简单易用的轻量级生成对抗网络工具库:TFGAN

TFGAN是谷歌开源的一个轻量级生成对抗网络(GAN)工具库,它为开发者轻松训练 GAN 提供了基础条件,提供经过完整测试的损失函数和评估指标,同时提供易于使用的范例,这些范例展示了 TFGAN 的表达能力和灵活性。这个库被包含在了TensorFlow contrib中,可以直接通过tf来进行使用,本文通过一个简单的unconditional gan模型在MNIST数据集上进行演示。

Githubhttps://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/gan

环境

  • Python 3.6
  • Tensorflow-gpu 1.8.0

GAN

生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。模型通过框架中(至少)两个模块:生成模型(Generative Model)和判别模型(Discriminative Model)的互相博弈学习产生相当好的输出。原始 GAN 理论中,并不要求 G 和 D 都是神经网络,只需要是能拟合相应生成和判别的函数即可。但实用中一般均使用深度神经网络作为 G 和 D 。一个优秀的GAN应用需要有良好的训练方法,否则可能由于神经网络模型的自由性而导致输出不理想。

GAN

在训练的过程中固定一方,更新另一方的网络权重,交替迭代,在这个过程中,双方都极力优化自己的网络,从而形成竞争对抗,直到双方达到一个动态的平衡(纳什均衡),此时生成模型 G 恢复了训练数据的分布(造出了和真实数据一模一样的样本),判别模型再也判别不出来结果,准确率为 50%,约等于乱猜。
当固定生成网络 G 的时候,对于判别网络 D 的优化,可以这样理解:输入来自于真实数据,D 优化网络结构使自己输出 1,输入来自于生成数据,D 优化网络结构使自己输出 0;当固定判别网络 D 的时候,G 优化自己的网络使自己输出尽可能和真实数据一样的样本,并且使得生成的样本经过 D 的判别之后,D 输出高概率。

上述过程可以表述为如下公式:

TFGAN

TFGAN中的训练通常包括以下步骤:

1.指定网络的输入。
2.使用GANModel设置生成器和鉴别器。
3.使用GANLoss指定损失。
4.使用GANTrainOps创建训练操作。
5.开始训练模型。

数据

首先读入MNIST数据作为输入数据x,如下所示:

import tensorflow as tf
import tensorflow.contrib.gan as tfgan
import tensorflow.contrib.layers as layers
from tensorflow.examples.tutorials.mnist import input_data


def provide_data(batch_size, num_threads=1):
    file = "MNIST"
    # range 0~1
    mnist = input_data.read_data_sets(file, one_hot=True)

    train_data = mnist.train.images.reshape(-1, 28, 28, 1) * 255
    train_labels = mnist.train.labels

    # transfer to -1~1
    train_data = (tf.to_float(train_data) - 128.0) / 128.0

    # Creates a QueueRunner for the pre-fetching operation.
    input_queue = tf.train.slice_input_producer([train_data, train_labels], shuffle=True)
    images, labels = tf.train.batch(
            input_queue,
            batch_size=batch_size,
            num_threads=num_threads,
            capacity=5 * batch_size)

    return images, labels

生成随机噪声作为输入数据z,如下所示:

    images, _ = provide_data(batch_size, num_threads=2)
    noise = tf.random_normal([batch_size, 64])

定义模型

首先我们需要定义生成器(generator)和鉴别器(discriminator)。

generator定义如下所示,将一个一维的随机噪声通过反卷积生成通道数为1的图片数据,使用tanh是为了保持生成数据的范围与输入数据一致:

