4. torch.expand() 和 torch.repeat()

1.torch.expand()

torch.expand(), 只能把维度为1的拓展成指定维度。如果哪个维度为-1,就是该维度不变。

x = torch.rand((2, 1, 3, 1))
x_expand = x.expand(2, 3, 3, 2)
x_expand_1 = x.expand(-1, -1, -1, 4)

print(x.shape)
print(x_expand.shape)
print(x_expand_1.shape)


+++++++++++++++++++++++++++++++++++++++++++++++++++++++
torch.Size([2, 1, 3, 1])
torch.Size([2, 3, 3, 2])
torch.Size([2, 1, 3, 4])

2. torch.repeat()

torch.repeat()里面参数代表是重复多少次,就是复制多少次,比如下面2, 3, 1, 6代表复制2, 3, 1, 6次,原来为2, 1, 3, 1。相乘就是后面维度:4, 3, 3, 6. 它不允许使用参数 -1

print(x.shape)
x_rep = x.repeat(2, 3, 1, 6)
print(x_rep.shape)

++++++++++++++++++++++++++++++++++++++++++++++
torch.Size([2, 1, 3, 1])
torch.Size([4, 3, 3, 6])
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容