上回我们讲完了张量的维度,创建方法以及形状,这一章我们来讲讲张量的索引、切片以及关于切片的相关方法
1.3 索引&切片(indics & slices)
张量的索引和python原生数据结构的索引相同,都是从0开始;也可以通过传入负数从后往前取值
通过索引我们可以取出张量内对应的元素,由于上一章对向量,矩阵以及高维张量分别进行了讨论,这里也按照这个思路对不同维度的张量进行讨论,首先先说说向量的索引
1.3.1 向量的索引和切片
pytorch里的向量其实就是一维张量,对于一维张量,我们可以直接用[star:end:step]
取出对应的元素,下面是案例
# for vector
t1 = torch.arange(9)
print(t1[0]) # tensor(0)
print(t1[0].dtype) # torch.int64
print(t1[0:-1:2]) # tensor([0, 2, 4, 6])
这里要注两点:
- 我们可以取出某一个元素,这个元素被取出后仍然是张量类型【使用
dtype()
查看】 - 传入索引参数时可以传入负数,代表从后往前取值,但是
step
参数必须为正数
1.3.2 矩阵的索引和切片
pytorch里的矩阵就是二维张量,区别于向量,矩阵有行和列,我们在对矩阵切片时可以按照[row, column]
来切
对于row
或column
参数,我们都可以按照[star:end:step]
来传入,灵活的切取我们想要的元素
t1 = t1.reshape([3,3])
t1
'''
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
'''
# slices for matrix
# select the first row
print(t1[0]) # tensor([0, 1, 2])
# select the first column
print(t1[:,0]) # tensor([0, 3, 6])
# select the first item
print(t1[0,0]) # tensor(0)
# select the third items from first and third rows
print(t1[[0,2],2]) # tensor([2, 8])
print(t1[::2,::2])
'''
tensor([[0, 2],
[6, 8]])
'''
这里要注意以下两点:
- 传入参数时先传的是行参数,再传入列参数
- 行列参数间用","分开,而对于单个行或列参数要通过切片取值则用":",注意不能
:
和,
混着用
1.3.3 高维张量的索引和切片
对于高纬张量,因为不只有1个维度,我们就用1维,2维...n维来说
高维张量的取值也不难,其实就是按照维度从高到低一层一层传入参数取值
['dimension(n))','dimension(n-1)',...'dimension1']
下面以三维张量为例
# for three dimensions
t2 = torch.arange(27).reshape([3,3,3])
print(t2.ndim) # 3
print(t2)
'''
tensor([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]],
[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]],
[[18, 19, 20],
[21, 22, 23],
[24, 25, 26]]])
'''
# 取出第一个矩阵中第1行第1列的元素
print(t2[0,0,0]) # tensor(0)
print(t2[::2,::2,::2]) # 取出第一第三个矩阵总的第1、3行、第1、3列的元素
'''
tensor([[[ 0, 2],
[ 6, 8]],
[[18, 20],
[24, 26]]])
'''
print(t2[1,::,::]) # 取出第2个矩阵
'''
tensor([[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]])
'''
总结一下
其实张量的切片就是按照各个维度从高到低来进行切分,选择各个维度下想要的值即可
这时候有小伙伴想问了,如果我想取出张量中的某一块,或者把一个张量分成两个或几个张量,那我需要一个一个传入索引取吗?那不就很麻烦???
这个方法可以实现,但其实torch里给我们提供了相关函数来实现
1.3.4 切片相关函数
在pytorch下,我们可以使用chunk()
和split()
函数来实现张量的分块和切分,下面就来逐个介绍下
1. split函数
split(tensor,section,dim)
对于split函数,第一个参数传入需要切分的张量,第二个参数传入需要怎么切,第三个参数传入切分的维度
第一和第三个参数都好理解,关键是第二个参数,第二个参数可以理解成切的份数,传入[1,2]就表示我要把传入的张量切成2部分,第一个张量有1份,第二个张量有2份
如果只传入1个数n,split()
就会帮我们把传入的张量进行均匀切分,切分后每个张量都包含n份
torch.split(t2,[2,1],dim=0) # 切成两个张量,第一个张量包含2份,第二个张量包含1份
'''
(tensor([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]],
[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]]]),
tensor([[[18, 19, 20],
[21, 22, 23],
[24, 25, 26]]]))
'''
t1,t3,t4 = torch.split(t2,1,dim=0) # 均匀切分,每个张量包含1份
'''
(tensor([[[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]]),
tensor([[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]]]),
tensor([[[18, 19, 20],
[21, 22, 23],
[24, 25, 26]]]))
'''
【注意】
-
split()
函数返回的是对应张量切分后的视图,也就是说跟远张量共享一份内存空间【浅拷贝】,修改切分后的张量的值,则对应的原张量的值也会发生改变
2. chunk函数
chunk本身就是分块的意思,顾名思义,这个函数就是把传入的张量均匀分成若干块
chunk(tensor,chunks,dim)
chunk()
函数的第一、三个参数跟split()
一样,都是传入的张量对象和切分的维度,第二个参数略有不同,chunk()
方法不能自主选择切分的份额,都为均匀切分,传入的参数表示要均匀分成几份【分成3份则传入3】
t3
'''
tensor([[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]]])
'''
torch.chunk(t3,3,dim=1) # dim=1表示列维度消失,按照行切分
'''
(tensor([[[ 9, 10, 11]]]), tensor([[[12, 13, 14]]]), tensor([[[15, 16, 17]]]))
'''
【注意】
- 和
split()
函数类似,chunk()
函数切分后的张量是原传入张量的视图(view) - 对于不能均匀切分的张量,就返回次一级切分方法的结果【比如,不能均匀切3份的时候就会切成2份】
那万一我不想切了,我就想把俩个或多个张量合成一个张量那咋办呢?
别急,cat()
和stack()
函数可以帮助我们
3. cat函数
cat([t1,t2...],dim)
cat()
方法需要传入两个参数,一个是由需要拼接的张量组成的列表,第二个是拼接的维度
torch.cat([t1,t3],dim=0) # dim=0表示行维度消失,按照列进行拼接
'''
tensor([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]],
[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]]])
'''
torch.cat([t1,t3],dim=1) # dim=1表示列维度消失,按照行拼接
'''
tensor([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]]])
'''
【注意】
-
cat()
方法默认的维度是0dim=0
- 当拼接的两张量形状不匹配时会报错
4. stack函数
stack([t1,t2...])
torch.stack([t1,t3])
'''
tensor([[[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]]],
[[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]]]])
'''
细心的小伙伴可能已经发现,虽然stack()
和cat()
函数都能对张量进行拼接,但返回的张量的维度是不一样的
cat()
函数返回的张量的维度和参与拼接的张量的维度是相同的;而stack()
则升了一维
这是因为stack()
方法是把参与拼接的张量拿出来,封装到一个更高维度的张量中
1.4 升维/降维
stack()
方法拼接后,对比原来个张量升了一个维度,那如果我们就单纯想对原张量进行升维或者降维需要怎么办呢?
pytorch中的squeeze()
和unsqueeze()
可以帮助我们
squeeze就是压缩;压榨的意思,很容易理解,当我们需要降维时,就使用squeeze()
方法;相反,需要升维时,就使用unsqueeze()
方法
squeeze(t)
方法会把所有为1的维度删除,不为1的维度没有影响,要查看张量中有哪些维度为1,我们可以调用t.shape
查看
torch.squeeze(t1)
print(t1.shape) # torch.Size([1, 3, 3])
print(torch.squeeze(t1).shape) # torch.Size([3, 3])
unsqueeze(t,dim)
需要传入维度参数,表示在哪个维度上升维度,同样,我们可以使用shape
查看
print(torch.unsqueeze(t1,dim=2).shape) # torch.Size([1, 3, 1, 3])
print(torch.unsqueeze(t1,dim=1).shape) # torch.Size([1, 1, 3, 3])
【注意】
-
shape
返回的元素有几个那就是几维度,也可以使用ndim
来查看 -
unsqueeze
指定的dim
参数不一样,则返回的高维的张量时不一样的
总结一下
- 在对张量进行切片时,我们可以通过再各个维度按照
(star:end:steps)
传入索引和步长来对切出我们想要的元素 -
split()
和chunk()
方法可以帮我们更高效的切分张量 -
cat()
和stack()
方法可以帮我们合并张量,注意两种方法返回张量维度的区别 -
squeeze()
和unsqueeze()
方法可以帮助我们手动对张量进行升维和降维,注意squeeze()
只会删去为1的维度
掌握以上思路和方法后,相信小伙伴们对张量的切分应该不存在太大的问题啦~
往期文章
算法小白的pytorch笔记--chapter1 Tensor张量
参考
[1] pytorch中torch.cat(),torch.chunk(),torch.split()函数的使用方法
[2] pytorch官方文档
[3] pytorch中的dim说明