CycleGAN

这篇文章理解自知乎上两篇文章:

GAN 补充

深度生成模型的分类树如下:

可以根据极大似然原理学习的深度生成模型,根据如何表示或预估概率,可以分为显式密度模型和隐式密度模型。显式密度模型可以构建一个明确的密度模型,p(x;θ),因此可以求得使可能性最大的 θ 值。显式密度模型又分为易解决的和不易解决的(需要使用近似法求最大化可能性的 θ)。对于隐式密度模型,则没有明确表示数据空间的概率分布,相反,该模型提供了一些与该概率分布间接相互作用的方式——生成样本,即定义一种在没有任何输入的情况下,通过随机转换现有样本,以便获取另一个服从同一分布的样本的方法。GAN 即属于隐式密度模型,它直接从模型表示的分布中采样,而非使用马尔可夫链。

GAN 核心原理的数学描述为:

简单分析一下这个公式:

  • 整个式子由两项构成。x 表示真实图片,z 表示输入 G 网络的噪声,而 G(z) 表示 G 网络生成的图片。
  • D(x) 表示 D 网络判断真实图片是否真实的概率(因为x就是真实的,所以对于 D 来说,这个值越接近1越好)。而 D(G(z)) 是 D 网络判断 G 生成的图片的是否真实的概率。
  • G 的目的:G 希望自己生成的图片“越接近真实越好”。也就是说,G 希望 D(G(z)) 尽可能得大,这时 V(D, G) 会变小。因此式子对于 G 来说是求最小(min_G)。
  • D的目的:D 的能力越强,D(x) 应该越大,D(G(x)) 应该越小,这时 V(D,G) 会变大。因此式子对于 D 来说是求最大(max_D)。

用随机梯度下降法训练 D 和 G 的算法为:

第一步训练 D,D 是希望 V(G, D) 越大越好,所以是加上梯度(ascending)。第二步训练 G 时,V(G, D) 越小越好,所以是减去梯度(descending)。整个训练过程交替进行。

CycleGAN 原理

CycleGAN的原理可以概述为:将一类图片转换成另一类图片。也就是说,现在有两个样本空间,X 和 Y,我们希望把 X 空间中的样本转换成 Y 空间中的样本。因此,实际的目标就是学习从 X 到 Y 的映射(设这个映射为 F),F 就对应着 GAN 中的生成器,F 可以将 X 中的图片 x 转换为 Y 中的图片 F(x)。对于生成的图片,我们还需要 GAN 中的判别器来判别它是否为真实图片,由此构成对抗生成网络。设这个判别器为 DY。这样的话,根据这里的生成器和判别器,我们就可以构造一个 GAN 损失,表达式为:

这个损失实际上和原始的 GAN 损失是一模一样的,但单纯的使用这一个损失是无法进行训练的。原因在于,映射 F 完全可以将所有 x 都映射为 Y 空间中的同一张图片,使损失无效化。对此,作者又提出了所谓的循环一致性损失(cycle consistency loss)。再假设一个映射 G,它可以将 Y 空间中的图片 y 转换为 X 中的图片 G(y)。CycleGAN 同时学习 F 和 G 两个映射,并要求 F(G(y)) ≈ y,以及 G(F(x)) ≈ x。也就是说,将 X 的图片转换到 Y 空间后,应该还可以转换回来。这样就杜绝模型把所有 X 的图片都转换为 Y 空间中的同一张图片了。根据 F(G(y)) ≈ y 和 G(F(x)) ≈ x,循环一致性损失就定义为:

同时,为 G 也引入一个判别器 DX,由此可以同样定义一个 GAN 的损失 LGAN(G,DX,X,Y)。最终的损失就由三部分组成:

CycleGAN 的结构示意图如下:

从上图可以了解 CycleGAN 的运作过程:两个输入被传递到对应的鉴别器(一个是对应于该域的原始图像,另一个是通过生成器产生的图像),并且鉴别器的任务是区分它们,识别出生成器输出的生成图像,并拒绝此生成图像。生成器想要确保这些图像被鉴别器接受,所以它将尝试生成与 DB 类中原始图像非常接近的新图像。事实上,在生成器分布与所需分布相同时,生成器和鉴别器之间实现了纳什均衡(Nash equilibrium)。

CycleGAN 的灵活性在于不需要提供从源域到目标域的配对转换例子就可以训练。比如,我们希望训练一个将白天的照片转换为夜晚的模型。如果使用pix2pix模型,那么我们必须在搜集大量地点在白天和夜晚的两张对应图片,而使用CycleGAN只需同时搜集白天的图片和夜晚的图片,不必满足对应关系。因此CycleGAN的用途要比pix2pix更广泛,利用CycleGAN就可以做出更多有趣的应用。

