语义分割——DeeplabV3plus

DeeplabV3plus 是一种先进的用于语义分割任务的深度学习模型。DeepLabV3plus模型采用了编码器-解码器(Encoder-Decoder)结构,通过编码器提取图像特征,再通过解码器将这些特征映射回原始图像尺寸,实现像素级的分类。具体来说,模型的主干网络(论文中对ResNet101或Xception做了实验)负责特征提取,特征提取分为高层语义提取和底层的语义提取两个部分。然后,模型会利用空洞卷积(Dilated Convolution)技术,构建了ASPP(Atrous Spatial Pyramid Pooling)模块,提高模型在不同尺度特征提取上的能力。最后,通过解码器恢复图像的细节信息,得到最终的分割结果。总体流程如下:



这里面,核心部分是ASPP模块,也就是空洞金字塔池化模块,该模型最大的特点就是利用空洞卷积来提取出不同尺度的信息。并把不同尺度的特征信息进行拼接,再结合浅层特征后进行上采样,得到影像的预测结果。具体流程如下:

  1. 原始图像经过骨干特征提取特征,采用ResNet或Xception等卷积神经网络进行特征提取;
  2. 这里分成两部分,一部分是较为浅层的特征x1,一部分是较为深层的特征x2;
  3. 将较为深层的特征x2,输入ASPP模块,在ASPP中,分为五个分支:
    a. 第一个分支经过1x1卷积,不改变特征大小,得到特征图;
    b. 第二个分支经过3x3卷积,设置空洞系数为6,填充和空洞系数一致,不改变特征大小,得到特征图;
    c. 第三个分支经过3x3卷积,设置空洞系数为12,填充和空洞系数一致,不改变特征大小,得到特征图;
    d. 第四个分支经过3x3卷积,设置空洞系数为18,填充和空洞系数一致,不改变特征大小,得到特征图;
    e. 第五个分支经过平均池化操作,再经过一个1x1卷积改变通道数,得到特征图;
    f. 按通道维度合并五个分支的特征;
    g. 合并后的特征经过1x1卷积,得到深层特征的最终特征图x3;
  4. 将较为浅层的特征x1进行1x1卷积,得到特征图x4;
  5. 将深层特征的最终特征图x3进行上采样,恢复到和浅层特征x1一样的大小,假设称为x5;
  6. 按通道维度合并浅层特征x4和深层特征x5;
  7. 再进行一个3x3卷积,得到分类结果;
  8. 上采样,恢复成原始输入图像的大小,得到图像分割结果。

空洞卷积的内容,网上有很多介绍。大家可以自己去查阅相关资料,简单来说,空洞卷积或者叫膨胀卷积,就是为了增加感受野的一种卷积方式。



扩张率为1的时候,就是普通卷积,可以看到感受野就是3x3,当扩张率为2的时候,卷积核元素之间就会间隔1个像素点,实际参与运算的感受野范围就会扩大,等效卷积核变成了5x5,感受野变成了7x7,当扩张率为4的时候,卷积核元素之间就会间隔3个像素点,等效卷积核变成了9x9,感受野扩张到15x15。可以看到,空洞卷积的目的就是在不增加卷积核元素的前提下,增加感受野。DeeplabV3plus模型就是利用这种卷积方式,获取到不同尺度下的特征值。
DeeplabV3plus的代码实现如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet101  # 可以选择其他主干网络

class ASPPModule(nn.Module):
    def __init__(self, in_channels, out_channels, dilations):
        super(ASPPModule, self).__init__()
        self.branches = nn.ModuleList()
        self.branches.append(
            # image pooling 分支
            nn.Sequential(nn.AvgPool2d(3,1,1),
                          nn.Conv2d(in_channels, out_channels, 1, 1),
                          nn.BatchNorm2d(out_channels),
                          nn.ReLU(inplace=True))
        )
        # 四个空洞卷积分支
        for d in dilations:
            self.branches.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, 3, 1, dilation=d, padding=d),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True)
                )
            )
        # 1x1卷积
        self.conv_bn_relu = nn.Sequential(
            nn.Conv2d((len(dilations)+1) * out_channels, out_channels, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5)
        )

    def forward(self, x):
        size = x.size()[2:]
        print("size: ",size)
        features = []
        # 获取各个分支的特征,并把大小调整到一致
        for i in range(len(self.branches)):
            out = self.branches[i](x)
            print("out.shape: ",out.shape)
            out = F.interpolate(out, size=size, mode='bilinear', align_corners=True)
            print("upsample out.shape: ",out.shape)
            features.append(out)
        # 按通道维度合并五个特征分支
        features = torch.cat(features, dim=1)
        return self.conv_bn_relu(features)

