论文:TMI.2020
代码:https://github.com/FENGShuanglang/CPFNet_Project
Abstract
医学图像的精确自动分割是临床诊断和分析的关键步骤。基于U形结构的卷积神经网络(CNN)方法在许多不同的医学图像分割任务中取得了显著的效果。然而,在这种结构中,由于类别不平衡和边界模糊等问题,single stage(单级)上下文信息提取能力不足。本文提出了一种新的上下文金字塔融合网络(CPFNet),通过组合两个金字塔模块来融合全局/多尺度上下文信息。基于U形结构,我们首先在编码器和解码器之间设计了多个全局金字塔制导(GPG)模块,旨在通过重建跳转连接为解码器提供不同级别的全局上下文信息。我们进一步设计了一个尺度感知金字塔融合(SAPF)模块来动态融合高层特征中的多尺度上下文信息。这两个金字塔模块可以逐步利用和融合丰富的上下文信息。实验结果表明,在皮肤病变分割、视网膜线性病变分割、胸部危险器官多类分割和视网膜水肿病变多类分割四种不同的挑战性任务上,我们提出的方法与其他最先进的方法具有很强的竞争力。
Introduction
U型网络的问题:
- 首先,一方面,编码器较深阶段获取的全局上下文信息被逐渐传输到较浅层,由于单个阶段的特征提取能力较弱,可能会逐渐稀释全局上下文信息;另一方面,每个阶段的简单跳过连接忽略了全局信息,是局部信息的任意组合,会引入无关的杂波,导致像素的误分类。
- 第二,在每一个single stage,都没有有效地提取和利用多尺度的上下文信息。当处理具有复杂结构的目标时,此类信息是必要的,以便同时考虑结构的周围环境,避免做出模棱两可的决定[22]。
网络的架构还是很简单的,依旧采用U型结构,在跳跃链接处增加的GPG模块,并用SAPF模块替换了编码器的最后一层,编码器和解码器的结构略有改动,在上图也很好的标记出来了,下面我们主要看一下,作者提出的GPG和SAPF模块。
GPG
在GPG模块中,将这一阶段的特征图与所有更高一级阶段的特征图相结合,重构跳过连接。
以Stage3为例,首先,通过规则的3 × 3卷积将所有阶段的特征映射到与Stage3相同的通道空间。接下来,将生成的feature map F4和F5上采样到与F3相同的大小并进行连接。然后,为了从不同层次的特征地图中提取全局上下文信息,并行使用三个不同膨胀率(1、2、4)的可分离卷积[33](Dsconv@1, Dsconv@2, Dsconv@4),其中可分离卷积用于降低模型参数。值得注意的是,平行路径的数量和膨胀率随融合阶段的数量而变化。最后,利用规则的卷积得到最终的特征图。
SAPF
在SAPF模块中,使用三个不同空洞率(1,2,4)的并行的空洞卷积来获取不同的尺度信息。注意,这些不同的空洞卷积共享卷积权值,这可以减少模型参数的数量和过拟合的风险。
在此基础上,我们设计了一个scale-aware模块来融合不同尺度的特征。如下图所示,引入空间注意机制,通过自我学习动态的选择合适的尺度特征并融合。具体来说,两个不同尺度的特征和
经过一系列卷积,得到两个特征图
(H:特征图的高度,W:特征图的宽度)。然后利用softmax对空间值生成像素方向的注意映射
,最后以加权和的形式得到融合特征图。
我们使用两个级联的尺度感知模块来获得三个分支的最终融合特征。然后,利用具有可学习参数α的残差连接来获得整个SAPF模块的输出。
class SAPblock(nn.Module):
def __init__(self, in_channels):
super(SAPblock, self).__init__()
self.conv3x3 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, dilation=1, kernel_size=3,
padding=1)
self.bn = nn.ModuleList([nn.BatchNorm2d(in_channels), nn.BatchNorm2d(in_channels), nn.BatchNorm2d(in_channels)])
self.conv1x1 = nn.ModuleList(
[nn.Conv2d(in_channels=2 * in_channels, out_channels=in_channels, dilation=1, kernel_size=1, padding=0),
nn.Conv2d(in_channels=2 * in_channels, out_channels=in_channels, dilation=1, kernel_size=1, padding=0)])
self.conv3x3_1 = nn.ModuleList(
[nn.Conv2d(in_channels=in_channels, out_channels=in_channels // 2, dilation=1, kernel_size=3, padding=1),
nn.Conv2d(in_channels=in_channels, out_channels=in_channels // 2, dilation=1, kernel_size=3, padding=1)])
self.conv3x3_2 = nn.ModuleList(
[nn.Conv2d(in_channels=in_channels // 2, out_channels=2, dilation=1, kernel_size=3, padding=1),
nn.Conv2d(in_channels=in_channels // 2, out_channels=2, dilation=1, kernel_size=3, padding=1)])
self.conv_last = ConvBnRelu(in_planes=in_channels, out_planes=in_channels, ksize=1, stride=1, pad=0, dilation=1)
self.gamma = nn.Parameter(torch.zeros(1))
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x_size = x.size()
branches_1 = self.conv3x3(x)
branches_1 = self.bn[0](branches_1)
branches_2 = F.conv2d(x, self.conv3x3.weight, padding=2, dilation=2) # share weight
branches_2 = self.bn[1](branches_2)
branches_3 = F.conv2d(x, self.conv3x3.weight, padding=4, dilation=4) # share weight
branches_3 = self.bn[2](branches_3)
feat = torch.cat([branches_1, branches_2], dim=1)
# feat=feat_cat.detach()
feat = self.relu(self.conv1x1[0](feat))
feat = self.relu(self.conv3x3_1[0](feat))
att = self.conv3x3_2[0](feat)
att = F.softmax(att, dim=1)
att_1 = att[:, 0, :, :].unsqueeze(1)
att_2 = att[:, 1, :, :].unsqueeze(1)
fusion_1_2 = att_1 * branches_1 + att_2 * branches_2
feat1 = torch.cat([fusion_1_2, branches_3], dim=1)
# feat=feat_cat.detach()
feat1 = self.relu(self.conv1x1[0](feat1))
feat1 = self.relu(self.conv3x3_1[0](feat1))
att1 = self.conv3x3_2[0](feat1)
att1 = F.softmax(att1, dim=1)
att_1_2 = att1[:, 0, :, :].unsqueeze(1)
att_3 = att1[:, 1, :, :].unsqueeze(1)
ax = self.relu(self.gamma * (att_1_2 * fusion_1_2 + att_3 * branches_3) + (1 - self.gamma) * x)
ax = self.conv_last(ax)
return ax