[AI] 8 LSTM

一、传统RNN的反向传播

structure of traditional RNN, from [1]

隐藏层:S_t=W_x·X_t+W_s·S_{t-1}+b_1
输出层:O_t=W_o·S_t+b_2

假设网络结构如上图所示,神经元没有激活函数,当t=3时,初始值S_0给定,则t=3时刻,损失函数L_3=\frac{1}{2}(Y_3-O_3)^2
S_1=W_x·X_1+W_s·S_{0}+b_1
O_1=W_o·S_1+b_2
S_2=W_x·X_2+W_s·S_{1}+b_1
O_2=W_o·S_2+b_2
S_3=W_x·X_2+W_s·S_{2}+b_1
O_3=W_o·S_2+b_2
训练的目标就是采用梯度下降法更新W_xW_sW_ob_1b_2使L达到最小,L=\sum_{t=0}^{T}L_t
这里需要说明一下,更新参数是以每个时刻t为单位的L_3,根据输出与label计算偏导更新参数;而L用来评判训练完一次后模型的效果,比如是否准确率增加、是否过拟合等等
接下来开始计算梯度:
\frac{\partial L_3}{\partial W_o}=\frac{\partial L_3}{\partial O_3}\frac{\partial L_3}{\partial W_o}
\frac{\partial L_3}{\partial W_x}=\frac{\partial L_3}{\partial O_3}\frac{\partial L_3}{\partial S_3}\frac{\partial S_3}{\partial W_x}+\frac{\partial L_3}{\partial O_3}\frac{\partial L_3}{\partial S_3}\frac{\partial S_3}{\partial S_2}\frac{\partial S_2}{\partial W_x}+\frac{\partial L_3}{\partial O_3}\frac{\partial L_3}{\partial S_3}\frac{\partial S_3}{\partial S_2}\frac{\partial S_2}{\partial S_1}\frac{\partial S_1}{\partial W_x}
\frac{\partial L_3}{\partial W_s}=\frac{\partial L_3}{\partial O_3}\frac{\partial L_3}{\partial S_3}\frac{\partial S_3}{\partial W_s}+\frac{\partial L_3}{\partial O_3}\frac{\partial L_3}{\partial S_3}\frac{\partial S_3}{\partial S_2}\frac{\partial S_2}{\partial W_s}+\frac{\partial L_3}{\partial O_3}\frac{\partial L_3}{\partial S_3}\frac{\partial S_3}{\partial S_2}\frac{\partial S_2}{\partial S_1}\frac{\partial S_1}{\partial W_s}
根据上式,可以求的任意时刻L_tW_sW_x的偏导
\frac{\partial L_t}{\partial W_x}=\sum_{k=0}^{t}\frac{\partial L_t}{\partial O_t}\frac{\partial O_t}{\partial S_t}(\prod_{j=k+1}^t \frac{\partial S_j}{\partial S_{j-1}})\frac{\partial S_k}{\partial W_x}
\frac{\partial L_t}{\partial W_s}=\sum_{k=0}^{t}\frac{\partial L_t}{\partial O_t}\frac{\partial O_t}{\partial S_t}(\prod_{j=k+1}^t \frac{\partial S_j}{\partial S_{j-1}})\frac{\partial S_k}{\partial W_s}
如果神经元经过激活函数,S_t=tanh(W_x·X_t+W_s·S_{t-1}+b_1),那么上式中\prod_{j=k+1}^t \frac{\partial S_j}{\partial S_{j-1}}=\prod_{j=k+1}^t tanh'(y)|_{y=W_x·X_t+W_x·S_{t-1}+b_1}·W_x其中,y=tanh'(x)的函数图像如下

the image of tanh', from [2]
接下来就是RNN的两个劣势,梯度消失梯度爆炸

二、RNN的梯度消失和梯度爆炸💥

通过y=tanh'(x)的函数图像可知,y=tanh'(x) \le 1,如果W_x \le 1,那么随着网络深度的不断加深,\prod_{j=k+1}^t \frac{\partial S_j}{\partial S_{j-1}}会趋于零,即前层的梯度趋于零,那么浅层网络对左右参数的更新能力较弱,这就是(浅层)梯度消失。
同样的道理,如果W_x >> 1,直观表现为tanh'(y)·W_x >>1,那么随着网络深度的不断加深,\prod_{j=k+1}^t \frac{\partial S_j}{\partial S_{j-1}}会趋于无穷大,即梯度趋于无穷大,即梯度爆炸💥。

常用解决方案:
梯度消失:尝试其他网络结构,e.g. LSTM
梯度爆炸:梯度剪裁

