Generating Digits with PyTorch

In this blog post we'll implement a generative image model that converts random noise into images of digits! The full code is available here; just clone it to your machine and it's ready to play with. As a former Torch7 user, I attempt to reproduce the results from the Torch7 post.

For this, we employ a Generative Adversarial Network (GAN). A GAN consists of two components: a generator, which converts random noise into images, and a discriminator, which tries to distinguish between generated and real images. Here, 'real' means that the image came from our training set, in contrast to the generated fakes.

To train the model we let the discriminator and generator play a game against each other. We first show the discriminator a mixed batch of real images from our training set and fake images produced by the generator. We then simultaneously optimize the discriminator to answer NO to fake images and YES to real images, and optimize the generator to fool the discriminator into believing that the fake images are real. This corresponds to minimizing the classification error with respect to the discriminator and maximizing it with respect to the generator. With careful optimization, both generator and discriminator will improve, and the generator will eventually start generating convincing images.
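One round of this game can be sketched roughly as follows. This is a minimal illustration, not the post's actual training code, and it rests on several assumptions that the text does not state: the layer sizes, the Adam optimizers and their learning rate, and the binary cross-entropy loss are all chosen for illustration. The two updates also alternate (discriminator first, then generator) rather than running truly simultaneously, and the generator uses the common non-saturating loss (label the fakes as real) instead of literally maximizing the discriminator's error. Random tensors stand in for real images.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen for illustration only.
latent_size, hidden_size, image_size, batch_size = 64, 256, 784, 32

# Small MLPs standing in for the discriminator and the generator.
D = nn.Sequential(nn.Linear(image_size, hidden_size), nn.LeakyReLU(0.2),
                  nn.Linear(hidden_size, 1), nn.Sigmoid())
G = nn.Sequential(nn.Linear(latent_size, hidden_size), nn.ReLU(),
                  nn.Linear(hidden_size, image_size), nn.Tanh())

criterion = nn.BCELoss()
# Assumption: Adam with lr=2e-4; the post does not specify these settings.
d_optimizer = torch.optim.Adam(D.parameters(), lr=2e-4)
g_optimizer = torch.optim.Adam(G.parameters(), lr=2e-4)


def train_step(real_images):
    """Run one discriminator update followed by one generator update."""
    n = real_images.size(0)
    real_labels = torch.ones(n, 1)   # YES
    fake_labels = torch.zeros(n, 1)  # NO

    # Discriminator: minimize classification error on real and fake images.
    # detach() keeps this update from touching the generator's weights.
    fake_images = G(torch.randn(n, latent_size))
    d_loss = (criterion(D(real_images), real_labels)
              + criterion(D(fake_images.detach()), fake_labels))
    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()

    # Generator: push the discriminator to answer YES on fresh fakes.
    fake_images = G(torch.randn(n, latent_size))
    g_loss = criterion(D(fake_images), real_labels)
    g_optimizer.zero_grad()
    g_loss.backward()
    g_optimizer.step()
    return d_loss.item(), g_loss.item()


# Random values in [-1, 1] stand in for a batch of flattened real images.
real_batch = torch.rand(batch_size, image_size) * 2 - 1
d_loss, g_loss = train_step(real_batch)
```

In a real run, `train_step` would be called once per mini-batch of MNIST images, and the two returned losses are the quantities one would log per epoch.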

Implementing a GAN

We implement the generator and discriminator as multilayer perceptrons (MLPs) and train them with stochastic gradient descent.

The discriminator is an MLP built from consecutive blocks of a Linear layer followed by a LeakyReLU activation.

D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid())

This is a pretty standard architecture. The 28x28 grey images of digits are converted into a 784x1 vector by stacking their columns. The discriminator takes this 784x1 vector as input and predicts YES or NO with a single sigmoid output.
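As a small illustration of the flattening step, the sketch below reshapes a dummy batch shaped like MNIST into rows of 784 values. The batch size and the random data are assumptions for the example. Note that `view` flattens each image row by row rather than column by column; for an MLP the ordering does not matter as long as it is applied consistently.

```python
import torch

# A dummy batch shaped like MNIST: (batch, channels, height, width).
images = torch.rand(16, 1, 28, 28)

# Flatten every image into one row of 28 * 28 = 784 values.
flat = images.view(images.size(0), -1)
print(flat.shape)  # torch.Size([16, 784])
```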

The generator is an MLP that passes the input noise vector through repeated blocks of a Linear layer followed by a ReLU activation, ending in a Tanh output layer:

G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh())

To generate an image we feed the generator with noise distributed N(0,1). After successful training, the output should be meaningful images!

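# Sample latent noise from N(0, 1) on the GPU (G must live on the GPU too).
z = torch.randn(batch_size, latent_size).cuda()
# Since PyTorch 0.4 tensors can be passed directly; no Variable wrapper is needed.
fake_images = G(z)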
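The generator's output is a batch of flat vectors, so it has to be reshaped before it can be viewed as images. A minimal sketch, assuming the output lies in [-1, 1] (the range of the final Tanh layer) and that pixel values should be displayed in [0, 1]; a random tensor stands in for the generator output here:

```python
import torch

# Stand-in for the generator output: flat vectors in [-1, 1],
# the range produced by a final Tanh layer.
fake_images = torch.tanh(torch.randn(8, 784))

# Reshape to (batch, channels, height, width) and rescale to [0, 1].
grid = fake_images.view(-1, 1, 28, 28)
grid = ((grid + 1) / 2).clamp(0, 1)
```

A tensor in this layout can then be written to disk with, for example, `torchvision.utils.save_image`, if torchvision is installed.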

Generating digits

We train our GAN on images of digits from the MNIST dataset. After around 5 epochs you should start to see blurry digits, and after 80 epochs the results look pleasant.

after 5 epochs

after 100 epochs

Loss and Discriminator Accuracy

accuracy.png

loss.png

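We also record the accuracy of the discriminator and the losses of the discriminator and the generator after each epoch. According to the original GAN paper, when the global minimum of the training criterion is reached, the criterion takes the value -log 4 and the discriminator's accuracy should be 0.5, since it can no longer tell real images from fakes. If the discriminator loss is measured as the summed binary cross-entropy on the real and fake batches, this corresponds to a loss of log 4, about 1.39. Our model does not achieve this ideal result. Moreover, as we can see from the figure, our model starts to overfit at around 80 epochs. The model structure and the training strategy should be improved in the future.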
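The equilibrium value can be checked with a few lines of arithmetic. The sketch below assumes the standard GAN value function V = E[log D(x)] + E[log(1 - D(G(z)))] and a discriminator that outputs 0.5 for every image; the variable names are only illustrative.

```python
import math

# At the optimum the discriminator outputs 0.5 for every image.
d_out = 0.5

# Value function V = E[log D(x)] + E[log(1 - D(G(z)))].
value = math.log(d_out) + math.log(1 - d_out)

# Binary cross-entropy is a negative log-likelihood, so the summed
# real + fake discriminator loss is -V.
bce_sum = -value

print(value)    # about -1.386, i.e. -log 4
print(bce_sum)  # about 1.386, i.e. log 4
```

So a logged discriminator loss near 1.39 (for a summed cross-entropy) together with an accuracy near 0.5 would indicate that training is close to the theoretical equilibrium.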
