import torch
import torch.nn as nn
lstm = nn.LSTM(
input_size=10,
hidden_size=20,
num_layers=1,
batch_first=True
)
input = torch.randn(3, 5, 10) # batch_size=3, seq_len=5, num_features=10
h0 = torch.randn(1, 3, 20)
c0 = torch.randn(1, 3, 20)
output, (h, c) = lstm(input, (h0, c0))
"""
h和c都是三维张量,其中第一维度表示该LSTM层的层数num_layers,默认为1
output是三维张量
output[:, -1, :] 与 h[-1, :, :]是一样的
当多个LSTM层叠加时,它们之间的数据传递用每一层的output
最后一个LSTM层与全连接层相连时,采用最后一层的h[-1, :, :]作为全连接层的输入
"""
torch中LSTM层的理解与记录
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。
推荐阅读更多精彩内容
- 一次上《记梁任公先生的一次演讲》这课时,有同学问我,为什么梁启超讲到《桃花扇》会痛哭流涕而不能自已。那今天我就这个...
- 这8种学生永远拿不到高分!早看早受益! 下面是一位资深班主任总结了8种成绩提不上去的原因,分别对应8类孩子,如果你...