12、MSCA模块
论文《SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation》
1、作用
SegNeXt旨在为语义分割任务提供一个简单而有效的卷积网络架构。通过重新考虑卷积注意力的设计,提出了一种比传统自注意力机制更高效的方法来编码空间信息。
2、机制
1、SegNeXt结合了强大的编码器、多尺度信息交互和空间注意力来提升语义分割的性能。
2、通过采用便宜的卷积操作和简化的设计,SegNeXt实现了与先进方法相比的显著性能提升,同时大幅减少了参数数量。
3、该模型通过使用多尺度卷积特征来激发空间注意力,采用简单的元素级乘法操作,证明了这种方式比标准卷积和自注意力在空间信息编码方面更高效。
3、独特优势
1、SegNeXt在多个流行的基准测试中,包括ADE20K、Cityscapes、COCO-Stuff、Pascal VOC、Pascal Context和iSAID上,显著改善了之前最先进方法的性能。
2、特别是,SegNeXt使用只有EfficientNet-L2 w/ NAS-FPN 1/10参数的情况下,在Pascal VOC 2012测试排行榜上达到了90.6% mIoU的成绩。
3、平均而言,SegNeXt在ADE20K数据集上比最先进方法提高了约2.0%的mIoU,同时计算量更少或相同。
4、代码
import torch
from torch import nn
class AttentionModule(nn.Module):
def __init__(self, dim):
super().__init__()
# 使用5x5核的卷积层,应用深度卷积
self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
# 两组卷积层,分别使用1x7和7x1核,用于跨度不同的特征提取,均应用深度卷积
self.conv0_1 = nn.Conv2d(dim, dim, (1, 7), padding=(0, 3), groups=dim)
self.conv0_2 = nn.Conv2d(dim, dim, (7, 1), padding=(3, 0), groups=dim)
# 另外两组卷积层,使用更大的核进行特征提取,分别为1x11和11x1,也是深度卷积
self.conv1_1 = nn.Conv2d(dim, dim, (1, 11), padding=(0, 5), groups=dim)
self.conv1_2 = nn.Conv2d(dim, dim, (11, 1), padding=(5, 0), groups=dim)
# 使用最大尺寸的核进行特征提取,为1x21和21x1,深度卷积
self.conv2_1 = nn.Conv2d(dim, dim, (1, 21), padding=(0, 10), groups=dim)
self.conv2_2 = nn.Conv2d(dim, dim, (21, 1), padding=(10, 0), groups=dim)
# 最后一个1x1卷积层,用于整合上述所有特征提取的结果
self.conv3 = nn.Conv2d(dim, dim, 1)
def forward(self, x):
u = x.clone() # 克隆输入x,以便之后与注意力加权的特征进行相乘
attn = self.conv0(x) # 应用初始的5x5卷积
# 应用1x7和7x1卷积,进一步提取特征
attn_0 = self.conv0_1(attn)
attn_0 = self.conv0_2(attn_0)
# 应用1x11和11x1卷积,进一步提取特征
attn_1 = self.conv1_1(attn)
attn_1 = self.conv1_2(attn_1)
# 应用1x21和21x1卷积,进一步提取特征
attn_2 = self.conv2_1(attn)
attn_2 = self.conv2_2(attn_2)
attn = attn + attn_0 + attn_1 + attn_2 # 将所有特征提取的结果相加
attn = self.conv3(attn) # 应用最后的1x1卷积层整合特征
return attn * u # 将原始输入和注意力加权的特征相乘,返回最终结果
if __name__ == "__main__":
# 创建 AttentionModule 实例,这里以64个通道为例
attention_module = AttentionModule(dim=64)
# 创建一个假的输入数据,维度为 [batch_size, channels, height, width]
# 例如,1个样本,64个通道,64x64的图像
input_tensor = torch.rand(1, 64, 64, 64)
# 通过AttentionModule处理输入
output_tensor = attention_module(input_tensor)
# 打印输出张量的形状
print(output_tensor.shape)