# 凯明初始化
def initialize_weights(*models):
    for model in models:
        for module in model.modules():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                nn.init.kaiming_normal(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

class DeepLabV3Plus(nn.Module):
    def __init__(self, n_classes=21, backbone='resnet101', output_stride=16):
        super(DeepLabV3Plus, self).__init__()
        if backbone == 'resnet101':
            # 这里要用新的写法,否则会显示警告信息,提示过期
            #self.backbone = resnet101(pretrained=False)
            self.backbone = resnet101(weights="IMAGENET1K_V1")
            # 修改ResNet的最后几个层以适应DeepLabV3+
            # 移除最后的平均池化层和分类层
            self.backbone = nn.Sequential(*list(self.backbone.children())[:-2])
            self.first = self.backbone[0:3]
            self.layer1 = self.backbone[4]
            self.layer2 = self.backbone[5]
            self.layer3 = self.backbone[6]
            self.layer4 = self.backbone[7]
        else:
            raise ValueError('Unsupported backbone - `{}`, Use resnet101'.format(backbone))

        self.aspp = ASPPModule(2048, 256, [1, 6, 12, 18])
        self.conv1x1 = nn.Conv2d(256, 48, 1, 1)
        self.upsample4 = nn.ConvTranspose2d(48, 48, 4, stride=2, padding=1)
        self.low_level_conv = nn.Sequential(
            nn.Conv2d(256, 48, 1, 1),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True)
        )
        self.final_conv = nn.Conv2d(96, n_classes, 3, 1, 1)

        initialize_weights(self.backbone, self.aspp, self.conv1x1, self.low_level_conv, self.final_conv)

    def forward(self, x):
        # 获取主干网络的特征图
        c2, c3, c4, c5 = self._forward_backbone(x)
        size0 = x.size()[2:]
        print("size:",size0)
        # ASPP模块
        features = self.aspp(c5)
        print("features.shape: ", features.shape)
        features = self.conv1x1(features)
        print("features.shape: ", features.shape)
        features = self.upsample4(features)
        print("features.shape: ", features.shape)

        # 低级特征融合
        low_level_features = self.low_level_conv(c3)
        size = low_level_features.size()[2:]
        features = F.interpolate(features, size=size, mode='bilinear', align_corners=True)
        features = torch.cat([features, low_level_features], dim=1)

        # 最终分类层
        output = self.final_conv(features)
        # 最终上采样
        output = F.interpolate(output, size=size0, mode='bilinear', align_corners=True) 
        return output

    def _forward_backbone(self, x):
        c2 = self.first(x)
        c3 = self.layer1(c2)
        c4 = self.layer2(c3)
        c5 = self.layer3(c4)
        c5 = self.layer4(c5)
        print("c2.shape: {}".format(c2.shape))
        print("c3.shape: {}".format(c3.shape))
        print("c4.shape: {}".format(c4.shape))
        print("c5.shape: {}".format(c5.shape))
        return c2, c3, c4, c5

# 示例用法
model = DeepLabV3Plus(n_classes=21)  # Pascal VOC数据集的类别数
input_tensor = torch.randn(1, 3, 513, 513)  # 示例输入,批量大小为1,3个通道,高度和宽度为513
output = model(input_tensor)
print(output.shape)  # 输出形状应该是[1, 21, 513, 513],表示每个像素的类别预测
# 输出:
c2.shape: torch.Size([1, 64, 257, 257])
c3.shape: torch.Size([1, 256, 257, 257])
c4.shape: torch.Size([1, 512, 129, 129])
c5.shape: torch.Size([1, 2048, 33, 33])
size: torch.Size([513, 513])
size:  torch.Size([33, 33])
out.shape:  torch.Size([1, 256, 33, 33])
upsample out.shape:  torch.Size([1, 256, 33, 33])
out.shape:  torch.Size([1, 256, 33, 33])
upsample out.shape:  torch.Size([1, 256, 33, 33])
out.shape:  torch.Size([1, 256, 33, 33])
upsample out.shape:  torch.Size([1, 256, 33, 33])
out.shape:  torch.Size([1, 256, 33, 33])
upsample out.shape:  torch.Size([1, 256, 33, 33])
out.shape:  torch.Size([1, 256, 33, 33])
upsample out.shape:  torch.Size([1, 256, 33, 33])
features.shape:  torch.Size([1, 256, 33, 33])
features.shape:  torch.Size([1, 48, 33, 33])
features.shape:  torch.Size([1, 48, 66, 66])

torch.Size([1, 21, 513, 513])

对遥感影像解译数据集GID进行训练,学习率0.01,batch_size设置为8,训练100个epoch,总体精度达到0.847,各类别精度如下:

2024-12-12 13:22:26,051 - __main__ - DEBUG - --------------------------------------
2024-12-12 13:22:26,051 - __main__ - DEBUG - |0|background|0.84|
2024-12-12 13:22:26,051 - __main__ - DEBUG - |1|building|0.78|
2024-12-12 13:22:26,052 - __main__ - DEBUG - |2|farmland|0.83|
2024-12-12 13:22:26,052 - __main__ - DEBUG - |3|tree|0.21|
2024-12-12 13:22:26,052 - __main__ - DEBUG - |4|grass|0.37|
2024-12-12 13:22:26,052 - __main__ - DEBUG - |5|water|0.75|
2024-12-12 13:22:26,052 - __main__ - DEBUG - --------------------------------------
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 215,794评论 6 498
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,050评论 3 391
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 161,587评论 0 351
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,861评论 1 290
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,901评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,898评论 1 295
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,832评论 3 416
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,617评论 0 271
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,077评论 1 308
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,349评论 2 331
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,483评论 1 345
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,199评论 5 341
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,824评论 3 325
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,442评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,632评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,474评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,393评论 2 352

推荐阅读更多精彩内容