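DeepLabV3+ (also written DeepLabV3plus) is an advanced deep learning model for semantic segmentation. It uses an encoder-decoder structure: the encoder extracts image features, and the decoder maps those features back to the original image size so that every pixel is classified. Specifically, the backbone network (the paper experiments with ResNet-101 and Xception) handles feature extraction, which yields two kinds of features: high-level (deep) semantic features and low-level (shallow) features. The model then uses atrous convolution (also called dilated convolution) to build the ASPP (Atrous Spatial Pyramid Pooling) module, which improves feature extraction across different scales. Finally, the decoder recovers image detail and produces the final segmentation result. The overall flow is outlined below.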
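The core component is the ASPP module, that is, atrous spatial pyramid pooling. The model's defining feature is its use of atrous convolution to extract information at different scales. The features from the different scales are concatenated, combined with the shallow features, and then upsampled to produce the prediction for the image. The detailed steps are: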
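- The input image is passed through the backbone, a convolutional neural network such as ResNet or Xception, for feature extraction;
- The backbone output is split into two parts: a shallower feature map x1 and a deeper feature map x2;
- The deeper feature map x2 is fed into the ASPP module, which has five branches:
a. The first branch applies a 1x1 convolution, which leaves the spatial size unchanged, and outputs a feature map;
b. The second branch applies a 3x3 convolution with dilation rate 6 and padding equal to the dilation rate, so the spatial size is unchanged, and outputs a feature map;
c. The third branch applies a 3x3 convolution with dilation rate 12 and padding equal to the dilation rate, so the spatial size is unchanged, and outputs a feature map;
d. The fourth branch applies a 3x3 convolution with dilation rate 18 and padding equal to the dilation rate, so the spatial size is unchanged, and outputs a feature map;
e. The fifth branch applies average pooling followed by a 1x1 convolution that changes the channel count, and outputs a feature map;
f. The five branch outputs are concatenated along the channel dimension;
g. The concatenated features pass through a 1x1 convolution, giving the final deep feature map x3;
- The shallower feature map x1 passes through a 1x1 convolution, giving feature map x4;
- The final deep feature map x3 is upsampled to the same spatial size as the shallow feature map x1; call the result x5;
- The shallow feature x4 and the deep feature x5 are concatenated along the channel dimension;
- A 3x3 convolution then produces the classification result;
- A final upsampling restores the original input size and yields the segmentation result.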
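To see why setting the padding equal to the dilation rate keeps the feature map size unchanged in the 3x3 ASPP branches, the standard convolution output-size formula can be evaluated directly. The sketch below is a plain-Python illustration only; the helper name `conv_out_size` and the 33x33 input size are assumptions chosen for the example.

```python
def conv_out_size(size, kernel, stride=1, padding=0, dilation=1):
    """Spatial output size of a 2D convolution along one axis."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# Hypothetical 33x33 deep feature map entering the ASPP branches.
feature_size = 33

# 1x1 branch: no padding is needed and the size is unchanged.
print("1x1:", conv_out_size(feature_size, kernel=1))

# 3x3 branches with dilation 6, 12, 18 and padding equal to the dilation.
for rate in (6, 12, 18):
    same = conv_out_size(feature_size, kernel=3, padding=rate, dilation=rate)
    print("3x3, rate", rate, "->", same)

# Without padding, the same dilated kernel would shrink the map.
print("3x3, rate 6, no padding ->", conv_out_size(feature_size, kernel=3, dilation=6))
```

Because the 3x3 dilated kernel spans `2 * rate + 1` pixels, a padding of `rate` on each side exactly compensates for it, which is why all branches can be concatenated without resizing.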
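Atrous convolution is covered extensively online, so look up the details if you need them. In short, atrous convolution, also called dilated convolution, is a convolution variant designed to enlarge the receptive field.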
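With a dilation rate of 1 it is an ordinary convolution, and a 3x3 kernel has a 3x3 receptive field. With a dilation rate of 2, adjacent kernel elements are 1 pixel apart, so the region actually covered grows: the equivalent kernel becomes 5x5, and when this layer is stacked on top of the rate-1 layer the receptive field becomes 7x7. With a dilation rate of 4, adjacent kernel elements are 3 pixels apart, the equivalent kernel becomes 9x9, and stacked on the two previous layers the receptive field grows to 15x15. In other words, atrous convolution enlarges the receptive field without adding kernel elements. DeepLabV3+ uses this kind of convolution to obtain features at different scales.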
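The numbers in this paragraph can be reproduced with two small formulas: the equivalent kernel size of a single dilated convolution, and the receptive field of several stride-1 dilated convolutions stacked in sequence. The following is a minimal sketch; the function names are introduced here for illustration.

```python
def effective_kernel(kernel, dilation):
    """Side length of the region spanned by one dilated kernel."""
    return kernel + (kernel - 1) * (dilation - 1)


def stacked_receptive_field(kernel, dilations):
    """Receptive field after each layer when stride-1 dilated convs are stacked in order."""
    field = 1
    fields = []
    for dilation in dilations:
        # Each stride-1 layer widens the field by (kernel - 1) * dilation pixels.
        field += (kernel - 1) * dilation
        fields.append(field)
    return fields


for rate in (1, 2, 4):
    print("rate", rate, "equivalent kernel", effective_kernel(3, rate))  # 3, 5, 9

print(stacked_receptive_field(3, [1, 2, 4]))  # [3, 7, 15]
```

Note that 7x7 and 15x15 are receptive fields of the stacked layers; a single rate-2 or rate-4 layer applied on its own covers only its 5x5 or 9x9 equivalent kernel.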
The code for DeepLabV3+ is as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet101  # other backbones can be used as well
class ASPPModule(nn.Module):
    def __init__(self, in_channels, out_channels, dilations):
        super(ASPPModule, self).__init__()
        self.branches = nn.ModuleList()
        self.branches.append(
            # Image pooling branch.
            # Note: this is a 3x3 average pooling with stride 1 and padding 1, which
            # keeps the spatial size; it is not the global average pooling of the paper.
            nn.Sequential(nn.AvgPool2d(3, 1, 1),
                          nn.Conv2d(in_channels, out_channels, 1, 1),
                          nn.BatchNorm2d(out_channels),
                          nn.ReLU(inplace=True))
        )
        # One 3x3 convolution branch per dilation rate (four in total).
        # padding == dilation keeps the spatial size unchanged.
        # Note: with d == 1 this is a plain 3x3 convolution, whereas the description
        # above uses a 1x1 convolution for the first branch.
        for d in dilations:
            self.branches.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, 3, 1, dilation=d, padding=d),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True)
                )
            )
        # 1x1 convolution that fuses the concatenated branches
        self.conv_bn_relu = nn.Sequential(
            nn.Conv2d((len(dilations) + 1) * out_channels, out_channels, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5)
        )

    def forward(self, x):
        size = x.size()[2:]
        print("size: ", size)
        features = []
        # Run every branch and resize its output to a common spatial size
        for i in range(len(self.branches)):
            out = self.branches[i](x)
            print("out.shape: ", out.shape)
            out = F.interpolate(out, size=size, mode='bilinear', align_corners=True)
            print("upsample out.shape: ", out.shape)
            features.append(out)
        # Concatenate the five branches along the channel dimension
        features = torch.cat(features, dim=1)
        return self.conv_bn_relu(features)

# Kaiming (He) initialization
def initialize_weights(*models):
    for model in models:
        for module in model.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                # kaiming_normal_ is the in-place replacement for the deprecated kaiming_normal
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

class DeepLabV3Plus(nn.Module):
    def __init__(self, n_classes=21, backbone='resnet101', output_stride=16):
        # Note: output_stride is accepted but not used below.
        super(DeepLabV3Plus, self).__init__()
        if backbone == 'resnet101':
            # Use the newer `weights` argument; the old `pretrained` argument
            # triggers a deprecation warning.
            # self.backbone = resnet101(pretrained=False)
            self.backbone = resnet101(weights="IMAGENET1K_V1")
            # Adapt ResNet for DeepLabV3+:
            # drop the final average pooling layer and the classification layer.
            self.backbone = nn.Sequential(*list(self.backbone.children())[:-2])
            # Indices 0-2 are conv1, bn1, relu. Index 3 (max pooling) is skipped,
            # so the backbone downsamples by 16 overall instead of 32.
            self.first = self.backbone[0:3]
            self.layer1 = self.backbone[4]
            self.layer2 = self.backbone[5]
            self.layer3 = self.backbone[6]
            self.layer4 = self.backbone[7]
        else:
            raise ValueError('Unsupported backbone - `{}`, Use resnet101'.format(backbone))
        self.aspp = ASPPModule(2048, 256, [1, 6, 12, 18])
        self.conv1x1 = nn.Conv2d(256, 48, 1, 1)
        self.upsample4 = nn.ConvTranspose2d(48, 48, 4, stride=2, padding=1)
        self.low_level_conv = nn.Sequential(
            nn.Conv2d(256, 48, 1, 1),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True)
        )
        self.final_conv = nn.Conv2d(96, n_classes, 3, 1, 1)
        # The pretrained backbone is left out here so that its ImageNet weights
        # are not overwritten by the random initialization.
        initialize_weights(self.aspp, self.conv1x1, self.low_level_conv, self.final_conv)

    def forward(self, x):
        # Feature maps from the backbone
        c2, c3, c4, c5 = self._forward_backbone(x)
        size0 = x.size()[2:]
        print("size:", size0)
        # ASPP module
        features = self.aspp(c5)
        print("features.shape: ", features.shape)
        features = self.conv1x1(features)
        print("features.shape: ", features.shape)
        features = self.upsample4(features)
        print("features.shape: ", features.shape)
        # Fuse with the low-level features (c3 is the output of layer1)
        low_level_features = self.low_level_conv(c3)
        size = low_level_features.size()[2:]
        features = F.interpolate(features, size=size, mode='bilinear', align_corners=True)
        features = torch.cat([features, low_level_features], dim=1)
        # Final classification layer
        output = self.final_conv(features)
        # Final upsampling back to the input size
        output = F.interpolate(output, size=size0, mode='bilinear', align_corners=True)
        return output

    def _forward_backbone(self, x):
        c2 = self.first(x)
        c3 = self.layer1(c2)
        c4 = self.layer2(c3)
        c5 = self.layer3(c4)
        c5 = self.layer4(c5)
        print("c2.shape: {}".format(c2.shape))
        print("c3.shape: {}".format(c3.shape))
        print("c4.shape: {}".format(c4.shape))
        print("c5.shape: {}".format(c5.shape))
        return c2, c3, c4, c5

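# Example usage
model = DeepLabV3Plus(n_classes=21)  # number of classes in the Pascal VOC dataset
input_tensor = torch.randn(1, 3, 513, 513)  # example input: batch size 1, 3 channels, height and width 513
output = model(input_tensor)
print(output.shape)  # expected shape [1, 21, 513, 513]: one class score map per class for every pixel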
# Output:
c2.shape: torch.Size([1, 64, 257, 257])
c3.shape: torch.Size([1, 256, 257, 257])
c4.shape: torch.Size([1, 512, 129, 129])
c5.shape: torch.Size([1, 2048, 33, 33])
size: torch.Size([513, 513])
size: torch.Size([33, 33])
out.shape: torch.Size([1, 256, 33, 33])
upsample out.shape: torch.Size([1, 256, 33, 33])
out.shape: torch.Size([1, 256, 33, 33])
upsample out.shape: torch.Size([1, 256, 33, 33])
out.shape: torch.Size([1, 256, 33, 33])
upsample out.shape: torch.Size([1, 256, 33, 33])
out.shape: torch.Size([1, 256, 33, 33])
upsample out.shape: torch.Size([1, 256, 33, 33])
out.shape: torch.Size([1, 256, 33, 33])
upsample out.shape: torch.Size([1, 256, 33, 33])
features.shape: torch.Size([1, 256, 33, 33])
features.shape: torch.Size([1, 48, 33, 33])
features.shape: torch.Size([1, 48, 66, 66])
torch.Size([1, 21, 513, 513])
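The spatial sizes in this log can be checked by hand with the convolution and transposed-convolution output-size formulas. The sketch below assumes the layer settings of torchvision's ResNet-101 (a 7x7 stride-2 stem convolution with padding 3, and a 3x3 stride-2 convolution with padding 1 at the start of layer2, layer3, and layer4); the helper names `conv_out` and `deconv_out` are hypothetical.

```python
def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output size of a convolution along one spatial axis."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def deconv_out(size, kernel, stride=1, padding=0):
    """Output size of a transposed convolution (dilation 1, output_padding 0)."""
    return (size - 1) * stride - 2 * padding + kernel


size = 513
stem = conv_out(size, kernel=7, stride=2, padding=3)       # c2
layer1 = stem                                              # c3: stride 1, max pooling is skipped in the code above
layer2 = conv_out(layer1, kernel=3, stride=2, padding=1)   # c4
layer3 = conv_out(layer2, kernel=3, stride=2, padding=1)
layer4 = conv_out(layer3, kernel=3, stride=2, padding=1)   # c5
upsampled = deconv_out(layer4, kernel=4, stride=2, padding=1)

print(stem, layer1, layer2, layer3, layer4, upsampled)  # 257 257 129 65 33 66
```

The 66x66 map produced by the transposed convolution does not match the 257x257 low-level features, which is why the code still calls `F.interpolate` before concatenating the two.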
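The model was trained on GID, a remote sensing image interpretation dataset, with a learning rate of 0.01, a batch_size of 8, and 100 epochs. The overall accuracy reached 0.847; the per-class accuracy is as follows: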
2024-12-12 13:22:26,051 - __main__ - DEBUG - --------------------------------------
2024-12-12 13:22:26,051 - __main__ - DEBUG - |0|background|0.84|
2024-12-12 13:22:26,051 - __main__ - DEBUG - |1|building|0.78|
2024-12-12 13:22:26,052 - __main__ - DEBUG - |2|farmland|0.83|
2024-12-12 13:22:26,052 - __main__ - DEBUG - |3|tree|0.21|
2024-12-12 13:22:26,052 - __main__ - DEBUG - |4|grass|0.37|
2024-12-12 13:22:26,052 - __main__ - DEBUG - |5|water|0.75|
2024-12-12 13:22:26,052 - __main__ - DEBUG - --------------------------------------
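Overall accuracy counts every pixel equally, so classes with many pixels dominate it and weak classes can be hidden. As a quick check, the unweighted mean of the six per-class values in the log above can be computed. The dictionary below simply copies those values; `per_class` is a name introduced for this sketch, and the exact metric behind each value is whatever the training script logged.

```python
# Per-class values copied from the training log above.
per_class = {
    "background": 0.84,
    "building": 0.78,
    "farmland": 0.83,
    "tree": 0.21,
    "grass": 0.37,
    "water": 0.75,
}

# Unweighted mean over classes: every class counts equally.
mean_per_class = sum(per_class.values()) / len(per_class)
print(round(mean_per_class, 3))  # 0.63

# The class with the lowest value is the first candidate for further work.
weakest = min(per_class, key=per_class.get)
print(weakest)  # tree
```

The gap between this mean (about 0.63) and the overall accuracy of 0.847 comes mainly from the tree and grass classes.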