简介
根据CoGAN作者的官方开源实现了CoGAN的数字集域适应任务。
开源地址:https://github.com/mingyuliutw/CoGAN,。
开源代码的分析见机器学习·CoGAN开源解析。
关键字
CoGAN、域适应、深度学习、神经网络、机器学习、复现
正文
1. 概述
前天解析了大神开源,原来打算把作者大神的官方代码拿来跑跑看看结果就行,结果由于本地环境与开源代码的环境差异有点大,各种尝试未果,决定参考官方代码自行实现。按照昨天总结深度学习算法实现的一般套路(pytorch)的结构,按照数据处理,模型定义,参数更新,模型选择4部分来复现。本地环境:python3.7.5,pytorch1.3.1,torchvision0.4.2;1080Ti,cuda10。
2. 数据
数字集两个mnist,usps,第一个mnist使用torchvision.datasets自带的MNIST是28*28的,直接使用就好,其中训练集样本数量60000,测试集样本数量10000;第二个usps集在torchvision.datasets里也有,但是早期torchvision.datasets版本是没有的,在torchvision.datasets内的usps是16*16的。作者使用的是自己处理好的usps数据集,以及处理好28*28,可以从作者官方开源的网址获取,其中训练集样本数量7438,测试集样本数量1860。
3. 实现概要
一共9个文件,分述如下:
config.py:配置,主要包括迭代次数,快照间隔,批数量,损失平衡系数,任务选择,随机种子等。
get_dataloader.py:读取数据,输出3个dataloader对象,分别是A域训练集(含标签),B域训练集(不含标签),B域测试集(不含标签)。
init.py:对模型参数的初始化,定义了两种,高斯和Xavier,供Trainer内对模型进行初始化,这边只用了Xavier的初始化方式。
main.py:主函数,读取config.py,然后执行slover.py就完了。
model.py:基本照搬官方开源代码,提供了模型的网络结构。
solver.py:参考开源自行实现。作用是声明Trainer对象,使用Trainer的更新参数行为训练判别器、生成器、分类器。按照快照间隔使用B域测试集测试分类准确率。直到最大迭代次数。
trainer.py:包装了模型及其模型的行为,也就是一次参数更新。
usps.py:基本照搬官方开源代码,对usps数据源进行处理,形成dataset类型的数据对象。
utils.py:按照随机种子,固定程序的一些随机行为。
4. 结果
任务:在A域训练集(含标签),B域训练集(不含标签)上训练,在B域(不含标签)的测试集上测试,指标为分类准确率。做了2个任务在mnist上训练,在usps上测试,最高准确率为95.97%;在usps上训练,在mnist上测试,最高准确率为94.34%。最大迭代次数10000,每100次进行一次测试,如图,上图是m2u,下图是u2m,纵轴是准确率,单位是1%;横轴是次数,单位是10^2。
参考资料
[1] Liu M Y, Tuzel O. Coupled generative adversarial networks[C]//Advances in neural information processing systems. 2016: 469-477.
[2] https://github.com/mingyuliutw/CoGAN
[3] https://www.jianshu.com/p/5b31cc80e3a2
[4] https://www.jianshu.com/p/c02b5567655e
[5] https://www.jianshu.com/p/a534e0e50d10