什么是生成式对抗网络GAN
(本教程的代码和训练数据全部来自《深度学习框架PyTorch：入门与实践》，在此致谢。此书采用 PyTorch 0.4.0 版本，API 接口与 1.0 有所差别；PyTorch 1.0 中已经不推荐使用 Variable。)
开发/测试环境
- Ubuntu 18.04
- anaconda3, python3.6
- pycharm
- PyTorch 1.0
训练过程
刚开始训练,输入为噪声向量, 生成的图像也是噪声
训练了几十次迭代之后
随着迭代次数增加,逐渐产生轮廓,仔细观察刚开始生成的图像为黑白灰度图像,没有彩色信息。
继续迭代,逐渐产生了彩色信息。
Loss曲线的变化
使用GPU进行训练
用 CPU 进行训练太慢了：笔者使用 Intel i7-5500U CPU 训练，一秒钟大概只能迭代一次，而且 batch size 只能设置为 4~8。之后切换到 GPU（单块 NVIDIA GTX 1080 Ti）上，计算速度为 20~30 iter/sec，batch size=64，直观上比 CPU 快 20 倍以上。
迭代30K次
Process
代码
网络定义
GAN网络不同于一般的分类网络,由2部分组成: 生成器,判别器。
生成器
NetG
输入: 1x100x1x1 (NxCxHxW) 100维的噪声向量
输出: 1x3x96x96 3(Channels)x96(Height)x96(Width)的图像
from torch import nn
class NetG(nn.Module):
    '''
    Generator network.

    Maps noise of shape (N, opt.nz, 1, 1) to images of shape (N, 3, 96, 96)
    through five ConvTranspose2d layers (the first four followed by
    BatchNorm2d + ReLU) and a final Tanh, so every output pixel lies in (-1, 1).

    Reads ``opt.nz`` (noise length) and ``opt.ngf`` (base feature-map count).

    NOTE: each layer's position inside ``self.main`` determines its
    state_dict key (``main.0.weight``, ``main.1.weight``, ...). Reordering,
    inserting or removing layers breaks loading of existing checkpoints.
    '''
    def __init__(self, opt):
        super(NetG, self).__init__()
        ngf = opt.ngf  # base number of generator feature maps
        self.main = nn.Sequential(
            # Input: an nz-dim noise vector, treated as an nz x 1 x 1 feature map.
            # bias=False because the following BatchNorm2d has its own learnable shift.
            nn.ConvTranspose2d(opt.nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # Output of the previous step: (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # Output of the previous step: (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # Output of the previous step: (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # Output of the previous step: (ngf) x 32 x 32
            # kernel 5, stride 3, padding 1: (32 - 1) * 3 - 2 * 1 + 5 = 96
            nn.ConvTranspose2d(ngf, 3, 5, 3, 1, bias=False),
            nn.Tanh()  # Tanh bounds the output to (-1, 1), the range of the normalized real images
            # Output shape: 3 x 96 x 96
        )
    def forward(self, input):
        # input: noise tensor (N, opt.nz, 1, 1) -> images (N, 3, 96, 96)
        return self.main(input)
判别器
NetD
输入: 1x3x96x96 的图像
输出: 形状为 (1,) 的张量，即每张图像对应一个 0~1 之间的概率值（网络最后一层输出 1x1x1x1，forward 中用 view(-1) 展平）
class NetD(nn.Module):
    '''
    Discriminator network.

    Maps images of shape (N, 3, 96, 96) to a tensor of shape (N,) holding, for
    each image, a Sigmoid score in (0, 1) that is trained (via BCELoss in
    train()) to be high for real images and low for generated ones.

    Reads ``opt.ndf`` (base feature-map count).

    NOTE: each layer's position inside ``self.main`` determines its
    state_dict key, so reordering layers breaks existing checkpoints.
    '''
    def __init__(self, opt):
        super(NetD, self).__init__()
        ndf = opt.ndf  # base number of discriminator feature maps
        self.main = nn.Sequential(
            # Input: 3 x 96 x 96
            # kernel 5, stride 3, padding 1: (96 + 2 - 5) // 3 + 1 = 32
            nn.Conv2d(3, ndf, 5, 3, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # Output: (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # Output: (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # Output: (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # Output: (ndf*8) x 4 x 4
            # A 4x4 kernel with no padding collapses the map to 1 x 1 x 1.
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()  # one probability per image
        )
    def forward(self, input):
        # Flatten (N, 1, 1, 1) -> (N,) so it matches the 1-D label tensors used with BCELoss.
        return self.main(input).view(-1)
参数配置
- batch_size
- learning_rate
- max_epoch 最大迭代epoch个数
import os
import ipdb
import torch as t
import torchvision as tv
import tqdm
from model import NetG, NetD
from torch.autograd import Variable
from torchnet.meter import AverageValueMeter
class Config(object):
    '''
    Hyper-parameters and paths shared by train() and generate().

    Values are class attributes; train()/generate() override them on the
    ``opt`` instance with ``setattr`` from their keyword arguments.
    '''

    # ---------------- data ----------------
    data_path = 'data/'           # root directory handed to torchvision's ImageFolder
    num_workers = 4               # DataLoader worker processes
    image_size = 96               # images are resized and center-cropped to this size
    batch_size = 16
    max_epoch = 200               # number of passes over the dataset

    # ---------------- model ----------------
    nz = 100                      # length of the generator's input noise vector
    ngf = 64                      # base feature-map count of the generator
    ndf = 64                      # base feature-map count of the discriminator

    # ---------------- optimisation ----------------
    lr1 = 2e-4                    # generator learning rate
    lr2 = 2e-4                    # discriminator learning rate
    beta1 = 0.5                   # Adam's beta1
    gpu = False                   # use CUDA when True
    d_every = 1                   # update the discriminator every d_every batches
    g_every = 5                   # update the generator every g_every batches
    decay_every = 10              # every decay_every epochs: save images/checkpoints, rebuild optimizers

    # ---------------- checkpoints ----------------
    # Pretrained weights; train() skips loading when the value is falsy.
    netd_path = './checkpoints/netd_100.pth'
    netg_path = './checkpoints/netg_100.pth'

    # ---------------- monitoring / output ----------------
    save_path = 'imgs/'           # directory for image grids saved during training
    vis = True                    # plot to visdom during training
    env = 'GAN'                   # visdom environment name
    plot_every = 20               # plot once every plot_every batches
    debug_file = '/tmp/debuggan'  # if this file exists, training drops into ipdb

    # ---------------- generation only (generate()) ----------------
    gen_img = 'result.png'        # output file
    gen_num = 64                  # keep this many of the best-scored images...
    gen_search_num = 512          # ...out of this many candidates
    gen_mean = 0                  # mean of the input noise
    gen_std = 1                   # standard deviation (not variance) of the input noise


# Module-wide configuration instance used by train() and generate().
opt = Config()
训练
训练生成器网络
训练判别器网络
def train(**kwargs):
    '''
    Adversarially train the generator (NetG) and the discriminator (NetD).

    Keyword arguments override same-named attributes of the module-level
    ``opt`` config before anything else runs, e.g. ``train(gpu=True, vis=False)``.

    Side effects:
      * loads pretrained weights from ``opt.netd_path`` / ``opt.netg_path``
        when they are set;
      * every ``opt.decay_every`` epochs writes a PNG grid generated from a
        fixed noise batch to ``opt.save_path`` and both state_dicts to
        ``checkpoints/``. It then resets the loss meters and rebuilds both
        Adam optimizers, which discards their moment estimates;
      * when ``opt.vis`` is true, plots losses and image grids to visdom.
    '''
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    device = t.device('cuda' if opt.gpu else 'cpu')

    if opt.vis:
        from visualize import Visualizer
        vis = Visualizer(opt.env)

    # Real images are scaled to [-1, 1], the same range as the generator's Tanh output.
    transforms = tv.transforms.Compose([
        # Resize replaces Scale, a deprecated alias removed in later torchvision releases.
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
    # drop_last keeps every batch exactly opt.batch_size long, matching the
    # fixed-size label tensors created below.
    dataloader = t.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.num_workers,
        drop_last=True,
    )

    # Networks. Checkpoints are loaded onto the CPU first and moved afterwards.
    netg, netd = NetG(opt), NetD(opt)
    map_location = lambda storage, loc: storage
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    netd.to(device)
    netg.to(device)

    # Optimizers and loss.
    optimizer_g = t.optim.Adam(netg.parameters(), opt.lr1, betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(), opt.lr2, betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss().to(device)

    # Real images are labelled 1, fake images 0.
    true_labels = t.ones(opt.batch_size).to(device)
    fake_labels = t.zeros(opt.batch_size).to(device)
    # fix_noises stays constant so saved/plotted samples are comparable over time;
    # noises is refilled in place before each generator forward pass.
    fix_noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)
    noises = t.randn(opt.batch_size, opt.nz, 1, 1).to(device)

    errord_meter = AverageValueMeter()
    errorg_meter = AverageValueMeter()

    # Output directories must exist before save_image / t.save write into them.
    os.makedirs(opt.save_path, exist_ok=True)
    os.makedirs('checkpoints', exist_ok=True)

    for epoch in range(opt.max_epoch):
        for ii, (img, _) in tqdm.tqdm(enumerate(dataloader), total=len(dataloader)):
            real_img = img.to(device)

            if ii % opt.d_every == 0:
                # ---- Train the discriminator ----
                optimizer_d.zero_grad()
                # Push real images towards label 1.
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                # Push generated images towards label 0. detach() keeps this
                # step from sending gradients into the generator.
                noises.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises).detach()
                output = netd(fake_img)
                error_d_fake = criterion(output, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

                error_d = error_d_fake + error_d_real
                errord_meter.add(error_d.item())

            if ii % opt.g_every == 0:
                # ---- Train the generator: make netd label fakes as real ----
                optimizer_g.zero_grad()
                noises.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)
                error_g = criterion(output, true_labels)
                error_g.backward()
                optimizer_g.step()
                errorg_meter.add(error_g.item())

            if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                # ---- Visualisation ----
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                with t.no_grad():
                    fix_fake_imgs = netg(fix_noises)
                # * 0.5 + 0.5 maps [-1, 1] back to [0, 1] for display.
                vis.images(fix_fake_imgs.detach().cpu().numpy()[:64] * 0.5 + 0.5, win='fixfake')
                vis.images(real_img.detach().cpu().numpy()[:64] * 0.5 + 0.5, win='real')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])

        if epoch % opt.decay_every == 0:
            # ---- Save images and models ----
            # Regenerate here rather than reusing the visdom batch: that variable
            # does not exist when opt.vis is False (previously a NameError).
            with t.no_grad():
                fix_fake_imgs = netg(fix_noises)
            # Manual [-1, 1] -> [0, 1] rescale, equivalent to
            # normalize=True, range=(-1, 1), but valid on all torchvision versions.
            tv.utils.save_image(
                (fix_fake_imgs[:64] * 0.5 + 0.5).clamp(0, 1),
                os.path.join(opt.save_path, '%s.png' % epoch),
            )
            t.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            t.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            errord_meter.reset()
            errorg_meter.reset()
            # Fresh optimizers: resets Adam's running moment estimates.
            optimizer_g = t.optim.Adam(netg.parameters(), opt.lr1, betas=(opt.beta1, 0.999))
            optimizer_d = t.optim.Adam(netd.parameters(), opt.lr2, betas=(opt.beta1, 0.999))
