GAN for a One-Dimensional Variable

import tensorflow as tf


# tf.config.run_functions_eagerly(True)  # enable for debugging

class GAN(tf.keras.Model):
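    """Simple GAN for low-dimensional data.

    generator:     latent vector of shape (latent_dim,) -> sample of shape (data_dim,)
    discriminator: sample of shape (data_dim,) -> sigmoid probability of being fake
    hidden_dims1 / hidden_dims2 list the hidden-layer widths of the discriminator /
    generator; an empty list yields a single Dense output layer.
    """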
    def __init__(self, data_dim, latent_dim, hidden_dims1, hidden_dims2):
        super(GAN, self).__init__()
        self.data_dim = data_dim
        self.latent_dim = latent_dim
        self.hidden_dims1 = hidden_dims1
        self.hidden_dims2 = hidden_dims2
        # Build the discriminator
        input_d = tf.keras.layers.Input(shape=(data_dim,))
        if len(hidden_dims1) != 0:
            discriminator = tf.keras.layers.Dense(hidden_dims1[0], activation='relu')(input_d)
            for dim in hidden_dims1[1:]:
                discriminator = tf.keras.layers.Dense(dim, activation='relu')(discriminator)
            discriminator = tf.keras.layers.Dense(1, activation='sigmoid')(discriminator)
        else:
            discriminator = tf.keras.layers.Dense(1, activation='sigmoid')(input_d)
        self.discriminator = tf.keras.Model(input_d, discriminator, name='discriminator')
        self.discriminator.summary()
        # Build the generator
        input_g = tf.keras.layers.Input(shape=(latent_dim,))  # latent input to the generator
        if len(hidden_dims2) != 0:
            generator = tf.keras.layers.Dense(hidden_dims2[0], activation='relu')(input_g)
            for dim in hidden_dims2[1:]:
                generator = tf.keras.layers.Dense(dim, activation='relu')(generator)
            generator = tf.keras.layers.Dense(data_dim, activation='sigmoid')(generator)
        else:
            generator = tf.keras.layers.Dense(data_dim, activation='sigmoid')(input_g)
        self.generator = tf.keras.Model(input_g, generator, name='generator')
        self.generator.summary()
        # ---- Optimizers and trackers for d_loss and g_loss
        self.d_optimizer = None
        self.g_optimizer = None
        self.loss_func = None
        self.d_loss = None
        self.g_loss = None

    def compile(self, d_optimizer, g_optimizer, loss_func):
        super().compile()
        # Optimizers and loss function
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_func = loss_func
        self.d_loss = tf.keras.metrics.Mean(name="d_loss")
        self.g_loss = tf.keras.metrics.Mean(name="g_loss")

    def train_step(self, real_datas):
        # Sample from the latent space
        batch_size = tf.shape(real_datas)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        # Produce fake samples with the generator
        generated_datas = self.generator(random_latent_vectors)
        # Combine real and fake samples
        combined_datas = tf.concat([generated_datas, real_datas], axis=0)
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )  # labels: 0 = real samples, 1 = fake (as in the official Keras example)
        # Add random noise to the labels - important trick! (done in the official Keras example)
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        # Train the discriminator, 5 updates per step
        for i in range(5):
            with tf.GradientTape() as tape:
                predictions = self.discriminator(combined_datas)
                d_loss = self.loss_func(labels, predictions)
                self.d_loss.update_state(d_loss)
            grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
            self.d_optimizer.apply_gradients(
                zip(grads, self.discriminator.trainable_weights)
            )

        # Sample from the latent space again
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        # Pretend these samples are all real (label 0)
        misleading_labels = tf.zeros((batch_size, 1))
        # Train the generator, 1 update per step
        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(random_latent_vectors))
            g_loss = self.loss_func(misleading_labels, predictions)
            self.g_loss.update_state(g_loss)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # End of one training step
        return {
            "d_loss": self.d_loss.result(),
            "g_loss": self.g_loss.result(),
        }
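
The post does not show how the class is used, so here is a minimal usage sketch. Everything in it is illustrative rather than from the original: the toy 1-D dataset (Beta-distributed samples, which lie in [0, 1] and so match the generator's sigmoid output), the layer widths, the Adam learning rates, and the batch size and epoch count are all assumptions.

import numpy as np

data_dim, latent_dim = 1, 8

# Toy 1-D dataset in [0, 1] (assumption), matching the generator's sigmoid output
real_samples = np.random.beta(2.0, 5.0, size=(10000, data_dim)).astype("float32")

gan = GAN(data_dim=data_dim, latent_dim=latent_dim,
          hidden_dims1=[32, 32], hidden_dims2=[32, 32])
gan.compile(
    d_optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    g_optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss_func=tf.keras.losses.BinaryCrossentropy(),
)
gan.fit(real_samples, batch_size=64, epochs=20)

# Compare the learned distribution with the real one, e.g. by plotting histograms
# of real_samples and fake_samples side by side.
z = tf.random.normal(shape=(10000, latent_dim))
fake_samples = gan.generator(z).numpy()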