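import torch
x1 = torch.arange(6).view(2,3)  # integer arange -> dtype torch.int64 (Long)
x2 = torch.ones(3,2)            # ones() -> default dtype torch.float32 (Float)
x2[:1]+=1                       # add 1 to the first row of x2
torch.mm(x1, x2)                # fails: the two matrices have different dtypes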
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-55-25db65725f06> in <module>
----> 1 describe(torch.mm(x1, x2))
RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #3 'mat2' in call to _th_addmm_out
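Cause of the problem:
x1 is a Long tensor (torch.int64) and x2 is a Float tensor (torch.float32). torch.mm does not promote mixed dtypes, so both matrices must have the same type.
Changing x1 = torch.arange(6).view(2,3) to x1 = torch.arange(6.0).view(2,3) fixes it, because a floating-point argument makes arange return a Float tensor.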
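A minimal sketch of the fix, assuming a recent PyTorch release. It checks the dtypes, applies the arange(6.0) change from the post, and shows two equivalent alternatives that are not from the original post: passing dtype explicitly, and casting the existing tensor with .to(). The names x1_long, x1_explicit and x1_cast are illustrative only.

```python
import torch

# Reproduce the original operands
x1_long = torch.arange(6).view(2, 3)   # int64 (Long)
x2 = torch.ones(3, 2)                  # float32 (Float)
x2[:1] += 1                            # first row of x2 becomes 2.0

print(x1_long.dtype, x2.dtype)         # torch.int64 torch.float32

# Fix from the post: a float argument makes arange return float32 (the default float dtype)
x1 = torch.arange(6.0).view(2, 3)

# Alternative (assumption, not from the post): request the dtype explicitly
x1_explicit = torch.arange(6, dtype=torch.float32).view(2, 3)

# Alternative (assumption, not from the post): cast to match the other operand's dtype
x1_cast = x1_long.to(x2.dtype)

# All three now share x2's dtype, so torch.mm succeeds
result = torch.mm(x1, x2)
print(result)
```

Casting with .to(x2.dtype) is handy when x1 comes from elsewhere and you cannot change how it was created.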