https://pytorch.org/docs/stable/generated/torch.gather.html
一个简单的例子:
t = torch.rand(2,3)
"""
tensor([[0.8133, 0.5586, 0.7917],
[0.0551, 0.2322, 0.9087]])
"""
t.gather(dim=0,index=torch.tensor([[0,1,0],[1,0,1]]))
"""
tensor([[0.8133, 0.2322, 0.7917],
[0.0551, 0.5586, 0.9087]])
"""
- dim = 0,说明index中所有索引均是索引行。
- 关于index的shape:dim=1时,input.shape[0]=index.shape[0], 同理, 可推dim=0时,input.shape[1]=index.shape[1]
# 常用于以下需求:
# celoss = torch.tensor([i_s[i_t] for i_s,i_t in zip(softmax,target)])
input = torch.randn(3, 5, requires_grad=True) # (3,5)
n_samples = input.shape[0] # 注意dim=1时,input.shape[0]=index.shape[0], 同理, 可推dim=0时,input.shape[1]=index.shape[1]
channel = 6
idx = torch.randint(low=0,high=5,size=(n_samples*channel,)).reshape(n_samples,channel)
"""
tensor([[0, 0, 4, 2, 3, 1], 第一行取第0个,第0个,第4个...
[3, 3, 1, 0, 2, 2], 第二行取第3个,第3个,第1个...
[4, 4, 4, 2, 1, 3]]) ...
"""
input.gather(dim=1,index=idx) # torch.Size([3, 6])