深度学习模块17-EMSA模块

16、EMSA模块

论文《ResT: An Efficient Transformer for Visual Recognition》

1、作用

ResT是一种高效的多尺度视觉Transformer,作为图像识别领域的通用骨架。与现有的Transformer方法相比,ResT在处理不同分辨率的原始图像时具有多个优点:(1) 构建了一个内存高效的多头自注意力机制,通过简单的深度卷积来压缩内存,并在保持多头注意力的多样性能力的同时,跨注意力头维度进行交互;(2) 位置编码被设计为空间注意力,更加灵活,可以处理任意大小的输入图像,而无需插值或微调;(3) 与在每个阶段开始时直接对原始图像进行分块(tokenization)不同,设计了将分块嵌入作为一系列重叠卷积操作的堆栈,实现了更有效的特征提取。

2、机制

1、多头自注意力压缩

ResT通过简单的深度卷积操作压缩内存,并在注意力头维度进行交互,减少了MSA在Transformer块中的计算和内存需求。

2、空间注意力的位置编码

ResT采用位置编码作为空间注意力,使模型能够灵活地处理不同尺寸的输入图像。

3、重叠卷积的分块嵌入

通过设计分块嵌入为重叠卷积操作的堆栈,ResT在不同阶段有效地提取特征,创建了多尺度的特征金字塔。

3、独特优势

1、高效和灵活性

ResT通过引入压缩的多头自注意力机制和空间注意力的位置编码,在保持计算效率的同时,提供了处理不同分辨率图像的灵活性。

2、改进的特征提取能力

通过重叠卷积的分块嵌入,ResT能够更有效地捕获图像中的局部和全局信息,提高了模型对图像特征的理解能力。

3、通用性

ResT作为一个通用骨架,在图像分类和下游任务(如对象检测和实例分割)上展现了卓越的性能,证明了其作为强大骨架网络的潜力。

4、代码

import numpy as np
import torch
from torch import nn
from torch.nn import init

# 多尺度注意力模块(EMSA),用于实现多尺度注意力机制
class EMSA(nn.Module):

    def __init__(self, d_model, d_k, d_v, h, dropout=.1, H=7, W=7, ratio=3, apply_transform=True):

        super(EMSA, self).__init__()
     
        self.H = H# 输入特征图的高度
        self.W = W# 输入特征图的宽度
        self.fc_q = nn.Linear(d_model, h * d_k)# 查询向量的全连接层
        self.fc_k = nn.Linear(d_model, h * d_k) # 键向量的全连接层
        self.fc_v = nn.Linear(d_model, h * d_v)# 值向量的全连接层
        self.fc_o = nn.Linear(h * d_v, d_model)# 输出的全连接层
        self.dropout = nn.Dropout(dropout)# Dropout层,用于防止过拟合

        self.ratio = ratio # 空间降采样比例
        if (self.ratio > 1):
            # 如果空间降采样比例大于1,添加空间降采样层
            self.sr = nn.Sequential()
            self.sr_conv = nn.Conv2d(d_model, d_model, kernel_size=ratio + 1, stride=ratio, padding=ratio // 2,
                                     groups=d_model)
            self.sr_ln = nn.LayerNorm(d_model)

        self.apply_transform = apply_transform and h > 1
        if (self.apply_transform):
              # 如果应用变换,添加变换层
            self.transform = nn.Sequential()
            self.transform.add_module('conv', nn.Conv2d(h, h, kernel_size=1, stride=1))
            self.transform.add_module('softmax', nn.Softmax(-1))
            self.transform.add_module('in', nn.InstanceNorm2d(h))

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h

        self.init_weights()
         # 初始化权重
    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):

        b_s, nq, c = queries.shape
        nk = keys.shape[1]
         # 生成查询、键和值向量
        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)

        if (self.ratio > 1):
            # 如果空间降采样,处理查询以生成键和值向量
            x = queries.permute(0, 2, 1).view(b_s, c, self.H, self.W)  # bs,c,H,W
            x = self.sr_conv(x)  # bs,c,h,w
            x = x.contiguous().view(b_s, c, -1).permute(0, 2, 1)  # bs,n',c
            x = self.sr_ln(x)
            k = self.fc_k(x).view(b_s, -1, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, n')
            v = self.fc_v(x).view(b_s, -1, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, n', d_v)
        else:
            # 不进行空间降采样,直接生成键和值向量
            k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)
            v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)

        if (self.apply_transform):
             # 应用变换计算注意力权重
            att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, n')
            att = self.transform(att)  # (b_s, h, nq, n')
        else:
            # 直接计算注意力权重
            att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, n')
            att = torch.softmax(att, -1)  # (b_s, h, nq, n')

        if attention_weights is not None:
            att = att * attention_weights
        if attention_mask is not None:
            att = att.masked_fill(attention_mask, -np.inf)

        att = self.dropout(att)# 应用dropout
      # 计算输出
        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)
        out = self.fc_o(out)  # (b_s, nq, d_model)
        return out # 返回输出结果


if __name__ == '__main__':
    block = EMSA(d_model=512, d_k=512, d_v=512, h=8, H=8, W=8, ratio=2, apply_transform=True).cuda()# 创建EMSA模块实例,并配置到CUDA上(如果可用)
    input = torch.rand(64, 64, 512).cuda()# 随机生成输入数据
    output = block(input, input, input)# 前向传播
    print(output.shape)

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容