简介
Visual Transformer (ViT) 出自于论文《AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE》,是基于Transformer的模型在视觉领域的开篇之作。本文将尽可能简洁地介绍一下ViT模型的整体架构以及基本原理。ViT模型是基于Transformer Encoder模型的,在这里假设读者已经了解Transformer的基本知识,如果不了解可以参考链接。
Vision Transformer如何工作
我们知道Transformer模型最开始是用于自然语言处理(NLP)领域的,NLP主要处理的是文本、句子、段落等,即序列数据。但是视觉领域处理的是图像数据,因此将Transformer模型应用到图像数据上面临着诸多挑战,理由如下:
- 与单词、句子、段落等文本数据不同,图像中包含更多的信息,并且是以像素值的形式呈现。
- 如果按照处理文本的方式来处理图像,即逐像素处理的话,即使是目前的硬件条件也很难。
- Transformer缺少CNNs的归纳偏差,比如平移不变性和局部受限感受野。
- CNNs是通过相似的卷积操作来提取特征,随着模型层数的加深,感受野也会逐步增加。但是由于Transformer的本质,其在计算量上会比CNNs更大。
- Transformer无法直接用于处理基于网格的数据,比如图像数据。
为了解决上述问题,Google的研究团队提出了ViT模型,它的本质其实也很简单,既然Transformer只能处理序列数据,那么我们就把图像数据转换成序列数据就可以了呗。下面来看下ViT是如何做的。
ViT模型架构
我们先结合下面的动图来粗略地分析一下ViT的工作流程,如下:
- 将一张图片分成patches
- 将patches铺平
- 将铺平后的patches的线性映射到更低维的空间
- 添加位置embedding编码信息
- 将图像序列数据送入标准Transformer encoder中去
- 在较大的数据集上预训练
- 在下游数据集上微调用于图像分类
ViT工作原理解析
我们将上图展示的过程近一步分解为6步,接下来一步一步地来解析它的原理。如下图:步骤1、将图片转换成patches序列
这一步很关键,为了让Transformer能够处理图像数据,第一步必须先将图像数据转换成序列数据,但是怎么做呢?假如我们有一张图片,patch大小为,那么我们可以创建个图像patches,可以表示为,其中,就是序列的长度,类似一个句子中单词的个数。在上面的图中,可以看到图片被分为了9个patches。
步骤2、将Patches铺平
在原论文中,作者选用的patch大小为16,那么一个patch的shape为(3,16,16),维度为3,将它铺平之后大小为3x16x16=768。即一个patch变为长度为768的向量。不过这看起来还是有点大,此时可以使用加一个Linear transformation,即添加一个线性映射层,将patch的维度映射到我们指定的embedding的维度,这样就和NLP中的词向量类似了。
步骤3、添加Position embedding
与CNNs不同,此时模型并不知道序列数据中的patches的位置信息。所以这些patches必须先追加一个位置信息,也就是图中的带数字的向量。实验表明,不同的位置编码embedding对最终的结果影响不大,在Transformer原论文中使用的是固定位置编码,在ViT中使用的可学习的位置embedding 向量,将它们加到对应的输出patch embeddings上。
步骤4、添加class token
在输入到Transformer Encoder之前,还需要添加一个特殊的class token,这一点主要是借鉴了BERT模型。添加这个class token的目的是因为,ViT模型将这个class token在Transformer Encoder的输出当做是模型对输入图片的编码特征,用于后续输入MLP模块中与图片label进行loss计算。
步骤5、输入Transformer Encoder
将patch embedding和class token拼接起来输入标准的Transformer Encoder中,
步骤6、分类
注意Transformer Encoder的输出其实也是一个序列,但是在ViT模型中只使用了class token的输出,将其送入MLP模块中,去输出最终的分类结果。
总结
ViT的整体思想还是比较简单,主要是将图片分类问题转换成了序列问题。即将图片patch转换成token,以便使用Transformer来处理。听起来很简单,但是ViT需要在海量数据集上预训练,然后在下游数据集上进行微调才能取得较好的效果,否则效果不如ResNet50等基于CNN的模型。