无中生有炼丹术,生成逼真卡通人脸——DCGAN(对抗生成网络)实战

前一阵子,偶然看到一个换脸的视频,觉得实在是很神奇,于是饶有兴致的去了解一下换脸算法。原来背后有一个极为有意思的算法思想——对抗生成。今天笔者斗胆来介绍一下在学术界大名鼎鼎的GAN(Generative Adversarial Networks ),此网络结构由Ian J. Goodfellow大神在2014年提出,一经推出,就引爆了学术界。

随后各种各样的GAN算法以指数级增长的方式涌现出来,比如WGAN(Wasserstein GAN),CGAN(condition gan),SRGAN(super resolution gan)等。据说后来提出的GAN在取名字简称的时候——XXGAN,其中GAN的前面的XX,26个英文字母两两排列组合都快不够用了,这足以见得这个算法最近几年的热度。而GAN也有很多应用场景:

  • 高清图片生成。
  • 消除马赛克。
  • 侧脸转正等等。

由于笔者只在稍微了解过图像领域的GAN算法,所以只能说出以上具体的应用场景。不过据了解在自然语言处理领域GAN也可用了训练聊天机器人(chatbot)。总之笔者感觉GAN这个算法如果用对了地方,还是能够发挥出它的潜力的。

GAN算法简介

GAN的结构

首先我们简单了解一下最原始的GAN网络结构,如下图,主要只看分为黄色长方形和粉红色长方形,这两部分为Network的部分:

  • 一个生成器(粉红的generator),
  • 一个判别器(黄色的discriminator),

接下来注意了,我们仔细研究下这两个网络的输入和输出,同时了解一下这两个网络的关系:

  • 生成器的输入随机生成的噪声向量输出一张图片(2维或者3维矩阵)
  • 判别器的输入是真实的图片(2维或者3维矩阵)生成器生成的图片(2维或者3维矩阵),输出是0或者1。
    GAN

    按照农场文的普遍的讲法,整个GAN做的事情就是类似于假画师鉴画师之间的博弈,是不是现在有点对抗(Adversarial )的意思了,其整个过程分为以下两部分:
  • 生成器的训练(假画师提高自己画假画的水平):生成的图片能够以假乱真欺骗判别器。
  • 判别器的训练(鉴画师提高自己鉴别假画的水平):能够鉴别出生成器生成的假图片。

最后结果可想而之,这两方在互相博弈之间,都得到了极大的提升。判别器鉴别能力越来越强,而生成器生成的图片越来越像真的。最终我们拿到训练好的生成器,随机输入一个噪声向量给它,它也能输出一张以假乱真的图片。

GAN的原理

笔者在这里不想讲太多的原理部分,大家感兴趣的可以去访问我的参考文献部分,其中台湾大学的李宏毅老师视频和苏剑林大神的博客中,将GAN讲得很通俗易懂。在数学理论方面笔者只强调一句话:GAN的训练目的是希望生成器生成的数据分布和真实数据里的分布越像越好,如下图所示。

GAN的数学解释

这里特别推荐一下李宏毅老师的课程,不需要你懂太高深的数学,也可以了解GAN和WGAN原理部分的精髓。

DCGAN的实战部分

DCGAN

实战部分笔者采用的是DCGAN,这个DCGAN架构规定一些搭建GAN网络时的规则:

  • 在生成网络和判别网络上必须使用批处理规范化。
  • 对于更深的架构移除全连接隐藏层。
  • 在生成网络的所有层上必须使用ReLU激活函数,除了输出层使用Tanh激活函数。
  • 在判别网络的所有层上必须使用LeakyReLU激活函数。

这些规定主要是为了使优化效果变得更好,没有特别好的数学解释。既然DCGAN做图像生成效果好,那我们用起来吧。

载入数据

数据载入部分其实很简单,就是一堆开通人物人脸的图片,没有label。从下方代码中可以看到,笔者本次实验图片是528张shape为(96, 96, 3)的图片,所以生成器需要生成的图片矩阵维度就必须是(96,96,3)。

import numpy as np
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import keras.backend as K
import matplotlib.pyplot as plt
import sys
from PIL import Image
import os
pic_list = os.listdir('./anime-faces/1boy')
pic_arr_list = []
###read anime-faces data from folder
for i in range(len(pic_list)):
    t = Image.open("./anime-faces/1boy/{}".format(pic_list[I]))
    t = np.array(t)
    t = t/127.5 - 1
    pic_arr_list.append(t)
### convert the picture data to array
train_data = np.array(pic_arr_list)
train_data.shape#(528, 96, 96, 3)
定义generator

注意DCGAN中生成网络的所有层使用ReLU的激活函数,除了输出层使用Tanh激活函数。而且每层不要忘了加BN层(作用是批处理规范化)。同时定义好输入随机噪声的维度,这里笔者定义的是100维。输出维度就是生成图片的维度(96,96,3)。

def build_generator():
        model = Sequential()
        model.add(Dense(128 * 24 * 24, activation="relu", input_dim=100))
        model.add(Reshape((24, 24, 128)))
        model.add(UpSampling2D())
        model.add(Conv2D(128, kernel_size=3, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation("relu"))
        model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=3, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation("relu"))
        model.add(Conv2D(3, kernel_size=3, padding="same"))
        model.add(Activation("tanh"))
        model.summary()
        noise = Input(shape=(100,))
        img = model(noise)
        return Model(noise, img)

