本篇主要介绍pytorch中tensor的基本操作如:对tensor进行flatten操作, 对tensor进行拼接,tensor的broadcast(广播)机制.
1. 对tensor进行flatten操作
flatten顾名思义就是平展,将一个tensor由高维(rank)变为1维,元素的数量保持不变,这在深度学习中很常用,当输入是一个图像时候,图像的维度(rank)是3,当网络的输入层只能输入一维的数据时(如全连接层),flatten操作就显得非常有用了.
下面我们说两种flatten的实现方式:使用上篇的squeeze,reshape函数间接实现;使用tensor自带的flatten函数实现.
1. 使用squeeze,reshape函数间接实现
实现代码如下,我们可以写一个flatten函数,该函数的功能是输入一个tensor,将它维度变为1输出.
def flatten(t):
t = t.reshape(1, -1)
t = t.squeeze()
return t
函数的第一行先利用reshape函数将tensor变为2维度,具体可以参考上篇. 此时可以注意到,第一个axis的长度为1,因此第二行使用squeeze函数将长度为1的axis去掉,这时候tensor的维度自然而然就变成1啦!
举个栗子吧:
t = torch.tensor([
[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]
], dtype=torch.float32)
print(t.shape)
print(flatten(t))
print(flatten(t).shape)
output:
torch.Size([3, 4])
tensor([1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.])
torch.Size([12])
可以看到,使用我们自己写的flatten函数,将tensor变成了1维,但这很不方便欸 ,有自带的干嘛不用呢,用它!
2. 直接使用tensor的flatten函数
tensor是有flatten函数的,话不多说,直接举个栗子:
t1 = torch.ones(4, 4)
t2 = torch.ones(4, 4) * 2
t3 = torch.ones(4, 4) * 3
t = torch.stack((t1, t2, t3))
t = t.reshape(3, 1, 4, 4)
print(t.flatten(start_dim=1))
print(t.flatten(start_dim=1).shape)
print(t.reshape(t.shape[0], -1))
print(t.flatten().shape)
output:
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.]])
torch.Size([3, 16])
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.]])
torch.Size([48])
可以看到flatten有个参数start_dim,表示从start_dim到最后一个维度都做平展操作,因此t.flatten(start_dim=1)后,就只有第0维保持不变,其它维度做了flatten,变成了一维,从而shape变为torch.Size([3, 16]),这在深度学习中,向全连接层输入时经常会使用到.
当flatten没有参数时,默认将整个tensor进行flatten操作,即start_dim=0,因此经过t.flatten(),其shape变为torch.Size([48]).
2. 对tensor进行拼接
将2个甚至更多的tensor进行拼接是经常要使用到的功能,pytorch中对tensor拼接常用的函数为cat,举个栗子吧:
t1 = torch.tensor([
[1, 2],
[3, 4]
])
t2 = torch.tensor([
[5, 6],
[7, 8]
])
print(torch.cat((t1, t2), dim=0))
print(torch.cat((t1, t2), dim=0).shape)
print(torch.cat((t1, t2), dim=1))
print(torch.cat((t1, t2), dim=1).shape)
output:
torch.Size([12])
tensor([[1, 2],
[3, 4],
[5, 6],
[7, 8]])
torch.Size([4, 2])
tensor([[1, 2, 5, 6],
[3, 4, 7, 8]])
torch.Size([2, 4])
可以看到,cat中的dim参数决定了拼接的维度.
当dim=0时,在第0维(第1个axis)进行拼接 ,其余维度长度不变,因此shape变为torch.Size([4, 2]).
当dim=1时,在第1维(第2个axis)进行拼接 ,其余维度长度不变,因此shape变为torch.Size([2, 4]).
为了达成拼接的目的,很容易我们可以看出,对于要拼接的tensor(可以不止2个),除了需要拼接的维度,其余维度的长度必须保持相同,否则会引起错误.
3. tensor的broadcast(广播)机制
在数学运算中,两个形状不同的矩阵进行加减运算显然是不行的,但对于tensor,在某些形况下是完全可以的,这得益于pytorch中tensor的broadcast机制.
举个栗子:
t1 = torch.tensor([
[1, 2],
[3, 4]
])
t2 = torch.tensor([
[9, 8],
[7, 6]
])
print(t1 + t2)
print(t1 + 2)
print(t1 + torch.tensor(
np.broadcast_to(2, t1.shape),
dtype=torch.int32
)) # equal to last line
output:
tensor([[10, 10],
[10, 10]])
tensor([[3, 4],
[5, 6]])
tensor([[3, 4],
[5, 6]])
从倒数第二个print那里,我们惊讶的发现print(t1 + 2)竟然也可以运算!这得益于broadcast机制,等同于最后一个print中的语句,pytorch将2自动扩充成了与t1形状相同的tensor,这样当然就可以运算啦.
那么有个大胆的想法,除了常数,不同形状的tensor是否也可以这样操作,举个栗子试一试:
t1 = torch.tensor([
[1, 2],
[3, 4]
])
t3 = torch.tensor([2, 4])
print(t1 + t3)
print(t1 + torch.tensor(
np.broadcast_to(t3.numpy(), t1.shape),
dtype=torch.int32
)) # equal to last line
output:
tensor([[3, 6],
[5, 8]])
tensor([[3, 6],
[5, 8]])
哈哈,果然是可以的,pytorch将t3自动扩充成了与t1形状相同的tensor(将t3又复制了一行)再与之运算. 这个特性很有意思,很方便我们书写代码.