Tensorflow神经网络之GAN

生成对抗网络简介

生成对抗网络(GAN)启发自博弈论中的二人零和博弈(two-player game),类似于周伯通的绝学——“左右互搏”GAN 模型中的两位博弈方分别由生成式模型(generative model)和判别式模型(discriminative model)充当。生成模型 G 捕捉样本数据的分布,用服从某一分布(均匀分布,高斯分布等)的噪声 z 生成一个类似真实训练数据的样本,追求效果是越像真实样本越好;判别模型 D 是一个二分类器,估计一个样本来自于训练数据(而非生成数据)的概率,如果样本来自于真实的训练数据,D 输出大概率,否则,D 输出小概率。可以做如下类比:生成网络 G 好比假币制造团伙,专门制造假币,判别网络 D 好比警察,专门检测使用的货币是真币还是假币,G 的目标是想方设法生成和真币一样的货币,使得 D 判别不出来,D 的目标是想方设法检测出来 G 生成的假币。随着训练时间的增加,判别模型与生成模型的能力都相应的提升!

具体生成网络的示意图如下所示:
[图片上传失败...(image-587d3a-1551263869078)]

Tensorflow生成对抗网络实现

from __future__ import division, print_function, absolute_import

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

导入数据集

# 导入mnist数据集
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./data/", one_hot=True)

Extracting ./data/train-images-idx3-ubyte.gz
Extracting ./data/train-labels-idx1-ubyte.gz
Extracting ./data/t10k-images-idx3-ubyte.gz
Extracting ./data/t10k-labels-idx1-ubyte.gz

参数设置

# Training Params
num_steps = 70000 #总迭代次数
batch_size = 128  # 批量大小
learning_rate = 0.0002 #学习率

# Network Params
image_dim = 784 # 28*28 pixels,生成器的输出层节点数,也是判别器的输入
gen_hidden_dim = 256 # 生成器隐藏层节点数
disc_hidden_dim = 256 # 判别器隐藏层节点数
noise_dim = 100 # Noise data points 生成器输入节点数

# Xavier 初始化方式(更适合有ReLU的网络训练)
def glorot_init(shape):
    return tf.random_normal(shape=shape, stddev=1. / tf.sqrt(shape[0] / 2.))

Xavier 初始化方式方差:

image

这里的参数是标准差。

设置每一层的权重与偏置

# 设置每一层的权重(Xavier初始化)与偏置(初始化为零)
weights = {
    'gen_hidden1': tf.Variable(glorot_init([noise_dim, gen_hidden_dim])),#(100 - 256)
    'gen_out': tf.Variable(glorot_init([gen_hidden_dim, image_dim])), #(256 - 784)
    'disc_hidden1': tf.Variable(glorot_init([image_dim, disc_hidden_dim])),#(784 - 256)
    'disc_out': tf.Variable(glorot_init([disc_hidden_dim, 1])),#(256 - 1)
}
biases = {
    'gen_hidden1': tf.Variable(tf.zeros([gen_hidden_dim])),
    'gen_out': tf.Variable(tf.zeros([image_dim])),
    'disc_hidden1': tf.Variable(tf.zeros([disc_hidden_dim])),
    'disc_out': tf.Variable(tf.zeros([1])),
}

定义生成对抗网络

# 定义生成器函数
def generator(x):
    # 输入x是1x100的矩阵,weights['gen_hidden1']是100x256的矩阵,矩阵相乘结果是1x256的矩阵,生成器隐藏层含256个节点
    hidden_layer = tf.matmul(x, weights['gen_hidden1'])
    # biases['gen_hidden1']是1x256的矩阵,生成器隐藏层含256个节点
    hidden_layer = tf.add(hidden_layer, biases['gen_hidden1'])
    # 激活函数 relu
    hidden_layer = tf.nn.relu(hidden_layer)
    # hidden_layer是1x256的矩阵,weights['gen_out']是256x784的矩阵,矩阵相乘结果是1x784的矩阵,生成器输出层含784个节点
    out_layer = tf.matmul(hidden_layer, weights['gen_out'])
    # biases['gen_out']是1x784的矩阵,生成器输出层含784个节点
    out_layer = tf.add(out_layer, biases['gen_out'])
    # 激活函数 sigmoid
    out_layer = tf.nn.sigmoid(out_layer)
    return out_layer


