Generate To Adapt: Aligning Domains using Generative Adversarial Networks
作者所属机构 :UMIACS, University of Maryland, College Park(马里兰大学帕克分校)
(1) Digit classification (MNIST, SVHN and USPS datasets
(2) Object recognition using OFFICE dataset
(3) Domain adaptation from synthetic to real dataset
论文实验代码:pytorch实现 GitHub - yogeshbalaji/Generate_To_Adapt: Implementation of "Generate To Adapt: Aligning Domains using Generative Adversarial Networks"
文中用到了条件GAN的变种AC-GAN。
直观理解:stream2的作用是让F生成的原域和目标域的特征尽可能的接近,这样用分类器C去分类目标域数据时,被分对的可能性就很大了。
1.训练阶段
训练阶段分为两个网络: F-C网络以及F-G-D网络
我比较在意的是,用F-G-D网络有什么用?
我自己的理解是,这个网络更新F的参数,用更新后的F参数,可能会使特征提取器提取出的特征更好,但是好在哪儿呢?
F-C网络就是传统的特征提取器,分类器网络,对带标签的原域数据它进行训练,训练后更新F和C的参数。
损失函数如下:
F-G-D网络就比较复杂:
首先,G的输入由三部分组成,其中F(x)有两个来源,一个是原域样本经过F后提取出的特征,一个是目标域样本经过F后提取出的特征。G用来生成source-like的图像,生成的图像输入D进行判断,同时D的输入还有原域图像和目标域图像。
关于这个D,输出两个分布,一个是Ddata(x),用于判别样本是生成器G生成的还是真实的 ,可以看成是个二分类器,另一个是Dcls(x),用来判断输入图片是属于哪一个类的,可以看成是个多分类器。
关于G和D的训练,若输入目标域数据,D同样输出两个分布,但是由于不知道类标签,因此只有Ddata用于反向传播梯度,来更新D。若输入原域数据,由于知道类标签,因而Ddata和Dcls都用于反向传播,它们俩都去更新D。需要注意的是,由于Dcls是和类标签有关的,此时还可以用Dcls反向传播来更新F。
此时,更新D的公式如下:
判别器D判别目标域生成的图片为假的,故用Ladv,tgt来更新判别器。
第一项是用原域数据来更新D.
第二项是用原域数据来更新D.
第三项是用目标域数据来更新D,希望更新后的D能把目标域数据生成的source-like图像判定为假的。
更新G:
原域数据作为G的输入,用来自D的对抗损失和分类损失来更新:
用原域数据来更新G。
关于更新F,一个是用F-C网络来更新,另一个是用F-G-D网络来自D的对抗梯度来更新。
F-C网络用对带标签的原域数据它进行训练,训练后更新F的参数,即用Lc来更新。
第一项是在F-C网络下,用原域数据来更新F。
第二项是在F-G-D网络下,用原域数据来更新F。
第三项是在F-G-D网络下,用目标域数据来更新F,希望更新后的D能把目标域数据生成的source-like图像判定为真的。
为啥用第三项LFadv来更新F?
这一项是D认为目标域生成的source-like图像是真实的。
这里就构成了F与D的对抗,因为F想让判别器判断目标域生成的source-like图像是真实的,而D则根据自身的使命,会判断目标域生成的source-like图像是假的。
所以,文中有两个对抗,一个是F与D的,另一个则是G和D的对抗。
F想让判别器判断目标域生成的source-like图像是真实的,则F必须改变自身,使得提取出的目标域的embedding和提取出的原域的embedding尽可能的像,只有二者像,才有可能让生成器生成出一个能欺骗判别器的source-like图像。
关于更新Ddata,用最小化损失Ldata,sc来实现。关于更新Dcls,用最小化损失Lcls,src来实现。