索引与切片
import torch
# A batch of 4 "images": 3 channels, 28x28 pixels, values uniform in [0, 1).
a = torch.rand(4, 3, 28, 28)
print(a.shape)

# Each integer index removes one leading dimension.
first_image = a[0]               # shape (3, 28, 28)
first_channel = first_image[0]   # shape (28, 28), identical to a[0][0]
print(first_image.shape)
print(first_channel.shape)

# Indexing every dimension with an integer yields a 0-d tensor (one element).
print(a[0, 0, 2, 4])

# Slices keep the dimension; start:stop:step behaves as for Python lists.
views = [
    a[:2],                     # first two images
    a[:2, :1, :, :],           # first two images, channel 0 only
    a[:2, 1:, :, :],           # first two images, channels 1 and 2
    a[:2, -1:, :, :],          # first two images, last channel only
    a[:, :, 0:28:2, 0:28:2],   # every second row and column
    a[:, :, ::2, ::2],         # same subsampling using default start/stop
]
for view in views:
    print(view.shape)
输出结果:
torch.Size([4, 3, 28, 28])
torch.Size([3, 28, 28])
torch.Size([28, 28])
tensor(0.9962)
torch.Size([2, 3, 28, 28])
torch.Size([2, 1, 28, 28])
torch.Size([2, 2, 28, 28])
torch.Size([2, 1, 28, 28])
torch.Size([4, 3, 14, 14])
torch.Size([4, 3, 14, 14])
torch.index_select、省略号(...)索引、torch.masked_select 与 torch.take 的用法
import torch
a = torch.rand(4, 3, 28, 28)
print(a.shape)

# index_select(dim, index): pick entries along one dimension.
# `index` must be a 1-D integer tensor, not a plain Python list.
batch_pick = torch.tensor([0, 2])
print(a.index_select(0, index=batch_pick).shape)   # images 0 and 2
print(a.index_select(2, torch.arange(8)).shape)     # first 8 rows

# `...` (Ellipsis) expands to as many `:` as needed to cover the other dims.
for view in (a[...], a[0, ...], a[:, 1, ...], a[..., :2]):
    print(view.shape)

# Boolean masking.
x = torch.randn(3, 4)
print(x)
mask = x.ge(0.5)  # True where x >= 0.5 (torch.bool; torch.uint8 on PyTorch < 1.2)
print(mask)
print(torch.masked_select(x, mask=mask))  # 1-D tensor of the elements >= 0.5

# take() indexes the tensor as if it had been flattened to 1-D (row-major).
src = torch.tensor([[4, 3, 5],
                    [6, 7, 8]])
flat_positions = torch.tensor([0, 2])
print(torch.take(src, flat_positions))
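输出结果(注:下面的输出来自 PyTorch 1.2 之前的版本,当时 ge() 返回 torch.uint8;在 1.2 及以后的版本中返回 torch.bool,mask 会显示为 True/False):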
torch.Size([4, 3, 28, 28])
torch.Size([2, 3, 28, 28])
torch.Size([4, 3, 8, 28])
torch.Size([4, 3, 28, 28])
torch.Size([3, 28, 28])
torch.Size([4, 28, 28])
torch.Size([4, 3, 28, 2])
tensor([[ 0.0064, 0.2761, 0.9907, -0.7130],
[ 0.4022, -1.2620, 1.0982, -0.1781],
[ 0.1812, 0.1491, 0.5287, -1.0477]])
tensor([[0, 0, 1, 0],
[0, 0, 1, 0],
[0, 0, 1, 0]], dtype=torch.uint8)
tensor([0.9907, 1.0982, 0.5287])
tensor([4, 5])