# 定义判别器函数
def discriminator(x):
    # 输入x是生成器生成的1x784的矩阵(生成的图片),weights['disc_hidden1']是784x256的矩阵,矩阵相乘结果是1x256的矩阵,判别器隐藏层含256个节点
    hidden_layer = tf.matmul(x, weights['disc_hidden1'])
    # biases['disc_hidden1']是1x256的矩阵,生成器隐藏层含256个节点
    hidden_layer = tf.add(hidden_layer, biases['disc_hidden1'])
    # 激活函数 relu
    hidden_layer = tf.nn.relu(hidden_layer)
    # hidden_layer是1x256的矩阵,weights['disc_out']是256x1的矩阵,矩阵相乘结果是一个数,判别器输出层含1个节点
    out_layer = tf.matmul(hidden_layer, weights['disc_out'])
    # biases['disc_out']是一个数,判别器输出层含1个节点
    out_layer = tf.add(out_layer, biases['disc_out'])
    # 激活函数 sigmoid
    out_layer = tf.nn.sigmoid(out_layer)
    return out_layer

# 构建网络
# 网络输入
gen_input = tf.placeholder(tf.float32, shape=[None, noise_dim], name='input_noise') # 生成器 输入噪点 batch*100,none是一个空值,后面赋值batch_size
disc_input = tf.placeholder(tf.float32, shape=[None, image_dim], name='disc_input') # 判别器 输入真实图像 batch*784

# 构建生成器(generator)
gen_sample = generator(gen_input)

# 构建两个判别器(一个是真实图像输入,一个是生成图像)
disc_real = discriminator(disc_input) # 真实图像
disc_fake = discriminator(gen_sample) # 通过生成器生成的图像

# 创建损失函数
# 关于GAN的理论推导,可参见 [^1]
gen_loss = -tf.reduce_mean(tf.log(disc_fake)) # 生成器损失函数
disc_loss = -tf.reduce_mean(tf.log(disc_real) + tf.log(1. - disc_fake)) # 判别器损失函数

# 创建优化器(采用Adam方法),可参见 [^2]
optimizer_gen = tf.train.AdamOptimizer(learning_rate=learning_rate)
optimizer_disc = tf.train.AdamOptimizer(learning_rate=learning_rate)

# Training Variables for each optimizer
# By default in TensorFlow, all variables are updated by each optimizer, so we
# need to precise for each one of them the specific variables to update.
# 生成网络的变量
gen_vars = [weights['gen_hidden1'], weights['gen_out'],
            biases['gen_hidden1'], biases['gen_out']]
# 判别网络的变量
disc_vars = [weights['disc_hidden1'], weights['disc_out'],
            biases['disc_hidden1'], biases['disc_out']]

# 创建训练操作
train_gen = optimizer_gen.minimize(gen_loss, var_list=gen_vars)
train_disc = optimizer_disc.minimize(disc_loss, var_list=disc_vars)

# 变量全局初始化
init = tf.global_variables_initializer()

GAN的网络结构类似于多层感知机:

image

训练生成对抗网络

# 开始训练
# 创建一个会话
sess = tf.Session()

# 初始化
sess.run(init)

