本文公式显示效果不太好,可移步至LSTM学习笔记
Long Short-Term Memory(LSTM) 是一种循环神经网络(Recurrent Neural Network, RNN)。跟所有RNN一样,在网络单元足够多的条件下,LSTM可以计算传统计算机所能计算的任何东西。
Like most RNNs, an LSTM network is universal in the sense that given enough network units it can compute anything a conventional computer can compute. 维基百科
RNN
传统前馈神经网络(feedforward neural networks)如下图所示:
前馈神经网络只从输入节点接受信息,它只能对输入空间进行操作,对不同时间序列的输入没有“记忆”。在前馈神经网络中,信息只能从输入层流向隐藏层,再流向输出层。这种网络无法解决带有时序性的问题,比如预测句子中的下一个单词,这种情况下,往往需要使用到前面已知的单词。假设要预测这样一句话:百度是一家____公司。显然,受到前面的词语“百度”的影响,横线中填入“互联网”的概率远大于“金融”,即使单从语法上考虑,填入“金融”也是正确的。
RNN与前馈神经网络最大的不同是,它不仅能对输入空间进行操作,还能对内部状态空间进行操作,它的结构如下:
可以看到,RNN的隐藏层多了一条连向自己的边。因此,它的输入不仅包括输入层的数据,还包括了来自上一时刻的隐藏层的输出。
RNN可以采用BPTT(Back-Propagation Through Time)算法进行学习。BPTT和BP算法类似,都是基于梯度的训练方法。首先,将RNN按时间序列展开,如下图:
图中表示的是时间步长数为3的RNN的展开。一般的,时间步长为T的RNN展开后将含有T个隐藏层,T可以是任意的。当T为1时,RNN退化为一个普通的前馈神经网络。将RNN展开之后,就可以使用与训练带BP的前馈神经网络类似的方法进行训练。有一点需要注意的是,在展开的RNN中,每一层隐藏层实际上是相同的(只是在不同时刻的副本),也就是说,它们最后得到的参数必须是一致的。在训练过程中,不同时刻的隐藏层的参数可能会不一致,最后可以将它们的平均数作为模型的参数。
关于RNN的内容,具体的可以参考“A guide to recurrent neural networks and backpropagation”。
LSTM
梯度消失与避免
在实际应用中,上述RNN模型存在着梯度消失和梯度爆炸的问题。根据链式法则,输出误差对于输入层的偏导等于各层偏导的乘积(关于前馈神经网络的误差传播参考“一文弄懂神经网络中的反向传播法——BackPropagation”)。
假设使用的是平均平方误差,则在时刻t,输出层k的误差信号表示为(以下推导来自“Long Short-Term Memory”,本文采用与论文一致的表示方法):
其中,
是表示非输入单元,fi是可微函数,
表示当前网络单元的输入,wij是单元j和i的之间权重。非输出单元j的反向误差信号为:
对于时间步长为q的RNN网络,在t-q时刻的的误差可以通过以下递归函数来求解:
令lq = v,l0 = u,上式可以进一步写成:
由于梯度最终以乘积的形式得出,若乘式中的每一项(或大部分)都大于1,
则将导致梯度爆炸;若每一项都小于1,
则随着乘法次数的增加,梯度会消失。
梯度消失和梯度爆炸都会严重影响学习的过程。
为了避免梯度消失和梯度爆炸,一个简单的做法是强制让流过每个神经元的误差都为1,即
简单推导可以知道,f是一个线性函数。这样就保证了误差将以参数的形式在网络中流动,不会出现梯度爆炸或者梯度消失的问题,把这样的结构称为CEC(constant error carousel)。但是这种做法存在着权重冲突的问题。
权重冲突问题与解决
前面说到,RNN的隐藏层同时接受外界信息和上一时刻隐藏层的输出作为输入,回到前面将RNN展开成深层网络的情况。隐藏层在每一个时刻t的输出都通过权重向量U影响着下一个时刻t+1的隐藏层。而实际上,对于不同的时刻t, t+1, ..., t+k,图中隐藏层的权重向量V和U是相同的。虽然在展开图中它们看起来像是不同层次的节点,但实际上,它们都是同一个实体。
这就产生一个问题,在某一个时刻t,隐藏层可能需要使得权重向量U整体有一个较大的值,即t-1时刻隐藏层的输出对t时刻很重要;而在时刻t+k,隐藏层可能需要使权重U有一个较小的值,也就是说此时隐藏层不想受到t+k-1时刻的计算结果的影响。考虑现实的例子,假如要预测以下句子中动词的形式,括号中为需要预测的词:I may (go) home and (get) my book, but that (depends)。句子中go和get的形式受到may的影响,而depends的形态由that决定,而与may无关。因此预测过程中,当读入and时,may的影响还在,因此应该增大权重U的值,而在读入that的时候,应该同时减小来自上一时刻的隐藏层输出的影响,即减小权重U的,这就与前面产生了冲突。
为了解决这种冲突,Sepp Hochreiter和Jürgen Schmidhuber早在1997年就在论文“Long Short-Term Memory”提出了LSTM,下面的部分图以及公式来自该论文。为了使隐藏层的输出对下一个时刻的影响变得可控,LSTM引入了输入门,输出门的概念。LSTM的基本单元称为记忆元件(memory cell),它是在CEC的基础上扩展而成的,如下图:
图中,3表示输入门(input gate);6表示输出门(output gate);4是CEC单元,可以看到它有一条到自己的权值为1的边。之后,又有人将LSTM进一步扩展,引入了遗忘门,如下图(图片内容出自“Understanding LSTM Networks”):
为了方便表示,把上一个图称为lstm-1,这个图称为lstm-2。lstm-2中与lstm-1编号相同的单元分别一一对应,可以看到,lstm-2比lstm-1多了单元8,它就是所谓的遗忘门。lstm-2中元素的含义如下:
表示一个门,它由一个网络层和一个乘法单元构成。
在上面的图例中,每一条黑线表示信息(向量)的传递,从一个节点的输出到其他节点的输入。粉色的圈代表 pointwise 的操作,诸如向量的和,而黄色的矩形表示学习到的神经网络层。合在一起的线表示向量的连接,分开的线表示内容被复制,然后分发到不同的位置。
LSTM的信息传播
输入信息前向传播
(以下图片内容出自“Understanding LSTM Networks”)
首先,记忆元件(memory cell)接受上一个时刻的输出(ht-1)以及这个时刻的外界信息(xt)作为输入,将它们合并成一个长向量,经过𝞼变换成为ft。
接着,前面说到的两个输入合并成的向量又分别进行了𝞼变换和tanh变换,成为it和Ĉt进入输入门。其中it的值用于决定是否要接受输入信息(即Ĉt)。
之后,根据ft的值决定是否保留上一时刻隐藏层CEC的输出(Ct-1),再将经过it缩放之后的输入Ĉt累加到CEC(Ct-1)中,成为Ct。
最后在输出门中,由Ot决定是否将经过tanh变换的Ct输出,作为下一时刻输入的一部分。
Sepp Hochreiter和Jürgen Schmidhuber的论文中还提到一种称为记忆元件块(memory cell block)的部件。由S个共享相同的输入门、输出门(论文中的LSTM并没有记忆门)的记忆元件所构成的结构称为“大小为S的记忆元件块”。当S = 1时,该结构就是普通的记忆元件。
误差的反向传播
首先,规定下标的表示如下:
- k:输出单元
- i:隐藏单元
- Cj:第j个记忆元件块
- Cjv:第j个记忆元件块Cj中的第v个单元
- l, m, u:任意的网络单元
- t:给定输入序列所有的时间步长
在t时刻,LSTM的平方误差计算如下:
其中,tk(t)是输出单元k在t时刻的输出目标。
在学习率为𝞪的条件下,wlm基于梯度的更新如下:
将单元l在t时刻的误差(error)定义为:
使用标准的方向误差传播算法(backdrop)就可以计算出输出单元(output unit,l = k)、隐藏单元(hidden unit,l = i)、输出门单元(output gate unit,l = outj)的权重更新:
对于所有可能的单元l,时刻t对权重wlm所贡献的更新为:
eoutj(t)的计算式子中,括号内的h函数一开始把我卡住了,但是看回记忆元件的结构图:
可以发现输出门的计算公式为:gateout = h*youtj,其中youtj是netoutj的函数,h与netoutj无关,根据求导法则,h被保留了下来,gateout' = h*(youtj)',于是得到上式。
剩下的对输入门单元(l = inj)和记忆元件单元<l = Cjv的更新与常规的单元会有些差别。定义内部状态SCjv的误差为:
虽然上式乍看之下形式有些复杂,但仔细分析可以发现,这与求输出门的梯度的情况是类似的,由于youtj是与SCj无关的项,所以在求导过程中被保留了下来。
由以上推导可以得到,当l = inj或者l = Cjv,v = 1, ..., Sj时的误差:
中间状态单元SCjv对于输入netinj单元的权重winj的偏导可以计算如下:
因此,时刻t对winjm更新的贡献为:
类似地,可以得到SCjv对于netCjv的权重wCjv的偏导为:
因此,时刻t对wCjv更新的贡献为:
以上就是反向传播算法所需要使用到的等式。在更新权重的过程中,每个权重总的更新值是所有时刻t对w权重更新的贡献之和。
LSTM参数更新的计算复杂度
LSTM每次更新的计算复杂度为:
其中,K表示输出单元的数量,C表示记忆元件块的数量,S>0表示记忆元件块的大小,H是隐藏单元的数量,I是与记忆元件、门单元和隐藏单元直接相连(forward-connected)的单元的数量,而
是权重的数量。