visualize.py
#coding:utf8
from itertools import chain
import visdom
import torch
import time
import torchvision as tv
import numpy as np
class Visualizer(object):
    '''
    Thin convenience wrapper around ``visdom.Visdom``.

    Keeps a per-window point counter so repeated ``plot`` calls append to a
    line chart. Any attribute not defined here (for example ``images`` or
    ``text``) is forwarded to the underlying ``visdom.Visdom`` object through
    ``__getattr__``, so the native visdom API remains available.
    '''

    def __init__(self, env='default', **kwargs):
        self.vis = visdom.Visdom(env=env, **kwargs)
        # Window name -> number of points already plotted in it, i.e. the x
        # coordinate of the next point. ('loss', 23) means the next loss point is the 23rd.
        self.index = {}
        self.log_text = ''

    def reinit(self, env='default', **kwargs):
        '''
        Reconnect with a new visdom configuration; returns self for chaining.
        '''
        self.vis = visdom.Visdom(env=env, **kwargs)
        return self

    def plot_many(self, d):
        '''
        Plot several scalars at once.

        @params d: dict mapping window name -> value, e.g. {'loss': 0.11}
        '''
        # dict.items(): iteritems() does not exist on Python 3.
        for k, v in d.items():
            self.plot(k, v)

    def img_many(self, d):
        '''
        Show several images at once; ``d`` maps window name -> image tensor.
        '''
        for k, v in d.items():
            self.img(k, v)

    def plot(self, name, y):
        '''
        Append one point to the line chart in window ``name``.

        Example: self.plot('loss', 1.00)
        '''
        x = self.index.get(name, 0)
        self.vis.line(Y=np.array([y]), X=np.array([x]),
                      win=name,
                      opts=dict(title=name),
                      # The first point creates the window; later points append to it.
                      update=None if x == 0 else 'append'
                      )
        self.index[name] = x + 1

    def img(self, name, img_):
        '''
        Show a single image tensor in window ``name``.

        Example: self.img('input_img', t.Tensor(64, 64))
        Tensors with fewer than 3 dims get a leading channel dim added.
        '''
        if len(img_.size()) < 3:
            img_ = img_.cpu().unsqueeze(0)
        self.vis.image(img_.cpu(),
                       # str(): the Python 2 unicode() builtin is gone on Python 3.
                       win=str(name),
                       opts=dict(title=name)
                       )

    def img_grid_many(self, d):
        '''
        Call ``img_grid`` for every (window name, batch) pair in ``d``.
        '''
        for k, v in d.items():
            self.img_grid(k, v)

    def img_grid(self, name, input_3d):
        '''
        Show the channels of the first sample in a batch as a grid of tiles.

        ``input_3d[0]`` of shape (C, H, W) becomes C single-channel H x W tiles,
        for example (36, 64, 64) gives 36 tiles of 64x64. The tiles are laid out
        by ``make_grid`` (8 per row by default), and values are clamped to [0, 1].
        '''
        self.img(name, tv.utils.make_grid(
            input_3d.cpu()[0].unsqueeze(1).clamp(max=1, min=0)))

    def log(self, info, win='log_text'):
        '''
        Append a timestamped line to a text window and redraw it.

        Example: self.log({'loss': 1, 'lr': 0.0001})
        '''
        self.log_text += ('[{time}] {info} <br>'.format(
            time=time.strftime('%m%d_%H%M%S'),
            info=info))
        # Honour the caller's window; previously this was hard-coded to 'log_text'.
        self.vis.text(self.log_text, win=win)

    def __getattr__(self, name):
        # Only called for attributes not found normally. If ``vis`` itself is
        # missing (e.g. instance built without __init__), fail with
        # AttributeError instead of recursing forever through self.vis.
        if name == 'vis':
            raise AttributeError(name)
        return getattr(self.vis, name)
