1.torch.expand()
torch.expand()
, 只能把维度为1的拓展成指定维度。如果哪个维度为-1,就是该维度不变。
x = torch.rand((2, 1, 3, 1))
x_expand = x.expand(2, 3, 3, 2)
x_expand_1 = x.expand(-1, -1, -1, 4)
print(x.shape)
print(x_expand.shape)
print(x_expand_1.shape)
+++++++++++++++++++++++++++++++++++++++++++++++++++++++
torch.Size([2, 1, 3, 1])
torch.Size([2, 3, 3, 2])
torch.Size([2, 1, 3, 4])
2. torch.repeat()
torch.repeat()
里面参数代表是重复多少次,就是复制多少次,比如下面2, 3, 1, 6代表复制2, 3, 1, 6次,原来为2, 1, 3, 1。相乘就是后面维度:4, 3, 3, 6. 它不允许使用参数 -1
。
print(x.shape)
x_rep = x.repeat(2, 3, 1, 6)
print(x_rep.shape)
++++++++++++++++++++++++++++++++++++++++++++++
torch.Size([2, 1, 3, 1])
torch.Size([4, 3, 3, 6])