个人理解
我们知道rnn主要来维持一个cell state(上文信息),用来预测句子的下文,但如果句子很长,我们将上文中所有的信息都记住了,这会导致两个问题:
-
rnn反向传播时会导致梯度消失或爆炸
我们知道反向更新求导时,遵循BP的链式求导法则,可以简要认为是
其中连乘的部分可以看作是:
因为tanh的导数通常在(0,1),所以如果时刻太多,经过连乘之后,值会趋近于0,W如果也<1,则会出现梯度消失,反之W如果很大,则会出现梯度爆炸。
2.如果我们把句子中每个时刻的信息全部记住,需要训练的参数会呈现O(x^2)的增长,句子过长将非常难以训练
而LSTM的出现正好可以解决上面两个问题:
lstm的原理
如参考引用例子(小丽是一个女孩,她唱歌很好,小明是一个男孩,他篮球很好),当预测她时,我们希望记住上文中的小丽,当预测他时,我们希望忘记小丽,记住小明
遗忘门
ht-1为上时刻输出结果,xt为当前时刻输入,Wf为训练参数,目的是将[ht-1, xt]转为同一维度,sigmod为激活函数,目的将所有维度的值变到(0,1)之间,可以认为是可筛选器,这样得到的ft * Ct-1,将ft中接近于0的值 进行清除,接近于1的值保存下来
这里就可以解释rnn中的两个问题
- 梯度爆炸和消失
梯度爆炸和消失主要是存在tanh的导数和W进行连乘,而lstm增加遗忘门解决这点,可以认为将训练参数W也进行了sigmoid,所以tanh取值为(0,1),W取值也为(0,1),而lstm做的事情就是让他们两个的乘积要么为0,要么为1
2.0 解决训练参数过大
我们可以看出ft中为0的点将被清除掉,那么这个点的参数w在训练时也会进行清楚,这样会大大减少训练参数量
输入门
如图所示,it为输入做了一个筛选操作,而带波浪的Ct则做了个维度同一操作
输出门
同样ot做了一个输出筛选操作,tanh将Ct中每一位变到(-1,1)之间