# 训练
for i in range(1, num_steps+1):
    # 准备数据
    # 拿到下一批次的 MNIST 数据 (仅需要图像, 不需要标签)
    batch_x, _ = mnist.train.next_batch(batch_size) # 判别器输入 真实图像,batch_*784
    # 给生成器生成噪点数据
    z = np.random.uniform(-1., 1., size=[batch_size, noise_dim]) # 生成器输入 噪声,batch*100

    # 训练
    feed_dict = {disc_input: batch_x, gen_input: z} #给placeholder填入值
    _, _, gl, dl = sess.run([train_gen, train_disc, gen_loss, disc_loss],
                            feed_dict=feed_dict)
    if i % 2000 == 0 or i == 1:
        print('Step %i: Generator Loss: %f, Discriminator Loss: %f' % (i, gl, dl))

    Step 1: Generator Loss: 0.223592, Discriminator Loss: 2.090910
    Step 2000: Generator Loss: 4.678916, Discriminator Loss: 0.041115
    Step 4000: Generator Loss: 3.605874, Discriminator Loss: 0.068698
    Step 6000: Generator Loss: 3.845584, Discriminator Loss: 0.190420
    Step 8000: Generator Loss: 4.470613, Discriminator Loss: 0.117488
    Step 10000: Generator Loss: 3.813103, Discriminator Loss: 0.146255
    Step 12000: Generator Loss: 2.991248, Discriminator Loss: 0.392258
    Step 14000: Generator Loss: 3.769275, Discriminator Loss: 0.153639
    Step 16000: Generator Loss: 4.366917, Discriminator Loss: 0.206618
    Step 18000: Generator Loss: 4.052875, Discriminator Loss: 0.225112
    Step 20000: Generator Loss: 3.574747, Discriminator Loss: 0.362798
    Step 22000: Generator Loss: 3.760236, Discriminator Loss: 0.188211
    Step 24000: Generator Loss: 3.055995, Discriminator Loss: 0.354645
    Step 26000: Generator Loss: 3.619049, Discriminator Loss: 0.211489
    Step 28000: Generator Loss: 3.523777, Discriminator Loss: 0.273607
    Step 30000: Generator Loss: 3.889854, Discriminator Loss: 0.286803
    Step 32000: Generator Loss: 3.106094, Discriminator Loss: 0.298111
    Step 34000: Generator Loss: 3.548391, Discriminator Loss: 0.343262
    Step 36000: Generator Loss: 3.081174, Discriminator Loss: 0.332788
    Step 38000: Generator Loss: 2.946176, Discriminator Loss: 0.335102
    Step 40000: Generator Loss: 3.078653, Discriminator Loss: 0.465524
    Step 42000: Generator Loss: 2.601799, Discriminator Loss: 0.409574
    Step 44000: Generator Loss: 3.168177, Discriminator Loss: 0.325075
    Step 46000: Generator Loss: 2.601811, Discriminator Loss: 0.428143
    Step 48000: Generator Loss: 2.853810, Discriminator Loss: 0.403768
    Step 50000: Generator Loss: 2.690175, Discriminator Loss: 0.483180
    Step 52000: Generator Loss: 3.278867, Discriminator Loss: 0.375016
    Step 54000: Generator Loss: 2.869437, Discriminator Loss: 0.477840
    Step 56000: Generator Loss: 2.561056, Discriminator Loss: 0.449300
    Step 58000: Generator Loss: 2.814199, Discriminator Loss: 0.484522
    Step 60000: Generator Loss: 2.469474, Discriminator Loss: 0.428359
    Step 62000: Generator Loss: 2.721684, Discriminator Loss: 0.494090
    Step 64000: Generator Loss: 2.491284, Discriminator Loss: 0.654795
    Step 66000: Generator Loss: 2.725388, Discriminator Loss: 0.423149
    Step 68000: Generator Loss: 2.758215, Discriminator Loss: 0.513224
    Step 70000: Generator Loss: 3.072056, Discriminator Loss: 0.481437

测试

# 测试
# 通过训练出的生成网络输入噪点,生成图像
n = 6
canvas = np.empty((28 * n, 28 * n))
for i in range(n):
    # 噪点输入
    z = np.random.uniform(-1., 1., size=[n, noise_dim])
    # 生成图像
    g = sess.run(gen_sample, feed_dict={gen_input: z})
    # 颜色反转便于显示
    g = -1 * (g - 1)
    for j in range(n):
        # 绘制生成的手写体数字
        canvas[i * 28:(i + 1) * 28, j * 28:(j + 1) * 28] = g[j].reshape([28, 28])

plt.figure(figsize=(n, n))
plt.imshow(canvas, origin="upper", cmap="gray")
plt.show()
image

参考

[1] 机器之心GitHub项目:GAN完整理论推导与实现,Perfect!

[2] 深度学习最全优化方法总结比较(SGD,Adagrad,Adadelta,Adam,Adamax,Nadam)

[3] 不要怂,就是GAN (生成式对抗网络) (一): GAN 简介

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 216,997评论 6 502
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,603评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 163,359评论 0 353
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,309评论 1 292
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,346评论 6 390
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,258评论 1 300
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,122评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,970评论 0 275
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,403评论 1 313
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,596评论 3 334
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,769评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,464评论 5 344
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,075评论 3 327
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,705评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,848评论 1 269
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,831评论 2 370
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,678评论 2 354

推荐阅读更多精彩内容