pyTorch上的TimeDistributed

Keras有个TimeDistributed包装器,pytorch上用nn.Linear就能实现。老是忘在这里记录下:

给定输入in[batch, steps, in_dims],希望在每个step内Dense,然后输出out[batch, steps, out_dims],

只需要直接指定nn.Linear(in_dims, out_dims)就好了,例如:

batchs=2
steps=3
in_dims=4
out_dims=2

m = nn.Linear(in_dims, out_dims, False)
print(m.weight)
Out[16]:
Parameter containing:
tensor([[ 0.4397,  0.0982, -0.0458, -0.0480],
        [-0.1751, -0.2792, -0.2744,  0.1664]], requires_grad=True)

input = torch.ones(batchs, steps, in_dims)
m(input)
Out[17]:
tensor([[[ 0.4440, -0.5623],
         [ 0.4440, -0.5623],
         [ 0.4440, -0.5623]],
        [[ 0.4440, -0.5623],
         [ 0.4440, -0.5623],
         [ 0.4440, -0.5623]]], grad_fn=<UnsafeViewBackward>)

输出维度是[batch,step,out_dims]每个step内作dense

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容

  • 官方所有教程的地址:pytorch.org/tutorials 以下是基于实例来入门pytorch Learnin...
    MiracleJQ阅读 5,782评论 0 4
  • 幸福是什么?托尔斯泰说,幸福具备三个要素:有事做,有人爱,有希望。每个人对幸福都有自己的定义和感受。 上世纪80年...
    何处遇见阅读 3,493评论 0 5
  • 如果我可以更努力 如果我可以更痴迷 如果我不曾软弱 如果我不曾犹疑 如果我可以提早预见 如果我可以洒脱放弃 如果这...
    淡墨染重峦阅读 1,287评论 0 1
  • NoSQL NoSQL,指非关系型的数据库。NoSQL有时也称作Not Only SQL的缩写,是对不同于传统的关...
    honehou阅读 9,137评论 0 5