1 卷积网络流程图
卷积网络网络图解.png
一个简单的卷积神经网络对于给定的输入图片,经过卷积池化过后得到feature map 然后再拉成一列,经过若干个隐含层得到FC,然后得到输出Z,Z的神经元个数取决于分类类别的个数,然后再经过softmax层得到每个类别的概率,最后经过交叉熵损失函数进行梯度回传。
关于卷积神经网络网上有很多优秀的文章,暂不细讲。图中只给出了网络中最后的一部分,即跟本文内容有关的部分,也就是softmax层和交叉熵。
2 softmax 和 交叉熵
关于softmax和交叉熵的历史,百度会有一堆文章,这里暂且不提,本文着重探讨公式的推导过程和代码验证。
- softmax公式如下:
式中,表示第
个输出,
表示
对应的softmax值,
为类别数。
- 交叉熵损失函数公式如下:
式中,表示第
个类的标签或者说真实值。
3 求导
在一个网络中,参数需要损失函数对
求负梯度来更新,也就是
,根据链式求导法则
,因此,需要先求
。同理,
.
在求导中,需要分成两步,下面式中的表示损失函数
对第
个输出求导。
-
,也就是对应
中的
:
-
,也就是对应
中下标不等于
的部分
综上:
由于在一般的分类任务中,标签一般使用one-hot编码,例如:[0,0,0,1,0,0]表示分类为第4类的编码,因此
(此处
除了是one-hot编码外也可以是一个概率分布,满足和为1即可),所以,
这个公式的意义就是,损失函数对输出
的导数就等于经过softmax后的输出
减去标签
。
4 在pytorch中的代码验证
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
x = torch.randn(1,5,requires_grad = True)#随机生成一个size为(1,5)的数据
y = torch.tensor([3]).long()#目标是属于第三类,若是one-hot编码应为[0,0,0,1,0,]
prob = F.softmax(x.float(),dim=1)
loss = nn.CrossEntropyLoss()(x,y)#注意,pytorch中的CrossEntropyLoss函数输入为类别下标和实际输出,该函数会自动进行one-hot编码和softmax计算
print("单样本输出为:"+ str(x))
print("#"*20+"pytorch中的结果"+"#"*20)
print("softmax:"+str(prob.data.numpy()))
print("loss:"+str(loss.data.numpy()))
loss.backward()#求导
print("grad:"+str(x.grad.data.numpy()))
print("\n")
softmax_=np.exp(x.data.numpy().squeeze())/np.sum(np.exp(x.data.numpy()))
loss_ = -np.log(softmax_[3])#因为one-hot编码中标签除了目标类为1,其他的标签都为0,因此loss_= -np.log(softmax_[3])
grad = [item-(i==3) for i,item in enumerate(softmax_)]#根据公式可知,梯度只需要对第三类的softmax输出减去标签1即可,因为其他类的标签为0,就只是softmax的输出值
print("#"*20+"自己的结果"+"#"*20)
print("softmax:"+ str(softmax_))
print("loss:"+str(loss_))
print("grad:"+str(grad))
'''
单样本输出为:tensor([[-2.0777, 0.3450, 0.7605, 0.2708, -0.3969]], requires_grad=True)
####################pytorch中的结果####################
softmax:[[0.02212428 0.24947256 0.3779738 0.2316277 0.11880173]]
loss:1.462624
grad:[[ 0.02212428 0.24947257 0.3779738 -0.7683723 0.11880173]]
####################自己的结果####################
softmax:[0.02212428 0.24947256 0.37797377 0.2316277 0.11880171]
loss:1.462624
grad:[0.022124277, 0.24947256, 0.37797377, -0.7683723, 0.11880171]
'''
最后提一句,在得到之后,根据
便可对参数进行求导,从而更新参数,后续有时间再写一篇文章讨论。
如有错误,欢迎指正!