CycleGAN 实现

一、构建生成器

生成器的结构如下:

生成器由三部分组成:编码器、转换器、解码器。

编码
第一步是利用卷积网络从输入图像中提取特征。整个编码过程,将 DA 域中一个尺寸为 [256,256,3] 的图像,输入到设计的编码器中,获得了尺寸为 [64,64,256] 的输出 OAenc。

转换
这些网络层的作用是组合图像的不同相近特征,然后基于这些特征,确定如何将图像的特征向量 OAenc 从 DA 域转换为 DB 域的特征向量。因此,作者使用了 6 层 Resnet 模块。OBenc 表示该层的最终输出,尺寸为 [64,64,256],这可以看作是 DB 域中图像的特征向量。

一个 Resnet 模块是一个由两个卷积层组成的神经网络层,其中部分输入数据直接添加到输出。这样做是为了确保先前网络层的输入数据信息直接作用于后面的网络层,使得相应输出与原始输入的偏差缩小,否则原始图像的特征将不会保留在输出中且输出结果会偏离目标轮廓。这个任务的一个主要目标是保留原始图像的特征,如目标的大小和形状,因此残差网络非常适合完成这些转换。Resnet 模块的结构如下所示:

解码
解码过程与编码方式完全相反,从特征向量中还原出低级特征,这是利用了反卷积层(deconvolution)来完成的。最后,我们将这些低级特征转换得到一张在DB域中的图像,得到一个大小为 [256,256,3] 的生成图像 genB。

二、构建鉴别器

鉴别器将一张图像作为输入,并尝试预测其为原始图像或是生成器的输出图像。鉴别器的结构如下所示:

鉴别器本身就属于卷积网络,需要从图像中提取特征;然后是确定这些特征是否属于该特定类别,使用一个产生一维输出的卷积层来完成这个任务。

至此,已经完成该模型的两个主要组成部分,即生成器和鉴别器。由于要使这个模型可以从 A→B 和 B→A 两个方向工作,所以设置了两个生成器,即生成器 A→B 和生成器 B→A,以及两个鉴别器,即鉴别器 A 和鉴别器 B。

三、建立模型

在定义损失函数前,先定义基础输入变量,来构建模型:

input_A = tf.placeholder(tf.float32, [batch_size, img_width, img_height, img_layer], name="input_A")
input_B = tf.placeholder(tf.float32, [batch_size, img_width, img_height, img_layer], name="input_B")

同时定义模型如下:

gen_B = build_generator(input_A, name="generator_AtoB")
gen_A = build_generator(input_B, name="generator_BtoA")
dec_A = build_discriminator(input_A, name="discriminator_A")
dec_B = build_discriminator(input_B, name="discriminator_B")

dec_gen_A = build_discriminator(gen_A, "discriminator_A")
dec_gen_B = build_discriminator(gen_B, "discriminator_B")
cyc_A = build_generator(gen_B, "generator_BtoA")
cyc_B = build_generator(gen_A, "generator_AtoB")

gen 表示使用相应的生成器后生成的图像,dec 表示在将相应输入传递到鉴别器后做出的判断。因此:

  • gen_A 是生成器 B2A 根据真 B 生成的假 A,
    gen_B 是生成器 A2B 根据真 A 生成的假 B;
  • dec_A 是鉴别器 A 对真 A 的鉴别结果,
    dec_B 是鉴别器 B 对真 B 的鉴别结果;
  • dec_gen_A 是鉴别器 A 对 gen_A 的鉴别结果,
    dec_gen_B 是鉴别器 B 对 gen_B 的鉴别结果;
  • cyc_A 是生成器 B2A 根据 gen_B 生成的假 A,
    cyc_B 是生成器 A2B 根据 gen_A 生成的假 B.
四、损失函数

现在我们有两个生成器和两个鉴别器。我们要按照实际目的来设计损失函数。损失函数应该包括如下四个部分:

  1. 鉴别器必须允许所有相应类别的原始图像,即对应输出置 1;
  2. 鉴别器必须拒绝所有想要愚弄过关的生成图像,即对应输出置 0;
  3. 生成器必须使鉴别器允许通过所有的生成图像,来实现愚弄操作;
  4. 所生成的图像必须保留有原始图像的特性,所以如果我们使用生成器 GeneratorA→B 生成一张假图像,那么要能够使用另一个生成器 GeneratorB→A 来努力恢复成原始图像。此过程必须满足循环一致性。

