pytorch GAN手写体识别

这是一个github的教程,我用自己的理解复述一遍,方面记忆

一、关于生成对抗网络的第一篇论文是Generative Adversarial Networks,这是2014年发表的一篇论文

GAN模型的基本原理

二、随机产生噪声数据(至于为什么要使用噪声作为输入数据可以参考这里对抗生成网络GAN为什么输入随机噪声?), 通过生成器变成一个假数据(Fake Data)一般来说是图片。

三、生成器的目标是尽力生成和目标数据分布一致的假数据, 而判别模型目标是分辨真数据和假数据, 抽象来说可以看下图:


GAN原理

假设数据是一维数据,如图所示,上方的图形(绿线,虚线)表示数据的密度分布。绿线表示真实数据分布,大虚线表示GAN生成数据的分布,而小虚线表示一个判别函数(你可以近似看成Sigmoid)。

(a)图中由于没有迭代,导致判别函数只能大致判断真数据和假数据。
(b)图中经过训练,可以看出判别函数已经可以大致判断出真数据和假数据分布(你可以理解为>0.5表示假数据,<0.5表示真数据)。
(c)图中,经过不断学习,GAN生成的数据和真实分布已经很接近了,这是判别函数已经很难区分真实数据和假数据分布了。
(d)最终,真数据和假数据分布一致,判别函数无法判断(恒为0.5),这是GAN生成的数据和真实数据分布已经一致,可以达到以假乱真。

1. 加载必要的包

这里使用的pytorch==1.3.0, torchvision==0.4.1

import os
import numpy as np

import torchvision.transforms as transforms # 对数据进行转化,归一化等操作
from torchvision.utils import save_image

from torch.utils.data import DataLoader # 加载数据,变成一个类似迭代器的东西
from torchvision import datasets # 再带的数据数据集,一般都是常用的数据集

import torch.nn as nn
import torch.nn.functional as F
import torch

2. 参数预设

这里使用一个字典opt保存所需变量

opt = {}
opt['n_epochs'] = 200 # 迭代次数
opt['batch_size'] = 64 # 每个batch的样本数
opt['lr'] = 0.0002 # 优化器adam的学习率
opt['b1'] = 0.5 # 优化器adam的梯度的一阶动量衰减 momentum
opt['b2'] = 0.999 # 优化器adam的梯度的二阶动量衰减 momentum

opt['latent_dim'] = 100  # latent(潜)空间的维数, 可以理解为噪声数据的维度
opt['img_size'] = 28 # 输入数据是一个1*28*28的灰度图片
opt['channels'] = 1 # RGB通道个数,这里是1个通道的灰度图
opt['sample_interval'] = 400 # 图像采样间隔(做记录)

# 输入图片大小 1*28*28
img_shape = (opt['channels'], opt['img_size'], opt['img_size'])
# 如果GPU可以使用, 则先在这里立个flag
cuda = True if torch.cuda.is_available() else False
print(cuda)
# GPU可以使用的话使用GPU的FloatTensor
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 

3.加载数据

这里使用torch自带的手写体数据集
# 输出图片(GAN生成假图片文件夹)
os.makedirs("images", exist_ok=True)
# 训练数据文件夹
os.makedirs("mnist", exist_ok=True)
# 下载图片数据
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt['img_size']), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt['batch_size'], # 每次取出数据量
    shuffle=True,
)

4.定义模型

# 生成器模型
class Generator(nn.Module):
  def __init__(self):
    super(Generator, self).__init__()
  
    # 简化代码,将Linear,BatchNorm1d, LeakyReLU装在一起
    def block(in_feat, out_feat, normalize=True):
      layers = [nn.Linear(in_feat, out_feat)]
      if normalize:
        layers.append(nn.BatchNorm1d(out_feat, 0.8))
      layers.append(nn.LeakyReLU(0.2, inplace=True)) # inplace一个原地操作
      # 是对于Conv2d这样的上层网络传递下来的tensor直接进行修改,好处就是可以节省运算内存,不用多储存变量y
      return layers
  
    # 这里前面加*相当于在Sequential中extend
    self.model = nn.Sequential(
        *block(opt['latent_dim'], 128, normalize=False),
        *block(128,256),
        *block(256,512),
        *block(512, 1024),
        nn.Linear(1024, int(np.prod(img_shape))), # np.prod摊开
        nn.Tanh()
    )
  def forward(self, data):
    img = self.model(data)
    img = img.view(img.size(0), *img_shape) # 因为输出数据应该为一张图片,所以需要将Reshae变为图片(图片数, channel,长, 宽)
    return img

