前一阵子,偶然看到一个换脸的视频,觉得实在是很神奇,于是饶有兴致的去了解一下换脸算法。原来背后有一个极为有意思的算法思想——对抗生成。今天笔者斗胆来介绍一下在学术界大名鼎鼎的GAN(Generative Adversarial Networks ),此网络结构由Ian J. Goodfellow大神在2014年提出,一经推出,就引爆了学术界。
随后各种各样的GAN算法以指数级增长的方式涌现出来,比如WGAN(Wasserstein GAN),CGAN(condition gan),SRGAN(super resolution gan)等。据说后来提出的GAN在取名字简称的时候——XXGAN,其中GAN的前面的XX,26个英文字母两两排列组合都快不够用了,这足以见得这个算法最近几年的热度。而GAN也有很多应用场景:
- 高清图片生成。
- 消除马赛克。
- 侧脸转正等等。
由于笔者只在稍微了解过图像领域的GAN算法,所以只能说出以上具体的应用场景。不过据了解在自然语言处理领域GAN也可用了训练聊天机器人(chatbot)。总之笔者感觉GAN这个算法如果用对了地方,还是能够发挥出它的潜力的。
GAN算法简介
GAN的结构
首先我们简单了解一下最原始的GAN网络结构,如下图,主要只看分为黄色长方形和粉红色长方形,这两部分为Network的部分:
- 一个生成器(粉红的generator),
- 一个判别器(黄色的discriminator),
接下来注意了,我们仔细研究下这两个网络的输入和输出,同时了解一下这两个网络的关系:
- 生成器的输入是随机生成的噪声向量,输出是一张图片(2维或者3维矩阵)
- 判别器的输入是真实的图片(2维或者3维矩阵)和生成器生成的图片(2维或者3维矩阵),输出是0或者1。
按照农场文的普遍的讲法,整个GAN做的事情就是类似于假画师和鉴画师之间的博弈,是不是现在有点对抗(Adversarial )的意思了,其整个过程分为以下两部分: - 生成器的训练(假画师提高自己画假画的水平):生成的图片能够以假乱真欺骗判别器。
- 判别器的训练(鉴画师提高自己鉴别假画的水平):能够鉴别出生成器生成的假图片。
最后结果可想而之,这两方在互相博弈之间,都得到了极大的提升。判别器鉴别能力越来越强,而生成器生成的图片越来越像真的。最终我们拿到训练好的生成器,随机输入一个噪声向量给它,它也能输出一张以假乱真的图片。
GAN的原理
笔者在这里不想讲太多的原理部分,大家感兴趣的可以去访问我的参考文献部分,其中台湾大学的李宏毅老师视频和苏剑林大神的博客中,将GAN讲得很通俗易懂。在数学理论方面笔者只强调一句话:GAN的训练目的是希望生成器生成的数据分布和真实数据里的分布越像越好,如下图所示。
这里特别推荐一下李宏毅老师的课程,不需要你懂太高深的数学,也可以了解GAN和WGAN原理部分的精髓。
DCGAN的实战部分
DCGAN
实战部分笔者采用的是DCGAN,这个DCGAN架构规定一些搭建GAN网络时的规则:
- 在生成网络和判别网络上必须使用批处理规范化。
- 对于更深的架构移除全连接隐藏层。
- 在生成网络的所有层上必须使用ReLU激活函数,除了输出层使用Tanh激活函数。
- 在判别网络的所有层上必须使用LeakyReLU激活函数。
这些规定主要是为了使优化效果变得更好,没有特别好的数学解释。既然DCGAN做图像生成效果好,那我们用起来吧。
载入数据
数据载入部分其实很简单,就是一堆开通人物人脸的图片,没有label。从下方代码中可以看到,笔者本次实验图片是528张shape为(96, 96, 3)的图片,所以生成器需要生成的图片矩阵维度就必须是(96,96,3)。
import numpy as np
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import keras.backend as K
import matplotlib.pyplot as plt
import sys
from PIL import Image
import os
pic_list = os.listdir('./anime-faces/1boy')
pic_arr_list = []
###read anime-faces data from folder
for i in range(len(pic_list)):
t = Image.open("./anime-faces/1boy/{}".format(pic_list[I]))
t = np.array(t)
t = t/127.5 - 1
pic_arr_list.append(t)
### convert the picture data to array
train_data = np.array(pic_arr_list)
train_data.shape#(528, 96, 96, 3)
定义generator
注意DCGAN中生成网络的所有层使用ReLU的激活函数,除了输出层使用Tanh激活函数。而且每层不要忘了加BN层(作用是批处理规范化)。同时定义好输入随机噪声的维度,这里笔者定义的是100维。输出维度就是生成图片的维度(96,96,3)。
def build_generator():
model = Sequential()
model.add(Dense(128 * 24 * 24, activation="relu", input_dim=100))
model.add(Reshape((24, 24, 128)))
model.add(UpSampling2D())
model.add(Conv2D(128, kernel_size=3, padding="same"))
model.add(BatchNormalization(momentum=0.8))
model.add(Activation("relu"))
model.add(UpSampling2D())
model.add(Conv2D(64, kernel_size=3, padding="same"))
model.add(BatchNormalization(momentum=0.8))
model.add(Activation("relu"))
model.add(Conv2D(3, kernel_size=3, padding="same"))
model.add(Activation("tanh"))
model.summary()
noise = Input(shape=(100,))
img = model(noise)
return Model(noise, img)
从上方的代码和下方的生成器的结构可视化中可以看出,整个网络的结构,以及输入维度和输出维度。
定义discriminator
在DCGAN中定义discriminator时,判别网络的所有层使用LeakyReLU的激活函数。而且每层也需要加BN层进行批处理规范化处理。而判别器的输入时一张(96,96,3)的矩阵,输出则是0或者1。
def build_discriminator():
model = Sequential()
model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=(96,96,3), padding="same"))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
model.add(ZeroPadding2D(padding=((0,1),(0,1))))
model.add(BatchNormalization(momentum=0.8))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
model.add(BatchNormalization(momentum=0.8))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))
model.add(BatchNormalization(momentum=0.8))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(1, activation='sigmoid'))
依然从上方代码和下方模型可视化输出中可以清晰看到,判别器的网络结构,以及输入,输出。总之判别器就是为了判断输入图片是真实图片还是生成图片。
联系生成器和判别器
这部分就是定义GAN的最关键部分,我们需要让生成器和判别器联系起来。下面部分代码有两点需注意:
- 将生成器生成的图片输入给判别器,
- 此时判别器不做训练,只训练生成器。
到这一步有的同学就会问了,为啥不训练判别器呢?别急,对抗的过程(判别器训练一步,生成器训练一步)在GAN的训练中才会体现出来。
optimizer = Adam(0.0002, 0.5)
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy',
optimizer=optimizer,
metrics=['accuracy'])
generator = build_generator()
z = Input(shape=(100,))
#feed the random noise to the generator
img = generator(z)
# For the combined model we will only train the generator
discriminator.trainable = False
# The discriminator takes generated images as input and determines validity
valid_g = discriminator(img)
# The combined model (stacked generator and discriminator)
# Trains the generator to fool the discriminator
combined = Model(z, valid_g)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)
combined.summary()
训练GAN
接下来接可以开始DCGAN的训练了,这里代码的含义是先训练一步判别器,在训练一步生成器,二者互相博弈,互相进步。
for epoch in range(epochs):
# ---------------------
# Train Discriminator
# ---------------------
# Select a random half of images
idx = np.random.randint(0, train_data.shape[0], batch_size)
imgs = train_data[idx]
# Sample noise and generate a batch of new images
noise = np.random.normal(0, 1, (batch_size, 100))
gen_imgs = generator.predict(noise)
# Train the discriminator (real classified as ones and generated as zeros)
d_loss_real = discriminator.train_on_batch(imgs, valid_d)
d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
# ---------------------
# Train Generator
# ---------------------
# Train the generator (wants discriminator to mistake images as real)
g_loss = combined.train_on_batch(noise, valid_d)
# Plot the progress
print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
# If at save interval => save generated image samples
if epoch % save_interval == 0:
r, c = 5, 5
noise_save = np.random.normal(0, 1, (r * c, 100))
gen_imgs = generator.predict(noise_save)
# gen_imgs = 0.5* gen_imgs + 0.5
fig, axs = plt.subplots(r, c)
cnt = 0
for i in range(r):
for j in range(c):
axs[i,j].imshow(gen_imgs[cnt, :,:,0], )
axs[i,j].axis('off')
cnt += 1
fig.savefig("/mnt/disk2/data/wp/test/images/boy_%d.png" % epoch)
plt.close()
从下图模型输出的损失函数中我们可以看到判别器和生成器博弈的过程,判别器的D loss先下降后上升(当然笔者的实验不是很明显),代表着判别器本先变强导致D loss下降,之后生成器开始发力,生成质量更好的图片,使得D loss上升。判别器的D loss应该是一个跌宕起伏的曲线。这里笔者的D loss 很小,说明判别器太强大了,其实在GAN训练的过程中,任意一方太强,都会导致模型训练效果不好,比较GAN是个相互进步,相互促进的过程,任何一方太强都会导致大家无法进步。
笔者在这里输出了模型跑了500个epoch和5000个epoch之后生成器生成的图像效果对比。
从两张图的生成效果上来说,5000个epoch时,生成器生成的图片质量更好一些,已经能够可看出卡通人物脸清晰的轮廓曲线了。
使用generator生成图片
pic = generator.predict(np.random.normal(0, 1, (1, 100)))
plt.imshow(np.squeeze(pic[0]))
最终笔者在模型训练好之后,运行上方代码,给生成器随机输入一个100维的向量,生成下方那个绿头发的卡通人脸。看起来效果还不错耶。是不是很神奇
结语
GAN确实是个很有趣的结构,对抗生成的思想很像我们人类社会中的棋逢对手的情况。在足球界梅西和C罗,正是因为对方的存在而促使对方努力,互相进步,形成绝代双骄的局面,在金庸的武侠世界老顽童周伯通发明来双手互搏来提升功力,而GAN正是使用这种对抗的方式学习进步。在人类世界独孤求败有时候也是很悲哀的一种情景,这也暗合了在train GAN时一定要保持判别器和生成器实力相当,不然你trian出来的GAN肯定很糟。
参考
https://spaces.ac.cn/archives/6240
http://speech.ee.ntu.edu.tw/~tlkagk/courses_MLDS18.html
https://arxiv.org/pdf/1511.06434.pdf
https://arxiv.org/abs/1406.2661
https://github.com/eriklindernoren/Keras-GAN