使用Keras实现生成式对抗网络GAN

生成式对抗网络(GAN)自2014年提出以来已经成为最受欢迎的生成模型。本文借鉴机器之心对 2014 GAN 论文的解读,在本机运行该Keras项目。

传送门: 机器之心GitHub项目:GAN完整理论推导与实现,Perfect!

接下来主要讲一下如何实现的:

1. 定义一个生成模型:

def generator_model():
    #下面搭建生成器的架构,首先导入序贯模型(sequential),即多个网络层的线性堆叠
    model = Sequential()
    #添加一个全连接层,输入为100维向量,输出为1024维
    model.add(Dense(input_dim=100, output_dim=1024))
    #添加一个激活函数tanh
    model.add(Activation('tanh'))
    #添加一个全连接层,输出为128×7×7维度
    model.add(Dense(128*7*7))
    #添加一个批量归一化层,该层在每个batch上将前一层的激活值重新规范化,即使得其输出数据的均值接近0,其标准差接近1
    model.add(BatchNormalization())
    model.add(Activation('tanh'))
    #Reshape层用来将输入shape转换为特定的shape,将含有128*7*7个元素的向量转化为7×7×128张量
    model.add(Reshape((7, 7, 128), input_shape=(128*7*7,)))
    #2维上采样层,即将数据的行和列分别重复2次
    model.add(UpSampling2D(size=(2, 2)))
    #添加一个2维卷积层,卷积核大小为5×5,激活函数为tanh,共64个卷积核,并采用padding以保持图像尺寸不变
    model.add(Conv2D(64, (5, 5), padding='same'))
    model.add(Activation('tanh'))
    model.add(UpSampling2D(size=(2, 2)))
    #卷积核设为1即输出图像的维度
    model.add(Conv2D(1, (5, 5), padding='same'))
    model.add(Activation('tanh'))
    return model

2. 定义一个判别模型:

