简介
Improved training of wasserstein gans.NIPS-2017,Cited-2438.放个pytorch的开源:wgan-gp,这大佬的仓库里一堆的gan及变种,Star-5.5k。
关键字
GAN,WGAN,WGAN-GP,EM距离,生成模型,深度学习,机器学习
正文
1. 存在问题和解决思路
WGAN使用可以在度量分布差异上处处连续的EM距离解决了GAN由于JS散度不连续造成的难以训练问题,解决问题的方式是对判别器的参数范围作clip,即限制参数的值只在
之内,然而这样的限制还存在梯度爆炸和梯度消失问题,这是由于clip方法会让这些梯度集中在clip的边界值
上的原因。
本文献提出的方法是不对参数作clip,而是采取更加温和的方法,对超过利普西茨常数的梯度作惩罚。利普西茨连续要求在所有可能的
上具有利普西茨连续,这显然做不到,那就取部分有代表性的样本
作惩罚作近似就好。
由于使用的是最优传输,相当于把样本属于的样本
搬到属于
的样本
位置上,所以只要求处于
之间的样本
满足利普西茨条件就好,所以每次在更新判别器
的参数时,从
样本对之间进行抽样
计算损失,加到判别器原有损失上就ok了。
2. 目标函数和算法
先看目标函数:
目标函数的第二项惩罚梯度,需要梯度越小越好,但也不能太小,所以就要求梯度都接近1或-1,最后看下算法(来自文献Alogorithm1):
备注:由于在WGAN-GP中是对每个样本独立地施加梯度惩罚,所以判别器的模型架构中不能使用Batch Normalization, 因为它会引入同个batch中不同样本的相互依赖关系。
3. 实验效果
关于效果,就贴个WGAN与WGAN-GP的效果(图来自文献Figure1),左边图是关于模式坍缩,上面一行是WGAN的,难以学到多峰的效果,陷入模式坍缩,下面一行是WGAN-GP的效果,在toy数据集上较好的解决了问题;右边图是关于梯度消失和爆炸的,就中间的GP比较稳定:
参考资料
[1] Gulrajani, Ishaan, et al. "Improved training of wasserstein gans." Advances in neural information processing systems. 2017.
[2] http://www.twistedwg.com/2018/02/02/WGAN-GP.html
[3] https://www.jianshu.com/p/7801f9f917d9
[4] https://www.jianshu.com/p/c000b27775cc