[pytorch]如何将label转化成onehot编码

之前用octave学习神经网络的时候,用逻辑回归,激活函数是sigmoid,损失函数是交叉熵损失函数,那个时候不用任何框架,需要把label转化成onehot编码:

c =[1:10]
y =(y==c)

只需要两行代码,很简单。
现在使用pytorch框架,刚开始学,情况比较复杂,废了半天时间才能把自己的数据正确导入程序(需要用固定的torch容器来装),之后训练神经网路的时候开始使用交叉熵损失函数(CrossEntropyLoss),没有发现错误,改用MSE损失函数后反而会报错。后来知道,使用交叉熵损失函数的时候会自动把label转化成onehot,所以不用手动转化,而使用MSE需要手动转化成onehot编码,转化方法如下(https://discuss.pytorch.org/t/convert-int-into-one-hot-format/507/3):

Paste_Image.png
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容