虽迟但到,重新定义视觉backbone:CNN+Transformer->BoTNet

在过往的十年里,卷积神经网络CNN支配了计算机视觉的技术发展,下一个十年,或许将由Transformer接过接力棒了?

还记得在2020年的年末回顾中(《[年终AI大事件回顾]我眼中的2020年TOP5》),我们梳理了2020年的TOP5的AI领域研究工作。其中被排在第四位的,便是一篇基于Transformer的视觉领域的工作。

当时给出的评语是:

在更复杂的图像任务上的效果还有待探索和加强,但这一工作让我们看到了“跨界”的可能,相信未来会有更多在视觉上的工作基于Transformer而展开。

图源[1] Vision Transformer

当时想的是:除了单纯用Transformer的结构完全替代CNN解决CV问题之外,是时候把CNN和Transformer结合在一起了吧。类似的工作比如说DETR,用CNN提取图像特征,在之后接了Transformer的encoder和decoder。

2021年1月27日在arxiv上发的一篇文章Bottleneck Transformers for Visual Recognition同样是采用了CNN+Transformer,但在我看来,似乎是更加elegant的做法:

1. 将Transformer的Self-attention融入了一个CNN的backbone中,而非叠加;

2. 具体来说是在ResNet的最后三个bottleneck blocks中用MHSA(多头自注意力层,Multi-Head Self-attention)替换了原本的3x3卷积(下图)。这些新的blocks被命名为BoT blocks,这一新的网络被命名为BotNet

图源[2] ResNet的Bottleneck vs. 加入自注意力的Bottleneck

这样做的好处是显而易见的

1. 能够利用成熟的、经过检验的CNN网络结构提取特征,CNN在视觉领域是有一些先验或者inductive biases的;

2. 用CNN对输入图像做了下采样后,再由self-attention进行运算,相比于直接使用self-attention在原图上处理,能够降低运算量;

3. 这样的设计能够与其他方法结合,例如可能作为backbone应用于DETR中。

在MHSA层中,特征输入X经过WQ, WK, WV三个矩阵映射成q,k,v,分别代表query,key和value。Self-attention的操作一般是对qkv进行计算。

Attention的计算公式

Multi-Head(多头)体现在对每一个head都有不同的WQ, WK, WV,完成对特征输入的映射,并进行上述自注意力的运算,以拓展模型在不同的表示空间里学习。

此外,由于上述的multi-head和self-attention操作中,没有引入与图像中位置相关的信息。因此,引入了相对位置编码(relative position encoding):R_h和R_w,分别表征高度和宽度编码。

这三个操作结合起来便是下图中的结构,对于熟悉Transformer的同学而言,其实并没有太多特殊的操作。

图源[2] 多头自注意力结构Multi-Head Self-Attention

这里引用一段在Pytorch上对这一Attention层的非官方实现[3]:

class Attention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        fmap_size,
        heads = 4,
        dim_head = 128,
        rel_pos_emb = False
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head

        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)

        rel_pos_class = AbsPosEmb if not rel_pos_emb else RelPosEmb
        self.pos_emb = rel_pos_class(fmap_size, dim_head)

    def forward(self, fmap):
        heads, b, c, h, w = self.heads, *fmap.shape

        q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), (q, k, v))

        q *= self.scale

        sim = einsum('b h i d, b h j d -> b h i j', q, k)
        sim += self.pos_emb(q)

        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
        return out

其中位置编码的做法(求qr^T):

def rel_to_abs(x):
    """
    Converts relative indexing to absolute.
    Input: [bs, heads, length, 2*length - 1]
    Output: [bs, heads, length, length]
    """
    b, h, l, _, device, dtype = *x.shape, x.device, x.dtype
    dd = {'device': device, 'dtype': dtype}
    col_pad = torch.zeros((b, h, l, 1), **dd)
    x = torch.cat((x, col_pad), dim = 3)
    flat_x = rearrange(x, 'b h l c -> b h (l c)')
    flat_pad = torch.zeros((b, h, l - 1), **dd)
    flat_x_padded = torch.cat((flat_x, flat_pad), dim = 2)
    final_x = flat_x_padded.reshape(b, h, l + 1, 2 * l - 1)
    final_x = final_x[:, :, :l, (l-1):]
    return final_x

def relative_logits_1d(q, rel_k):
    """
    Compute relative logits along one dimenion.
    `q`: [bs, heads, height, width, dim]
    `rel_k`: [2*width - 1, dim]
    """
    b, heads, h, w, dim = q.shape
    logits = einsum('b h x y d, r d -> b h x y r', q, rel_k)
    logits = rearrange(logits, 'b h x y r -> b (h x) y r')
    logits = rel_to_abs(logits)
    logits = logits.reshape(b, heads, h, w, w)
    logits = expand_dim(logits, dim = 3, k = h)
    return logits

class RelPosEmb(nn.Module):
    def __init__(
        self,
        fmap_size,
        dim_head
    ):
        super().__init__()
        height, width = pair(fmap_size)
        scale = dim_head ** -0.5
        self.fmap_size = fmap_size
        self.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)
        self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)

    def forward(self, q):
        h, w = self.fmap_size

        q = rearrange(q, 'b h (x y) d -> b h x y d', x = h, y = w)
        rel_logits_w = relative_logits_1d(q, self.rel_width)
        rel_logits_w = rearrange(rel_logits_w, 'b h x i y j-> b h (x y) (i j)')

        q = rearrange(q, 'b h x y d -> b h y x d')
        rel_logits_h = relative_logits_1d(q, self.rel_height)
        rel_logits_h = rearrange(rel_logits_h, 'b h x i y j -> b h (y x) (j i)')
        return rel_logits_w + rel_logits_h

下图定量分析了BoTNets在ImageNet数据集上的性能:top-1 acc与EfficientNet B7匹敌,但运算速度更快。

图源[2] 实验结果

定性分析分别使用ResNet50(下图左)和BoTNet50(下图右)作为MaskRCNN的backbone网络:

图源[2] 定性分析:ResNet-50 vs. BoTNet-50 Mask R-CNN

总结一下,这篇论文虽然做了一些不一样的工作,但却只是开了个头。
在我看来,未来必然的发展是:

1. 改进Transformer的计算和存储效率,使之能适用于低算力平台;

2. CNN+Transformer的结构会存在一段时间,但最终被Transformer完全取代;

3. 最终的形态是类似于在NLP领域的应用——在大量图像数据上训练出的Transformer网络,能够适用于多个下游视觉任务。

最后,在这里期待一下,下一个“造福人类”、引领下一个十年的方法是否会出现在这个领域呢。

图源[2]What will be the next?

参考资料:
[1]https://arxiv.org/pdf/2010.11929.pdf
[2] https://arxiv.org/abs/2101.11605
[3] https://github.com/lucidrains/bottleneck-transformer-pytorch

- END -

新朋友们可以看看我过往的相关文章⬇
【相关推荐阅读】
[年终AI大事件回顾]我眼中的2020年TOP5
模式识别学科发展报告丨前言
梯度手术-多任务学习优化方法[NeurIPS 2020]

欢迎分享/转载,并注明出处。

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 213,864评论 6 494
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,175评论 3 387
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 159,401评论 0 349
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,170评论 1 286
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,276评论 6 385
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,364评论 1 292
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,401评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,179评论 0 269
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,604评论 1 306
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,902评论 2 328
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,070评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,751评论 4 337
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,380评论 3 319
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,077评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,312评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,924评论 2 365
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,957评论 2 351

推荐阅读更多精彩内容