三、长期依赖的问题

长期依赖是指当前系统的状态,可能受很长时间之前系统状态的影响,是RNN中无法解决的一个问题。
如果从“这块冰糖味道真?”来预测下一个词,是很容易得出“甜”结果的。但是如果有这么一句话,“他吃了一口菜,被辣的流出了眼泪,满脸通红。旁边的人赶紧给他倒了一杯凉水,他咕咚咕咚喝了两口,才逐渐恢复正常。他气愤地说道:这个菜味道真?”,让你从这句话来预测下一个词,确实很难预测的。因为出现了长期依赖,预测结果要依赖于很长时间之前的信息。
理论上,通过调整参数,RNN是可以学习到时间久远的信息的。但是,实践中的结论是,RNN很难学习到这种信息的。RNN 会丧失学习时间价格较大的信息的能力,导致长期记忆失效。
长期记忆失效的原因
RNN中,S_t=W_xX_t+W_sS_{t-1}+b_1,如果用S_0来表示S_t,那么S_t前的系数为(W_s)^t,如果|W|<1,那么经过多次传递之后,(W_s)^t变得非常小,可以认为S_0对于S_t几乎不产生任何影响,也就是随着时间的推移,浅层信息机会被遗忘,这也就导致了长期记忆失效。
可以用LSTM来解决长期依赖问题。
循环神经网络(RNN)的长期依赖问题——曲曲菜

四、Long Short-Term Memory (LSTM)

传统RNN:

image from [4]
LSTM:
image from [4]
通过对比,LSTM与RNN的区别是细胞状态
C
——上图中为最上层的贯穿整个细胞的水平线,这也是LSTM的核心,它通过来控制信息的增加或者删除,LSTM中用遗忘门, 输入门, 输出门来控制数据流。

门:
门实际上就是一层全连接层,它的输入是一个向量或多个向量,输出是一个0到1之间的数字。
我们可以这样理解,如果门的输出是0, 就表示将门紧紧关闭,为1则表示将门完全打开,而位于0-1之间的实数表示将门半开,至于开的幅度跟这个数的大小有关。
理解 LSTM 网络——朱小虎XiaohuZhu

遗忘门:遗忘门决定了上一时刻的单元状态C_{t-1}有多少保留到当前时刻 C_{t}

image from [4]
f_t:当前遗忘门输出
\sigma: 激活函数 sigmoid
W_f : 遗忘门的权重矩阵
h_{t-1}: 上一时刻LSTM隐层输出值。
x_t: 当前时刻输入
b_f: 遗忘门偏置
[h_{t-1},x_t]:矩阵连接
还需要指明一下各个维度:
d_x:输入数据维度
d_h:隐藏层维度
d_c:细胞状态维度,通常d_c=d_h
W_{fh}:维度为d_c*d_h,对应输入为h_{t-1}
W_{fx}:维度为d_c*d_x,对应输入为x_t
W_f:由W_{fh}W_{fx}拼接而成,维度为d_c*(d_h+d_x),对应输入为[h_{t-1},x_t]

输入门:输入门决定了当前时刻网络的输入x_t有多少保存到单元状态c_t

image from [4]
i_t: 输入门的输出值,是一个0 - 1 之间的实数,决定了当前时刻网络的输入x_t有多少保存到单元状态c_t
W_i: 输入门的权重矩阵
\tilde{C_t}:此时的细胞状态,不包含前期细胞状态C_{t-1}
tanh :激活函数,它输出一个(-1, 1) 的实数值
W_c: 权重矩阵
b_c: 偏置
更新细胞状态
image from [4]
将当前的记忆\tilde{C_t}和长期的记忆C_{t-1}组合在一起,形成新的单元状态C_t。由于遗忘门的控制,它可以保存很久很久之前的信息,由于输入门的控制,它又可以避免当前无关紧要的内容进入记忆。

输出门:控制单元状态C_t有多少输入到LSTM的当前输出值h_t

image from [4]
o_t: 一个位于(0,1) 之间的实数值,用来控制单元状态c_t有多少独处到 LSTM 的当前输出值h_t

五、LSTM的应用

  • handwriting recognition
  • time series prediction, e.g. financial data
  • speech recognition
  • language modelling
  • language translation

参考资料:

  1. RNN梯度消失和爆炸的原因——沉默中的思索
  2. RNN 的梯度消失问题——老宋的茶书会
  3. 循环神经网络(RNN)的长期依赖问题——曲曲菜
  4. 理解 LSTM 网络——朱小虎XiaohuZhu
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。