Tensorflow RNN中使用dropout的一些坑

使用 tensorflow 中的 DropoutWrapper 引发的问题

最近炼(tiao)丹(can)的时候遇到了 RNN 模型过拟合比较严重的问题,当时只是在 RNN 的输入特征加了 dropout。于是尝试在 RNN 的状态向量中也引入 dropout,具体方法可以查看参考文献1,tensorflow 也根据此文献实现了针对 RNN 的 dropout,函数如下:

class DropoutWrapper(RNNCell):
  """Operator adding dropout to inputs and outputs of the given cell."""

  def __init__(self, cell, input_keep_prob=1.0, output_keep_prob=1.0,
               state_keep_prob=1.0, variational_recurrent=False,
               input_size=None, dtype=None, seed=None,
               dropout_state_filter_visitor=None):

其中三个 keep_prob 分别决定 RNN 的输入、输出、状态向量的 dropout 概率。我最初没有用文献1提出的在 RNN 内部使用 dropout,所以只是把 input_keep_prob 设小于1的概率。然后在引入 RNN dropout 时,将 input_keep_prob 和 state_keep_prob 都设置为 小于1的值,并且把 variational_recurrent 设置为 True,此参数决定了是否使用文献1的方式。

注意文献1提出的 RNN dropout 方式是对于一个序列的所有时间步,输入向量、输出向量、状态向量分别采用相同的 dropout mask,如图所示,相同颜色表示相同的 dropout mask:
variational RNN dropout.png

注意如果将 state_keep_prob 设置为小于1的值,一般都会把 variational_recurrent 设置为 True,使用上述方式对状态进行 dropout。否则,每个时间步的状态进行随机的 dropout 会导致 RNN 几乎无法记录长期依赖特征,这样的 dropout 反而会使 RNN 性能变差。

当时使用的是 GRUCell,是普通 RNN 的改进版。兴高采烈地对 GRUCell 引入了 DropoutWrapper,state_keep_prob 设为 0.8,variational_recurrent 设为 True,然后开始训练。结果一脸懵逼,loss 直接变成了 nan,根据以往的经验猜测是发生了梯度爆炸。接着,我把 state_keep_prob 设为 1.0 即不引入状态的 dropout,看是否是这里引起的,果然在这样修改之后 loss 又正常下降了。最后猜测是不是 GRU 的特殊结构导致进行这样的 dropout 时会有问题,于是将 GRUCell 替换成了 LSTMCell,state_keep_prob 改回0.8,没想到这回 loss 正常下降了!总结下来是:tensorflow 的 variational_recurrent dropout 与 LSTM 结合能正常使用,但是与 GRU 结合却有问题。

发现问题根结

由于不确定是代码的 bug 还是算法结构的问题,所以 Google 一下,发现果然有大牛已经发现了问题。大牛在参考文献3 tensorflow 库的 issues 里提出代码和算法结构存在的问题。
下图是 GRU 的计算结构:

GRU.png
在对 GRU 的状态向量引入 dropout 后结构如下图,图中 h* 表示 dropout 后的状态向量。我们知道,dropout 根据 1-keep_rate 的概率生成一个 mask 向量,根据 mask 对状态向量中某些值置0,然后为了保证向量在后续与权值矩阵 W 相乘后得到的结果向量的 scale 与不 dropout 时一致,会对向量 h 中的对应 mask 非 0 的值除以 keep_rate。在我的实验中 dropout 对h中每个对应 mask 非0值除以0.8即乘以1.25。如前所述,由于变分 RNN dropout 中所有时间步的 dropout mask 都是相同的,所以对于一个长 n 的序列,状态向量 h* 中有些值在这个序列中永远是0,而另外一些值每经过一个时间步就要乘以 1/keep_rate(我的实验中是1.25),一个序列计算完后状态向量的值要乘以 (1/keep_rate)n,这些值在长序列情况下会变得非常大甚至溢出。这样就解释了为什么将 GRU 和 variational RNN dropout 结合使用的时候 loss 会变成 nan。
GRU with dropout.png

