Swin Transformer

目前transformer从语言到视觉任务的挑战主要是由于这两个领域间的差异:

  • 1、尺度变化大
  • 2、高分辨率的输入

为了解决以上两点,我们提出了层级Transformer,通过滑动窗口提取特征的方式将使得self.attention的计算量降低为和图像尺寸的线性相关。

简介

我们观察到将语言领域迁移到视觉领域的主要问题可以被总结为两种:

  • 1、不同于word token,它的尺度是固定的,但是视觉领域的尺度变化非常剧烈
  • 2、相对于上下文中的words,图片有着更高分辨率的像素,计算量会随着图片的尺寸成平方倍的增长。

结构

image.png

以上是论文中结构图,每一个stage feature map的尺寸都会减半。易知主要分为四个模块:

  • Patch Partition
  • Linear Embedding
  • Swin Transformer Block(主要模块)
    • W-MSA:regular window partitionmutil-head self attention
    • SW-MSA: shift window partitionmutil-head self attention
  • Patch Merging

1、Patch Partition 和 Linear Embedding

在源码实现中两个模块合二为一,称为PatchEmbedding。输入图片尺寸为H \times W \times 3 的RGB图片,将4x4x3视为一个patch,用一个linear embedding 层将patch转换为任意dimension(通道)的feature。源码中使用4x4的stride=4的conv实现。-> \frac{H}{4} \times \frac{W}{4} \times C

class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim
       
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

2、Swin Transformer Block

这是这篇论文的核心模块。

  • 如何解决计算量随着输入尺寸的增大成平方倍的增长? 抛弃传统的transformer基于全局来计算注意力的方法,将输入划分为不同的窗口,分别对每个窗口(window)施加注意力。
  • 仅仅对窗口(window)单独施加注意力,如何解决窗口(window)之间的信息流动?交替使用W-MSASW-MSA模块,因此SwinTransformerBlock必须是偶数。如下图所示:
    image.png

    整体流程如下:
    • 先对特征图进行LayerNorm
    • 通过self.shift_size决定是否需要对特征图进行shift
    • 然后将特征图切成一个个窗口
    • 计算Attention,通过self.attn_mask来区分Window Attention还是Shift Window Attention
    • 将各个窗口合并回来
    • 如果之前有做shift操作,此时进行reverse shift,把之前的shift操作恢复
    • 做dropout和残差连接
    • 再通过一层LayerNorm+全连接层,以及dropout和残差连接

2.1、window partition

window partition分为regular window partitionshift window partition,对应于W-MSASW-MSA。通过窗口划分,将输入的feature map B \times H \times W \times C转换为num_windows*B, window_size, window_size, C,其中 num_windows = H*W / window_size / window_size。然后resize 到 num_windows*B, window_size*window_size, C进行attention。源码如下:

