Tensor operation types:tensor四种操作
-
Reshaping operations
1.1 reshape
1.2 squeezing and unsqueezing
1.3 flatten a tensor
1.4 concatenating tensors: terch.cat/torch.stack - Element-wise operations
- Reduction operations
- Access operations
1. Stack vs Cat in PyTorch
torch.cat
和torch.stack
都是张量拼接相关的操作,二者有什么不同?
Concatenating joins a sequence of tensors along an existing axis.
Stacking joins a sequence of tensors along a new axis.
torch.cat
: 在已存在的轴上进行拼接
torch.cat
torch.stack
: 在新轴上进行拼接
等价于先添加新轴,再在新轴上使用cat进行拼接
dim=0
dim=1
在PyTorch中向张量添加轴,使用unsqueeze()
函数
2. 实际例子
2.1 将图像拼接为一个batch
假设有三张独立的图像张量。每个张量有三个维度,即颜色通道轴c,高度轴h,宽度轴w。现在,假设需要将这些张量拼接在一起以形成三张图像的a single batch tensor
单批张量。使用torch.cat
还是torch.stack
?
拼接前:
拼接后:
使用torch.stack
import torch
t1 = torch.zeros(3,28,28)
t2 = torch.zeros(3,28,28)
t3 = torch.zeros(3,28,28)
torch.stack(
(t1,t2,t3)
,dim=0
).shape
## output ##
torch.Size([3, 3, 28, 28])
2.1 将不同batches
拼接为单个batch
有三个batch_size
为1的batches
,如何合并为一个batch
呢?
拼接前:
拼接后:
使用torch.cat
import torch
t1 = torch.zeros(1,3,28,28)
t2 = torch.zeros(1,3,28,28)
t3 = torch.zeros(1,3,28,28)
torch.cat(
(t1,t2,t3)
,dim=0
).shape
## output ##
torch.Size([3, 3, 28, 28])
2.3 拼接图像和batch
假设有三个独立图像张量,并且已经有了一个batch
,如何将这三张图像与batch
拼接在一起?
拼接前:
拼接后:
先对三张图片使用torch.stack
,拼接成一个batch
再使用torch.cat
import torch
batch = torch.zeros(3,3,28,28)
t1 = torch.zeros(3,28,28)
t2 = torch.zeros(3,28,28)
t3 = torch.zeros(3,28,28)
torch.cat(
(
batch
,torch.stack(
(t1,t2,t3)
,dim=0
)
)
,dim=0
).shape
## output ##
torch.Size([6, 3, 28, 28])
等价于:
import torch
batch = torch.zeros(3,3,28,28)
t1 = torch.zeros(3,28,28)
t2 = torch.zeros(3,28,28)
t3 = torch.zeros(3,28,28)
torch.cat(
(
batch
,t1.unsqueeze(0)
,t2.unsqueeze(0)
,t3.unsqueeze(0)
)
,dim=0
).shape
## output ##
torch.Size([6, 3, 28, 28])