PN-GAN代码阅读笔记

一、简介

https://github.com/naiq/PN_GAN是论文Pose-Normalized Image Generation for Person Re-identification的实现代码

二、代码梳理

2.1 网络构建

2.1.1 生成网络

参数:ngf(channel数相关,可理解为特征图个数), num_resblock(残差块个数,默认为9个)
拼接原图片和pose图片:

 x = torch.cat((im, pose), dim=1)

下采样部分(卷积):

       self.conv1 = nn.Sequential(OrderedDict([
            ('pad', nn.ReflectionPad2d(3)),
            ('conv', nn.Conv2d(6, ngf, kernel_size=7, stride=1, padding=0, bias=True)),
            ('bn', nn.InstanceNorm2d(ngf)),
            ('relu', nn.ReLU(inplace=True)),
        ]))
        self.conv2 = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(ngf, ngf*2, kernel_size=3, stride=2, padding=1, bias=True)),
            ('bn', nn.InstanceNorm2d(ngf*2)),
            ('relu', nn.ReLU(inplace=True)),
        ]))
        self.conv3 = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(ngf*2, ngf*4, kernel_size=3, stride=2, padding=1, bias=True)),
            ('bn', nn.InstanceNorm2d(ngf*4)),
            ('relu', nn.ReLU(inplace=True)),
        ]))

残差块部分,包括num_resblock个ResBlock,ResBlock如下:

# ncf 为ngf*4
  ...
        self.conv1 = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(ncf, ncf, kernel_size=3, stride=1, padding=1, bias=use_bias)),
            ('bn', nn.InstanceNorm2d(ncf)),
            ('relu', nn.ReLU(inplace=True)),
        ]))
        self.conv2 = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(ncf, ncf, kernel_size=3, stride=1, padding=1, bias=use_bias)),
            ('bn', nn.InstanceNorm2d(ncf)),
        ]))
    ....
 

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = out + x
        out = self.relu(out)

        return out

上采样部分(解卷积):

      self.deconv3 = nn.Sequential(OrderedDict([
            ('deconv', nn.ConvTranspose2d(ngf*4, ngf*2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True)),
            ('bn', nn.InstanceNorm2d(ngf*2)),
            ('relu', nn.ReLU(True))
        ]))
        self.deconv2 = nn.Sequential(OrderedDict([
            ('deconv', nn.ConvTranspose2d(ngf*2, ngf, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True)),
            ('bn', nn.InstanceNorm2d(ngf)),
            ('relu', nn.ReLU(True))
        ]))
        self.deconv1 = nn.Sequential(OrderedDict([
            ('pad', nn.ReflectionPad2d(3)),
            ('conv', nn.Conv2d(ngf, 3, kernel_size=7, stride=1, padding=0, bias=False)),
            ('tanh', nn.Tanh())
        ]))

2.1.2 分类网络( Patch_Discriminator)

参数:ndf(channel数相关,可理解为特征图个数)
下采样部分:

      self.conv1 = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(3, ndf, kernel_size=4, stride=2, padding=1, bias=False)),
            ('relu', nn.LeakyReLU(0.2, True))
        ]))
        self.conv2 = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1, bias=True)),
            ('bn', nn.InstanceNorm2d(ndf*2)),
            ('relu', nn.LeakyReLU(0.2, True))
        ]))
        self.conv3 = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1, bias=True)),
            ('bn', nn.InstanceNorm2d(ndf*4)),
            ('relu', nn.LeakyReLU(0.2, True))
        ]))
        self.conv4 = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(ndf*4, ndf*8, kernel_size=4, stride=1, padding=0, bias=True)),
            ('bn', nn.InstanceNorm2d(ndf*8)),
            ('relu', nn.LeakyReLU(0.2, True))
        ]))

dis层,即为最后一层卷积:

     self.dis = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(ndf*8, 1, kernel_size=4, stride=1, padding=0, bias=False)),
        ]))

dis.squeeze()为网络最后输出,dis的channel数为1,又经过squeeze,即将dis弄成了batch_sizehw的尺寸。

2.2 优化器

生成网络的优化器为Adam优化器,lr=cfg.TRAIN.LR(默认为0.0002), betas=(0.5, 0.999)。
分类网络的优化器相同
生成网络和分类网络的学习率调整策略均为:
lr_policy = lambda epoch: (1 - 1 * max(0, epoch-cfg.TRAIN.LR_DECAY) / cfg.TRAIN.LR_DECAY)
即第epoch的学习率为原始lr * (1 - 1 * max(0, epoch-cfg.TRAIN.LR_DECAY) / cfg.TRAIN.LR_DECAY)

2.3 训练

生成网络G采用的损失函数为torch.nn.MSELoss(),分类网络D的损失函数为 torch.nn.L1Loss(),分别记做criterionGAN和criterionIdt

2.3.1 数据处理

经过一系列数据预处理之后,得到src_img(原图片),tgt_img(目标图片),pose。

2.3.2 生成图片

根据原图片src_img和姿态图片pose生成原图片中行人姿态变为pose的新图片fake_img

fake_img = netG(src_img, pose)

2.3.3 更新生成器

            D_fake_img = netD(fake_img)
            G_loss = criterionGAN(D_fake_img, torch.ones_like(D_fake_img))
            idt_loss = criterionIdt(fake_img, tgt_img) * cfg.TRAIN.lambda_idt

fake_img经过分类网络D得到D_fake_img,计算Lgan和Ll1(与论文对应):
G_loss对应Lgan:


QQ图片20190918172048.png

G_loss使用mse计算的,跟原文不大一致,但反正就是跟分类器相关,也就是生成器的理想状态是让生成图片被分类器识别成原图,也就是分类结果应为 torch.ones_like(D_fake_img)。
而idt_loss对应的就是:


L1loss.png

这是用了l1 loss计算的,跟原文一致
生成网络的优化目标为:
lgp.png

对应的代码为:

loss_G = G_loss + idt_loss

注意公式中lamda1就是 cfg.TRAIN.lambda_idt。

最后就是更新网络了:

optimizer_G.zero_grad()
loss_G.backward()
optimizer_G.step()

2.3.4 更新分类器

首先是损失函数的计算:

            D_fake_img = netD(fake_img.detach())
            D_real_img = netD(src_img)

            D_fake_loss = criterionGAN(D_fake_img, torch.zeros_like(D_fake_img))
            D_real_loss = criterionGAN(D_real_img, torch.ones_like(D_real_img))

            loss_D = D_fake_loss + D_real_loss

D_fake_img就是分类器对生成网络生成的图片的分类,D_real_img就是分类器对原图片的分类。原文中,分类器的损失函数公式为:


Ldp.png

但实际上实现的时候,就是计算分类的误差,也就是D_fake_img与理想值(全为0)和D_real_img与理想值(全为1)的平方和。
最后也是更新网络:

            optimizer_D.zero_grad()
            loss_D.backward()
            optimizer_D.step()
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 217,734评论 6 505
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,931评论 3 394
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 164,133评论 0 354
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,532评论 1 293
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,585评论 6 392
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,462评论 1 302
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,262评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,153评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,587评论 1 314
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,792评论 3 336
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,919评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,635评论 5 345
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,237评论 3 329
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,855评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,983评论 1 269
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,048评论 3 370
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,864评论 2 354

推荐阅读更多精彩内容

  • 前言:本文简述内存管理相关内容,如有错误请留言指正。 第一部分-定时器 1.1 NSTimer和CADisplay...
    梦蕊dream阅读 630评论 0 4
  • 每每听到别人说:“一班同学真团结!″这句话时,我心里就比吃了蜜还甜。因为我们一班不怕困难、不服,所有人的力量都拧成...
    爱肖兔兔DAYTOY阅读 151评论 0 3
  • 沟通手术1.5小时 协助迟完成第一台手术2.5小时 配合孙设计完成第二台手术3.5小时 值班看护留观顾客4.5小时...
    37a6b6adef7c阅读 125评论 0 0
  • 短暂几天的出差即将结束,已经定好了明天的机票回家。 今天晚上是和洋洋在蜗居里的最后一晚。经过7天的生活下来,我完全...
    我在安好阅读 253评论 0 2