3.Pytorch 中 torch.stack()/vstack()/hstack()和torch.cat()
1.torch.stack()
torch.stack(tensors, dim=0, *, out=None) → Tensor
作用:
Concatenates a sequence of tensors along a new dimension. All tensors need to be of the same size.
把一系列tensor沿着新的维度堆起来。注意要tensor都一样的size,并且会增加一个维度。默认,dim=0.
x = torch.arange(9).view(3,3)
print(x)
print("---")
new_x = torch.stack([x, x, x])
print(new_x.shape)
print(new_x)
================================================================
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
---
torch.Size([3, 3, 3])
tensor([[[0, 1, 2],
[3, 4, 5],
[6, 7, 8]],
[[0, 1, 2],
[3, 4, 5],
[6, 7, 8]],
[[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]])
如果指定维数的话,
- dim=0时, 这个维度是3, 将其看作3行,那么特征数是2。将2个特征列依次叠加。
- dim=1时, 这个维度是2, 将其看作2列,那么特征数是3。将3个特征行依次叠加。
a = torch.arange(0, 6).view((3, 2))
b = torch.arange(6, 12).view((3, 2))
print('a:', a)
print('b:', b)
ab0= torch.stack((a, b), dim=0)
ab1 = torch.stack((a, b), dim=1)
print(ab0, '\n', ab1)
+++++++++++++++++++++++++++++++++++++++++++
a: tensor([[0, 1],
[2, 3],
[4, 5]])
b: tensor([[ 6, 7],
[ 8, 9],
[10, 11]])
tensor([[[ 0, 1],
[ 2, 3],
[ 4, 5]],
[[ 6, 7],
[ 8, 9],
[10, 11]]])
tensor([[[ 0, 1],
[ 6, 7]],
[[ 2, 3],
[ 8, 9]],
[[ 4, 5],
[10, 11]]])
2. torch.vstack()和 torch.hstack()
torch.vstack(tensors, *, out=None) → Tensor
和hstack()
作用:这两个方法在1.8.0之后才支持,没有就用torch.cat()
在竖直、水平方向上堆tensor
ab_vstack_0 = torch.vstack((a, b))
ab_vstack_1 = torch.vstack((a, b))
print('ab_vstack_0 :', ab_vstack_0 )
print('ab_vstack_1 :', ab_vstack_1 )
print(torch.__version__)
++++++++++++++++++++++++++++++++++++++++++++
ab_vstack_0 : tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11]])
ab_vstack_1 : tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11]])
1.8.0+cu101
3. torch.cat()
torch.cat(tensors, dim=0, *, out=None) → Tensor
与 torch.stack()
区别是:不增加维度
作用:
Concatenates the given sequence of seq tensors in the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be empty.
按照指定维度连接tensor,所有tensor必须有同样的shape, 除了指定合并的维度或者是空tensor。
ab_cat_0 = torch.cat((a, b), dim=0)
ab_cat_1 = torch.cat((a, b), dim=1)
print('ab_cat_0 :', ab_cat_0 )
print('ab_cat_1 :', ab_cat_1 )
++++++++++++++++++++++++++++++++++++
ab_cat_0 : tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11]])
ab_cat_1 : tensor([[ 0, 1, 6, 7],
[ 2, 3, 8, 9],
[ 4, 5, 10, 11]])