Where的官方文档 可从此查到
具体用法:
torch.where(condition, x, y) → Tensor
具体的意思可以理解为:针对于x而言,如果其中的每个元素都满足condition,就返回x的值;如果不满足condition,就将y对应位置的元素或者y的值(如果y为氮元素tensor的话)替换x的值,最后返回结果。
可如下面的例子得到验证。
>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> x
tensor([[-0.4620, 0.3139],
[ 0.3898, -0.7197],
[ 0.0478, -0.1657]])
>>> torch.where(x > 0, x, y)
tensor([[ 1.0000, 0.3139],
[ 0.3898, 1.0000],
[ 0.0478, 1.0000]])
x = torch.randn(3, 2)
y=torch.tensor(3.0)
print(x)
print(torch.where(x > 0, x, y))
tensor([[ 0.9346, -1.5232],
[-0.1766, 1.0083],
[ 1.1510, -0.2411]])
tensor([[0.9346, 3.0000],
[3.0000, 1.0083],
[1.1510, 3.0000]])
在阅读代码的过程中发现了一个高级的用法,在此详述一下:
import torch
rows_with_eos = torch.zeros(5)
rows_with_eos[2]=1
输出:tensor([0., 0., 1., 0., 0.])
p = torch.where((rows_with_eos==0).unsqueeze(1),torch.tensor([-1.0]),rows_with_eos)
输出:
tensor([[-1., -1., -1., -1., -1.],
[-1., -1., -1., -1., -1.],
[ 0., 0., 1., 0., 0.],
[-1., -1., -1., -1., -1.],
[-1., -1., -1., -1., -1.]])
(rows_with_eos==0).unsqueeze(1)
输出:
tensor([[1],
[1],
[0],
[1],
[1]], dtype=torch.uint8)
我对于这段代码的理解是先去判断rows_with_eos中的每个元素是否为0(condition部分)(即逐个去比较,无论是否为0都进行unsqueeze操作);
如果满足条件,则用torch.tensor([-1.0])去替换,但是为什么最后结果是个5x5的矩阵而不是5x1的呢?
我试过用torch.tensor([-1.0,-1.0,-1.0,-1.0,-1.0,-1.0])去替换torch.tensor([-1.0]),但是报错认为这个6维的向量和rows_with_eos这个5维向量不匹配
说明一个问题其实替换的是要求torch.where(condition, x, y)
中x和y的维度要一样,1维的默认扩张成为与另一个元素维度相同的情况,所以这个问题也就迎刃而解了。