def discriminator_model():
    #下面搭建判别器架构,同样采用序贯模型
    model = Sequential()
    
    #添加2维卷积层,卷积核大小为5×5,激活函数为tanh,输入shape在‘channels_first’模式下为(samples,channels,rows,cols)
    #在‘channels_last’模式下为(samples,rows,cols,channels),输出为64维
    model.add(
            Conv2D(64, (5, 5),
            padding='same',
            input_shape=(28, 28, 1))
            )
    model.add(Activation('tanh'))
    #为空域信号施加最大值池化,pool_size取(2,2)代表使图片在两个维度上均变为原长的一半
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(128, (5, 5)))
    model.add(Activation('tanh'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    #Flatten层把多维输入一维化,常用在从卷积层到全连接层的过渡
    model.add(Flatten())
    model.add(Dense(1024))
    model.add(Activation('tanh'))
    #一个结点进行二值分类,并采用sigmoid函数的输出作为概念
    model.add(Dense(1))
    model.add(Activation('sigmoid'))
    return model

3. 拼接:

def generator_containing_discriminator(g, d):
    #将前面定义的生成器架构和判别器架构组拼接成一个大的神经网络,用于判别生成的图片
    model = Sequential()
    #先添加生成器架构,再令d不可训练,即固定d
    #因此在给定d的情况下训练生成器,即通过将生成的结果投入到判别器进行辨别而优化生成器
    model.add(g)
    d.trainable = False
    model.add(d)
    return model

4. 生成拼接的图片(即将一个batch所有生成图片放到一个图片中):

def combine_images(generated_images):
    #生成图片拼接
    num = generated_images.shape[0]
    width = int(math.sqrt(num))
    height = int(math.ceil(float(num)/width))
    shape = generated_images.shape[1:3]
    image = np.zeros((height*shape[0], width*shape[1]),
                     dtype=generated_images.dtype)
    for index, img in enumerate(generated_images):
        i = int(index/width)
        j = index % width
        image[i*shape[0]:(i+1)*shape[0], j*shape[1]:(j+1)*shape[1]] = \
            img[:, :, 0]
    return image

5. 训练:

def train(BATCH_SIZE):
    
    # 国内好像不能直接导入数据集,试了几次都不行,后来将数据集下载到本地'~/.keras/datasets/',也就是当前目录(我的是用户文件夹下)下的.keras文件夹中。
    #下载的地址为:https://s3.amazonaws.com/img-datasets/mnist.npz
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    #image_data_format选择"channels_last"或"channels_first",该选项指定了Keras将要使用的维度顺序。
    #"channels_first"假定2D数据的维度顺序为(channels, rows, cols),3D数据的维度顺序为(channels, conv_dim1, conv_dim2, conv_dim3)
    
    #转换字段类型,并将数据导入变量中
    X_train = (X_train.astype(np.float32) - 127.5)/127.5
    X_train = X_train[:, :, :, None]   # None将3维的X_train扩展为4维
    X_test = X_test[:, :, :, None]
    # X_train = X_train.reshape((X_train.shape, 1) + X_train.shape[1:])
    
    #将定义好的模型架构赋值给特定的变量
    d = discriminator_model()
    g = generator_model()
    d_on_g = generator_containing_discriminator(g, d)
    
    #定义生成器模型判别器模型更新所使用的优化算法及超参数
    d_optim = SGD(lr=0.001, momentum=0.9, nesterov=True)
    g_optim = SGD(lr=0.001, momentum=0.9, nesterov=True)
    
    #编译三个神经网络并设置损失函数和优化算法,其中损失函数都是用的是二元分类交叉熵函数。编译是用来配置模型学习过程的
    g.compile(loss='binary_crossentropy', optimizer="SGD")
    d_on_g.compile(loss='binary_crossentropy', optimizer=g_optim)
    
    #前一个架构在固定判别器的情况下训练了生成器,所以在训练判别器之前先要设定其为可训练。
    d.trainable = True
    d.compile(loss='binary_crossentropy', optimizer=d_optim)
    
    #下面在满足epoch条件下进行训练
    for epoch in range(30):
        print("Epoch is", epoch)
        
        #计算一个epoch所需要的迭代数量,即训练样本数除批量大小数的值取整;其中shape[0]就是读取矩阵第一维度的长度
        print("Number of batches", int(X_train.shape[0]/BATCH_SIZE))
        
        #在一个epoch内进行迭代训练
        for index in range(int(X_train.shape[0]/BATCH_SIZE)):
            
            #随机生成的噪声服从均匀分布,且采样下界为-1、采样上界为1,输出BATCH_SIZE×100个样本;即抽取一个批量的随机样本
            noise = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))
            
            #抽取一个批量的真实图片
            image_batch = X_train[index*BATCH_SIZE:(index+1)*BATCH_SIZE]
            
            #生成的图片使用生成器对随机噪声进行推断;verbose为日志显示,0为不在标准输出流输出日志信息,1为输出进度条记录
            generated_images = g.predict(noise, verbose=0)
            #print(np.shape(generated_images)) # (BATCH_SIZE,28,28,1) # 表示用BATCH_SIZE个100维向量生成BATCH_SIZE个图像的过程
            
            #每经过100次迭代输出一张生成的图片
            if index % 100 == 0:
                image = combine_images(generated_images)
                image = image*127.5+127.5
                Image.fromarray(image.astype(np.uint8)).save(
                    "./GAN/"+str(epoch)+"_"+str(index)+".png")
            
            #将真实的图片和生成的图片以多维数组的形式拼接在一起,真实图片在上,生成图片在下
            X = np.concatenate((image_batch, generated_images))
            # print(np.shape(X)) # # (2*BATCH_SIZE,28,28,1)
            
            #生成图片真假标签,即一个包含两倍批量大小的列表;前一个批量大小都是1,代表真实图片,后一个批量大小都是0,代表伪造图片
            y = [1] * BATCH_SIZE + [0] * BATCH_SIZE
            
            #判别器的损失;在一个batch的数据上进行一次参数更新
            d_loss = d.train_on_batch(X, y)  # (2*BATCH_SIZE,28,28,1) -> (2*BATCH_SIZE,1)
            print("batch %d d_loss : %f" % (index, d_loss))  # 理论上,d的loss越来越大,因为生成图片和真实图片越来越像
            
            #随机生成的噪声服从均匀分布
            noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
            
            #固定判别器
            d.trainable = False
            
            #计算生成器损失;在一个batch的数据上进行一次参数更新
            #生成器的目标是愚弄辨别器蒙混过关,需要达到的目标是对于生成的图片,输出为1(正好和鉴别器相反).
            g_loss = d_on_g.train_on_batch(noise, [1] * BATCH_SIZE) # (BATCH_SIZE,100) -> (BATCH_SIZE,28,28,1) -> (BATCH_SIZE,1)
            
            #令判别器可训练
            d.trainable = True
            print("batch %d g_loss : %f" % (index, g_loss)) # 理论上,g的loss越来越小,因为生成图像越接近真实,生成图像的label接近1
            
            #每100次迭代保存一次生成器和判别器的权重
            if index % 100 == 9:
                g.save_weights('generator', True)
                d.save_weights('discriminator', True)

