pytroch学习—生成式对抗网络GAN

什么是生成式对抗网络GAN

(本教程的代码,训练数据全部来自—《深度学习框架Pytroch入门与实践》, many thanks,此书采用pytorch 0.4.0版本,API接口与1.0有所差别,1.0版本pytroch中已经不推荐使用Variable)


开发/测试环境

  • Ubuntu 18.04
  • anaconda3, python3.6
  • pycharm
  • pytroch 1.0

训练过程

刚开始训练,输入为噪声向量, 生成的图像也是噪声

image.png

训练了几十次迭代之后

image.png

随着迭代次数增加,逐渐产生轮廓,仔细观察刚开始生成的图像为黑白灰度图像,没有彩色信息。

image.png
image.png
image.png

继续迭代,逐渐产生了彩色信息。

image.png
image.png
image.png
image.png
image.png
image.png
image.png

Loss曲线的变化

image.png
image.png
image.png
image.png
image.png

使用GPU进行训练

CPU进行训练太慢了,笔者采用Intel i7 5500u CPU进行训练,一秒钟大概只能迭代一次,而且batch size设置为4~8。之后切换到GPU上(Nvdia 1080ti), 单块GPU, 计算速度为20~30iter/sec, batch size=64, 直观上比CPU计算块20倍多。

image.png

迭代30K次


image.png

Process

深度录屏_选择区域_20190206000807.gif
深度录屏_TeamViewer_20190206000840.gif
深度录屏_TeamViewer_20190206001653.gif

代码

网络定义

GAN网络不同于一般的分类网络,由2部分组成: 生成器,判别器。

生成器

NetG
输入: 1x100x1x1 (NxCxHxW) 100维的噪声向量
输出: 1x3x96x96 3(Channels)x96(Height)x96(Width)的图像


from torch import nn


