from __future__ import print_function, division
from sklearn import datasets
import math
import matplotlib.pyplot as plt
import numpy as np
import progressbar
from sklearn.datasets import fetch_mldata
from mlfromscratch.deep_learning.optimizers import Adam
from mlfromscratch.deep_learning.loss_functions import CrossEntropy
from mlfromscratch.deep_learning.layers import Dense, Dropout, Flatten, Activation, Reshape, BatchNormalization
from mlfromscratch.deep_learning import NeuralNetwork
class GAN():
    """A Generative Adversarial Network with deep fully-connected neural nets as
    Generator and Discriminator.

    The generator maps latent noise vectors of length ``latent_dim`` to
    flattened images of length ``img_dim`` (tanh output), and the
    discriminator maps flattened images to a two-unit softmax where
    ``[1, 0]`` means "valid" and ``[0, 1]`` means "fake".  ``combined`` is a
    network whose layer list is the generator's layers followed by the
    discriminator's layers (the same layer objects, not copies), and is used
    to train the generator through the discriminator.

    Training Data: MNIST Handwritten Digits (28x28 images)
    """

    def __init__(self):
        """Build the generator, the discriminator and the combined model."""
        self.img_rows = 28
        self.img_cols = 28
        self.img_dim = self.img_rows * self.img_cols
        self.latent_dim = 100

        optimizer = Adam(learning_rate=0.0002, b1=0.5)
        # The loss is handed over as a class, not an instance, exactly as the
        # NeuralNetwork constructor is called everywhere in this file.
        loss_function = CrossEntropy

        # Build the discriminator
        self.discriminator = self.build_discriminator(optimizer, loss_function)
        # Build the generator
        self.generator = self.build_generator(optimizer, loss_function)

        # Build the combined model: generator layers followed by the
        # discriminator layers, sharing the layer objects of both networks.
        self.combined = NeuralNetwork(optimizer=optimizer, loss=loss_function)
        self.combined.layers.extend(self.generator.layers)
        self.combined.layers.extend(self.discriminator.layers)

        print()
        self.generator.summary(name="Generator")
        self.discriminator.summary(name="Discriminator")

    def build_generator(self, optimizer, loss_function):
        """Return the generator network.

        Three Dense blocks (256, 512, 1024 units), each followed by a
        leaky-ReLU activation and batch normalization, then a Dense layer of
        ``img_dim`` units with a tanh activation.

        optimizer     -- optimizer passed to NeuralNetwork
        loss_function -- loss passed to NeuralNetwork
        """
        model = NeuralNetwork(optimizer=optimizer, loss=loss_function)

        model.add(Dense(256, input_shape=(self.latent_dim,)))
        model.add(Activation('leaky_relu'))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(512))
        model.add(Activation('leaky_relu'))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(1024))
        model.add(Activation('leaky_relu'))
        model.add(BatchNormalization(momentum=0.8))

        # tanh output matches the [-1, 1] rescaling applied to the training
        # images in train().
        model.add(Dense(self.img_dim))
        model.add(Activation('tanh'))

        return model

    def build_discriminator(self, optimizer, loss_function):
        """Return the discriminator network.

        Two Dense blocks (512, 256 units), each followed by a leaky-ReLU
        activation and Dropout(0.5), then a two-unit Dense layer with softmax.

        optimizer     -- optimizer passed to NeuralNetwork
        loss_function -- loss passed to NeuralNetwork
        """
        model = NeuralNetwork(optimizer=optimizer, loss=loss_function)

        model.add(Dense(512, input_shape=(self.img_dim,)))
        model.add(Activation('leaky_relu'))
        model.add(Dropout(0.5))

        model.add(Dense(256))
        model.add(Activation('leaky_relu'))
        model.add(Dropout(0.5))

        model.add(Dense(2))
        model.add(Activation('softmax'))

        return model

    @staticmethod
    def _one_hot_labels(n_samples, valid):
        """Return an (n_samples, 2) array of [1, 0] rows if valid, else [0, 1] rows."""
        ones = np.ones((n_samples, 1))
        zeros = np.zeros((n_samples, 1))
        columns = (ones, zeros) if valid else (zeros, ones)
        return np.concatenate(columns, axis=1)

    def train(self, n_epochs, batch_size=128, save_interval=50):
        """Alternate discriminator and generator updates for ``n_epochs`` steps.

        Each step trains the discriminator on half a batch of real images and
        half a batch of generated images, then trains the generator (through
        ``combined``, with the discriminator frozen) on a full batch of noise
        labelled as valid.  Progress is printed every step.

        n_epochs      -- number of training steps
        batch_size    -- samples per generator step; the discriminator sees
                         batch_size // 2 real and as many generated samples
        save_interval -- save a grid of generated images every this many
                         steps; 0 or None disables saving

        Raises ValueError if batch_size is smaller than 2, because the
        discriminator half batch would then be empty.
        """
        # Validate before the dataset is fetched so bad arguments fail fast.
        if batch_size < 2:
            raise ValueError(
                "batch_size must be at least 2, got %r" % (batch_size,))
        half_batch = int(batch_size // 2)

        # NOTE(review): fetch_mldata was deprecated in scikit-learn 0.20 and
        # dropped from later releases; fetch_openml is the documented
        # replacement -- confirm which scikit-learn version this must run on.
        mnist = fetch_mldata('MNIST original')
        X = mnist.data

        # Rescale [-1, 1]
        X = (X.astype(np.float32) - 127.5) / 127.5

        for epoch in range(n_epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            self.discriminator.set_trainable(True)

            # Select a random half batch of images
            idx = np.random.randint(0, X.shape[0], half_batch)
            imgs = X[idx]

            # Sample noise to use as generator input
            noise = np.random.normal(0, 1, (half_batch, self.latent_dim))

            # Generate a half batch of images
            gen_imgs = self.generator.predict(noise)

            # Valid = [1, 0], Fake = [0, 1]
            valid = self._one_hot_labels(half_batch, True)
            fake = self._one_hot_labels(half_batch, False)

            # Train the discriminator
            d_loss_real, d_acc_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake, d_acc_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * (d_loss_real + d_loss_fake)
            d_acc = 0.5 * (d_acc_real + d_acc_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            # We only want to train the generator for the combined model
            self.discriminator.set_trainable(False)

            # Sample noise and use as generator input
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # The generator wants the discriminator to label the generated
            # samples as valid
            valid = self._one_hot_labels(batch_size, True)

            # Train the generator
            g_loss, g_acc = self.combined.train_on_batch(noise, valid)

            # Display the progress
            print ("%d [D loss: %f, acc: %.2f%%] [G loss: %f, acc: %.2f%%]" % (epoch, d_loss, 100*d_acc, g_loss, 100*g_acc))

            # If at save interval => save generated image samples.
            # A falsy interval (0 / None) means "never save" instead of
            # raising ZeroDivisionError.
            if save_interval and epoch % save_interval == 0:
                self.save_imgs(epoch)

    def save_imgs(self, epoch):
        """Save a 5x5 grid of generated images to ``mnist_<epoch>.png``.

        epoch -- step number used in the output file name
        """
        r, c = 5, 5  # Grid size
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))

        # Generate images and reshape to image shape
        gen_imgs = self.generator.predict(noise).reshape((-1, self.img_rows, self.img_cols))

        # Rescale images 0 - 1 (generator output is tanh, i.e. within [-1, 1])
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        plt.suptitle("Generative Adversarial Network")
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt, :, :], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig("mnist_%d.png" % epoch)
        # Close this specific figure so figures do not pile up over a long
        # training run, regardless of which figure is "current".
        plt.close(fig)
if __name__ == '__main__':
    # Script entry point: build the networks, then run the training loop
    # with the hyper-parameters below.
    run_settings = {
        'n_epochs': 200000,
        'batch_size': 64,
        'save_interval': 400,
    }
    model = GAN()
    model.train(**run_settings)
[Machine Learning From Scratch]-unsupervised_learning-generative_adversarial_network
©著作权归作者所有,转载或内容合作请联系作者
- 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
- 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
- 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
推荐阅读更多精彩内容
- Unsupervised Representation Learning with Deep Convolutio...