第十三章 生成对抗网络

在生成对抗网络(Generative Adversarial Network)之前,VAE被认为是理论完美,实现简单,使用神经网络训练起来很稳定,生成的图片逼近度也较高,但是人类还是可以轻易区分。

但是 Ian Goodfellow 提出了生成对抗网络,最新的算法在图片生成上的效果甚至达到了肉眼难辨的程度。

13.1 博弈学习实例

GAN 网络借鉴了博弈学习的思想,分别设立了两个子网络:负责生成样本的生成器G,负责鉴别真伪的鉴别器D。

鉴别器D通过观察真实的样本和生成器G产生的样本之间的区别,学会如何鉴别真假(真实样本为真。生成器G产生的样本为假)

生成器G希望产生的样本能够骗过鉴别器D,因此生成器G通过优化自身的参数,尝试使得自己产生的样本在鉴别器D中判别为真

生成器G和鉴别器D相互博弈,直至达到平衡点(此时生成器G生成的样本非常逼真,使得鉴别器D真假难分)。

在原始的GAN论文中,Ian Goodfellow 对GAN有一个形象的比喻:
生成器网络G的功能是产生一系列非常逼真的假钞试图欺骗鉴别器D,而鉴别器D通过学习真钞和生成器G产生的假钞来掌握钞票的鉴别方法;这两个网络在相互博弈的过程中间同步提升,直到生成器G产生的假钞非常逼真,连鉴别器D都真假难辨。

13.2 GAN 原理

这部分介绍生成对抗网络的网络结构和训练方法

13.2.1 网络结构

生成对抗网络包含了两个子网络:

  • 生成网络(负责学习样本的真实分布)
  • 判别网络(负责将生成网络采样的样本与真实样本区分开来)

13.2.1.1 生成网络G(\mathbf{z})

从先验分布p_{z}()中采样隐藏变量\mathbf{z} \sim p_{z}(),通过生成网络G参数化的p_{g}(x|z)获得生成样本x\sim p_{g}(x|z)(其中隐藏变量\mathbf{z}的先验分布p_{z}()可以假设属于某中已知的分布)

p_{g}(x|z)可以用深度神经网络来参数话。

e.g. 从均匀分布p_{z}中采样出隐藏变量\mathbf{z},经过多层转置卷积层网络参数化的p_{g}(x|z)分布中采样出样本\mathbf{x}_{f}

13.2.1.2 判别网络D(\mathbf{x})

判别网络和普通的二分类网络功能类似,它接受输入样本x,包含了采样自真实数据分布p_{r}()的样本\mathbf{x}_{r}\sim p_{r}(),也同时包含了采样自生成网络的假样本x_{f}\sim p_{g}(x|z)

判别网络输出为x属于真实样本的概率P,把所有真实样本x_{r}的标签标注为真(1),所有生成网络产生的样本x_{f}标注为假(0),通过最小化判别网络预测值与标签之间的误差来优化判断网络参数。

13.2.2 网络训练

GAN 博弈学习的思想体现在它的训练方式上,由于生成器G和判别器D的优化目标不一样,不能和之前的网络训练一样,只采用一个损失函数。

对于判别网络D,它的目标是能够很好地分辨出真样本x_{r}与假样本x_{f}

以图片生成为例,它的目标是最小化图片的预测值和真实值之间的交叉熵损失函数:
\min_{\theta} \mathcal{L}=Crossentropy(D_{\theta}(x_{r}),y_{r},D_{theta}(x_{f}),y_{f})

其中,D_{\theta}(x_{r})代表真实样本x_{r}在判别网络D_{\theta}的输出;D_{\theta}(x_{f})代表生成样本x_{f}在判别网络D_{\theta}的输出,y_{r}x_{r}的标签,由于真实样本标注为真(y_{r}=1),y_{f}x_{f}的标签,由于生成样本标注为假(y_{f}=0)。

根据二分类问题的交叉熵损失函数定义:
\mathcal{L}=-\sum_{x_{r}\sim p_{r}}log\;D_{\theta}(x_{r}) - \sum_{x_{f}\sim p_{g}}log\;(1-D_{theta}(x_{f}))

判别网络的优化目标是:
\theta^{*} = \arg\min_{\theta} \mathcal{L}

\min_{\theta}\;L问题转换为\max_{\theta}-\mathcal{L},并写成期望形式:
\theta^{*}=\arg\max_{\theta}\mathbb{E}_{x_{r}\sim p_{r}}log\;D_{\theta}(x_{r})+\mathbb{E}_{x_{f}\sim p_{g}}log\;(1-D_{\theta}(x_{f}))

