Torch 常见的坑

1、ValueError: Target size (torch.Size([64])) must be the same as input size (torch.Size([64, 1]))

计算loss时出现的错误,预测的batch维度与真实的batch维度不同,要压缩一个,具体压缩那个看输出信息,(一般压缩预测)     predict = predict.squeeze(-1)       

2、RuntimeError: Expected object of type torch.FloatTensor but found type torch.LongTensor for argument #2 'other'

计算loss时出现的错误,预测的type为FloatTensor,真实的type为LongTensor,在生成batch数据时用的是torch.LongTensor,例如

                                            Train_text = torch.LongTensor(train_text)

                                            Train_label = torch.LongTensor(train_label)

如果是二分类的话,因为用sigmoid函数将其压缩到0-1之间的小数,用LongTensor会报错应改为:

                                            Train_label = torch.FloatTensor(train_label)

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容