DeepMind 提出 Perceiver:使用RNN的方式进行注意力,通过交叉注意力节省计算量,附使用方法

今天要解读的论文来自 DeepMind ,论文名为《Perceiver: General Perception with Iterative Attention》,文中介绍了一种基于 Transformer 的结构,不对数据做任何假设,不需要修改网络结构,就可以利用于各种模态的数据。

image

我们人在感知世界的时候,是通过同时处理各个模态的高维数据,而现在深度学习中使用的方法,都会引入很多领域内的知识,比如现在几乎所有的视觉方法,都引入了”局部性“的假设,即在一张图像内,局部的特征是有用的,这也是 CNN 有用的基本原理。引入这些有帮助信息的同时,也将模型的作用范围限制在了某一个模态以内。

在这篇论文中,作者提出了 Perceiver,它是一个基于 Transformer 的模型,几乎没有做任何关于输入数据之间关系的结构性假设,但是也与 ConvNets 一样,可以扩展到数十万的输入上。

In this paper we introduce the Perceiver - a model that builds upon Transformers and hence makes few architectural assumptions about the relationship between its inputs, but that also scales to hundreds of thousands of inputs, like ConvNets.

作者提出的结构,达到甚至超过了精心设计用于某一个模态的模型的效果。在实验中,作者用了 ImageNet 的图像数据,AudioSet 的视频和音频数据,以及 3D 点云数据。

image

<figcaption class="Image-caption" style="margin-top: 0.66667em; padding: 0px 1em; font-size: 0.9em; line-height: 1.5; text-align: center; color: rgb(153, 153, 153);">所用的数据</figcaption>

方法

image

使用了两个部分来构建网络:

  1. 使用交叉注意力机制(cross attention)来将一个输入向量(文中叫做 byte array)与一个隐向量映射为一个隐向量

  2. 使用 transformer 塔将一个隐向量映射为另一个同样大小的隐向量

输入向量的大小被输入数据所决定,这个一般会很大,例如一张 ImageNet 中的图像,有 224*224 维,也就是 50176 维。而隐向量是模型中的一个超参数,可以人为控制,这个一般很小,作者在 ImageNet 中使用了 1024 维。

所提方法的关键在于:通过一个低维的注意力瓶颈层,将输入的高维数据,映射到低维,再将它送入深度的 transformer 中。

这样做的好处是,如果仅直接使用 transformer 层,那么面临最大的问题是,训练太耗费时间,以及需要非常大的显存。作者在文中分析,transformer 的时间复杂度为序列长度的二次关系,即 O(M^2),这里 M 指序列长度。使用文中提出的交叉注意力机制,变成了 O(MN),而一般可以设置 N 远小于 M。

接下来是一些我的理解:

熟悉注意力机制的都知道,它包括三个部分,分别是Q、K和V。一般的作用方式是,序列长度是多少,那么Q、K 和 V的长度就是多少。但这一点其实是没有必要的。对于一张图,我们不需要每一个位置,都需要一个查询向量(Q)。这样就容易理解,作者提出的结构。对于长度为 M 的序列中的每一个元素,我们会有 N 个查询向量作用于它,所以时间复杂度就变为了 O(MN)。当有了这样 N 个结果以后,再送入传统的 Transformer 结构,这样就极大程度上减少了运算量和显存的占用。

迭代式的注意力机制

瓶颈层可能会限制网络捕捉必要信息的能力,为了缓解这个现象,Perceiver 使用多个 byte-attend 层,也就是交叉注意力层,当网络需要详细的输入信息时候,它就能够获得到这些信息。

最后,借助这样的迭代的注意力机制,可以将网络设计成权值共享的形式(最终的网络结构非常类似于RNN)。权值共享使得参数量减少约 10 倍,减少了网络的过拟合,提高了验证集上的性能。

实验部分

在 ImageNet 上的实验

image

在 ImageNet 上的实验结果。红色的方法代表设计模型时引入了一些特定知识,蓝色的方法代表没有引入。可以看到 Perceiver 达到了非常有竞争力的效果。

将图像像素随机打乱

image

这里作者将图像中的像素随机打乱,Fixed 代表所有图像都是用同一个打乱的方式,Random 代表每张图都是随机打乱,可以看到,当进行随机打乱时,其余方法的性能大幅下降。

后面一列是每个模型输入单元的感受野。

这里可能会有一个问题,就是,既然我们知道图像中局部的信息是有用的,为什么不利用它呢?作者的考虑主要是,这样可以得到一个应用范围更广的模型,因为如果面临的是多模态任务,比如视频、音频、嗅觉传感器和触摸传感器等等数据,再去手动设计输入数据的交互形式是非常困难的。

注意力可视化

image

这里展示的是交叉注意力可视化的结果。

其中,蓝色代表是第一层网络的可视化结果,绿色代表第2-7层网络的结果,橙色代表第八层网络的可视化结果。第一行是每层抽了一个注意力图作为特写。

从图中可以看到,所提方法没有取局部的信息,而是以一种类似格网的形式扫描整张图。

视频音频的结果

image

使用了 AudioSet 数据集,单独使用视频或音频,或者两者结合使用,都达到了最好的结果。

点云数据

image

在点云数据的结果中,PointNet ++ 使用了额外的几何特征,以及更多的增强技术。蓝色的方法都没有使用这些技术。在蓝色的里面,效果是最好的。

使用方法

安装

pip install perceiver-pytorch

使用


import torch
from perceiver_pytorch import Perceiver

model = Perceiver(
    input_channels = 3,          # 序列中每一个元素的维度
    input_axis = 2,              # 输入数据的坐标数(用于构建位置编码,图像的话就是2:x和y)
    num_freq_bands = 6,          # number of freq bands, with original value (2 * K + 1)
    max_freq = 10.,              # maximum frequency, hyperparameter depending on how fine the data is
    depth = 6,                   # 网络深度
    num_latents = 256,           # 隐向量的个数
    cross_dim = 512,             # 交叉注意力的维度
    latent_dim = 512,            # 隐向量的维度
    cross_heads = 1,             # 交叉注意力的头数
    latent_heads = 8,            # 隐自注意力的头数
    cross_dim_head = 64,
    latent_dim_head = 64,
    num_classes = 1000,          # 最终输出类别数
    attn_dropout = 0.,
    ff_dropout = 0.,
)

img = torch.randn(1, 224, 224, 3) # imagenet 图像数据

model(img) # (1, 1000)

参考资料


写在最后:如果觉得这篇文章对您有帮助,欢迎点赞收藏评论支持我,谢谢!
也欢迎关注我的公众号:算法小哥克里斯。

推荐阅读:

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 219,366评论 6 508
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 93,521评论 3 395
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 165,689评论 0 356
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,925评论 1 295
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,942评论 6 392
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,727评论 1 305
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,447评论 3 420
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,349评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,820评论 1 317
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,990评论 3 337
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 40,127评论 1 351
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,812评论 5 346
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,471评论 3 331
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 32,017评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 33,142评论 1 272
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,388评论 3 373
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 45,066评论 2 355

推荐阅读更多精彩内容