简介
文献Coupled generative adversarial networks的官方开源,源码阅读解析。开源网址:https://github.com/mingyuliutw/CoGAN/tree/master/cogan_pytorch
这里解析的是pytorch版本,只对域转换相关的部分代码进行解析。CoGAN的简介见https://www.jianshu.com/p/5b31cc80e3a2。
关键字
CoGAN,开源代码
文件及程序结构
域转换相关的代码涉及9个文件,2个是数据集,2个配置文件,1个配置解析文件,1个网络结构定义,1个网络参数初始化,1个训练器,1个训练过程文件。大体分为4部分,具体如下。
1. 数据部分
两个文件,一个是dataset_mnist.py,dataset_usps.py,这两个文件的功能是把mnist,usps两个数据集包装成dataset数据类的格式,在使用时需要在把dataset封装成dataloader。
2. 配置部分
相关的3个文件,其中有两个文件是训练时的配置文件,分别是mnist2usps_full_cogan.yaml、usps2mnist_full_cogan.yaml,这两个分别对应mnist与usps相互转换过程中的配置,主要参数有max_iter(最大迭代次数)、batch_size(每批样本数量)、latent_dim(特征维度)、cls_weight(类损失平衡系数)、mse_weight(高层特征的欧式距离损失)等。还有一个net_config.py用来解析上面的两个配置文件。
3. 网络结构
net_cogan_mnist2usps.py定义了网络结构,包括生成器、判别器、分类器等网络的结构。
判别器共4层,分别是:卷积+池化(不共享参数)、卷积+池化(共享)、卷积+激活(共享)、卷积(共享)。输入图像样本,输出2维的向量,对应真假样本。最后2维向量与标签计算交叉熵。
分类器也是4层,前3层与判别器相同,最后一层卷积输出向量是10维的,对应数字集10分类。最后10维向量与标签计算交叉熵。
生成器5层,前4层每层都是由反卷积、批标准化、激活函数构成,都是共享参数的,最后1层是反卷积+Sigmoid,用来生成归一化的图像,不共享参数。
4. 网络结构
剩下3个对应训练过程,init.py用来初始化网络参数,有高斯分布、Xavier两种方式;trainer_cogan_mnist2usps.py打包了判别器、生成器各自更新一次参数的过程;train_cogan_mnist2usps.py定义了完整的训练过程。
参考资料
[1] Liu M Y, Tuzel O. Coupled generative adversarial networks[C]//Advances in neural information processing systems. 2016: 469-477.
[2] https://github.com/mingyuliutw/CoGAN
[3] https://www.jianshu.com/p/5b31cc80e3a2