def unconditional_generator(noise, weight_decay=2.5e-5, is_training=True):
    """Simple generator to produce MNIST images.

    Args:
        noise: A single Tensor representing noise.
        weight_decay: The value of the l2 weight decay.
        is_training: If `True`, batch norm uses batch statistics. If `False`, batch
            norm uses the exponential moving average collected from population 
            statistics.

    Returns:
        A generated image in the range [-1, 1].
    """
    with tf.contrib.framework.arg_scope(
        [layers.fully_connected, layers.conv2d_transpose],
        activation_fn=tf.nn.relu, normalizer_fn=layers.batch_norm,
        weights_regularizer=layers.l2_regularizer(weight_decay)):
        with tf.contrib.framework.arg_scope([layers.batch_norm], is_training=is_training,
                        zero_debias_moving_mean=True):

            net = layers.fully_connected(noise, 1024)
            net = layers.fully_connected(net, 7 * 7 * 128)
            net = tf.reshape(net, [-1, 7, 7, 128])
            net = layers.conv2d_transpose(net, 64, [4, 4], stride=2)
            net = layers.conv2d_transpose(net, 32, [4, 4], stride=2)
            # Make sure that generator output is in the same range as `inputs`
            # ie [-1, 1].
            net = layers.conv2d(net, 1, [4, 4], normalizer_fn=None, activation_fn=tf.tanh)

            return net

discriminator的定义如下,是一个比较简单的二分类网络,用来判断输入的数据是生成的还是真实的:

def unconditional_discriminator(img, unused_conditioning, weight_decay=2.5e-5,
                     is_training=True):
    """Discriminator network on MNIST digits.

    Args:
        img: Real or generated MNIST digits. Should be in the range [-1, 1].
        unused_conditioning: The TFGAN API can help with conditional GANs, which
            would require extra `condition` information to both the generator and the
            discriminator. Since this example is not conditional, we do not use this
            argument.
        weight_decay: The L2 weight decay.
        is_training: If `True`, batch norm uses batch statistics. If `False`, batch
            norm uses the exponential moving average collected from population 
            statistics.

    Returns:
        Logits for the probability that the image is real.
    """
    with tf.contrib.framework.arg_scope(
        [layers.conv2d, layers.fully_connected],
        activation_fn=tf.nn.relu, normalizer_fn=None,
        weights_regularizer=layers.l2_regularizer(weight_decay),
        biases_regularizer=layers.l2_regularizer(weight_decay)):

        net = layers.conv2d(img, 64, [4, 4], stride=2)
        net = layers.conv2d(net, 128, [4, 4], stride=2)
        net = layers.flatten(net)

        with tf.contrib.framework.arg_scope([layers.batch_norm], is_training=is_training):
            net = layers.fully_connected(net, 1024, normalizer_fn=layers.batch_norm)

        return layers.linear(net, 1)

然后使用TFGAN定义一个GAN模型:

    with tf.name_scope('model'):
        # Build the generator and discriminator.
        gan_model = tfgan.gan_model(
            generator_fn=unconditional_generator,  # you define 
            discriminator_fn=unconditional_discriminator,  # you define
            real_data=images,
            generator_inputs=noise)

设置损失函数

使用TFGAN自带的损失函数配置模型,如下所示:

    with tf.name_scope('loss'):
        # Build the GAN loss.
        gan_loss = tfgan.gan_loss(
            gan_model,
            generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
            discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
            gradient_penalty_weight=1.0,
            add_summaries=True)

同时可以进行自定义的损失函数,损失函数接收一个gan_model,对gan_model的output进行损失计算。如下所示:

def silly_custom_generator_loss(gan_model, add_summaries=False):
    return tf.reduce_mean(gan_model.discriminator_gen_outputs)

def silly_custom_discriminator_loss(gan_model, add_summaries=False):
    return (tf.reduce_mean(gan_model.discriminator_gen_outputs) -
            tf.reduce_mean(gan_model.discriminator_real_outputs))

custom_gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=silly_custom_generator_loss,
    discriminator_loss_fn=silly_custom_discriminator_loss)

配置训练操作

接下来使用TFGAN来配置训练操作,制定模型、损失、优化器、训练率等参数。关于check_for_unused_update_ops这个参数,由于batch_norm层的原因,如果设置为True会导致更新参数检查不一致,因此需要设置为False。

    with tf.name_scope('train'):
        # Create the train ops, which calculate gradients and apply updates to weights.
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
            discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5),
            check_for_unused_update_ops=False,
            summarize_gradients=True,
            aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

训练模型

通过TFGAN指定训练操作与保存的文件夹,就可以很容易的开始训练过程。如下所示:

    # Run the train ops in the alternating training scheme.
    tfgan.gan_train(
        train_ops,
        hooks=[tf.train.StopAtStepHook(num_steps=max_steps)],
        logdir=train_log_dir,
        save_summaries_steps=10)

