DCGAN学习和编程

DCGAN

网络结构

DCGAN其主要贡献在于把原始GAN中的全连接层替换为了卷积层。具体如下:

  • 首先是全卷积网络,这使用了跨步卷积代替了确定性的空间池化功能(例如最大池化等操作),从而能让网络能够学习自身的空间下采样。
  • 其次是再卷积层特征上消除全连接层的趋势,全局池化就是一个最好的例子。
  • 第三是采用了BatchNormalization,这个通过将输入归一化从而稳定了训练的过程,并有助于在梯度在更深的模型中进行流动,BN并不用于生成器输出层和鉴别器输入层。
  • 使用了ReLU激活函数,并发现使用LeaklyReLU函数可以让正常工作,特别是对于更高分辨率的建模。

DCGAN的生成器结构可以用如下的图来表示:


DCGAN生成64*64图像的生成器结构

DCGAN的判别器和生成器的结构基本相反,其主要是通过进行卷积降维从而把输入的图像生成为一个标量,从而使用Sigmoid激活层确认其概率。

一些的DCGAN结构指南

  • 用跨步卷积(针对鉴别器)和分数跨步卷积(针对生成器)替换掉所有的池化层。
  • 在生成器和鉴别器中都使用BN,并且需要注意的是不对生成器的最后一层和鉴别器的输入层使用BN。
  • 删除掉全连接的隐藏层从而实现更深层次的体系结构。
  • 在生成器中全都使用ReLU激活函数,并在最后一层使用Tanh激活函数
  • 在鉴别其中,对所有层使用LeakyReLU激活函数。

训练的一些细节:

  • 使用了batch_size=128
  • 所有权重都服从0中心方差为0.02的正态分布。
  • 在LeakyReLU的泄露斜率值都为0.2
  • 使用Adam的优化器,lr=0.0002,\beta 1=0.5(作者发现0.9会有不稳定的情况发生)

代码实现:

# 使用pytorch在ununtu20上使用的代码
# gpu:Nvidia RTX2070s 8g显存

import os,math,torch,torchvision
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
import torch.nn as nn
import random
from torch.utils.data import Dataset
random.seed(666)
torch.manual_seed(666)
from torch.autograd import Variable
import torch.nn.functional as F
os.makedirs('myImages', exist_ok=True)
#下面是一些初始化数据的定义
n_epochs=2
batch_size=512
lr=0.0002
b1=0.5
b2=0.999
n_cpu=8
latent_dim=100
img_size=64
channels=3
sample_interval=400
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# dataset = torchvision.datasets.MNIST(root='../../data/mnist',download=True,
#                             transform=transforms.Compose([transforms.Resize(size=img_size),
#                                                          transforms.ToTensor(),
#                                                          transforms.Normalize([0.5]*3,[0.5]*3)]
#                                                           )
#                             )
# dataloader = DataLoader(dataset=dataset,batch_size=batch_size,shuffle=False,num_workers=n_cpu)
import PIL.Image as Image
class CeleBaDataset(torch.utils.data.Dataset):
    def __init__(self,img_root:str,transform=None):
        super(CeleBaDataset,self).__init__()
        temp_list=list()
        for s in os.listdir(path=img_root):
            if s.find('.png'):
                temp_list.append(os.path.join(img_root,s))
        self.datalist = temp_list
        self.transform = transform
    def __len__(self):
        return len(self.datalist)
    def __getitem__(self,idx):
        image = Image.open(self.datalist[idx])
        if self.transform:
            image = self.transform(image)
        return image
dataloader = DataLoader(dataset=CeleBaDataset(img_root='/home/hx/Desktop/WorkDisk/DadaSets/CelebA/Img/img_align_celeba_png.7z/img_align_celeba_png/'
                                              ,transform=transforms.Compose([transforms.Resize(size=img_size),
                                                         transforms.Resize(64),
                                                         transforms.CenterCrop(64),
                                                         transforms.ToTensor(),
                                                         transforms.Normalize([0.5]*3,[0.5]*3),
                                                        ])
                                              ),
                        batch_size=batch_size,
                        num_workers=n_cpu,
                        shuffle=False,
                        pin_memory=True)

#%%

def weight_init(modules:torch.nn.Module):
    for m in modules.modules():
        if isinstance(m,nn.ConvTranspose2d):
            nn.init.normal_(m.weight.data,0,0.02)
        elif isinstance(m,nn.BatchNorm2d):
            nn.init.normal_(m.weight.data,0,0.02)