从上方的代码和下方的生成器的结构可视化中可以看出,整个网络的结构,以及输入维度和输出维度。


gen
定义discriminator

在DCGAN中定义discriminator时,判别网络的所有层使用LeakyReLU的激活函数。而且每层也需要加BN层进行批处理规范化处理。而判别器的输入时一张(96,96,3)的矩阵,输出则是0或者1。

def build_discriminator():  
    model = Sequential()
    model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=(96,96,3), padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
    model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
    model.add(ZeroPadding2D(padding=((0,1),(0,1))))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
    model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
    model.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))

依然从上方代码和下方模型可视化输出中可以清晰看到,判别器的网络结构,以及输入,输出。总之判别器就是为了判断输入图片是真实图片还是生成图片

dis

联系生成器和判别器

这部分就是定义GAN的最关键部分,我们需要让生成器和判别器联系起来。下面部分代码有两点需注意:

  • 将生成器生成的图片输入给判别器,
  • 此时判别器不做训练,只训练生成器。

到这一步有的同学就会问了,为啥不训练判别器呢?别急,对抗的过程(判别器训练一步,生成器训练一步)在GAN的训练中才会体现出来。

optimizer = Adam(0.0002, 0.5)
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])
generator = build_generator()
z = Input(shape=(100,))
#feed the random noise to the generator
img = generator(z)
# For the combined model we will only train the generator
discriminator.trainable = False
# The discriminator takes generated images as input and determines validity
valid_g = discriminator(img)
# The combined model  (stacked generator and discriminator)
# Trains the generator to fool the discriminator
combined = Model(z, valid_g)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)
combined.summary()
GAN
训练GAN

接下来接可以开始DCGAN的训练了,这里代码的含义是先训练一步判别器,在训练一步生成器,二者互相博弈,互相进步。

for epoch in range(epochs):
    # ---------------------
    #  Train Discriminator
    # ---------------------

    # Select a random half of images
    idx = np.random.randint(0, train_data.shape[0], batch_size)
    imgs = train_data[idx]

    # Sample noise and generate a batch of new images
    noise = np.random.normal(0, 1, (batch_size, 100))
    gen_imgs = generator.predict(noise)

    # Train the discriminator (real classified as ones and generated as zeros)
    d_loss_real = discriminator.train_on_batch(imgs, valid_d)
    d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # ---------------------
    #  Train Generator
    # ---------------------
    # Train the generator (wants discriminator to mistake images as real)
    g_loss = combined.train_on_batch(noise, valid_d)

    # Plot the progress
    print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

    # If at save interval => save generated image samples
    if epoch % save_interval == 0:
        r, c = 5, 5
        noise_save = np.random.normal(0, 1, (r * c, 100))
        gen_imgs = generator.predict(noise_save)
#         gen_imgs = 0.5* gen_imgs + 0.5
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], )
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("/mnt/disk2/data/wp/test/images/boy_%d.png" % epoch)
        plt.close()

从下图模型输出的损失函数中我们可以看到判别器和生成器博弈的过程,判别器的D loss先下降后上升(当然笔者的实验不是很明显),代表着判别器本先变强导致D loss下降,之后生成器开始发力,生成质量更好的图片,使得D loss上升。判别器的D loss应该是一个跌宕起伏的曲线。这里笔者的D loss 很小,说明判别器太强大了,其实在GAN训练的过程中,任意一方太强,都会导致模型训练效果不好,比较GAN是个相互进步,相互促进的过程,任何一方太强都会导致大家无法进步。


train_loss

笔者在这里输出了模型跑了500个epoch和5000个epoch之后生成器生成的图像效果对比。


500 epoch

从两张图的生成效果上来说,5000个epoch时,生成器生成的图片质量更好一些,已经能够可看出卡通人物脸清晰的轮廓曲线了。
5000 epoch
使用generator生成图片
pic = generator.predict(np.random.normal(0, 1, (1, 100)))
plt.imshow(np.squeeze(pic[0]))

最终笔者在模型训练好之后,运行上方代码,给生成器随机输入一个100维的向量,生成下方那个绿头发的卡通人脸。看起来效果还不错耶。是不是很神奇


效果图

结语

GAN确实是个很有趣的结构,对抗生成的思想很像我们人类社会中的棋逢对手的情况。在足球界梅西和C罗,正是因为对方的存在而促使对方努力,互相进步,形成绝代双骄的局面,在金庸的武侠世界老顽童周伯通发明来双手互搏来提升功力,而GAN正是使用这种对抗的方式学习进步。在人类世界独孤求败有时候也是很悲哀的一种情景,这也暗合了在train GAN时一定要保持判别器和生成器实力相当,不然你trian出来的GAN肯定很糟。

参考

https://spaces.ac.cn/archives/6240
http://speech.ee.ntu.edu.tw/~tlkagk/courses_MLDS18.html
https://arxiv.org/pdf/1511.06434.pdf
https://arxiv.org/abs/1406.2661
https://github.com/eriklindernoren/Keras-GAN

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 214,172评论 6 493
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,346评论 3 389
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 159,788评论 0 349
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,299评论 1 288
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,409评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,467评论 1 292
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,476评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,262评论 0 269
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,699评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,994评论 2 328
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,167评论 1 343
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,827评论 4 337
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,499评论 3 322
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,149评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,387评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,028评论 2 365
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,055评论 2 352

推荐阅读更多精彩内容