Pytorch 中 Bi-GRU / Bi-LSTM 的输出问题

  在 PyTorch 中,GRU / LSTM 模块的调用十分方便,以 GRU 为例,如下:

import torch
from torch.nn import LSTM, GRU
from torch.autograd import Variable
import numpy as np

# [batch_size, seq_len, input_feature_size]
random_input = Variable(torch.FloatTensor(1, 5, 1).normal_(), requires_grad=False)
gru = GRU( 
  input_size=1, hidden_size=1, num_layers=1, 
  batch_first=True, bidirectional=False
)
# output: [batch_size, seq_len, num_direction * hidden_size]
# hidden: [num_layers * num_directions, batch, hidden_size]
output, hidden = gru(random_input)

  其中,output[:, -1, :] 即为 hidden。LSTM 只是比 GRU 多了一个返回值 cell_state,其余不变。
  当我们将 bidirectional 参数设置为 True 的时候,GRU/LSTM 会自动地将两个方向的状态拼接起来。遇到一些序列分类问题,我们常常会将 Bi-GRU/LSTM 的最后一个隐状态输出到分类层中,也即使用 output[:, -1, :],那么这样做是否正确呢?
  考虑这样一个问题:当模型正向遍历序列1, 2, 3, 4, 5 的时候,output[:, -1, :] 是依次计算节点 1~5 之后的隐状态;当模型反向遍历序列1, 2, 3, 4, 5 的时候,t = 5 位置对应的隐状态仅仅是计算了节点 5 之后的隐状态。output[:, -1, :] 就是拼接了上述两个向量的特征,但我们想要放入分类层的逆序特征应该是 t=1 位置对应的隐状态,也即依次遍历 5~1 节点、编码整个序列信息的特征。
  下面通过具体的代码佐证上述结论,样例主要参考 Understanding Bidirectional RNN in PyTorch

1) 数据 & 模型准备

# import 如上
random_input = Variable(torch.FloatTensor(1, 5, 1).normal_(), requires_grad=False)
# random_input[0, :, 0]
# tensor([ 0.0929,  0.6335,  0.6090, -0.0992,  0.7811])

# 分别建立一个 双向 和 单向 GRU
bi_gru = GRU(input_size=1, hidden_size=1, num_layers=1, batch_first=True, bidirectional=True)
reverse_gru = GRU(input_size=1, hidden_size=1, num_layers=1, batch_first=True, bidirectional=False)

# 使 reverse_gru 的参数与 bi_gru 中逆序计算的部分保持一致
# 这样 reverse_gru 就可以等价于 bi_gru 的逆序部分
reverse_gru.weight_ih_l0 = bi_gru.weight_ih_l0_reverse
reverse_gru.weight_hh_l0 = bi_gru.weight_hh_l0_reverse
reverse_gru.bias_ih_l0 = bi_gru.bias_ih_l0_reverse
reverse_gru.bias_hh_l0 = bi_gru.bias_hh_l0_reverse

# random_input 正序输入 bi_gru,逆序输入 reverse_gru
bi_output, bi_hidden = bi_gru(random_input)
reverse_output, reverse_hidden = reverse_gru(random_input[:, np.arange(4, -1, -1), :])

2)结果对比

bi_output
'''
# shape = [1, 5, 2]
tensor([[[0.0867, 0.7053],
         [0.2305, 0.6983],
         [0.3245, 0.5996],
         [0.2290, 0.4437],
         [0.3471, 0.3395]]], grad_fn=<TransposeBackward1>)
'''

reverse_output
# shape = [1, 5, 1]
'''
tensor([[[0.3395],
         [0.4437],
         [0.5996],
         [0.6983],
         [0.7053]]], grad_fn=<TransposeBackward1>)
'''

  捋一捋,先只看 reverse_gru,这是个单向gru,我们输入了一个序列,那么编码了真格序列信息的隐状态自然是最后一个隐状态,也即 0.7053 是序列 [0.7811, -0.0992, 0.609, 0.6335, 0.0929] 的最后一个隐状态(序列向量);bi_output 的第二列代表着逆向编码的结果,刚好是 reverse_output 的倒序,如果我们直接把 bi_output[:, -1, :] 作为序列向量,显然是不符合期望的。正确的做法是:

Method 1:
seq_vec = torch.cat(bi_output[:, -1, 0], bi_output[:, 0, 1])
'''
tensor([0.3471, 0.7053], grad_fn=<CatBackward>)
'''

Method 2:
seq_vec = bi_hidden.reshape([bi_hidden.shape[0], -1])
'''
tensor([[0.3471],
        [0.7053]], grad_fn=<ViewBackward>)
'''

  也即 hidden 这个变量是返回了 序列编码 的信息,满足了我们的要求,可以放心用,也推荐使用第二种方法,少做不必要折腾。

bi_hidden
'''
tensor([[[0.3471]],
        [[0.7053]]], grad_fn=<StackBackward>)
'''

reverse_hidden
'''
tensor([[[0.7053]]], grad_fn=<StackBackward>)
'''
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 215,245评论 6 497
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,749评论 3 391
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 160,960评论 0 350
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,575评论 1 288
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,668评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,670评论 1 294
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,664评论 3 415
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,422评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,864评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,178评论 2 331
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,340评论 1 344
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,015评论 5 340
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,646评论 3 323
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,265评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,494评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,261评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,206评论 2 352

推荐阅读更多精彩内容