希望样本x_{f}在判别网络的输出越接近真实标签越好,意味着,在训练生成网络时,希望判别网络的输出D(G(z))越逼近1越好

交叉熵损失函数为:
\min_{\phi}\mathcal{L}=Crossentropy(D(G_{\phi}(z)),1) = -log\;D(G_{\phi}(z)))

\min_{\phi}\mathcal{L}问题转换成最大化问题,并写出期望形式:

\phi^{*} = \arg\max_{\phi} \mathbb{E}_{z\sim p_{z}}log\;D(G_{\phi}(z))

再次等价转化为:
\phi^{*}=\arg\min_{\phi}\mathcal{L}=\mathbb{E}_{z\sim p_{z}}log\;[1-D(G_{\phi}(z))]

13.2.3 统一目标函数

把判别网络的目标和生成网络的目标合并,写成min-max 博弈形式:
\begin{split} \min_{\phi}\max_{\theta}\mathcal{L}(D,G)=&\mathbb{E}_{x_{r}\sim p_{r}}log\;D_{\theta}(x_{r})+\mathbb{E}_{x_{f}\sim p_{g}}log\;(1-D_{\theta}(x_{f})) \\ =&\mathbb{E}_{x_{r}\sim p_{r}}log\;D_{\theta}(x_{r})+\mathbb{E}_{x_{f}\sim p_{g}}log\;(1-D_{\theta}(G_{\phi}(z))) \end{split}

13.3 DCGAN 实战

import os
import glob
import tensorflow as tf
import numpy as np
from PIL import Image
resize = 64
batch_size = 64
def preprocess(path):
    img = tf.io.read_file(path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img,[resize,resize])
    img = tf.clip_by_value(img,0,255)
    img = img / 127.5 - 1
    
    return img
img_paths = glob.glob('D:/faces/*.jpg')
dataset = tf.data.Dataset.from_tensor_slices(img_paths)
dataset = dataset.map(preprocess)
dataset = dataset.batch(batch_size)
class Generator(tf.keras.Model):
    def __init__(self):
        super(Generator, self).__init__()
        filter = 64
        
        self.conv1 = tf.keras.layers.Conv2DTranspose(filter * 8, 
                                    4, 1, padding = 'valid',
                                    use_bias = False)
        self.bn1 = tf.keras.layers.BatchNormalization()
        
        self.conv2 = tf.keras.layers.Conv2DTranspose(filter * 4,
                                    4, 2, padding = 'same',
                                    use_bias = False)
        self.bn2 = tf.keras.layers.BatchNormalization()
        
        self.conv3 = tf.keras.layers.Conv2DTranspose(filter * 2,
                                    4, 2, padding = 'same',
                                    use_bias = False)
        self.bn3 = tf.keras.layers.BatchNormalization()
        
        self.conv4 = tf.keras.layers.Conv2DTranspose(filter * 1,
                                    4, 2, padding = 'same',
                                    use_bias = False)
        self.bn4 = tf.keras.layers.BatchNormalization()
        
        self.conv5 = tf.keras.layers.Conv2DTranspose(3, 4, 2,
                                    padding = 'same',
                                    use_bias = False)
        
    def call(self, inputs, training = None):
        x = inputs
        x = tf.reshape(x, (x.shape[0], 1, 1, x.shape[1]))
        x = tf.nn.relu(x)
        x = tf.nn.relu(self.bn1(self.conv1(x), training=training))
        x = tf.nn.relu(self.bn2(self.conv2(x), training=training))
        x = tf.nn.relu(self.bn3(self.conv3(x), training=training))
        x = tf.nn.relu(self.bn4(self.conv4(x), training=training))
        x = self.conv5(x)
        x = tf.tanh(x)
        
        return x