def weight_init_apply(m:object):
    if m.__class__.__name__.find('Conv'):
        nn.init.normal_(m.weight.data,0,0.02)
    elif m.__class__.__name__.find('BatchNorm'):
        nn.init.normal_(m.weight.data,0,0.02)

class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
        in_channels=[latent_dim,512,256,128,64]
        out_channels=[512,256,128,64,3]
        paddings=[0,1,1,1,1]
        strides=[1,2,2,2,2]
        layers=[]
        for i in range(5):
            layers.append(nn.BatchNorm2d(num_features=in_channels[i]))
            layers.append(nn.ConvTranspose2d(in_channels=in_channels[i],
                                             out_channels=out_channels[i],
                                             kernel_size=4,
                                             stride=strides[i],
                                             padding=paddings[i]))
            if i != 4:
                layers.append(nn.LeakyReLU(negative_slope=0.2,inplace=True))
            else:
                layers.append(nn.Tanh())
        self.G=nn.Sequential(*layers)

    def forward(self,x):
        return self.G(x)


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        layers=[]
        def block(in_channels,out_channels,stride=2,padding=1,if_bn=True,if_relu=True):
            if if_bn:
                layers.append(nn.BatchNorm2d(in_channels))
            layers.append(nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=4,stride=stride,padding=padding))
            if if_relu:
                layers.append(nn.LeakyReLU(negative_slope=0.2,inplace=True))
            else:
                layers.append(nn.Sigmoid())
        block(3,64,stride=2,padding=1,if_bn=False)  # 此时64*32*32
        block(64,128,2,1)                           # 此时128*16*16
        block(128,256,2,1)                          # 此时256*8*8
        block(256,512,2,1)                          # 此时512*4*4
        block(512,1,1,0,if_relu=False)              # 此时1*1*1
        self.D=nn.Sequential(*layers)

    def forward(self,x):
        return self.D(x)

#%%

generator = Generator()
weight_init(generator)
discriminator=Discriminator()
weight_init(discriminator)
loss_fn = torch.nn.BCELoss()
generator.to(device)
discriminator.to(device)
loss_fn.to(device)

opm_G = torch.optim.Adam(generator.parameters(),lr=lr,betas=(b1,b2))
opm_D = torch.optim.Adam(discriminator.parameters(),lr=lr,betas=(b1,b2))

#%%
data = CeleBaDataset(img_root='/home/hx/Desktop/WorkDisk/DadaSets/CelebA/Img/img_align_celeba_png.7z/img_align_celeba_png/')
data.__getitem__(10000)

#%%

for epoch in range(20):
    for i,img in enumerate(dataloader):
        img = img.to(device)
        real = torch.ones((img.shape[0],1),device=device)
        fake = torch.zeros((img.shape[0],1),device=device)
        z = torch.randn(size=(img.shape[0],latent_dim,1,1),device=device)
        opm_D.zero_grad()
        real_loss = loss_fn(discriminator(img).view(img.shape[0],-1),real)
        fake_loss = loss_fn(discriminator(generator(z).detach())view(img.shape[0],-1),fake)
        d_loss = (real_loss+fake_loss)/2
        d_loss.backward()
        opm_D.step()
        print('Dloss:',d_loss)

        opm_G.zero_grad()
        z = torch.randn(size=(img.shape[0],latent_dim,1,1),device=device)
        g_loss = loss_fn(discriminator(generator(z)).view(img.shape[0],-1),fake)
        g_loss.backward()
        opm_G.step()
        print('Gloss:',g_loss)
    print('epoch:{}Dloss:{}Gloss:{}',epoch,d_loss,g_loss)

最后的图像生成效果。。。待跑完

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容

  • 16.批量归一化和残差网络 批量归一化(BatchNormalization) BN是由Google于2015年提...
    0d3382cf56eb阅读 853评论 0 0
  • GAN 由Goodfellow等人于2014年引入的生成对抗网络(GAN)是用于学习图像潜在空间的VAE的替代方案...
    七八音阅读 7,831评论 1 3
  • 想从Tensorflow循环生成对抗网络开始。但是发现从最难的内容入手还是?太复杂了所以搜索了一下他的始祖也就是深...
    Feather轻飞阅读 5,117评论 1 4
  • (转)生成对抗网络(GANs)最新家谱:为你揭秘GANs的前世今生 生成对抗网络(GAN)一...
    Eric_py阅读 4,380评论 0 4
  • 推荐指数: 6.0 书籍主旨关键词:特权、焦点、注意力、语言联想、情景联想 观点: 1.统计学现在叫数据分析,社会...
    Jenaral阅读 5,770评论 0 5