然后看看 LSTM,LSTM 的结构如下图所示。由图中结构可以看到,LSTM 的状态向量 h 每次都与矩阵相乘后再使用,这样可以保证即使每个时间步 h 的某些值会乘以 1/keep_rate,在与矩阵相乘后不会造成像 GRU 那样 h 的值呈指数上升的情况。
LSTM.png

但是在这个 issue 里面大牛还是指出了 LSTM 在当时用这样的 dropout 有问题,于是看了这个 issue 关闭时关联的修改 commit(参考文献4)。发现当时tensorflow 在实现 variational rnn dropout 时没有严格遵守文献1的方法,它对 LSTM 的记忆状态 c 也进行了类似的 dropout,这样就导致了 c 的值会想 GRU 的 h 那样指数爆炸。

最后,tensorflow 在参考文献4的 commit 里面修改了这个 bug,去掉了对 c 的 dropout,但是仍然没有解决 GRU 的问题。所以目前可以将 LSTM 和 variational rnn dropout 结合使用,但不能将 GRU 与 variational rnn dropout 结合使用。

总结与解决方案

  • 由于 GRU 与 LSTM 结构差异,variantional RNN dropout 可以很好地在 LSTM 中使用,但是不能在 GRU 中使用,variantional RNN dropout 有这样的局限性。
  • 由于 tensorflow 1.4.0 之前的版本在 DropoutWrapper 中错误的实现了 variantional RNN dropout,所以当时在 LSTM 中使用也会有问题,但是在 1.4.0 版本中已经解决了,参考文献4。
  • 对于 GRU,可以尝试参考文献2中提出的 dropout 方式,该方式不会对状态向量进行 dropout,只对一些门控进行 dropout,所以这样的方式适用于 GRU 和 LSTM。如图所示,左右两边分别是在 LSTM 中进行 variantional RNN dropout 和门控dropout的结构,虚线表示 dropout 连接。只是 tensorflow 中没有实现右边这种 dropout,可以参考文献5自己实现。
    两种不同的 RNN dropout 方式.png

参考文献

1、《A Theoretically Grounded Application of Dropout in Recurrent Neural Networks》
2、《Recurrent Dropout without Memory Loss》
3、https://github.com/tensorflow/tensorflow/issues/11650
4、https://github.com/tensorflow/tensorflow/commit/cb3314159fe102419289d394246d7ac9c2a422c1
5、https://github.com/stas-semeniuta/drop-rnn

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 214,588评论 6 496
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,456评论 3 389
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 160,146评论 0 350
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,387评论 1 288
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,481评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,510评论 1 293
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,522评论 3 414
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,296评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,745评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,039评论 2 330
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,202评论 1 343
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,901评论 5 338
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,538评论 3 322
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,165评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,415评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,081评论 2 365
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,085评论 2 352

推荐阅读更多精彩内容

  • 1 2年前,刚刚大学毕业的小敏被一家大型文化公司录取,同时被录取的共有十个人。 上班第一个月的某天,快下班时部门主...
    大俗小雅阅读 936评论 0 4
  • DISC性格分析中D型人是“领导型”。这类人最突出的特点是思维敏捷,做事效率高,不害怕别人的攻击,喜欢创新,行动力...
    W南茜阅读 10,801评论 0 0
  • 来北京十多天了,大大小小的面试也有七八个了,喜欢我的和我喜欢的总是差一点点,俗话说差之毫厘,失之千里,感觉自己快魔...
    初时不语阅读 310评论 7 3
  • 穷则思变。 今天卞工对我发了脾气。 也怪我自己不好,没有把系统当回事儿。一个劲儿的不知道在干什么。 可能是在消极抵...
    阿卡墙阅读 309评论 0 0