class Discriminator(tf.keras.Model):
    def __init__(self):
        super(Discriminator, self).__init__()
        filter = 64
        
        self.conv1 = tf.keras.layers.Conv2D(filter, 4, 2, 
                                    padding = 'valid',
                                    use_bias = False)
        self.bn1 = tf.keras.layers.BatchNormalization()
        
        self.conv2 = tf.keras.layers.Conv2D(filter * 2, 4, 2,
                                    padding = 'valid',
                                    use_bias = False)
        self.bn2 = tf.keras.layers.BatchNormalization()
        
        self.conv3 = tf.keras.layers.Conv2D(filter * 4, 4, 2,
                                    padding = 'valid',
                                    use_bias = False)
        self.bn3 = tf.keras.layers.BatchNormalization()
        
        self.conv4 = tf.keras.layers.Conv2D(filter * 8, 3, 1,
                                    padding = 'same',
                                    use_bias = False)
        self.bn4 = tf.keras.layers.BatchNormalization()
        
        self.conv5 = tf.keras.layers.Conv2D(filter * 16, 3, 1, 
                                    padding = 'same',
                                    use_bias = False)
        self.bn5 = tf.keras.layers.BatchNormalization()
        
        self.pool = tf.keras.layers.GlobalAveragePooling2D()
        
        self.flatten = tf.keras.layers.Flatten()
        
        self.fc = tf.keras.layers.Dense(1)
        
    def call(self, inputs, training = None):
        x = tf.nn.leaky_relu(self.bn1(self.conv1(inputs), training=training)) 
        x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))

        x = tf.nn.leaky_relu(self.bn3(self.conv3(x), training=training))

        x = tf.nn.leaky_relu(self.bn4(self.conv4(x), training=training))

        x = tf.nn.leaky_relu(self.bn5(self.conv5(x), training=training))
           
        x = self.pool(x)
            
        x = self.flatten(x)
            
        logits = self.fc(x)
            
        return logits   
def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training):
    fake_image = generator(batch_z, is_training)
    d_fake_logits = discriminator(fake_image, is_training)
    d_real_logits = discriminator(batch_x, is_training)
    
    d_loss_real = celoss_ones(d_real_logits)
    
    d_loss_fake = celoss_zeros(d_fake_logits)
    
    loss = d_loss_fake + d_loss_real
    
    return loss

def celoss_ones(logits):
    y = tf.ones_like(logits)
    loss = tf.keras.losses.binary_crossentropy(y, logits,
                                              from_logits = True)
    
    return tf.reduce_mean(loss)

def celoss_zeros(logits):
    y = tf.zeros_like(logits)
    loss = tf.keras.losses.binary_crossentropy(y, logits,
                                              from_logits=True)
    
    return tf.reduce_mean(loss)

def g_loss_fn(generator, discriminator, batch_z, is_training):
    fake_image = generator(batch_z, is_training)
    d_fake_logits = discriminator(fake_image, is_training)
    loss = celoss_ones(d_fake_logits)
    
    return loss
generator = Generator()
generator.build(input_shape = (4, 100))
discriminator = Discriminator()
discriminator.build(input_shape=(4, 64, 64, 3))
g_optimizer = tf.keras.optimizers.Adam(learning_rate = 0.0002,
                                      beta_1 = 0.5)
d_optimizer = tf.keras.optimizers.Adam(learning_rate = 0.0002,
                                      beta_1 = 0.5)
for epoch in range(300):
    for _ in range(5):
        batch_z = tf.random.normal([batch_size, 100])
        batch_x = next(iter(dataset))
        with tf.GradientTape() as tape:
            d_loss = d_loss_fn(generator,discriminator, batch_z, batch_x, True)
        grads = tape.gradient(d_loss, discriminator.trainable_variables)
        d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))
    batch_z = tf.random.normal([batch_size, 100])
    batch_x = next(iter(dataset))
    with tf.GradientTape() as tape:
        g_loss = g_loss_fn(generator,discriminator, batch_z, True)
        
    grads = tape.gradient(g_loss, generator.trainable_variables)
    g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))
    
    if epoch % 100 ==0:
        print(epoch,'d_loss:',float(d_loss),'g_loss:',float(g_loss))

13.4 GAN 变种

原始GAN 模型在图片生成效果并不突出,和VAE差别不明显,并没有展现出它强大的分布逼近能力。但是由于GAN在理论方面教新颖,实现方面也有很多可以改进的地方,因此激发了学术界的研究兴趣。

13.4.1 DCGAN

DCGAN (2015) 提出了使用转置卷积层实现的生成网络,普通卷积层来实现的判别网络,来降低网络的参数量,同时图片的生成效果也大幅提升。

13.4.2 InfoGAN

InfoGAN (2016) 尝试使用无监督的方式去学习输入x的可接受隐向量z表示方法。

13.4.3 CycleGAN

CycleGAN (2017)是华人学者朱俊彦提出的无监督的方式进行图片转换的算法,并且其算法清晰简单,实验效果完成的非常好