可视化评估

def test(eval_dir, checkpoint_dir):
    tf.reset_default_graph()

    if not tf.gfile.Exists(eval_dir):
        tf.gfile.MakeDirs(eval_dir)

    random_inputs = tf.random_normal([100, 64])

    with tf.variable_scope('Generator'):
        images = unconditional_generator(random_inputs, is_training=False)

    reshaped_images = tfgan.eval.image_reshaper(images[:100, ...], num_cols=10)
    uint8_images = float_image_to_uint8(reshaped_images)

    image_write_ops = tf.write_file(
          '%s/%s' % (eval_dir, 'unconditional_gan.png'),
          tf.image.encode_png(uint8_images[0]))

    tf.contrib.training.evaluate_repeatedly(
            checkpoint_dir,
            eval_ops=image_write_ops,
            hooks=[tf.contrib.training.StopAfterNEvalsHook(1)],
            max_number_of_evaluations=1)

不同Epoch阶段生成的效果如下:

Epoch:400
Epoch:1000
Epoch:2000
Epoch:3000
Epoch:4000
Epoch:5000
Loss

PS:
在实验中发现生成器的学习率对生成效果有着很大的影响,最初参考官方文档使用1e-3的学习率,发现在Epoch:200左右生成了较为模糊的数字图片,进一步训练生成图片反而全是黑色背景,参考Loss变化发现整个训练过程非常不稳定。后面改用1e-4的学习率,才达到当前的生成效果。

完整的unconditional gan代码如下所示,也可以参考官方的tutorial

import tensorflow as tf
import tensorflow.contrib.gan as tfgan
import tensorflow.contrib.layers as layers
from tensorflow.examples.tutorials.mnist import input_data


def float_image_to_uint8(image):
    """Convert float image in [-1, 1) to [0, 255] uint8.
    Note that `1` gets mapped to `0`, but `1 - epsilon` gets mapped to 255.
    Args:
        image: An image tensor. Values should be in [-1, 1).
    Returns:
        Input image cast to uint8 and with integer values in [0, 255].
    """
    image = (image * 128.0) + 128.0

    return tf.cast(image, tf.uint8)


def provide_data(batch_size, num_threads=1):
    file = "MNIST"
    # range 0~1
    mnist = input_data.read_data_sets(file, one_hot=True)

    train_data = mnist.train.images.reshape(-1, 28, 28, 1) * 255
    train_labels = mnist.train.labels

    # transfer to -1~1
    train_data = (tf.to_float(train_data) - 128.0) / 128.0

    # Creates a QueueRunner for the pre-fetching operation.
    input_queue = tf.train.slice_input_producer([train_data, train_labels], shuffle=True)
    images, labels = tf.train.batch(
            input_queue,
            batch_size=batch_size,
            num_threads=num_threads,
            capacity=5 * batch_size)

    return images, labels


def unconditional_generator(noise, weight_decay=2.5e-5, is_training=True):
    """Simple generator to produce MNIST images.

    Args:
        noise: A single Tensor representing noise.
        weight_decay: The value of the l2 weight decay.
        is_training: If `True`, batch norm uses batch statistics. If `False`, batch
            norm uses the exponential moving average collected from population 
            statistics.

    Returns:
        A generated image in the range [-1, 1].
    """
    with tf.contrib.framework.arg_scope(
        [layers.fully_connected, layers.conv2d_transpose],
        activation_fn=tf.nn.relu, normalizer_fn=layers.batch_norm,
        weights_regularizer=layers.l2_regularizer(weight_decay)):
        with tf.contrib.framework.arg_scope([layers.batch_norm], is_training=is_training,
                        zero_debias_moving_mean=True):

            net = layers.fully_connected(noise, 1024)
            net = layers.fully_connected(net, 7 * 7 * 128)
            net = tf.reshape(net, [-1, 7, 7, 128])
            net = layers.conv2d_transpose(net, 64, [4, 4], stride=2)
            net = layers.conv2d_transpose(net, 32, [4, 4], stride=2)
            # Make sure that generator output is in the same range as `inputs`
            # ie [-1, 1].
            net = layers.conv2d(net, 1, [4, 4], normalizer_fn=None, activation_fn=tf.tanh)

            return net


