论文研读:SA-GAN

前言

SA-GAN( Self-Attention GAN)[1] 由 Google Brain出品,效果比之前的 SOTA 提高的有点多,并且还有开源代码学习下:https://github.com/brain-research/self-attention-gan

PS. Attention Mechanism 相关的论文还没怎么看,所以注意力的模块原理等先不展开讨论,主要讨论SA-GAN中的自注意力模块的设计。

问题以及 idea

SA-GAN 先从之前的SOTA的图像生成模型[2] 出发,发现之前的模型在生成纹理性强的类的图片(如海洋、天空等)时效果好,但是生成结构性强的图片效果不好。

SA-GAN 针对的是图像生成中卷积层的缺点:传统的卷积层由于本身局部性的特点只能处理一定大小的局部相邻的信息,如果仅使用卷积层对图像进行建模,对图像中长距离依赖的特征的建模效率不高(要使用多层卷积层才能处理到更大范围、更高层的特征)。基于这个问题,SA-GAN 提出了自注意力层。

Self-Attention Module

自注意力模块架构如图:

Self-Attention Module

对自注意力模块的输入图片{x \in \mathbb{R}^{C \times N}}

此时 {x} 是将宽高两个维度flatten成一个维度了({N = W \times H}
),将某一个位置(i,j)的通道上的向量看成该位置的拥有的feature,也就是{x_{i} \in \mathbb{R}^{C \times 1}},接着做如下的操作:

  • {x}通过函数{f,g}

  • 计算{s},其中{s_{ij} = f(x_i)^T g(x_j) }

  • {s}按行做softmax归一化得到{\beta},此时{\beta}的元素{\beta_{ji}}表示了第{i}位置的像素的feature 对第j位置的输出的贡献(权重,注意力)

  • {x}通过函数{h}变换,再使用计算的注意力{\beta}{h(x)}相乘(也就是一个加权的过程),这样输出的每个像素都由原来全部像素点的特征组合而来。

  • 通过函数{v}得到输出。

  • 最后将注意力层的输出乘上一个因子{\gamma}和输入累加输出。

{ y = \gamma o + x }

实现中:

  • 其中,四个函数的权重矩阵的大小为:
    {W_f \in \mathbb{R}^{\tilde{C} \times C }}{W_g \in \mathbb{R}^{\tilde{C} \times C }}{W_h \in \mathbb{R}^{\tilde{C} \times C }}{W_f \in \mathbb{R}^{C \times \tilde{C} }}

  • 实现中所有函数都由一个1x1的卷积层代替,{\tilde{C} = C / k},并且文中提到将{\tilde{C}}调整到{C/8} 时模型表现没有下降,为了效率起见使用{k = 8}

  • 其中自注意力层的输出比例{\gamma}为可学习的参数,初始化为0,由模型学习自注意力层应该在原图上修改多少细节。

设计思路阐述:

  • 使用1x1卷积层和{\tilde{C} = C / 8} 的设计应该是减少计算量,{f,g,h}{v}就是分别起降维和升维的作用。

  • 输入图片{x}的每个位置在通道维度的向量作为该位置的特征向量{x_i}

简单的pytorch 实现:

import torch.nn as nn 
import torch 
import torch.nn.functional as F
from torch.nn.utils import spectral_norm

def sn_conv1x1(in_dim, out_dim):
    return spectral_norm(nn.Conv2d(in_dim, out_dim, 1,  1))

class SAModule(nn.Module):
    def __init__(self, dim, k = 8):
        super(SAModule, self).__init__()
        self.c = dim // k
        self.f = sn_conv1x1(dim, self.c)
        self.g = sn_conv1x1(dim, self.c)
        self.h = sn_conv1x1(dim, self.c)
        self.v = sn_conv1x1(self.c, dim)
        self.gamma = nn.Parameter(torch.tensor(0, dtype=torch.float),  requires_grad=True)
    
    def forward(self, x):
        N, C, H, W = x.size()
        num_locals = H * W
        theta = self.f(x).view(N, self.c, num_locals).permute(0,2,1)
        phi = self.g(x).view(N, self.c, num_locals)

        s = torch.bmm(theta, phi)
        atten = F.softmax(s, dim = 2)

        a = self.h(x).view(N, self.c, num_locals).permute(0,2,1)
        out = torch.bmm(atten, a).permute(0,2,1).view(N, self.c, H, W).contiguous()
        x0 = self.v(out)

        return x0 * self.gamma + x

使用的 trick

  • Spectral Normalization[3]: 谱归一化在[2]中就已经使用,而且[2][3]都是同一个作者提出,这里简述一下:谱归一化从WGAN出发,是一种限制模型权重Lipschitz常数的方法,好处是计算代价低(仅需要一次迭代),而且不需要超参数。[2] 中的模型仅在判别器上使用,SA-GAN在生成器和判别器上都使用 Spectral Normalization。可以参考下这个简单的实验体会Spectral Normalization 的作用。

  • 判别器和生成器使用不同的学习率,判别器0.0004,生成器0.0001。

  • 在判别器上使用了[2] 中的 projection discriminator 结构,在生成器上使用了 Conditional BatchNorm 的结构。

实验、复现细节、超参数、实验结论

  • 使用 Hinge Adv Loss

  • 文章通过实验证明:对 middel-to-high level 的 feature 使用 Self-Attention Layer 比对 low level 的 feature 上使用效果更好(即自注意力层尽量放置在靠近输出层的卷积层之后)因为这样Self-Attention 可以捕捉更高级特征和更大区域的feature 的联系

  • 文中又用残差块替换原有卷积层训练新模型,结果效果没有baseline好,证明自注意力层的加入起作用不是因为加深了模型、扩大了模型的容量

Reference

  1. Self-Attention Generative Adversarial Networks
  2. cGAN with projection discriminator
  3. Spectral Normalization for Generative Adversarial Networks

后言

文章是为了方便自己理解而写,所以难免有不清楚或错误之处、或者自创的方便理解的术语,如有错误,欢迎指正。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容