torch中LSTM层的理解与记录

import torch
import torch.nn as nn

lstm = nn.LSTM(
    input_size=10,
    hidden_size=20,
    num_layers=1,
    batch_first=True
)

input = torch.randn(3, 5, 10)   # batch_size=3, seq_len=5, num_features=10
h0 = torch.randn(1, 3, 20)
c0 = torch.randn(1, 3, 20)

output, (h, c) = lstm(input, (h0, c0))

"""
h和c都是三维张量,其中第一维度表示该LSTM层的层数num_layers,默认为1
output是三维张量
output[:, -1, :] 与 h[-1, :, :]是一样的

当多个LSTM层叠加时,它们之间的数据传递用每一层的output
最后一个LSTM层与全连接层相连时,采用最后一层的h[-1, :, :]作为全连接层的输入
"""
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容