12

tensorflow学习,循环神经网络(RNN)相关的函数简介(25)

 3 years ago
source link: https://blog.popkx.com/tensorflow-study-introduction-of-rnn-revelent-funcitons-in-tensorflow/
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

tensorflow学习,循环神经网络(RNN)相关的函数简介(25)

发表于 2018-07-25 22:07:29   |   已被 访问: 638 次   |   分类于:   tensorflow   |   1 条评论

本节,将介绍 tensorflow 实现循环神经网络 RNN 的主要函数。

实现 RNN 的基本单元 RNNCell


RNNCell 是 tensorflow 中的循环神经网络的基本单元,它是一个抽象类,本身不能实例化。它的两个子类,一个 BasicRNNCell,另一个BasicLSTMCell,分别对应经典循环神经网络,和长短记忆循环神经网络。

学习 RNNCell 要重点关注三个地方:

  • 类方法 call
  • 类属性 state_size
  • 类属性 output_size

简单的说,call方法就是用来计算隐状态的。关于隐状态可以参考前面两节(RNNLSTM)。而state_sizeoutput_size则表示隐状态的大小和输出向量的大小。

output, next_state = call(input, state)
通常 input 的形状是 [batch_size, input_size],所以隐状态的形状 [batch_size, state_size],输出形状[batch_size, output_size]。

定义经典 RNN 单元的方法

rnnCell = tf.nn.rnn_cell.BasicRNNCell(num_units=128)
print rnnCell.state_size
# 应 state_size = 128

定义 LSTM 单元的方法

lstmCell = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)
print lstmCell.state_size
# 应 state_size = LSTMStateTuple(c=128, h=128)

多层循环神经网络:MultiRNNCell


很多时候,单层 RNN 的能力有限,需要多层 RNN,在 tensorflow 中,可以使用 tf.nn.rnn_cell.MultiRNNCell 函数建立多层的 RNN,下面是一个示例小 demo

import tensorflow as tf
import numpy as np

# 创建单个cell并堆叠多层
def get_a_cell(lstm_size, keep_prob):
    rnn = tf.nn.rnn_cell.BasicRNNCell(num_units=128)
    return rnn
# 建立 3 层
cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell() for _ in range(3)])

这里 cell 的 state_size 为 (128,128,128),表示有 3 个隐状态,每个隐状态大小为 128。

MultiRNNCell 也是 RNNCell 的子类,所以它也有 call 方法,和 state_size, output_size 属性。

使用 dynamic_rnn 展开时间维度


对于单个 RNNCell,使用它的 call 方法进行运算时,只在序列时间是前进了一步。如使用 x1,h0 计算得到 h1,根据 x2,h1 计算得到 h2等。如果序列长度为 n,就需要调用 n 次 call 函数。tensorflow 提供了 tf.nn.dynamic_rnn 函数,等价于调用 n 次 call 函数。即通过 {h0, x1, x2, x3, ...} 直接得到 {h1, h2, ...}

outputs, state = tf.nn.dynamic_rnn(cell, inputs)

至此,建立循环神经网络的几个比较重要的 tensorflow 函数就介绍完了,下一节将尝试建立 RNN 网络,训练其作诗。

本节主要参考《21个项目玩转深度学习》。

阅读更多:   tensorflow


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK