15、SCSE模块
论文《Concurrent Spatial and Channel `Squeeze & Excitation' in Fully Convolutional Networks》
1、作用
scSE模块主要用于增强F-CNN在图像分割任务中的性能,通过对特征图进行自适应的校准来提升网络对图像中重要特征的响应能力。该模块通过同时在空间和通道上对输入特征图进行校准,鼓励网络学习更加有意义、在空间和通道上都相关的特征图。
2、机制
空间挤压和通道激励(scSE):
scSE模块是通过将通道激励(cSE)和空间激励(sSE)模块的输出进行元素级加法操作得到的。这一操作使得输入特征图的每个位置在获取通道重缩放和空间重缩放的高重要性时获得更高的激活值。通过这种校准方式,网络能够更加有效地关注于图像中的重要特征,同时忽略不重要的信息。
3、独特优势
1、模型复杂度的微小增加:
尽管scSE模块为F-CNN引入了额外的参数,但它对整体网络复杂度的增加非常小。例如,在实验中使用的U-Net添加scSE模块仅增加了大约1.5%的参数量,表明SE模块只需很小的复杂度增加就能显著提升性能。
2、通用性和高效性:
scSE模块可以无缝集成到不同的F-CNN架构中,在多个图像分割任务上都能取得一致的性能提升。这证明了scSE模块是一个高度通用且有效的网络组件,能够在多种医学应用中作为神经网络的重要组成部分。
4、代码
import torch
import torch.nn as nn
class sSE(nn.Module): # 空间(Space)注意力
def __init__(self, in_ch) -> None:
super().__init__()
self.conv = nn.Conv2d(in_ch, 1, kernel_size=1, bias=False) # 定义一个卷积层,用于将输入通道转换为单通道
self.norm = nn.Sigmoid() # 应用Sigmoid激活函数进行归一化
def forward(self, x):
q = self.conv(x) # 使用卷积层减少通道数至1:b c h w -> b 1 h w
q = self.norm(q) # 对卷积后的结果应用Sigmoid激活函数:b 1 h w
return x * q # 通过广播机制将注意力权重应用到每个通道上
class cSE(nn.Module): # 通道(channel)注意力
def __init__(self, in_ch) -> None:
super().__init__()
self.avgpool = nn.AdaptiveAvgPool2d(1) # 使用自适应平均池化,输出大小为1x1
self.relu = nn.ReLU() # ReLU激活函数
self.Conv_Squeeze = nn.Conv2d(in_ch, in_ch // 2, kernel_size=1, bias=False) # 通道压缩卷积层
self.norm = nn.Sigmoid() # Sigmoid激活函数进行归一化
self.Conv_Excitation = nn.Conv2d(in_ch // 2, in_ch, kernel_size=1, bias=False) # 通道激励卷积层
def forward(self, x):
z = self.avgpool(x) # 对输入特征进行全局平均池化:b c 1 1
z = self.Conv_Squeeze(z) # 通过通道压缩卷积减少通道数:b c//2 1 1
z = self.relu(z) # 应用ReLU激活函数
z = self.Conv_Excitation(z) # 通过通道激励卷积恢复通道数:b c 1 1
z = self.norm(z) # 对激励结果应用Sigmoid激活函数进行归一化
return x * z.expand_as(x) # 将归一化权重乘以原始特征,使用expand_as扩展维度与原始特征相匹配
class scSE(nn.Module):
def __init__(self, in_ch) -> None:
super().__init__()
self.cSE = cSE(in_ch) # 通道注意力模块
self.sSE = sSE(in_ch) # 空间注意力模块
def forward(self, x):
c_out = self.cSE(x) # 应用通道注意力
s_out = self.sSE(x) # 应用空间注意力
return c_out + s_out # 合并通道和空间注意力的输出
x = torch.randn(4, 16, 4, 4) # 测试输入
net = scSE(16) # 实例化模型
print(net(x).shape) # 打印输出形状