深度学习模块10-TripletAttention模块

9、TripletAttention模块

论文《Rotate to Attend: Convolutional Triplet Attention Module》

1、作用

Triplet Attention是一种新颖的注意力机制,它通过捕获跨维度交互,利用三分支结构来计算注意力权重。对于输入张量,Triplet Attention通过旋转操作建立维度间的依赖关系,随后通过残差变换对信道和空间信息进行编码,实现了几乎不增加计算成本的情况下,有效增强视觉表征的能力。

2、机制

1、三分支结构

Triplet Attention包含三个分支,每个分支负责捕获输入的空间维度H或W与信道维度C之间的交互特征。

2、跨维度交互

通过在每个分支中对输入张量进行排列(permute)操作,并通过Z-pool和k×k的卷积层处理,以捕获跨维度的交互特征。

3、注意力权重的生成

利用sigmoid激活层生成注意力权重,并应用于排列后的输入张量,然后将其排列回原始输入形状。

3、 独特优势

1、跨维度交互

Triplet Attention通过捕获输入张量的跨维度交互,提供了丰富的判别特征表征,较之前的注意力机制(如SENet、CBAM等)能够更有效地增强网络的性能。

2、几乎无计算成本增加

相比于传统的注意力机制,Triplet Attention在提升网络性能的同时,几乎不增加额外的计算成本和参数数量,使得它可以轻松地集成到经典的骨干网络中。

3、无需降维

与其他注意力机制不同,Triplet Attention不进行维度降低处理,这避免了因降维可能导致的信息丢失,保证了信道与权重间的直接对应关系。

总的来说,Triplet Attention通过其独特的三分支结构和跨维度交互机制,在提高模型性能的同时,保持了计算效率,显示了其在各种视觉任务中的应用潜力。

4、代码

import torch
import torch.nn as nn

# 定义一个基本的卷积模块,包括卷积、批归一化和ReLU激活
class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        # 定义卷积层
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        # 条件性地添加批归一化层
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        # 条件性地添加ReLU激活函数
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)  # 应用卷积
        if self.bn is not None:
            x = self.bn(x)  # 应用批归一化
        if self.relu is not None:
            x = self.relu(x)  # 应用ReLU
        return x

# 定义ZPool模块,结合最大池化和平均池化结果
class ZPool(nn.Module):
    def forward(self, x):
        # 结合最大值和平均值
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)

# 定义注意力门,用于根据输入特征生成注意力权重
class AttentionGate(nn.Module):
    def __init__(self):
        super(AttentionGate, self).__init__()
        kernel_size = 7  # 设定卷积核大小
        self.compress = ZPool()  # 使用ZPool模块
        self.conv = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False)  # 通过卷积调整通道数

    def forward(self, x):
        x_compress = self.compress(x)  # 应用ZPool
        x_out = self.conv(x_compress)  # 通过卷积生成注意力权重
        scale = torch.sigmoid_(x_out)  # 应用Sigmoid激活
        return x * scale  # 将注意力权重乘以原始特征

# 定义TripletAttention模块,结合了三种不同方向的注意力门
class TripletAttention(nn.Module):
    def __init__(self, no_spatial=False):
        super(TripletAttention, self).__init__()
        self.cw = AttentionGate()  # 定义宽度方向的注意力门
        self.hc = AttentionGate()  # 定义高度方向的注意力门
        self.no_spatial = no_spatial  # 是否忽略空间注意力
        if not no_spatial:
            self.hw = AttentionGate()  # 定义空间方向的注意力门

    def forward(self, x):
        # 应用注意力门并结合结果
        x_perm1 = x.permute(0, 2, 1, 3).contiguous()  # 转置以应用宽度方向的注意力
        x_out1 = self.cw(x_perm1)
        x_out11 = x_out1.permute(0, 2, 1, 3).contiguous()  # 还原转置
        x_perm2 = x.permute(0, 3, 2, 1).contiguous()  # 转置以应用高度方向的注意力
        x_out2 = self.hc(x_perm2)
        x_out21 = x_out2.permute(0, 3, 2, 1).contiguous()  # 还原转置
        if not self.no_spatial:
            x_out = self.hw(x)  # 应用空间注意力
            x_out = 1 / 3 * (x_out + x_out11 + x_out21)  # 结合三个方向的结果
        else:
            x_out = 1 / 2 * (x_out11 + x_out21)  # 结合两个方向的结果(如果no_spatial为True)
        return x_out

# 示例代码
if __name__ == '__main__':
    input = torch.randn(50, 512, 7, 7)  # 生成随机输入
    triplet = TripletAttention()  # 实例化TripletAttention
    output = triplet(input)  # 应用TripletAttention
    print(output.shape)  # 打印输出形状

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容