def unconditional_discriminator(img, unused_conditioning, weight_decay=2.5e-5,
                     is_training=True):
    """Discriminator network on MNIST digits.

    Args:
        img: Real or generated MNIST digits. Should be in the range [-1, 1].
        unused_conditioning: The TFGAN API can help with conditional GANs, which
            would require extra `condition` information to both the generator and the
            discriminator. Since this example is not conditional, we do not use this
            argument.
        weight_decay: The L2 weight decay.
        is_training: If `True`, batch norm uses batch statistics. If `False`, batch
            norm uses the exponential moving average collected from population 
            statistics.

    Returns:
        Logits for the probability that the image is real.
    """
    with tf.contrib.framework.arg_scope(
        [layers.conv2d, layers.fully_connected],
        activation_fn=tf.nn.relu, normalizer_fn=None,
        weights_regularizer=layers.l2_regularizer(weight_decay),
        biases_regularizer=layers.l2_regularizer(weight_decay)):

        net = layers.conv2d(img, 64, [4, 4], stride=2)
        net = layers.conv2d(net, 128, [4, 4], stride=2)
        net = layers.flatten(net)

        with tf.contrib.framework.arg_scope([layers.batch_norm], is_training=is_training):
            net = layers.fully_connected(net, 1024, normalizer_fn=layers.batch_norm)

        return layers.linear(net, 1)


def train(batch_size, max_steps, gen_lr, dis_lr, train_log_dir):
    tf.reset_default_graph()

    if not tf.gfile.Exists(train_log_dir):
        tf.gfile.MakeDirs(train_log_dir)

    # Set up the input.
    images, _ = provide_data(batch_size)
    noise = tf.random_normal([batch_size, 64])

    with tf.name_scope('model'):
        # Build the generator and discriminator.
        gan_model = tfgan.gan_model(
            generator_fn=unconditional_generator,  # you define 
            discriminator_fn=unconditional_discriminator,  # you define
            real_data=images,
            generator_inputs=noise)

    with tf.name_scope('loss'):
        # Build the GAN loss.
        gan_loss = tfgan.gan_loss(
            gan_model,
            generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
            discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
            gradient_penalty_weight=1.0,
            add_summaries=True)

    with tf.name_scope('train'):
        # Create the train ops, which calculate gradients and apply updates to weights.
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
            discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5),
            check_for_unused_update_ops=False,
            summarize_gradients=True,
            aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

    # Run the train ops in the alternating training scheme.
    tfgan.gan_train(
        train_ops,
        hooks=[tf.train.StopAtStepHook(num_steps=max_steps)],
        logdir=train_log_dir,
        save_summaries_steps=10)


def test(eval_dir, checkpoint_dir):
    tf.reset_default_graph()

    if not tf.gfile.Exists(eval_dir):
        tf.gfile.MakeDirs(eval_dir)

    random_inputs = tf.random_normal([100, 64])

    with tf.variable_scope('Generator'):
        images = unconditional_generator(random_inputs, is_training=False)

    reshaped_images = tfgan.eval.image_reshaper(images[:100, ...], num_cols=10)
    uint8_images = float_image_to_uint8(reshaped_images)

    image_write_ops = tf.write_file(
          '%s/%s' % (eval_dir, 'unconditional_gan.png'),
          tf.image.encode_png(uint8_images[0]))

    tf.contrib.training.evaluate_repeatedly(
            checkpoint_dir,
            eval_ops=image_write_ops,
            hooks=[tf.contrib.training.StopAfterNEvalsHook(1)],
            max_number_of_evaluations=1)


if __name__ == '__main__':
    train(16, 5000, 1e-4, 1e-4, 'logs/')
    test('eval/', 'logs/')

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 215,384评论 6 497
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,845评论 3 391
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 161,148评论 0 351
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,640评论 1 290
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,731评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,712评论 1 294
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,703评论 3 415
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,473评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,915评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,227评论 2 331
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,384评论 1 345
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,063评论 5 340
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,706评论 3 324
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,302评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,531评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,321评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,248评论 2 352

推荐阅读更多精彩内容