在动手学习深度学习中学到了一个函数gather,原文是说可以通过gather得到标签的预测概率。
y_hat = torch.tensor([[0.1,0.3,0.6],[0.3,0.2,0.5]])
y = torch.LongTensor([0,2])
y_hat.gather(1,y.view(-1,1))
tensor([[0.1000],
[0.5000]])
开始我看到这个输出一头雾水 不知道怎么回事
查了查 gather的时候我才知道
torch.gather(input,dim,index,out=None)
example:
t = torch.Tensor([1,2],[3,4])
torch.gather(t,1,torchLongTensor([[0,0],[1,0]]))
1,1
4,3
可以看出gather的作用是根据索引返回该项元素,首先先输入一个Tensor 然后根据dim进行判断是是行的还是列的,当dim=0 时候竖行查找,当dim=1的时候是横向查找
上题中,dim=1,那么索引就是列号。index的大小就是输出的大小,比如index是[1,0;0,0]其实就是第一行的第二个元素和第一个元素,第二行的第一个元素也就是返回的是2,1 3,3
所以例子中是[0,0],[1,0]
返回的就是[1,1],[4,3]
在例题中的他是通过view
函数来返回index的,开始不知道view的意思,查过后知道了,他实际上和resize的意思差不多。
a = torch.Tensor([[1,2,3],[4,5,6]])
b = torch.Tensor([1,2,3,4,5,6])print(a.view(1,6))
print(b.view(1,6))
得到的都是
tensor([[1,2,3,4,5,6]])
再看一个例子
a = torch.Tensor([[1,2,3],[4,5,6]])
print(a.view(3,2))
将会得到
tensor([[1,2],
[3,4],
[5,6]
])
相当于就是从1,2,3,4,5,6 顺序的拿数组来填充需要的形状。
参数中的-1就代表这个位置由其他位置的数字来进行推断,只要不在歧义的情况下,view参数就可以推断出来,也就是人可以推断出形状的情况下,view也是可以推断出来的,比如a tensor
的数据个数是6个,如果view(1,-1)我们就可以推断出来-1代表6。而如果view(-1,-1,2)的话,人也不知道的话,机器也不会知道的,所以就会报错