class NetG(nn.Module):
    '''
    生成器定义
    '''
    def __init__(self, opt):
        super(NetG, self).__init__()
        ngf = opt.ngf  # 生成器feature map数

        self.main = nn.Sequential(
            # 输入是一个nz维度的噪声,我们可以认为它是一个1*1*nz的feature map
            nn.ConvTranspose2d(opt.nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # 上一步的输出形状:(ngf*8) x 4 x 4

            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # 上一步的输出形状: (ngf*4) x 8 x 8

            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # 上一步的输出形状: (ngf*2) x 16 x 16

            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # 上一步的输出形状:(ngf) x 32 x 32

            nn.ConvTranspose2d(ngf, 3, 5, 3, 1, bias=False),
            nn.Tanh()  # 输出范围 -1~1 故而采用Tanh
            # 输出形状:3 x 96 x 96
        )

    def forward(self, input):
        return self.main(input)


判别器

NetD

输入: 1x3x96x96 的图像
输出: 1x1x1x1 的一个数,表示概率值

class NetD(nn.Module):
    '''
    判别器定义
    '''
    def __init__(self, opt):
        super(NetD, self).__init__()
        ndf = opt.ndf
        self.main = nn.Sequential(
            # 输入 3 x 96 x 96
            nn.Conv2d(3, ndf, 5, 3, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出 (ndf) x 32 x 32
            
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出 (ndf*2) x 16 x 16
            
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出 (ndf*4) x 8 x 8
            
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出 (ndf*8) x 4 x 4
            
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()  # 输出一个数(概率)
        )

    def forward(self, input):
        return self.main(input).view(-1)

参数配置

  • batch_size
  • learning_rate
  • max_epoch 最大迭代epoch个数

import os
import ipdb
import torch as t
import torchvision as tv
import tqdm
from model import NetG, NetD
from torch.autograd import Variable
from torchnet.meter import AverageValueMeter   


class Config(object):
    data_path = 'data/'  # 数据集存放路径
    num_workers = 4  # 多进程加载数据所用的进程数
    image_size = 96  # 图片尺寸
    batch_size = 16
    max_epoch = 200
    lr1 = 2e-4  # 生成器的学习率
    lr2 = 2e-4  # 判别器的学习率
    beta1=0.5  # Adam优化器的beta1参数
    gpu=False  # 是否使用GPU
    nz=100  # 噪声维度
    ngf = 64  # 生成器feature map数
    ndf = 64  # 判别器feature map数
    
    save_path = 'imgs/'  #生成图片保存路径
    
    vis = True  # 是否使用visdom可视化
    env = 'GAN'  # visdom的env
    plot_every = 20  # 每间隔20 batch,visdom画图一次

    debug_file = '/tmp/debuggan'  # 存在该文件则进入debug模式
    d_every = 1  # 每1个batch训练一次判别器
    g_every = 5  # 每5个batch训练一次生成器
    decay_every = 10  # 没10个epoch保存一次模型
    netd_path = './checkpoints/netd_100.pth'  # 'checkpoints/netd_.pth' #预训练模型
    netg_path = './checkpoints/netg_100.pth'  # 'checkpoints/netg_211.pth'
    
    # 只测试不训练
    gen_img = 'result.png'
    # 从512张生成的图片中保存最好的64张
    gen_num = 64 
    gen_search_num = 512 
    gen_mean = 0  # 噪声的均值
    gen_std = 1  #噪声的方差
    


opt = Config()

训练

训练生成器网络
训练判别器网络


def train(**kwargs):
    for k_,v_ in kwargs.items():
        setattr(opt,k_,v_)
    if opt.vis:
        from visualize import Visualizer
        vis = Visualizer(opt.env)
    
    transforms = tv.transforms.Compose([
                    tv.transforms.Scale(opt.image_size),
                    tv.transforms.CenterCrop(opt.image_size),
                    tv.transforms.ToTensor(),
                    tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                        ])
    
    dataset = tv.datasets.ImageFolder(opt.data_path,transform=transforms)
    dataloader = t.utils.data.DataLoader(dataset,
                                         batch_size = opt.batch_size,
                                         shuffle = True,
                                         num_workers= opt.num_workers,
                                         drop_last=True
                                         )

    # 定义网络
    netg, netd = NetG(opt), NetD(opt)  
    map_location=lambda storage, loc: storage
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location = map_location)) 
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location = map_location))

    # 定义优化器和损失
    optimizer_g = t.optim.Adam(netg.parameters(),opt.lr1,betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(),opt.lr2,betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss()

    # 真图片label为1,假图片label为0
    # noises为生成网络的输入
    true_labels = Variable(t.ones(opt.batch_size))
    fake_labels = Variable(t.zeros(opt.batch_size))
    fix_noises = Variable(t.randn(opt.batch_size,opt.nz,1,1))
    noises = Variable(t.randn(opt.batch_size,opt.nz,1,1))

    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()

    if opt.gpu:
        netd.cuda()
        netg.cuda()
        criterion.cuda()
        true_labels,fake_labels = true_labels.cuda(), fake_labels.cuda()
        fix_noises,noises = fix_noises.cuda(),noises.cuda()
        
    epochs = range(opt.max_epoch)
    for epoch in iter(epochs):
        for ii,(img,_) in tqdm.tqdm(enumerate(dataloader)):
            real_img = Variable(img)
            if opt.gpu: 
                real_img=real_img.cuda()
            if ii%opt.d_every==0:
                # 训练判别器
                optimizer_d.zero_grad()
                ## 尽可能的把真图片判别为正确
                output = netd(real_img)
                error_d_real = criterion(output,true_labels)
                error_d_real.backward()
                
                ## 尽可能把假图片判别为错误
                noises.data.copy_(t.randn(opt.batch_size,opt.nz,1,1))
                fake_img = netg(noises).detach() # 根据噪声生成假图
                output = netd(fake_img)
                error_d_fake = criterion(output,fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

                error_d = error_d_fake + error_d_real

                errord_meter.add(error_d.data.item())

            if ii%opt.g_every==0:
                # 训练生成器
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size,opt.nz,1,1))
                fake_img = netg(noises)
                output = netd(fake_img)
                error_g = criterion(output,true_labels)
                error_g.backward()
                optimizer_g.step()                
                errorg_meter.add(error_g.data.item())

            if opt.vis and ii%opt.plot_every == opt.plot_every-1:
                ## 可视化
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                fix_fake_imgs = netg(fix_noises)
                vis.images(fix_fake_imgs.data.cpu().numpy()[:64]*0.5+0.5,win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:64]*0.5+0.5,win='real')  
                vis.plot('errord',errord_meter.value()[0])
                vis.plot('errorg',errorg_meter.value()[0])   

        if epoch%opt.decay_every==0:
            # 保存模型、图片
            tv.utils.save_image(fix_fake_imgs.data[:64],'%s/%s.png' %(opt.save_path,epoch),normalize=True,range=(-1,1))
            t.save(netd.state_dict(),'checkpoints/netd_%s.pth' %epoch)
            t.save(netg.state_dict(),'checkpoints/netg_%s.pth' %epoch)
            errord_meter.reset()
            errorg_meter.reset()
            optimizer_g = t.optim.Adam(netg.parameters(),opt.lr1,betas=(opt.beta1, 0.999))
            optimizer_d = t.optim.Adam(netd.parameters(),opt.lr2,betas=(opt.beta1, 0.999))
            

