TFGAN是谷歌开源的一个轻量级生成对抗网络(GAN)工具库,它为开发者轻松训练 GAN 提供了基础条件,提供经过完整测试的损失函数和评估指标,同时提供易于使用的范例,这些范例展示了 TFGAN 的表达能力和灵活性。这个库被包含在了TensorFlow contrib中,可以直接通过tf来进行使用,本文通过一个简单的unconditional gan模型在MNIST数据集上进行演示。
Github:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/gan
环境
- Python 3.6
- Tensorflow-gpu 1.8.0
GAN
生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。模型通过框架中(至少)两个模块:生成模型(Generative Model)和判别模型(Discriminative Model)的互相博弈学习产生相当好的输出。原始 GAN 理论中,并不要求 G 和 D 都是神经网络,只需要是能拟合相应生成和判别的函数即可。但实用中一般均使用深度神经网络作为 G 和 D 。一个优秀的GAN应用需要有良好的训练方法,否则可能由于神经网络模型的自由性而导致输出不理想。
在训练的过程中固定一方,更新另一方的网络权重,交替迭代,在这个过程中,双方都极力优化自己的网络,从而形成竞争对抗,直到双方达到一个动态的平衡(纳什均衡),此时生成模型 G 恢复了训练数据的分布(造出了和真实数据一模一样的样本),判别模型再也判别不出来结果,准确率为 50%,约等于乱猜。
当固定生成网络 G 的时候,对于判别网络 D 的优化,可以这样理解:输入来自于真实数据,D 优化网络结构使自己输出 1,输入来自于生成数据,D 优化网络结构使自己输出 0;当固定判别网络 D 的时候,G 优化自己的网络使自己输出尽可能和真实数据一样的样本,并且使得生成的样本经过 D 的判别之后,D 输出高概率。
上述过程可以表述为如下公式:
TFGAN
TFGAN中的训练通常包括以下步骤:
1.指定网络的输入。
2.使用GANModel设置生成器和鉴别器。
3.使用GANLoss指定损失。
4.使用GANTrainOps创建训练操作。
5.开始训练模型。
数据
首先读入MNIST数据作为输入数据x,如下所示:
import tensorflow as tf
import tensorflow.contrib.gan as tfgan
import tensorflow.contrib.layers as layers
from tensorflow.examples.tutorials.mnist import input_data
def provide_data(batch_size, num_threads=1):
file = "MNIST"
# range 0~1
mnist = input_data.read_data_sets(file, one_hot=True)
train_data = mnist.train.images.reshape(-1, 28, 28, 1) * 255
train_labels = mnist.train.labels
# transfer to -1~1
train_data = (tf.to_float(train_data) - 128.0) / 128.0
# Creates a QueueRunner for the pre-fetching operation.
input_queue = tf.train.slice_input_producer([train_data, train_labels], shuffle=True)
images, labels = tf.train.batch(
input_queue,
batch_size=batch_size,
num_threads=num_threads,
capacity=5 * batch_size)
return images, labels
生成随机噪声作为输入数据z,如下所示:
images, _ = provide_data(batch_size, num_threads=2)
noise = tf.random_normal([batch_size, 64])
定义模型
首先我们需要定义生成器(generator)和鉴别器(discriminator)。
generator定义如下所示,将一个一维的随机噪声通过反卷积生成通道数为1的图片数据,使用tanh是为了保持生成数据的范围与输入数据一致:
def unconditional_generator(noise, weight_decay=2.5e-5, is_training=True):
"""Simple generator to produce MNIST images.
Args:
noise: A single Tensor representing noise.
weight_decay: The value of the l2 weight decay.
is_training: If `True`, batch norm uses batch statistics. If `False`, batch
norm uses the exponential moving average collected from population
statistics.
Returns:
A generated image in the range [-1, 1].
"""
with tf.contrib.framework.arg_scope(
[layers.fully_connected, layers.conv2d_transpose],
activation_fn=tf.nn.relu, normalizer_fn=layers.batch_norm,
weights_regularizer=layers.l2_regularizer(weight_decay)):
with tf.contrib.framework.arg_scope([layers.batch_norm], is_training=is_training,
zero_debias_moving_mean=True):
net = layers.fully_connected(noise, 1024)
net = layers.fully_connected(net, 7 * 7 * 128)
net = tf.reshape(net, [-1, 7, 7, 128])
net = layers.conv2d_transpose(net, 64, [4, 4], stride=2)
net = layers.conv2d_transpose(net, 32, [4, 4], stride=2)
# Make sure that generator output is in the same range as `inputs`
# ie [-1, 1].
net = layers.conv2d(net, 1, [4, 4], normalizer_fn=None, activation_fn=tf.tanh)
return net
discriminator的定义如下,是一个比较简单的二分类网络,用来判断输入的数据是生成的还是真实的:
def unconditional_discriminator(img, unused_conditioning, weight_decay=2.5e-5,
is_training=True):
"""Discriminator network on MNIST digits.
Args:
img: Real or generated MNIST digits. Should be in the range [-1, 1].
unused_conditioning: The TFGAN API can help with conditional GANs, which
would require extra `condition` information to both the generator and the
discriminator. Since this example is not conditional, we do not use this
argument.
weight_decay: The L2 weight decay.
is_training: If `True`, batch norm uses batch statistics. If `False`, batch
norm uses the exponential moving average collected from population
statistics.
Returns:
Logits for the probability that the image is real.
"""
with tf.contrib.framework.arg_scope(
[layers.conv2d, layers.fully_connected],
activation_fn=tf.nn.relu, normalizer_fn=None,
weights_regularizer=layers.l2_regularizer(weight_decay),
biases_regularizer=layers.l2_regularizer(weight_decay)):
net = layers.conv2d(img, 64, [4, 4], stride=2)
net = layers.conv2d(net, 128, [4, 4], stride=2)
net = layers.flatten(net)
with tf.contrib.framework.arg_scope([layers.batch_norm], is_training=is_training):
net = layers.fully_connected(net, 1024, normalizer_fn=layers.batch_norm)
return layers.linear(net, 1)
然后使用TFGAN定义一个GAN模型:
with tf.name_scope('model'):
# Build the generator and discriminator.
gan_model = tfgan.gan_model(
generator_fn=unconditional_generator, # you define
discriminator_fn=unconditional_discriminator, # you define
real_data=images,
generator_inputs=noise)
设置损失函数
使用TFGAN自带的损失函数配置模型,如下所示:
with tf.name_scope('loss'):
# Build the GAN loss.
gan_loss = tfgan.gan_loss(
gan_model,
generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
gradient_penalty_weight=1.0,
add_summaries=True)
同时可以进行自定义的损失函数,损失函数接收一个gan_model,对gan_model的output进行损失计算。如下所示:
def silly_custom_generator_loss(gan_model, add_summaries=False):
return tf.reduce_mean(gan_model.discriminator_gen_outputs)
def silly_custom_discriminator_loss(gan_model, add_summaries=False):
return (tf.reduce_mean(gan_model.discriminator_gen_outputs) -
tf.reduce_mean(gan_model.discriminator_real_outputs))
custom_gan_loss = tfgan.gan_loss(
gan_model,
generator_loss_fn=silly_custom_generator_loss,
discriminator_loss_fn=silly_custom_discriminator_loss)
配置训练操作
接下来使用TFGAN来配置训练操作,制定模型、损失、优化器、训练率等参数。关于check_for_unused_update_ops这个参数,由于batch_norm层的原因,如果设置为True会导致更新参数检查不一致,因此需要设置为False。
with tf.name_scope('train'):
# Create the train ops, which calculate gradients and apply updates to weights.
train_ops = tfgan.gan_train_ops(
gan_model,
gan_loss,
generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5),
check_for_unused_update_ops=False,
summarize_gradients=True,
aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
训练模型
通过TFGAN指定训练操作与保存的文件夹,就可以很容易的开始训练过程。如下所示:
# Run the train ops in the alternating training scheme.
tfgan.gan_train(
train_ops,
hooks=[tf.train.StopAtStepHook(num_steps=max_steps)],
logdir=train_log_dir,
save_summaries_steps=10)
可视化评估
def test(eval_dir, checkpoint_dir):
tf.reset_default_graph()
if not tf.gfile.Exists(eval_dir):
tf.gfile.MakeDirs(eval_dir)
random_inputs = tf.random_normal([100, 64])
with tf.variable_scope('Generator'):
images = unconditional_generator(random_inputs, is_training=False)
reshaped_images = tfgan.eval.image_reshaper(images[:100, ...], num_cols=10)
uint8_images = float_image_to_uint8(reshaped_images)
image_write_ops = tf.write_file(
'%s/%s' % (eval_dir, 'unconditional_gan.png'),
tf.image.encode_png(uint8_images[0]))
tf.contrib.training.evaluate_repeatedly(
checkpoint_dir,
eval_ops=image_write_ops,
hooks=[tf.contrib.training.StopAfterNEvalsHook(1)],
max_number_of_evaluations=1)
不同Epoch阶段生成的效果如下:
PS:
在实验中发现生成器的学习率对生成效果有着很大的影响,最初参考官方文档使用1e-3的学习率,发现在Epoch:200左右生成了较为模糊的数字图片,进一步训练生成图片反而全是黑色背景,参考Loss变化发现整个训练过程非常不稳定。后面改用1e-4的学习率,才达到当前的生成效果。
完整的unconditional gan代码如下所示,也可以参考官方的tutorial。
import tensorflow as tf
import tensorflow.contrib.gan as tfgan
import tensorflow.contrib.layers as layers
from tensorflow.examples.tutorials.mnist import input_data
def float_image_to_uint8(image):
"""Convert float image in [-1, 1) to [0, 255] uint8.
Note that `1` gets mapped to `0`, but `1 - epsilon` gets mapped to 255.
Args:
image: An image tensor. Values should be in [-1, 1).
Returns:
Input image cast to uint8 and with integer values in [0, 255].
"""
image = (image * 128.0) + 128.0
return tf.cast(image, tf.uint8)
def provide_data(batch_size, num_threads=1):
file = "MNIST"
# range 0~1
mnist = input_data.read_data_sets(file, one_hot=True)
train_data = mnist.train.images.reshape(-1, 28, 28, 1) * 255
train_labels = mnist.train.labels
# transfer to -1~1
train_data = (tf.to_float(train_data) - 128.0) / 128.0
# Creates a QueueRunner for the pre-fetching operation.
input_queue = tf.train.slice_input_producer([train_data, train_labels], shuffle=True)
images, labels = tf.train.batch(
input_queue,
batch_size=batch_size,
num_threads=num_threads,
capacity=5 * batch_size)
return images, labels
def unconditional_generator(noise, weight_decay=2.5e-5, is_training=True):
"""Simple generator to produce MNIST images.
Args:
noise: A single Tensor representing noise.
weight_decay: The value of the l2 weight decay.
is_training: If `True`, batch norm uses batch statistics. If `False`, batch
norm uses the exponential moving average collected from population
statistics.
Returns:
A generated image in the range [-1, 1].
"""
with tf.contrib.framework.arg_scope(
[layers.fully_connected, layers.conv2d_transpose],
activation_fn=tf.nn.relu, normalizer_fn=layers.batch_norm,
weights_regularizer=layers.l2_regularizer(weight_decay)):
with tf.contrib.framework.arg_scope([layers.batch_norm], is_training=is_training,
zero_debias_moving_mean=True):
net = layers.fully_connected(noise, 1024)
net = layers.fully_connected(net, 7 * 7 * 128)
net = tf.reshape(net, [-1, 7, 7, 128])
net = layers.conv2d_transpose(net, 64, [4, 4], stride=2)
net = layers.conv2d_transpose(net, 32, [4, 4], stride=2)
# Make sure that generator output is in the same range as `inputs`
# ie [-1, 1].
net = layers.conv2d(net, 1, [4, 4], normalizer_fn=None, activation_fn=tf.tanh)
return net
def unconditional_discriminator(img, unused_conditioning, weight_decay=2.5e-5,
is_training=True):
"""Discriminator network on MNIST digits.
Args:
img: Real or generated MNIST digits. Should be in the range [-1, 1].
unused_conditioning: The TFGAN API can help with conditional GANs, which
would require extra `condition` information to both the generator and the
discriminator. Since this example is not conditional, we do not use this
argument.
weight_decay: The L2 weight decay.
is_training: If `True`, batch norm uses batch statistics. If `False`, batch
norm uses the exponential moving average collected from population
statistics.
Returns:
Logits for the probability that the image is real.
"""
with tf.contrib.framework.arg_scope(
[layers.conv2d, layers.fully_connected],
activation_fn=tf.nn.relu, normalizer_fn=None,
weights_regularizer=layers.l2_regularizer(weight_decay),
biases_regularizer=layers.l2_regularizer(weight_decay)):
net = layers.conv2d(img, 64, [4, 4], stride=2)
net = layers.conv2d(net, 128, [4, 4], stride=2)
net = layers.flatten(net)
with tf.contrib.framework.arg_scope([layers.batch_norm], is_training=is_training):
net = layers.fully_connected(net, 1024, normalizer_fn=layers.batch_norm)
return layers.linear(net, 1)
def train(batch_size, max_steps, gen_lr, dis_lr, train_log_dir):
tf.reset_default_graph()
if not tf.gfile.Exists(train_log_dir):
tf.gfile.MakeDirs(train_log_dir)
# Set up the input.
images, _ = provide_data(batch_size)
noise = tf.random_normal([batch_size, 64])
with tf.name_scope('model'):
# Build the generator and discriminator.
gan_model = tfgan.gan_model(
generator_fn=unconditional_generator, # you define
discriminator_fn=unconditional_discriminator, # you define
real_data=images,
generator_inputs=noise)
with tf.name_scope('loss'):
# Build the GAN loss.
gan_loss = tfgan.gan_loss(
gan_model,
generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
gradient_penalty_weight=1.0,
add_summaries=True)
with tf.name_scope('train'):
# Create the train ops, which calculate gradients and apply updates to weights.
train_ops = tfgan.gan_train_ops(
gan_model,
gan_loss,
generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5),
check_for_unused_update_ops=False,
summarize_gradients=True,
aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
# Run the train ops in the alternating training scheme.
tfgan.gan_train(
train_ops,
hooks=[tf.train.StopAtStepHook(num_steps=max_steps)],
logdir=train_log_dir,
save_summaries_steps=10)
def test(eval_dir, checkpoint_dir):
tf.reset_default_graph()
if not tf.gfile.Exists(eval_dir):
tf.gfile.MakeDirs(eval_dir)
random_inputs = tf.random_normal([100, 64])
with tf.variable_scope('Generator'):
images = unconditional_generator(random_inputs, is_training=False)
reshaped_images = tfgan.eval.image_reshaper(images[:100, ...], num_cols=10)
uint8_images = float_image_to_uint8(reshaped_images)
image_write_ops = tf.write_file(
'%s/%s' % (eval_dir, 'unconditional_gan.png'),
tf.image.encode_png(uint8_images[0]))
tf.contrib.training.evaluate_repeatedly(
checkpoint_dir,
eval_ops=image_write_ops,
hooks=[tf.contrib.training.StopAfterNEvalsHook(1)],
max_number_of_evaluations=1)
if __name__ == '__main__':
train(16, 5000, 1e-4, 1e-4, 'logs/')
test('eval/', 'logs/')