这次展示下孪生网络的模型训练示例,先看下使用的数据:
数据使用的是att_faces人脸数据,共40人,每人10张照片,我们最终的预测目标是根据一张人脸,找到其最相似的多张图片。
上述代码是对样本的处理,init函数中我们从样本文件中提取所有样本信息,这里包括两张图片的路径,以及label(0=同一人,1=不同人),getitem则是对每个样本进行处理,主要是图像的各种变换处理。
这是模型的网络层,其中包含4个卷积层、2个全连接层,代码中有具体注释。
这是损失函数,由于是孪生网络,所以对比的是两张图片的距离。
上述就是主要的处理函数,下面看下训练代码:
这里大致说一下逻辑,每个epoch是一个完整样本集的训练,每次训练用一个batch_size的数据集进行模型参数更新,这里在每次epoch后会对测试集进行损失计算,并根据训练集和测试集损失函数变换确定最终的epoch。
这是训练50个epoch后每个epoch对应的训练集和测试集损失函数值,从图中9以后测试集不再变小,训练集减小幅度也很小了,这里就取epoch=8为当前的最优值。
本节内容就这些,想下载源码的可以微信公众号回复:孪生网络模型训练
公众号名称:桔子的算法之路