def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows
image.png
  • Layer1regular window partition,窗口的大小是4x4,将图片分成了4个窗口。
  • Layer2shift window partition,为了保证不同窗口的信息流动,起始点从(windows_size//2, windows_size//2)开始进行划分,将图片分成了9个窗口。可以看到移位后的窗口包含了原本相邻窗口的元素。但是同时也引入了新的问题,窗口大小不一致的问题,有2x2、2x4、4x2、4x4,最简单的方法就是统一padding到4x4,但是窗口数量由4增加至9,计算量变大了2.25倍。因此作者提出了cycle shift去解决这个问题。
    image.png

    以下的示例图片来自于:https://mp.weixin.qq.com/s/8x1pgRLWaMkFSjT7zjhTgQ
    image.png

    首先对窗口进行shift window partition,得到左图部分。不进行padding,而是采用滚动的方式调整窗口,源码中用torch.roll()函数实现,得到了右图。这时候得到了和regular window partition一样的4个2x2大小的window,不同的是,在一个2x2的windows区域内是不连续的(index不一样)。
    image.png

    我们希望在计算Attention的时候,让具有相同index Q \times K^T进行计算,而忽略不同index QK计算结果。因此我们为其添加上mask。源码计算mask实现如下:
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

2.2、W-MSA

regular window partition模块 和 mutil-head self attention模块组成。
W-MSA相比于直接使用MSA主要是为了降低计算量。传统的transformer都是基于全局来计算注意力,因此计算复杂度非常高。但是swin transformer通过对每个窗口施加注意力,从而减少了计算量。attention的主要计算过程如下:
Q=x \times W^q \\ K=x \times W^k \\ V=x \times W^v \\ attn=Q \times K^T \\ Z = attn \times V \\ output = Z \times W^z
假设每一个window的区块大小为M\times M,输入的尺寸为h \times w,以下为原始的MSAW-MSA的计算复杂度:
\Omega(\mathrm{MSA})=4 h w C^{2}+2(h w)^{2} C\\ \Omega(\mathrm{W}-\mathrm{MSA})=4 h w C^{2}+2 M^{2} h w C

  • 对于MSA:对输入的feature map做全局attention,QKV的计算量分别是hwC^2attnZ的计算量分别是(hw)^2Coutput的计算量是hwC^2
  • 对于W-MSA:在windows内的M \times M大小的区域内做attention,feature map会被划分为\frac{h}{M} \times \frac{w}{M}windows,每个windows的尺寸为M \times MQKV的计算量分别是hwC^2attnZ的计算量的分别是M^2 hwCoutput的计算量是hwC^2。因此和输入尺寸成线性关系。

2.3、SW-MSA

虽然W-MSA降低了计算量,但是由于将attention限制在window内,因此不重合的window缺乏联系,限制了模型的性能。因此提出了SW-MSA模块。在MSA前面加上一个cycle shift window partition

3、Patch Merging

swin transformer中没有使用pooling进行下采样,而是使用了和yolov5中的focus层进行feature map的下采样。H\times W\times C -> \frac{H}{2} \times \frac{W}{2} \times 4C,在使用一个全连接层->\frac{H}{2} \times \frac{W}{2} \times 2C,在一个stage中将feature map的高宽减半,通道数翻倍。

image.png

class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

不同尺寸的网络结构

基准模型结构命名为Swin-B,模型大小和计算复杂度和ViT-B/DeiT-B相近。同时我们也提出了Swin-TSwin-SSwin-L,分别对应0.25×, 0.5×倍的模型尺寸和计算复杂度。Swin-TSwin-S的计算复杂度分别和ResNet-50ResNet-101相近。M默认设置为7。C代表第一层隐藏层的数量。

  • Swin-T: C = 96, layer numbers = {2, 2, 6, 2}
  • Swin-S: C = 96, layer numbers ={2, 2, 18, 2}
  • Swin-B: C = 128, layer numbers ={2, 2, 18, 2}
  • Swin-L: C = 192, layer numbers ={2, 2, 18, 2}
image.png

不同数据集的实验结果

1、ImageNet

image.png

2、COCO Object Detection

  • 在不同的模型上使用swin transformer 作为特征提取网络
  • 在cascade mask rcnn上使用swin transformer 作为backbone
  • 直接对比其他目标检测网络


    image.png

3、Semantic Segmentation on ADE20K

image.png

消融实验

1、shifted windows 的有效性

image.png

2、position bias

image.png

3、sliding window 和 shift window的速度和性能

image.png
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 194,390评论 5 459
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 81,821评论 2 371
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 141,632评论 0 319
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 52,170评论 1 263
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 61,033评论 4 355
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 46,098评论 1 272
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 36,511评论 3 381
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 35,204评论 0 253
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 39,479评论 1 290
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 34,572评论 2 309
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 36,341评论 1 326
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 32,213评论 3 312
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 37,576评论 3 298
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 28,893评论 0 17
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 30,171评论 1 250
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 41,486评论 2 341
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 40,676评论 2 335

推荐阅读更多精彩内容