pytorch实战示例-模型训练

这次展示下孪生网络的模型训练示例,先看下使用的数据:

数据使用的是att_faces人脸数据,共40人,每人10张照片,我们最终的预测目标是根据一张人脸,找到其最相似的多张图片。

上述代码是对样本的处理,init函数中我们从样本文件中提取所有样本信息,这里包括两张图片的路径,以及label(0=同一人,1=不同人),getitem则是对每个样本进行处理,主要是图像的各种变换处理。

这是模型的网络层,其中包含4个卷积层、2个全连接层,代码中有具体注释。

这是损失函数,由于是孪生网络,所以对比的是两张图片的距离。

上述就是主要的处理函数,下面看下训练代码:

这里大致说一下逻辑,每个epoch是一个完整样本集的训练,每次训练用一个batch_size的数据集进行模型参数更新,这里在每次epoch后会对测试集进行损失计算,并根据训练集和测试集损失函数变换确定最终的epoch。

这是训练50个epoch后每个epoch对应的训练集和测试集损失函数值,从图中9以后测试集不再变小,训练集减小幅度也很小了,这里就取epoch=8为当前的最优值。

本节内容就这些,想下载源码的可以微信公众号回复:孪生网络模型训练

公众号名称:桔子的算法之路

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。