# 定义判别模型
class Discriminator(nn.Module):
  def __init__(self):
    super(Discriminator, self).__init__()
    
    self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)), 512), # np.prod将图片展开成一维向量
                               nn.LeakyReLU(0.2, inplace=True),
                               nn.Linear(512, 256),
                               nn.LeakyReLU(0.2, inplace=True),
                               nn.Linear(256, 1),
                               nn.Sigmoid(),
                              )
  def forward(self, img):
    img_flat = img.view(img.size(0), -1) # img.size(0)表示每个batch图片数量
    # 拉成 图片1 维度相乘 1*28*28
    #          图片n 维度相乘
    validity = self.model(img_flat)
    return validity
需要说明的是block前面加*表示把block中的层安顺序平铺开,等价于下面的代码:
 self.model = nn.Sequential(
        *block(opt['latent_dim'], 128),
        *block(128,256),)

# 等价于下面的代码
self.model = nn.Sequential(
        # 这是一个block
        nn.Linear(in_feat, 128)
        layers.append(nn.BatchNorm1d(out_feat, 0.8))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        # 这是第二个block
        nn.Linear(128, 256)
        layers.append(nn.BatchNorm1d(256, 0.8))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
)

4.损失函数和优化器

# 判别器损失函数
adversarial_loss = torch.nn.BCELoss()

# 初始化生成器和判别器
generator = Generator()
discriminator = Discriminator()
# 如果CPU可以用
if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# 优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt['lr'], betas=(opt['b1'], opt['b2']))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt['lr'], betas=(opt['b1'], opt['b2']))
到此为止基本的模型定义就算完成了,我们现在有数据集dataloader(一个迭代器),generator生成器对象, discriminator判别器对象,下面我们将训练模型

5.训练模型

# 训练模型
for epoch in range(opt['n_epochs']):
    for i, (imgs, _) in enumerate(dataloader): # 这里不需要类标签
        # Variable和Tensor在新版本中已经合并
        valid = Tensor(imgs.size(0), 1).fill_(1.0) # 真实图片类标签设为1.0
        valid.requires_grad=False # 不能更新梯度
        fake = Tensor(imgs.size(0), 1).fill_(0.0) # 假图片类标签设为0.0
        fake.requires_grad=False
        
        # 将数据转成cuda的tensor,加速
        # imgs.type() = 'torch.FloatTensor'
        # real_imgs.type() = 'torch.cuda.FloatTensor'
        real_imgs = imgs.type(Tensor)
        
        optimizer_G.zero_grad() # 生成模型的优化器梯度清零
        
        # 训练生成器
        # 产生噪声(输入)数据 64张100维噪声
        noise = Tensor(np.random.normal(0, 1, (imgs.shape[0], opt['latent_dim'])))
        # 生成假图片
        fake_img = generator(noise)
        # 更新生层模型
        g_loss = adversarial_loss(discriminator(fake_img), valid)
        g_loss.backward() # 返回梯度
        optimizer_G.step() # 更新权重
        
        optimizer_D.zero_grad() # 判别模型的优化器梯度清零
        # 训练判别器
        real_loss = adversarial_loss(discriminator(real_imgs), valid) # 真实图片的损失
        # detach返回一个新的张量,它与当前图形分离。fake_img结果永远不需要梯度
        fake_loss = adversarial_loss(discriminator(fake_img.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2 # 去一个均值作为损失
        d_loss.backward()
        optimizer_D.step()
        # 打印损失
        if i%500 == 0:
             print(
                   "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                  % (epoch, opt['n_epochs'], i, len(dataloader), d_loss.item(),
                  g_loss.item())
          )
        # 选择前25个图片保存
        batches_done = epoch * len(dataloader) + i
        if batches_done % opt['sample_interval'] == 0:
      save_image(fake_img.data[:25], "images/%d.png" % batches_done, 
                          nrow=5, normalize=True)

下面是训练第0次(epoch * len(dataloader) + i),10000次和100000次的结果,可以看出若持续训练,结果会更加逼近真实数据

0.png
10000.png
100000.png

下面给出完整的代码

import os
import math
import numpy as np
import argparse

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

opt = {}
opt['n_epochs'] = 200 # 迭代次数
opt['batch_size'] = 64
opt['lr'] = 0.0002 # adam: 学习率
opt['b1'] = 0.5 # adam: 梯度的一阶动量衰减 momentum
opt['b2'] = 0.999 # adam: 梯度的一阶动量衰减 momentum

opt['latent_dim'] = 100  # latent空间的维数
opt['img_size'] = 28
opt['channels'] = 1
opt['sample_interval'] = 400 # 图像采样间隔(做记录)

# 输入图片大小
img_shape = (opt['channels'], opt['img_size'], opt['img_size'])
cuda = True if torch.cuda.is_available() else False
print(cuda)

os.makedirs("images", exist_ok=True)
os.makedirs("mnist", exist_ok=True)
# 下载图片
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt['img_size']), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt['batch_size'],
    shuffle=True,
)
# 如果CPU可以用
if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# 生成器模型
class Generator(nn.Module):
  def __init__(self):
    super(Generator, self).__init__()
  
    # 整个作为一层
    def block(in_feat, out_feat, normalize=True):
      layers = [nn.Linear(in_feat, out_feat)]
      if normalize:
        layers.append(nn.BatchNorm1d(out_feat, 0.8))
      layers.append(nn.LeakyReLU(0.2, inplace=True)) # inplace一个原地操作
      # 是对于Conv2d这样的上层网络传递下来的tensor直接进行修改,好处就是可以节省运算内存,不用多储存变量y
      return layers
  
    # 这里前面加*相当于在Sequential中extend
    self.model = nn.Sequential(
        *block(opt['latent_dim'], 128, normalize=False),
        *block(128,256),
        *block(256,512),
        *block(512, 1024),
        nn.Linear(1024, int(np.prod(img_shape))), # np.prod摊开
        nn.Tanh()
    )
  def forward(self, data):
    img = self.model(data)
    img = img.view(img.size(0), *img_shape) # 将平铺的数据变为图片(图片数, 长, 宽,channel)
    return img