鉴别器损失
通过训练鉴别器 A,使其对真 A 的鉴别输出接近于1,鉴别器 B 也是如此。因此,鉴别器 A 的训练目标为最小化 (DiscriminatorA(a)−1)2 的值,鉴别器 B 也是如此。

另外,由于鉴别器应该能够区分生成图像和原始图像,所以在处理生成图像时期望输出为 0,即鉴别器 A 要最小化 (DiscriminatorA(GeneratorB→A(b)))2 的值。

d_loss_A_1 = tf.reduce_mean(tf.squared_difference(dec_A,1))
d_loss_B_1 = tf.reduce_mean(tf.squared_difference(dec_B,1))

d_loss_A_2 = tf.reduce_mean(tf.square(dec_gen_A))
d_loss_B_2 = tf.reduce_mean(tf.square(dec_gen_B))

d_loss_A = (d_loss_A_1 + d_loss_A_2) / 2
d_loss_B = (d_loss_B_1 + d_loss_B_2) / 2

生成器损失
最终生成器应该使得鉴别器对生成图像的输出值尽可能接近 1。故生成器想要最小化 (DiscriminatorB(GeneratorA→B(a))−1)2。对应代码为:

g_loss_A_1 = tf.reduce_mean(tf.squared_difference(dec_gen_B,1))
g_loss_B_1 = tf.reduce_mean(tf.squared_difference(dec_gen_A,1))

循环损失
最后一个重要参数为循环丢失(cyclic loss),能判断用另一个生成器得到的生成图像与原始图像的差别。因此原始图像和循环图像之间的差异应该尽可能小:

cyc_loss = tf.reduce_mean(tf.abs(input_A - cyc_A)) + tf.reduce_mean(tf.abs(input_B - cyc_B))

所以完整的生成器损失为:

g_loss_A = g_loss_A_1 + 10 * cyc_loss
g_loss_B = g_loss_B_1 + 10 * cyc_loss

cyc_loss 的乘法因子设置为 10,说明循环损失比鉴别损失更重要。

五、训练模型

定义好损失函数,接下来只需要训练模型来最小化损失函数。

d_A_trainer = optimizer.minimize(d_loss_A, var_list=d_A_vars)
d_B_trainer = optimizer.minimize(d_loss_B, var_list=d_B_vars)
g_A_trainer = optimizer.minimize(g_loss_A, var_list=g_A_vars)
g_B_trainer = optimizer.minimize(g_loss_B, var_list=g_B_vars)

训练过程如下:

for epoch in range(0,100):
    # Define the learning rate schedule. The learning rate is kept
    # constant upto 100 epochs and then slowly decayed
    if(epoch < 100) :
        curr_lr = 0.0002
    else:
        curr_lr = 0.0002 - 0.0002*(epoch-100)/100

    # Running the training loop for all batches
    for ptr in range(0,num_images):

        # Train generator G_A->B
        _, gen_B_temp = sess.run([g_A_trainer, gen_B],
                                 feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})

        # We need gen_B_temp because to calculate the error in training D_B
        _ = sess.run([d_B_trainer],
                     feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})

        # Same for G_B->A and D_A as follow
        _, gen_A_temp = sess.run([g_B_trainer, gen_A],
                                 feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})
        _ = sess.run([d_A_trainer],
                     feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})

在训练函数中可以看到,在训练时需要不断调用不同鉴别器和生成器。为了训练模型,需要输入训练图像和选择优化器的学习率。由于 batch_size 设置为1,所以 num_batches 等于 num_images。

我们已经完成了模型构建,下面是模型中一些默认超参数。

生成图像库
计算每个生成图像的鉴别器损失是不可能的,因为会耗费大量的计算资源。为了加快训练,我们存储了之前每个域的所有生成图像,并且每次仅使用一张图像来计算误差。首先,逐个填充图像库使其完整,然后随机将某个库中的图像替换为最新的生成图像,并使用这个替换图像来作为该步的训练。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 203,271评论 5 476
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,275评论 2 380
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 150,151评论 0 336
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,550评论 1 273
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,553评论 5 365
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,559评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,924评论 3 395
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,580评论 0 257
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,826评论 1 297
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,578评论 2 320
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,661评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,363评论 4 318
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,940评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,926评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,156评论 1 259
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 42,872评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,391评论 2 342

推荐阅读更多精彩内容