visualize.py


#coding:utf8
from itertools import chain
import visdom
import torch
import time
import torchvision as tv
import numpy as np

class Visualizer():
    '''
    封装了visdom的基本操作,但是你仍然可以通过`self.vis.function`
    调用原生的visdom接口
    '''

    def __init__(self, env='default', **kwargs):
        import visdom
        self.vis = visdom.Visdom(env=env, **kwargs)
        
        # 画的第几个数,相当于横座标
        # 保存(’loss',23) 即loss的第23个点
        self.index = {} 
        self.log_text = ''
    def reinit(self,env='default',**kwargs):
        '''
        修改visdom的配置
        '''
        self.vis = visdom.Visdom(env=env,**kwargs)
        return self

    def plot_many(self, d):
        '''
        一次plot多个
        @params d: dict (name,value) i.e. ('loss',0.11)
        '''
        for k, v in d.iteritems():
            self.plot(k, v)

    def img_many(self, d):
        for k, v in d.iteritems():
            self.img(k, v)

    def plot(self, name, y):
        '''
        self.plot('loss',1.00)
        '''
        x = self.index.get(name, 0)
        self.vis.line(Y=np.array([y]), X=np.array([x]),
                      win=(name),
                      opts=dict(title=name),
                      update=None if x == 0 else 'append'
                      )
        self.index[name] = x + 1

    def img(self, name, img_):
        '''
        self.img('input_img',t.Tensor(64,64))
        '''
         
        if len(img_.size())<3:
            img_ = img_.cpu().unsqueeze(0) 
        self.vis.image(img_.cpu(),
                       win=unicode(name),
                       opts=dict(title=name)
                       )



    def img_grid_many(self,d):
        for k, v in d.iteritems():
            self.img_grid(k, v)

    def img_grid(self, name, input_3d):
        '''
        一个batch的图片转成一个网格图,i.e. input(36,64,64)
        会变成 6*6 的网格图,每个格子大小64*64
        '''
        self.img(name, tv.utils.make_grid(
            input_3d.cpu()[0].unsqueeze(1).clamp(max=1,min=0)))

    def log(self,info,win='log_text'):
        '''
        self.log({'loss':1,'lr':0.0001})
        '''

        self.log_text += ('[{time}] {info} <br>'.format(
                            time=time.strftime('%m%d_%H%M%S'),\
                            info=info)) 
        self.vis.text(self.log_text,win='log_text')   

    def __getattr__(self, name):
        return getattr(self.vis, name)

测试

输入: 1x100x1x1的噪声向量
输出: 1x3x96x96 的图像

def generate(**kwargs):
    '''
    随机生成动漫头像,并根据netd的分数选择较好的
    '''
    for k_,v_ in kwargs.items():
        setattr(opt,k_,v_)
    
    netg, netd = NetG(opt).eval(), NetD(opt).eval()  
    noises = t.randn(opt.gen_search_num,opt.nz,1,1).normal_(opt.gen_mean,opt.gen_std)
    noises = Variable(noises, volatile=True)

    map_location=lambda storage, loc: storage
    print(opt.netd_path)
    print(opt.netg_path)
    netd.load_state_dict(t.load(opt.netd_path, map_location='cpu'))
    netg.load_state_dict(t.load(opt.netg_path, map_location='cpu'))
    # netd.load_state_dict(t.load(opt.netd_path, map_location= map_location))
    # netg.load_state_dict(t.load(opt.netg_path, map_location= map_location))
    
    if opt.gpu:
        netd.cuda()
        netg.cuda()
        noises = noises.cuda()
        
    # 生成图片,并计算图片在判别器的分数
    fake_img = netg(noises)
    scores = netd(fake_img).data

    # 挑选最好的某几张
    indexs = scores.topk(opt.gen_num)[1]
    result = []
    for ii in indexs:
        result.append(fake_img.data[ii])
    # 保存图片
    tv.utils.save_image(t.stack(result),opt.gen_img,normalize=True,range=(-1,1))

生成的图像:


result.png
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容