最近开始学习pytorch,在训练时出现标题所示问题。浏览了很多方法后,总结出出现这个问题的主要原因是输入的数据类型与网络参数的类型不符。
Input type为torch.cuda.FloatTensor(GPU数据类型), weight type(即net.parameters)为torch.FloatTensor(CPU数据类型)
网上资料大多数的解决方法是 将网络放到GPU上。有以下两种方法
方法一
device = torch.device('cuda:0')
net.to(device)
方法二:
net = net.cuda()
但是上述方法并不适合我碰到的问题,经过试验,最后发现自己搭建的net有一部分没有写入GPU。出现问题的原因是,在搭建net时,在forward函数中两种定义网络的方式不能混用,如下图。
错误写法
解决办法:现将几个网络层在init中定义好
更改后的__init__
然后forward改为:
更改后的forward
问题解决!!!!