注意:运行加载MNIST数据集,调用mnist.load_data()函数需要翻墙。如果不翻墙,可在其他地方找到要加载的mnist.npz文件,把它放到Keras安装目录下的~/.keras/datasets/,也可以。不要试图用Tensorflow加载MNIST数据集的那个模块,因为那个模块对MNIST采取了one-hot的编码格式,得到的值都是归一化的数值。而Keras的函数mnist.load_data()加载的MNIST数据集是原始的像素值。

6. 生成:

def generate(BATCH_SIZE, nice= False ):
    #训练完模型后,可以运行该函数生成图片
    g = generator_model()
    g.compile(loss='binary_crossentropy', optimizer="SGD")
    g.load_weights('generator')
    if nice:
        d = discriminator_model()
        d.compile(loss='binary_crossentropy', optimizer="SGD")
        d.load_weights('discriminator')
        noise = np.random.uniform(-1, 1, (BATCH_SIZE*20, 100))
        generated_images = g.predict(noise, verbose=1)
        d_pret = d.predict(generated_images, verbose=1)
        index = np.arange(0, BATCH_SIZE*20)
        index.resize((BATCH_SIZE*20, 1))
        pre_with_index = list(np.append(d_pret, index, axis=1))
        pre_with_index.sort(key=lambda x: x[0], reverse=True)
        nice_images = np.zeros((BATCH_SIZE,) + generated_images.shape[1:3], dtype=np.float32)
        nice_images = nice_images[:, :, :, None]
        for i in range(BATCH_SIZE):
            idx = int(pre_with_index[i][1])
            nice_images[i, :, :, 0] = generated_images[idx, :, :, 0]
        image = combine_images(nice_images)
    else:
        noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
        generated_images = g.predict(noise, verbose=0)
        image = combine_images(generated_images)
    image = image*127.5+127.5
    Image.fromarray(image.astype(np.uint8)).save(
        "./GAN/generated_image.png")

以上代码在支持 Tensorflow、Kerasipython notebook中运行。

先训练模型(迭代30次):

train(100)  # 100为batch大小,可以随意指定。

迭代的效果如下:


9_500.png

19_500.png

29_500.png

再生成模型:

generate(132)  # 132为batch大小,可以随意指定。该值大小也决定了生成的图片中含有多少个数字。
generate(32)  # 32为batch大小,可以随意指定。该值大小也决定了生成的图片中含有多少个数字。

生成的效果如下:


generated_image132.png

generated_image32.png

训练过程分析:
将 MNIST 数据集(60000)分块训练(如 BITCH_SIZE = 200),则一个 Epoch 就会循环 300 次。每一次循环,生成器 g 先根据 BITCH_SIZE 个 100 维的随机噪声生成 BITCH_SIZE 张和真实图像同样大小的图像,然后将这 BITCH_SIZE 张图像和真实的 BITCH_SIZE 张图像拼接起来,给它们打上标签,计算判别器 d 的损失值。接下来,同样生成 BITCH_SIZE 个 100 维的随机噪声,打上和真实图像一样的标签(因为这是 g 所期望的),放入 d_on_g 中。在 d_on_g 中,先由 g 生成 BITCH_SIZE 张图像图像,再固定 d,再由 d 计算这 BITCH_SIZE 张图像的损失值,即生成器 g 的损失值。循环 300 次后,这才是 1 个 Epoch。如果想要生成更真实的图像, 要有多个 Epoch(代码中是30次)。

损失函数分析:
正常情况下,生成模型的损失和判别模型的损失会在一定范围内交替上升与下降。因为判别模型损失小,意味着更容易区分真实图像和假图像;相反,此时的生成模型与它的目标(真实图像)相差很大,损失函数会增大。反之亦然。最终理想的生成器的损失函数应该和判别器的损失函数一样大,这时判别器无法区分真实图像和假图像,生成器也达到了它所能到达的最小损失值。

更多实现细节可以参考我的github

https://github.com/xyxxmb/DeepLearning

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 213,254评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,875评论 3 387
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,682评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,896评论 1 285
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,015评论 6 385
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,152评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,208评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,962评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,388评论 1 304
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,700评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,867评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,551评论 4 335
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,186评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,901评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,142评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,689评论 2 362
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,757评论 2 351