RuntimeError: Found dtype Long but expected Float

说明此时需要float型数据,但识别到long型数据,此时需要对入参和出参做一下类型转换

output=output.to(torch.float32)
target=target.to(torch.float32)

例证如下:

output =net(input)
target = variable(t.arange(0,10))

#the point
output=output.to(torch.float32)
target=target.to(torch.float32)

criterion = nn.MSELoss()
loss = criterion(output,target)

net.zero_grad()
print("反向传播之前conv1.bias的梯度")
print(net.conv1.bias.grad)
loss.backward()          #此处疑难杂症  先跳过
print("反向传播之后conv1.bias的梯度")
print(net.conv1.bias.grad)
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容