测试
输入: 1x100x1x1的噪声向量
输出: 1x3x96x96 的图像
def generate(**kwargs):
    '''
    Randomly generate anime avatars and keep the ones NetD scores highest.

    Keyword arguments override same-named attributes of the module-level
    ``opt`` config. Samples ``opt.gen_search_num`` noise vectors from
    N(opt.gen_mean, opt.gen_std ** 2) and runs them through NetG. The
    ``opt.gen_num`` images with the highest NetD scores are written as one
    grid image to ``opt.gen_img``.

    Weights are always loaded from ``opt.netd_path`` and ``opt.netg_path``.
    '''
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    device = t.device('cuda' if opt.gpu else 'cpu')

    # eval(): BatchNorm uses its running statistics instead of batch statistics.
    netg, netd = NetG(opt).eval(), NetD(opt).eval()
    print(opt.netd_path)
    print(opt.netg_path)
    netd.load_state_dict(t.load(opt.netd_path, map_location='cpu'))
    netg.load_state_dict(t.load(opt.netg_path, map_location='cpu'))
    netd.to(device)
    netg.to(device)

    # Sample once directly from N(gen_mean, gen_std^2). The old randn().normal_()
    # drew twice and discarded the first draw.
    noises = t.empty(opt.gen_search_num, opt.nz, 1, 1).normal_(opt.gen_mean, opt.gen_std)
    noises = noises.to(device)

    # Inference only: no_grad replaces Variable(volatile=True), which has had
    # no effect since PyTorch 0.4, so graphs were being kept for every image.
    with t.no_grad():
        fake_img = netg(noises)
        scores = netd(fake_img)

    # Keep the best-scored images; topk raises if k exceeds the candidate count.
    k = min(opt.gen_num, opt.gen_search_num)
    indexs = scores.topk(k)[1]
    best = fake_img[indexs]

    # Manual [-1, 1] -> [0, 1] rescale, equivalent to normalize=True,
    # range=(-1, 1), but valid on all torchvision versions.
    tv.utils.save_image((best * 0.5 + 0.5).clamp(0, 1), opt.gen_img)
生成的图像: