[GAN笔记] first paper

Generative Adversarial Nets,生成对抗网络。是Ian Goodfellow其受博弈论思想的启发提出的一种通过对抗方式来评估生成模型的框架。主要分为两个模块:一个是生成模型 G,一个是判别模型 D。用原文的话来描述其思想就是:G是一个造假币的人,D则是警察。G将随机拿到的数据不断的仿造真币,并且将真币和假币糅合在一起企图蒙骗过关,而D则不断的提高自己辨别假币的能力,不断的尝试去将真币和假币分辨出来。在经过多次迭代后,G生成的数据的质量越来越逼近真实的数据,而D也就慢慢对输入的数据失去了辨别能力。

基本概念

假设真实数据的分布是 p_{data}(x),我们想要让生成器生成的数据的分布 p_{g} 尽可能的接近 p_{data}(x)。生成器 G(z;\theta_{g})将预定义的先验的输入噪声变量 p_{z}(z) 映射到数据空间中。这里的生成器G是一个受参数 \theta_{g} 控制的复杂函数(这里是一个多层感知机)。另一方面,GAN也定义了另一个多层感知机 D(x;\theta_{d}) ,D的输出是一个标量,可以看做一个二分类器,判断输入的数据是真实的数据还是生成器伪造的数据,若输入的数据是真实数据,D的输出会接近1,输入的是伪造数据,则其输出会接近0。训练时,最大化D的辨别能力,同时训练G去最小化 log(1-D(G(z)))。所以D和G就是对目标函数 V(G,D) 进行最大最小化博弈,目的就是为了不断的缩小生成分布 p_{g} 跟数据分布 p_{data} 之间的散度,让两个分布越接近越好。

GAN目标函数

在G和D的模型足够复杂时(无参数限制),对抗网络的可以逼近真实数据的分布。进行理论分析前,先看一下原文对于GAN的大概描述:
基本概念

如图,黑色的虚线表示真实数据的分布,绿色的实线表示生成数据的分布。蓝色的线表示判别器D的输出,当判别为真实数据时,D的输出很高,反之,D的输出则较小。x 旁边的水平线表示真实数据 X 的一部分,而 z 旁边的实线则表示采样域 Z 。箭头表示 x=G(z)对噪声样本 z 施加一个非均匀分布
p_{g}
 在训练开始时,由于生成器G性能比较差,判别器D可以很轻易的将数据的真实性辨别出来。多次迭代后,生成分布
p_{g}
会越来越接近
p_{data}
,当
p_{g}=p_{data}
时,G和D会达到一个均衡状态,双方都无法继续提升,D(x)会处处为 1/2 。这是因为当D收敛到最大值时,
D^{*}(x)=\frac{p_{data}(x)}{p_{data}(x) + p_{g}(x)}

理论分析

下面,对目标函数进行理论分析。即为什么要找一个D使V(D,G)最大,同时找一个G去使得V(D,G)最小。这样做为什么就可以缩小 p_{g}p_{data}的距离呢?
首先看\underset{D}{max}V(D,G),我们知道:

推导

而在原文中,上式右边的第二项,直接做了一个积分换元,得到下面的式子:


原文推导

这里的积分换元需要计算G^(-1),而G的逆并没有假设一定存在,对于此处的换元尚存疑问。
如果忽略上面的疑问,直接得到上面的等式的话,我们可以看到,\underset{D}{max}V(D,G)就是要最大化下面这个式子


这可以看做一个二分类器,当x来自于真实数据时,要使D(x)尽可能大,从而使得V较大,当x是生成的数据时,要使(1-D(x))尽可能大,也就是要使D(x)的值尽可能小。
给定G时,
p_{data}(x)
p_{g}(x)
就固定了,这时要求D(x)的最大值,可以将
p_{data}(x)
p_{g}(x)
看做常数。看做
f(D)= alog(D) + blog(1-D)

令df(D)/dD 为0时,求得D的值为:
D^{*}=\frac{a}{a+b}

也就是说,当
D^{*}(x)=\frac{p_{data}(x)}{p_{data}(x)+p_{g}(x)}
时,求得一个最大的V(D,G)
D^{*}(x)
代入V中,有:

h

最后化简出来一个JS散度,G的任务就是最小化p_{data}(x)p_{g}(x)的JS散度,使得\underset{D}{max} V(D,G)最小。也即G^{*}=arg\underset{G}{min} \underset{D}{max}V(D,G)
p_{data}(x)=p_{g}(x)p_{data}(x)p_{g}(x)的JS散度为0,此时价值函数最小,我们也就获得最优生成器。

算法流程

算法的流程可以看下图:


给定一个
G_{1}
,找到一个使V最大的D,然后固定D,找一个使得V最小的G。
实际中,因为无法算积分,所以从真实数据中采样。算法步骤如下:

(1)从P_{data}(x)中采样m个点\left \{ {x^{1},x^{2},...,x^{m}} \right \},从P_{z}(z)中采样m个点\left \{ {z^{1},z^{2},...,z^{m}} \right \}
(2)将\left \{ {z^{1},z^{2},...,z^{m}} \right \}通过生成器x=G(z),得到\left \{ {\widetilde{x}^{1},\widetilde{x}^{2},...,\widetilde{x}^{m}} \right \}
(3)更新D的参数\theta _{d}去最大化\widetilde{V} = \frac{1}{m} \sum_{1}^{m}log(D(x^{i})) + \frac{1}{m} \sum_{1}^{m}log(1-D(\widetilde{x}^{i}))
\theta_{d}=\theta_{d}+\eta\bigtriangledown\widetilde{V}(\theta_{d})
(4)再从P_{z}(z)中采样m个点\left \{ {z^{1},z^{2},...,z^{m}} \right \}
(5)更新G的参数\theta _{g}=\theta _{g} - \eta\bigtriangledown\widetilde{V}(\theta_{g})
先重复(1)到(3)步k次更新D,然后通过(4)和(5)更新一次G

接下来是GAN的tensorflow实现,代码来源于:
https://blog.csdn.net/jiongnima/article/details/80033169
数据集:
https://pan.baidu.com/s/1HFLbIzlW-CkZ5ysWco7yPA

import tensorflow as tf #导入tensorflow
from tensorflow.examples.tutorials.mnist import input_data #导入手写数字数据集
import numpy as np #导入numpy
import matplotlib.pyplot as plt #plt是绘图工具,在训练过程中用于输出可视化结果
import matplotlib.gridspec as gridspec #gridspec是图片排列工具,在训练过程中用于输出可视化结果
import os #导入os

def save(saver, sess, logdir, step): #保存模型的save函数
   model_name = 'model' #模型名前缀
   checkpoint_path = os.path.join(logdir, model_name) #保存路径
   saver.save(sess, checkpoint_path, global_step=step) #保存模型
   print('The checkpoint has been created.')
def xavier_init(size): #初始化参数时使用的xavier_init函数
    in_dim = size[0] 
    xavier_stddev = 1. / tf.sqrt(in_dim / 2.) #初始化标准差
    return tf.random_normal(shape=size, stddev=xavier_stddev) #返回初始化的结果

X = tf.placeholder(tf.float32, shape=[None, 784]) #X表示真的样本(即真实的手写数字)

D_W1 = tf.Variable(xavier_init([784, 128])) #表示使用xavier方式初始化的判别器的D_W1参数,是一个784行128列的矩阵
D_b1 = tf.Variable(tf.zeros(shape=[128])) #表示全零方式初始化的判别器的D_1参数,是一个长度为128的向量

D_W2 = tf.Variable(xavier_init([128, 1])) #表示使用xavier方式初始化的判别器的D_W2参数,是一个128行1列的矩阵
D_b2 = tf.Variable(tf.zeros(shape=[1])) ##表示全零方式初始化的判别器的D_1参数,是一个长度为1的向量

theta_D = [D_W1, D_W2, D_b1, D_b2] #theta_D表示判别器的可训练参数集合


Z = tf.placeholder(tf.float32, shape=[None, 100]) #Z表示生成器的输入(在这里是噪声),是一个N列100行的矩阵

G_W1 = tf.Variable(xavier_init([100, 128])) #表示使用xavier方式初始化的生成器的G_W1参数,是一个100行128列的矩阵
G_b1 = tf.Variable(tf.zeros(shape=[128])) #表示全零方式初始化的生成器的G_b1参数,是一个长度为128的向量

G_W2 = tf.Variable(xavier_init([128, 784])) #表示使用xavier方式初始化的生成器的G_W2参数,是一个128行784列的矩阵
G_b2 = tf.Variable(tf.zeros(shape=[784])) #表示全零方式初始化的生成器的G_b2参数,是一个长度为784的向量

theta_G = [G_W1, G_W2, G_b1, G_b2] #theta_G表示生成器的可训练参数集合


def sample_Z(m, n): #生成维度为[m, n]的随机噪声作为生成器G的输入
    return np.random.uniform(-1., 1., size=[m, n])


def generator(z): #生成器,z的维度为[N, 100]
    G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1) #输入的随机噪声乘以G_W1矩阵加上偏置G_b1,G_h1维度为[N, 128]
    G_log_prob = tf.matmul(G_h1, G_W2) + G_b2 #G_h1乘以G_W2矩阵加上偏置G_b2,G_log_prob维度为[N, 784]
    G_prob = tf.nn.sigmoid(G_log_prob) #G_log_prob经过一个sigmoid函数,G_prob维度为[N, 784]

    return G_prob #返回G_prob


def discriminator(x): #判别器,x的维度为[N, 784]
    D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1) #输入乘以D_W1矩阵加上偏置D_b1,D_h1维度为[N, 128]
    D_logit = tf.matmul(D_h1, D_W2) + D_b2 #D_h1乘以D_W2矩阵加上偏置D_b2,D_logit维度为[N, 1]
    D_prob = tf.nn.sigmoid(D_logit) #D_logit经过一个sigmoid函数,D_prob维度为[N, 1]

    return D_prob, D_logit #返回D_prob, D_logit


def plot(samples): #保存图片时使用的plot函数
    fig = plt.figure(figsize=(4, 4)) #初始化一个4行4列包含16张子图像的图片
    gs = gridspec.GridSpec(4, 4) #调整子图的位置
    gs.update(wspace=0.05, hspace=0.05) #置子图间的间距

    for i, sample in enumerate(samples): #依次将16张子图填充进需要保存的图像
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig


G_sample = generator(Z) #取得生成器的生成结果
D_real, D_logit_real = discriminator(X) #取得判别器判别的真实手写数字的结果
D_fake, D_logit_fake = discriminator(G_sample) #取得判别器判别的生成的手写数字的结果

D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real))) #对判别器对真实样本的判别结果计算误差(将结果与1比较)
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake))) #对判别器对虚假样本(即生成器生成的手写数字)的判别结果计算误差(将结果与0比较)
D_loss = D_loss_real + D_loss_fake #判别器的误差
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake))) #生成器的误差(将判别器返回的对虚假样本的判别结果与1比较)

dreal_loss_sum = tf.summary.scalar("dreal_loss", D_loss_real) #记录判别器判别真实样本的误差
dfake_loss_sum = tf.summary.scalar("dfake_loss", D_loss_fake) #记录判别器判别虚假样本的误差
d_loss_sum = tf.summary.scalar("d_loss", D_loss) #记录判别器的误差
g_loss_sum = tf.summary.scalar("g_loss", G_loss) #记录生成器的误差

summary_writer = tf.summary.FileWriter('snapshots/', graph=tf.get_default_graph()) #日志记录器

D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D) #判别器的训练器
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G) #生成器的训练器

mb_size = 128 #训练的batch_size
Z_dim = 100 #生成器输入的随机噪声的列的维度

mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True) #mnist是手写数字数据集

sess = tf.Session() #会话层
sess.run(tf.global_variables_initializer()) #初始化所有可训练参数

if not os.path.exists('out/'): #初始化训练过程中的可视化结果的输出文件夹
    os.makedirs('out/')

if not os.path.exists('snapshots/'): #初始化训练过程中的模型保存文件夹
    os.makedirs('snapshots/')

saver = tf.train.Saver(var_list=tf.global_variables(),   max_to_keep=50) #模型的保存器

i = 0 #训练过程中保存的可视化结果的索引

for it in range(1000000): #训练100万次
    if it % 1000 == 0: #每训练1000次就保存一下结果
        samples = sess.run(G_sample, feed_dict={Z: sample_Z(16, Z_dim)})

        fig = plot(samples) #通过plot函数生成可视化结果
        plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight') #保存可视化结果
        i += 1
        plt.close(fig)

    X_mb, _ = mnist.train.next_batch(mb_size) #得到训练一个batch所需的真实手写数字(作为判别器的输入)

#下面是得到训练一次的结果,通过sess来run出来
_, D_loss_curr, dreal_loss_sum_value, dfake_loss_sum_value, d_loss_sum_value = sess.run([D_solver, D_loss, dreal_loss_sum, dfake_loss_sum, d_loss_sum], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})
_, G_loss_curr, g_loss_sum_value = sess.run([G_solver, G_loss, g_loss_sum], feed_dict={Z: sample_Z(mb_size, Z_dim)})

if it%100 ==0: #每过100次记录一下日志,可以通过tensorboard查看
    summary_writer.add_summary(dreal_loss_sum_value, it)
    summary_writer.add_summary(dfake_loss_sum_value, it)
    summary_writer.add_summary(d_loss_sum_value, it)
    summary_writer.add_summary(g_loss_sum_value, it)

if it % 1000 == 0: #每训练1000次输出一下结果
    save(saver, sess, 'snapshots/', it)
    print('Iter: {}'.format(it))
    print('D loss: {:.4}'. format(D_loss_curr))
    print('G_loss: {:.4}'.format(G_loss_curr))
    print()
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 216,287评论 6 498
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,346评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 162,277评论 0 353
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,132评论 1 292
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,147评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,106评论 1 295
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,019评论 3 417
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,862评论 0 274
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,301评论 1 310
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,521评论 2 332
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,682评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,405评论 5 343
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,996评论 3 325
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,651评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,803评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,674评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,563评论 2 352

推荐阅读更多精彩内容