5、SimAM模块
论文《SimAM: A Simple, Parameter-Free Attention Module for Convolutional Neural Networks》
1、作用
SimAM(Simple Attention Module)提出了一个概念简单但非常有效的注意力模块,用于卷积神经网络。与现有的通道维度和空间维度注意力模块不同,SimAM能够为特征图中的每个神经元推断出3D注意力权重,而无需在原始网络中添加参数。
2、机制
1、能量函数优化:
SimAM基于著名的神经科学理论,通过优化一个能量函数来找出每个神经元的重要性。这个过程不添加任何新参数到原始网络中。
2、快速闭合形式解决方案:
对于能量函数,SimAM推导出了一个快速的闭合形式解决方案,并展示了这个解决方案可以在不到十行代码中实现。这种方法避免了结构调整的繁琐工作,使模块的设计更为简洁高效。
3、独特优势
1、无参数设计:
SimAM的一个显著优势是它不增加任何额外的参数。这使得SimAM可以轻松地集成到任何现有的CNN架构中,几乎不增加计算成本。
2、直接生成3D权重:
与大多数现有的注意力模块不同,SimAM能够直接为每个神经元生成真正的3D权重,而不是仅仅在通道或空间维度上。这种全面的注意力机制能够更精确地捕捉到重要的特征信息。
3、基于神经科学的设计:
SimAM的设计灵感来自于人类大脑中的注意力机制,尤其是空间抑制现象,使其在捕获视觉任务中的关键信息方面更为高效和自然。
4、代码
import torch
import torch.nn as nn
from thop import profile # 引入thop库来计算模型的FLOPs和参数数量
# 定义SimAM模块
class Simam_module(torch.nn.Module):
def __init__(self, e_lambda=1e-4):
super(Simam_module, self).__init__()
self.act = nn.Sigmoid() # 使用Sigmoid激活函数
self.e_lambda = e_lambda # 定义平滑项e_lambda,防止分母为0
def forward(self, x):
b, c, h, w = x.size() # 获取输入x的尺寸
n = w * h - 1 # 计算特征图的元素数量减一,用于下面的归一化
# 计算输入特征x与其均值之差的平方
x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
# 计算注意力权重y,这里实现了SimAM的核心计算公式
y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5
# 返回经过注意力加权的输入特征
return x * self.act(y)
# 示例使用
if __name__ == '__main__':
model = Simam_module().cuda() # 实例化SimAM模块并移到GPU上
x = torch.randn(1, 3, 64, 64).cuda() # 创建一个随机输入并移到GPU上
y = model(x) # 将输入传递给模型
print(y.size()) # 打印输出尺寸
# 使用thop库计算模型的FLOPs和参数数量
flops, params = profile(model, (x,))
print(flops / 1e9) # 打印以Giga FLOPs为单位的浮点操作数
print(params) # 打印模型参数数量