深度学习模块9-CoTAttention模块

8、CoTAttention模块

论文《Contextual Transformer Networks for Visual Recognition》

1、 作用

Contextual Transformer (CoT) block 设计为视觉识别的一种新颖的 Transformer 风格模块。该设计充分利用输入键之间的上下文信息指导动态注意力矩阵的学习,从而加强视觉表示的能力。CoT block 首先通过 3x3 卷积对输入键进行上下文编码,得到输入的静态上下文表示。然后,将编码后的键与输入查询合并,通过两个连续的 1x1 卷积学习动态多头注意力矩阵。学习到的注意力矩阵乘以输入值,实现输入的动态上下文表示。最终将静态和动态上下文表示的融合作为输出。

2、机制

1、上下文编码

通过 3x3 卷积在所有邻居键内部空间上下文化每个键表示,捕获键之间的静态上下文信息。

2、动态注意力学习

基于查询和上下文化的键的连接,通过两个连续的 1x1 卷积产生注意力矩阵,这一过程自然地利用每个查询和所有键之间的相互关系进行自我注意力学习,并由静态上下文指导。

3、静态和动态上下文的融合

将静态上下文和通过上下文化自注意力得到的动态上下文结合,作为 CoT block 的最终输出。

3、 独特优势

1、上下文感知

CoT 通过在自注意力学习中探索输入键之间的富上下文信息,使模型能够更准确地捕获视觉内容的细微差异。

2、动静态上下文的统一

CoT 设计巧妙地将上下文挖掘与自注意力学习统一到单一架构中,既利用键之间的静态关系又探索动态特征交互,提升了模型的表达能力。

3、灵活替换与优化

CoT block 可以直接替换现有 ResNet 架构中的标准卷积,不增加参数和 FLOP 预算的情况下实现转换为 Transformer 风格的骨干网络(CoTNet),通过广泛的实验验证了其在多种应用(如图像识别、目标检测和实例分割)中的优越性。

4、代码

# 导入必要的PyTorch模块
import torch
from torch import nn
from torch.nn import functional as F

class CoTAttention(nn.Module):
    # 初始化CoT注意力模块
    def __init__(self, dim=512, kernel_size=3):
        super().__init__()
        self.dim = dim  # 输入的通道数
        self.kernel_size = kernel_size  # 卷积核大小

        # 定义用于键(key)的卷积层,包括一个分组卷积,BatchNorm和ReLU激活
        self.key_embed = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size//2, groups=4, bias=False),
            nn.BatchNorm2d(dim),
            nn.ReLU()
        )

        # 定义用于值(value)的卷积层,包括一个1x1卷积和BatchNorm
        self.value_embed = nn.Sequential(
            nn.Conv2d(dim, dim, 1, bias=False),
            nn.BatchNorm2d(dim)
        )

        # 缩小因子,用于降低注意力嵌入的维度
        factor = 4
        # 定义注意力嵌入层,由两个卷积层、一个BatchNorm层和ReLU激活组成
        self.attention_embed = nn.Sequential(
            nn.Conv2d(2*dim, 2*dim//factor, 1, bias=False),
            nn.BatchNorm2d(2*dim//factor),
            nn.ReLU(),
            nn.Conv2d(2*dim//factor, kernel_size*kernel_size*dim, 1)
        )

    def forward(self, x):
        # 前向传播函数
        bs, c, h, w = x.shape  # 输入特征的尺寸
        k1 = self.key_embed(x)  # 生成键的静态表示
        v = self.value_embed(x).view(bs, c, -1)  # 生成值的表示并调整形状

        y = torch.cat([k1, x], dim=1)  # 将键的静态表示和原始输入连接
        att = self.attention_embed(y)  # 生成动态注意力权重
        att = att.reshape(bs, c, self.kernel_size*self.kernel_size, h, w)
        att = att.mean(2, keepdim=False).view(bs, c, -1)  # 计算注意力权重的均值并调整形状
        k2 = F.softmax(att, dim=-1) * v  # 应用注意力权重到值上
        k2 = k2.view(bs, c, h, w)  # 调整形状以匹配输出

        return k1 + k2  # 返回键的静态和动态表示的总和

# 实例化CoTAttention模块并测试
if __name__ == '__main__':
    block = CoTAttention(64)  # 创建一个输入通道数为64的CoTAttention实例
    input = torch.rand(1, 64, 64, 64)  # 创建一个随机输入
    output = block(input)  # 通过CoTAttention模块处理输入
    print(output.shape)  # 打印输入和输出的尺寸

©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容