# 定义判别函数
class Discriminator(nn.Module):
  def __init__(self):
    super(Discriminator, self).__init__()
    
    self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)), 512),
                               nn.LeakyReLU(0.2, inplace=True),
                               nn.Linear(512, 256),
                               nn.LeakyReLU(0.2, inplace=True),
                               nn.Linear(256, 1),
                               nn.Sigmoid(),
                              )
  def forward(self, img):
    img_flat = img.view(img.size(0), -1)
    # 拉成 图片1 维度相乘
    #     图片n 维度相乘
    validity = self.model(img_flat)
    return validity

# 判别器损失函数
adversarial_loss = torch.nn.BCELoss()

# 初始化生成器和判别器
generator = Generator()
discriminator = Discriminator()
# 优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt['lr'], betas=(opt['b1'], opt['b2']))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt['lr'], betas=(opt['b1'], opt['b2']))

# 训练模型
for epoch in range(opt['n_epochs']):
  for i, (imgs, _) in enumerate(dataloader):
    valid = Tensor(imgs.size(0), 1).fill_(1.0) # 真实图片类标签
    valid.requires_grad=False
    
    fake = Tensor(imgs.size(0), 1).fill_(0.0) # 假图片类标签
    fake.requires_grad=False
    
    # 将数据转成cuda的tensor
    real_imgs = imgs.type(Tensor)
    
    # 训练生成器
    optimizer_G.zero_grad() # 生成器梯度清零
    # 产生噪声数据 64张100维噪声
    noise = Tensor(np.random.normal(0, 1, (imgs.shape[0], opt['latent_dim'])))
    # 生成噪声图片
    fake_img = generator(noise)
    
    # 更新生层器
    g_loss = adversarial_loss(discriminator(fake_img), valid)
    g_loss.backward()
    optimizer_G.step()
    
    # 训练判别器
    optimizer_D.zero_grad()
    real_loss = adversarial_loss(discriminator(real_imgs), valid)
    # detach返回一个新的张量,它与当前图形分离。结果永远不需要梯度
    fake_loss = adversarial_loss(discriminator(fake_img.detach()), fake)
    d_loss = (real_loss + fake_loss) / 2
    d_loss.backward()
    optimizer_D.step()
    
    if i%500 == 0:
      print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt['n_epochs'], i, len(dataloader), d_loss.item(), g_loss.item())
          )
    
    batches_done = epoch * len(dataloader) + i
    if batches_done % opt['sample_interval'] == 0:
      save_image(fake_img.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 215,923评论 6 498
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,154评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 161,775评论 0 351
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,960评论 1 290
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,976评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,972评论 1 295
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,893评论 3 416
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,709评论 0 271
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,159评论 1 308
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,400评论 2 331
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,552评论 1 346
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,265评论 5 341
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,876评论 3 325
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,528评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,701评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,552评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,451评论 2 352