CycleGAN 基本的假设是,如果由图片A转换到图片B,再从图片B转换到A^{'}那么A^{'}应该和A是同一张图片,因此出了设立标准的GAN 损失项外,CycleGAN 还增设了循环一致性损失

13.4.4 WGAN/WGAN-GP

GAN 的训练很容易出现训练不收敛和模式崩塌的现象。

WGAN (2017)从理论层面分析了原始的 GAN 使用 JS 散度存在的缺陷,并提出了可以用 Wasserstein 距离来解决这个问题。

WGAN-GP(2017),作者提出了通过添加梯度惩罚项目,从工程层面很好的实现了 WGAN 算法,并且实验性证实了 WGAN 训练稳定的优点。

13.4.5 Equal GAN

Google Brain 的几位研究员在2018年提出了另一个观点:没有证据表明我们测试的GAN变种算法一直持续地比最初始的GAN要好。论文中对这些GAN变种进行了相对公平,全面的比较,在有足够计算资源的情况下,几乎所有的GAN变种都能达到相似的性能(FID分数)。

13.4.6 Self-Attention GAN

Self-Attention GAN (SAGAN)借鉴了 Attention 机制,提出了基于自我注意力机制的 GAN 变种。SAGAN(2019)把图片的逼真图指标,Inception score(从最好的36.8 提升到52.52),Frechet Inception distance(从27.62降到18.65)。

13.4.7 BigGAN

在 SAGAN 的基础上,BigGAN(2019)尝试将GAN的训练扩展到大规模上,利用正交正则化等技巧保证训练过程的稳定性。

BigGAN 的意义在于:GAN 网络的训练同样可以从大数据,大算力中间受益。

其把图片的逼真图指标,Inception score(提升到166.5),Frechet Inception distance(从27.62降到18.65)。

13.5 纳什均衡

从理论层面进行分析,通过博弈学习的训练方式,生成器G和判别器D分别会达到什么状态。

探索以下两个问题:

  • 固定GD会收敛到什么最优状态D^{*}
  • D达到最优状态D^{*}后,G会收敛到什么状态?

13.5.1 判别器状态

回归GAN的损失函数:


GAN损失函数

对于判别器D,优化的目标是最大化\mathcal{L}(G,D)函数,需要找出:

公式

公式

不是很能理解这里。

13.5.2 生成器状态

JS 散度,它定义为KL散度的组合:


推导
推导
推导
推导

13.6 GAN 训练难题

GAN 网络训练困难的问题,主要体现在:

13.6.1 超参数敏感

超参数敏感是指网络的结构,学习率,初始化状态等超参数对网络的训练过程影响较大,微量的超参数调整可能导致网络的训练结果截然不同。

为了能较好地训练GAN网络,DCGAN 论文作者提出了不使用 Pooling 层,多使用 Batch Normalization层,不使用全连接层,生成网络中激活函数使用ReLU,最后一层使用tanh,判别网络激活函数室友LeakyLeLU等一系列经验性技巧。

13.6.2 模式崩塌

模式崩塌是指模型生成的样本单一,多样性很差。

由于判别器只能鉴别单个样本是否采样直真实样本,并没有对样本多样性进行显示约束,导致生成模型可能倾向于生成真实分布的部分区间中的少量高质量样本,以此来在判别器的输出中获得较高的概率值,

但是,我们希望生成网络能够逼近真实的分布,而不是真实分布中的某部分。

13.7 WGAN 原理

WGAN 作者提出是因为 JS 散度在不重叠的分布p,g的梯度曲面是恒定为0的,当分布p,g不重叠时,JS散度的梯度始终为0,从而导致此时GAN的训练出现梯度弥散现象,参数长时间得不到更新,网络无法收敛。

要解决此问题,需要使用一种更好的分布距离衡量标准,使得它即使在分布不重叠时,也能平滑反映分布之间的距离变换。

13.7.2 EM 距离

WGAN 论文中发现了JS散度导致GAN训练不稳定的问题,并引入了一种新的分布距离度量方法:Wasserstein Distance,也叫做推土机距离(EM 距离),它表示了从一个分布变换到另一个分布的最小代价,定义为:


公式

image.png

13.7.3 WGAN-GP

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 215,723评论 6 498
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,003评论 3 391
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 161,512评论 0 351
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,825评论 1 290
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,874评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,841评论 1 295
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,812评论 3 416
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,582评论 0 271
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,033评论 1 308
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,309评论 2 331
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,450评论 1 345
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,158评论 5 341
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,789评论 3 325
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,409评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,609评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,440评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,357评论 2 352