Pytorch-Where

Where的官方文档 可从此查到

具体用法:

torch.where(condition, x, y) → Tensor

具体的意思可以理解为:针对于x而言,如果其中的每个元素都满足condition,就返回x的值;如果不满足condition,就将y对应位置的元素或者y的值(如果y为氮元素tensor的话)替换x的值,最后返回结果。
可如下面的例子得到验证。

>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> x
tensor([[-0.4620,  0.3139],
        [ 0.3898, -0.7197],
        [ 0.0478, -0.1657]])
>>> torch.where(x > 0, x, y)
tensor([[ 1.0000,  0.3139],
        [ 0.3898,  1.0000],
        [ 0.0478,  1.0000]])
x = torch.randn(3, 2)
y=torch.tensor(3.0)
print(x)
print(torch.where(x > 0, x, y))

tensor([[ 0.9346, -1.5232],
        [-0.1766,  1.0083],
        [ 1.1510, -0.2411]])
tensor([[0.9346, 3.0000],
        [3.0000, 1.0083],
        [1.1510, 3.0000]])

在阅读代码的过程中发现了一个高级的用法,在此详述一下:

import torch
rows_with_eos = torch.zeros(5)
rows_with_eos[2]=1
输出:tensor([0., 0., 1., 0., 0.])
p = torch.where((rows_with_eos==0).unsqueeze(1),torch.tensor([-1.0]),rows_with_eos)
输出:
tensor([[-1., -1., -1., -1., -1.],
        [-1., -1., -1., -1., -1.],
        [ 0.,  0.,  1.,  0.,  0.],
        [-1., -1., -1., -1., -1.],
        [-1., -1., -1., -1., -1.]])
(rows_with_eos==0).unsqueeze(1)
输出:
tensor([[1],
        [1],
        [0],
        [1],
        [1]], dtype=torch.uint8)

我对于这段代码的理解是先去判断rows_with_eos中的每个元素是否为0(condition部分)(即逐个去比较,无论是否为0都进行unsqueeze操作);
如果满足条件,则用torch.tensor([-1.0])去替换,但是为什么最后结果是个5x5的矩阵而不是5x1的呢?
我试过用torch.tensor([-1.0,-1.0,-1.0,-1.0,-1.0,-1.0])去替换torch.tensor([-1.0]),但是报错认为这个6维的向量和rows_with_eos这个5维向量不匹配
说明一个问题其实替换的是要求torch.where(condition, x, y)中x和y的维度要一样,1维的默认扩张成为与另一个元素维度相同的情况,所以这个问题也就迎刃而解了。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容