GAN: Generative Adversarial Network. Informally, it consists of two networks: a generator (g) network, which produces samples, and a discriminator (d) network, which judges them.
The goal of a GAN is to generate an image on its own. For example, after processing a series of cat images, the g network can "draw" a cat image by itself, as realistically as possible.
The d network does the judging: a real image and an image generated by the g network are both handed to the d network, which is trained until it can reliably pick out the "fake image" produced by the g network.
Returning to the two networks: the g network keeps improving so that it can fool the d network, while the d network keeps improving so that it can spot the "fake images" more accurately. This relationship of mutual promotion and mutual opposition is what the word "adversarial" refers to.
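Formally, this tug-of-war is the minimax game from the original GAN paper, stated here for reference (the code below optimizes the common non-saturating cross-entropy form of the same idea):

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]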
We can implement this with the MNIST handwritten-digit data that ships with TensorFlow.
The implementation works as follows:
Generator: an image of random pixels is passed through a fully connected layer and then a Leaky ReLU; to avoid overfitting it then goes through dropout, and finally through another fully connected layer with a tanh activation, producing a "fake image".
def get_generator(noise_img, n_units, out_dim, reuse=False, alpha=0.01):
    with tf.variable_scope("generator", reuse=reuse):
        hidden1 = tf.layers.dense(noise_img, n_units)   # fully connected layer
        hidden1 = tf.maximum(alpha * hidden1, hidden1)  # Leaky ReLU
        hidden1 = tf.layers.dropout(hidden1, rate=0.2)  # dropout against overfitting
        logits = tf.layers.dense(hidden1, out_dim)
        outputs = tf.tanh(logits)                       # squash pixel values into (-1, 1)
        return logits, outputs
Discriminator: the image to be judged is passed through a fully connected layer --> Leaky ReLU --> fully connected layer --> sigmoid activation, yielding a probability between 0 and 1 that the image is real.
def get_discriminator(img, n_units, reuse=False, alpha=0.01):
    with tf.variable_scope("discriminator", reuse=reuse):
        hidden1 = tf.layers.dense(img, n_units)
        hidden1 = tf.maximum(alpha * hidden1, hidden1)
        logits = tf.layers.dense(hidden1, 1)
        outputs = tf.sigmoid(logits)
        return logits, outputs
For the implementation, we first extract the images labeled 0 from the MNIST data and store them in a list:
i = j = 0
while i < 5000:
    if mnist.train.labels[j] == 0:
        samples.append(mnist.train.images[j])
        i += 1
    j += 1
This way, only images labeled 0 are used during training.
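As an aside, the same filtering can be done in one line with NumPy boolean indexing (a sketch that assumes the labels were loaded as plain integers, i.e. one_hot=False, as in the script below):

import numpy as np
# Select the images whose label is 0 and keep the first 5000 of them.
samples = list(mnist.train.images[mnist.train.labels == 0][:5000])

The complete script: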
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
import numpy as np

mnist = input_data.read_data_sets("D:/python/MNIST_data/")
img = mnist.train.images[50]

def get_inputs(real_size, noise_size):
    real_img = tf.placeholder(tf.float32, [None, real_size], name="real_img")
    noise_img = tf.placeholder(tf.float32, [None, noise_size], name="noise_img")
    return real_img, noise_img
# Generator
def get_generator(noise_img, n_units, out_dim, reuse=False, alpha=0.01):
    with tf.variable_scope("generator", reuse=reuse):
        hidden1 = tf.layers.dense(noise_img, n_units)   # fully connected layer
        hidden1 = tf.maximum(alpha * hidden1, hidden1)  # Leaky ReLU
        hidden1 = tf.layers.dropout(hidden1, rate=0.2)
        logits = tf.layers.dense(hidden1, out_dim)
        outputs = tf.tanh(logits)
        return logits, outputs
# Discriminator
def get_discriminator(img, n_units, reuse=False, alpha=0.01):
    with tf.variable_scope("discriminator", reuse=reuse):
        hidden1 = tf.layers.dense(img, n_units)
        hidden1 = tf.maximum(alpha * hidden1, hidden1)
        logits = tf.layers.dense(hidden1, 1)
        outputs = tf.sigmoid(logits)
        return logits, outputs
img_size = mnist.train.images[0].shape[0]
noise_size = 100
g_units = 128
d_units = 128
alpha = 0.01
learning_rate = 0.001
smooth = 0.1

tf.reset_default_graph()
real_img, noise_img = get_inputs(img_size, noise_size)
g_logits, g_outputs = get_generator(noise_img, g_units, img_size)
d_logits_real, d_outputs_real = get_discriminator(real_img, d_units)
d_logits_fake, d_outputs_fake = get_discriminator(g_outputs, d_units, reuse=True)

# one-sided label smoothing: the target for real images is 0.9 instead of 1
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_real, labels=tf.ones_like(d_logits_real) * (1 - smooth)
))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake, labels=tf.zeros_like(d_logits_fake)
))
d_loss = tf.add(d_loss_real, d_loss_fake)
# the generator is trained to make the discriminator label its output as real
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake, labels=tf.ones_like(d_logits_fake) * (1 - smooth)
))

train_vars = tf.trainable_variables()
g_vars = [var for var in train_vars if var.name.startswith("generator")]
d_vars = [var for var in train_vars if var.name.startswith("discriminator")]
d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)
epochs = 5000
samples = []
n_sample = 10
losses = []

i = j = 0
while i < 5000:
    if mnist.train.labels[j] == 0:
        samples.append(mnist.train.images[j])
        i += 1
    j += 1
print(len(samples))
size = samples[0].size
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    for e in range(epochs):
        batch_images = samples[e] * 2 - 1  # rescale pixels from [0, 1] to [-1, 1] to match tanh
        batch_noise = np.random.uniform(-1, 1, size=noise_size)
        _ = sess.run(d_train_opt, feed_dict={real_img: [batch_images], noise_img: [batch_noise]})
        _ = sess.run(g_train_opt, feed_dict={noise_img: [batch_noise]})
    # after training, draw one sample from the generator
    sample_noise = np.random.uniform(-1, 1, size=noise_size)
    g_logit, g_output = sess.run(get_generator(noise_img, g_units, img_size,
                                               reuse=True), feed_dict={
        noise_img: [sample_noise]
    })
    print(g_logit.size)
    g_output = (g_output + 1) / 2  # map tanh output back to [0, 1] for display
    plt.imshow(g_output.reshape([28, 28]), cmap='Greys_r')
    plt.show()
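As a side note, n_sample is defined above but never used. Below is a minimal sketch of how it could be put to work, placed inside the same with tf.Session() block after the training loop, to display several generated digits at once; it only reuses the generator ops and the plotting pattern from the script above:

    # sketch: run the (reused) generator on a batch of n_sample noise vectors
    sample_noise_batch = np.random.uniform(-1, 1, size=(n_sample, noise_size))
    _, gen_samples = sess.run(get_generator(noise_img, g_units, img_size, reuse=True),
                              feed_dict={noise_img: sample_noise_batch})
    gen_samples = (gen_samples + 1) / 2
    fig, axes = plt.subplots(1, n_sample, figsize=(n_sample, 1))
    for ax, s in zip(axes, gen_samples):
        ax.imshow(s.reshape([28, 28]), cmap='Greys_r')
        ax.axis('off')
    plt.show()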
Result:
As can be seen, after 5000 iterations, the image generated by the g network can already roughly take on the shape of a 0.