9

循环神经网络:RNN简介、BPTT算法、梯度消失、梯度爆炸

 2 years ago
source link: https://ylhao.github.io/2018/05/24/216/
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

符号定义和解释

首先说明这里的推导采用的符号如下图所示:

A recurrent neural network and the unfolding in time of the computation involved in its forward computation. Source: Nature

  1. xt 是第 t 个时刻的输入
  2. st 是第 t 个时刻隐藏层的状态
  3. ot 是第 t 个时刻的输出,比如如果我们想要预测下一个词是什么,那么可以认为 ot=softmax(Vst)
  4. st 计算方式为 st=f(Uxt+Wst−1),其中的函数 f 代表一个非线性函数,比如 tanh 或者 ReLU
  5. 第 1 个时刻对应的输入 s0 通常初始化为零向量
  6. U,W,V 是循环神经网络的参数,所有时刻共享,这在很大程度上减少了参数数量

形象化理解 st

st 是隐藏层的状态,可以把 st 看成循环神经网络的记忆,通过 st 可以知道在之前所有时刻都发生了什么,但是实际情况通常是通过 st 并不能知道“很久以前”到底发生了什么。循环神经网络的最重要的一个特性就是隐藏层状态了,我们可以通过隐藏层状态捕获到一个序列的相关信息。

对 ot 的理解

ot 是神经网络的输出,如上图所示,每个时刻都有输出,但是实际上每个时刻是否需要有输出是视情况而定的。比如我们做文本情感分析,我们只关心整个句子最后一个时刻的输出,那之前所有的时刻都是不需要输出的。

Backpropagation Through Time(BPTT)

首先我们有:
st=tanh(Uxt+Wst−1) ˆyt=softmax(Vst)

接着我们可以用交叉熵来计算每个时刻的 loss:
Et(yt,ˆyt)=−ytlogˆyt 

通常我们把每次输入的一个序列看成一个训练样本,所以就把每个时刻的交叉熵的和看成这个样本的总误差:
E(y,ˆy)=∑tE(yt,ˆyt) =−∑tytlogˆyt

首先我们根据以上的定义,给出以下简化了的图片:

接下来我们通过总误差计算 U,V,W 这些参数的梯度,然后用随机梯度下降法更新参数。接下来以 E3 为例详细分析。

首先计算参数 V 的梯度:

在以上推导过程中 z3=Vs3,⊗ 代表的是计算两个向量的外积(叉乘)。我们发现参数 V 的梯度仅仅依赖当前时刻的隐层状态和当前时刻的输出。

接下来计算参数 W 和参数 U 的梯度,首先我们给出一张计算图,图中的矩形框代表的输入,图中的椭圆形框代表的计算函数。通过计算图我们可以得到参数 W 和参数 U 的梯度:

∂E3∂W=∂E3∂ˆy3∂ˆy3∂s3∂s3∂W+∂E3∂ˆy3∂ˆy3∂s3∂s3∂s2∂s2∂W+∂E3∂ˆy3∂ˆy3∂s3∂s3∂s2∂s2∂s1∂s1∂W

∂E3∂U=∂E3∂ˆy3∂ˆy3∂s3∂s3∂U+∂E3∂ˆy3∂ˆy3∂s3∂s3∂s2∂s2∂U+∂E3∂ˆy3∂ˆy3∂s3∂s3∂s2∂s2∂s1∂s1∂U

形象化的图像如下图所示:

我们可以将参数 W 的表达式简化成以下形式:

RNN 中的梯度消失和梯度爆炸

首先我们从上式中可以看出,参数 W 和参数 U 的导数存在连乘的情况。通过以下公式

st=tanh(Uxt+Wst−1)

我们可以发现每次向前传递误差(每次求导)都要经过(处理)一个 tanh 函数。tanh 函数的图像和 tanh 函数的导函数图像如下图所示:

然后我们知道 tanh 函数的基本形式和其导数形式为:

tanh(x)=ex−e−xex+e−x tanh′(x)=(1−tanh2(x))

下面展开 ∂s3∂s2 来更加具体的说明梯度消失和梯度爆炸的问题:
s3=tanh(Uxt+Ws2) ∂s3∂s2=tanh′(Uxt+Ws2)W

通过上图我们知道 tanh 函数的导数小于等于 1,同时如果参数 W 初始化的很小的话,那么 ∂s3∂s2 将会是一个小于 0 的数,我们可以类推出 ∂st∂st−1 将会是一个小于 0 的数。假设有 n 项连乘,则可形式化表示为:
(小于0的数)n

随着时刻向前推移(误差沿着时间向前传递),梯度是呈指数级下降的。这也就是梯度消失问题。

如果 W 初始的很大的话,那么我们可以类推出 ∂st∂st−1 将会是一个大于 0 的数。如果有 n 项连乘,则可形式化表示为:
(大于0的数)n

随着时刻向前推移(误差沿着时间向前传递),梯度是呈指数级上升的。这也就是梯度爆炸的问题。

梯度爆炸的问题可以用梯度裁剪的方式来缓解。
梯度消失的问题则有以下缓解方式:

  1. 更换激活函数,比如可以选择 ReLU 函数。
  2. 更改 RNN 隐藏层的结构,比如采用 GRU 或者 LSTM 的隐藏层结构。

转载请注明来源,欢迎对文章中的引用来源进行考证,欢迎指出任何有错误或不够清晰的表达,可以在文章下方的评论区进行评论,也可以邮件至 [email protected]

About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK