【CV中的Attention机制】融合Non-Local和SENet的GCNet

前言: 之前已经介绍过SENet和Non Local Neural Network(NLNet),两者都是有效的注意力模块。作者发现NLNet中attention maps在不同位置的响应几乎一致,并结合SENet后,提出了Global Context block,用于全局上下文建模,在主流的benchmarks中的结果优于SENet和NLNet。

GCNet论文名称为:《GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond》,是由清华大学提出的一个注意力模型,与SE block、Non Local block类似,提出了GC block。为了克服NL block计算量过大的缺点,提出了一个Simplified NL block,由于其与SE block结构的相似性,于是在其基础上结合SE改进得到GC block。

SENet中提出的SE block是使用全局上下文对不同通道进行权值重标定,对通道依赖进行调整。但是采用这种方法,并没有充分利用全局上下文信息。

捕获长距离依赖关系的目标是对视觉场景进行全局理解,对很多计算机视觉任务都有效,比如图片分类、视频分类、目标检测、语义分割等。而NLNet就是通过自注意力机制来对长距离依赖关系进行建模。

作者对NLNet进行试验,选择COCO数据集中的6幅图,对于不同的查询点(query point)分别对Attention maps进行可视化,得到以下结果:

image

可以看出,对于不同的查询点,其attention map是几乎一致的,这说明NLNet学习到的是独立于查询的依赖(query-independent dependency),这说明虽然NLNet想要对每一个位置进行特定的全局上下文计算,但是可视化结果以及实验数据证明,全局上下文不受位置依赖。

基于以上发现,作者希望能够减少不必要的计算量,降低计算,并结合SENet设计,提出了GCNet融合了两者的优点,既能够有用NLNet的全局上下文建模能力,又能够像SENet一样轻量。

作者首先针对NLNet的问题,提出了一个Simplified NLNet, 极大地减少了计算量。

image

NLNet 中的Non-Local block可以表示为:
z_i=x_i+W_z\sum^{N_p}_{j=1}\frac{f(x_i,x_j)}{C(x)}(W_v×x_j)
输入的feature map定义为x=\{x_i\}^{N_p}_{i=1}, N_p是位置数量。x和z是NL block输入和输出。i是位置索引,j枚举所有可能位置。f(x_i,x_j)表示位置i和j的关系,C(x)是归一化因子。$W_z和W_v是线性转换矩阵。

NLNet中提出了四个相似度计算模型,其效果是大概相似的。作者以Embedded Gaussian为基础进行改进,可以表达为:
W_{ij}=\frac{exp(W_qx_i,W_kx_j)}{\sum_{m}exp(W_qx_i,W_kx_m)}
简化后版本的Simplified NLNet想要通过计算一个全局注意力即可,可以表达为:
z_i=x_i+W_v\sum^{N_p}_{j=1}\frac{exp(W_kx_j)}{\sum^{N_p}_{m=1}exp(W_kx_m)}x_j
这里的W_v、W_q、W_k都是1\times1卷积,具体实现可以参考上图。

简化后的NLNet虽然计算量下去了,但是准确率并没有提升,所以作者观察到SENet与当前的模块有一定的相似性,所以结合了SENet模块,提出了GCNet。

image

可以看出,GCNet在上下文信息建模这个地方使用了Simplified NL block中的机制,可以充分利用全局上下文信息,同时在Transform阶段借鉴了SE block。

GC block在ResNet中的使用位置是每两个Stage之间的连接部分,下边是GC block的官方实现(基于mmdetection进行修改):

代码实现:

import torch
from torch import nn

class ContextBlock(nn.Module):
    def __init__(self,inplanes,ratio,pooling_type='att',
                 fusion_types=('channel_add', )):
        super(ContextBlock, self).__init__()
        valid_fusion_types = ['channel_add', 'channel_mul']

        assert pooling_type in ['avg', 'att']
        assert isinstance(fusion_types, (list, tuple))
        assert all([f in valid_fusion_types for f in fusion_types])
        assert len(fusion_types) > 0, 'at least one fusion should be used'

        self.inplanes = inplanes
        self.ratio = ratio
        self.planes = int(inplanes * ratio)
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types

        if pooling_type == 'att':
            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        if 'channel_add' in fusion_types:
            self.channel_add_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
        else:
            self.channel_add_conv = None
        if 'channel_mul' in fusion_types:
            self.channel_mul_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
        else:
            self.channel_mul_conv = None


    def spatial_pool(self, x):
        batch, channel, height, width = x.size()
        if self.pooling_type == 'att':
            input_x = x
            # [N, C, H * W]
            input_x = input_x.view(batch, channel, height * width)
            # [N, 1, C, H * W]
            input_x = input_x.unsqueeze(1)
            # [N, 1, H, W]
            context_mask = self.conv_mask(x)
            # [N, 1, H * W]
            context_mask = context_mask.view(batch, 1, height * width)
            # [N, 1, H * W]
            context_mask = self.softmax(context_mask)
            # [N, 1, H * W, 1]
            context_mask = context_mask.unsqueeze(-1)
            # [N, 1, C, 1]
            context = torch.matmul(input_x, context_mask)
            # [N, C, 1, 1]
            context = context.view(batch, channel, 1, 1)
        else:
            # [N, C, 1, 1]
            context = self.avg_pool(x)
        return context

    def forward(self, x):
        # [N, C, 1, 1]
        context = self.spatial_pool(x)
        out = x
        if self.channel_mul_conv is not None:
            # [N, C, 1, 1]
            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
            out = out * channel_mul_term
        if self.channel_add_conv is not None:
            # [N, C, 1, 1]
            channel_add_term = self.channel_add_conv(context)
            out = out + channel_add_term
        return out

if __name__ == "__main__":
    in_tensor = torch.ones((12, 64, 128, 128))

    cb = ContextBlock(inplanes=64, ratio=1./16.,pooling_type='att')
    
    out_tensor = cb(in_tensor)

    print(in_tensor.shape)
    print(out_tensor.shape)

对这个模块进行了测试,需要说明的是,如果ratio × inplanes < 1, 将会出问题,这与通道个数有关,通道的个数是无法小于1的。

实验部分

作者基于mmdetection进行修改,添加了GC block,以下是消融实验。

image
  • 从block设计来讲,可以看出Simplified NL与NL几乎一直,但是参数量要小一些。而每个阶段都使用GC block的情况下能比baseline提高2-3%。
  • 从添加位置进行试验,在residual block中添加在add操作之后效果最好。
  • 从添加的不同阶段来看,施加在三个阶段效果最优,能比baseline高1-3%。
image
  • Bottleneck设计方面,测试使用缩放、ReLU、LayerNorm等组合,发现使用简化版的NLNet并使用1×1卷积作为transform的时候效果最好,但是其计算量太大。

  • 缩放因子设计:发现ratio=1/4的时候效果最好。

  • 池化和特征融合设计:分别使用average pooling和attention pooling与add、scale方法进行组合实验,发现attention pooling+add的方法效果最好。

image

此处对ImageNet数据集进行了实验,提升大概在1%以内。

image

在动作识别数据集Kinetics中,也取得了1%左右的提升。

总结:GCNet结合了SENet和Non Local的优点,在满足计算量相对较小的同时,优化了全局上下文建模能力,之后进行了详尽的消融实验证明了其在目标检测、图像分类、动作识别等视觉任务中的有效性。这篇论文值得多读几遍。


参考:

论文地址:https://arxiv.org/abs/1904.11492

官方实现代码:https://github.com/xvjiarui/GCNet

文章中核心代码:https://github.com/pprp/SimpleCVReproduction/tree/master/attention/GCBlock

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 216,372评论 6 498
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,368评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 162,415评论 0 353
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,157评论 1 292
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,171评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,125评论 1 297
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,028评论 3 417
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,887评论 0 274
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,310评论 1 310
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,533评论 2 332
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,690评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,411评论 5 343
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,004评论 3 325
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,659评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,812评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,693评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,577评论 2 353