- 论文地址:LEDNET: A LIGHTWEIGHT ENCODER-DECODER NETWORK FOR REAL-TIME SEMANTIC SEGMENTATION
- 论文代码: github-Pytorch
概述
- 目前语义分割领域倾向于追求高精度,CNNs由上百个卷积层和上千个通道组成,实时性差。
- 要求实时性的应用:
增强现实、机器人、自动驾驶... - 相关移动平台:
遥控飞机、机器人、智能手机... - 移动平台的限制:
续航问题、内存限制、有限的计算力... - 为了保持实时性,出现了两类相关研究:
网络压缩和卷积分解。具体压缩技术包含剪枝(pruning)、量化(quantization)、散列编码(hashing);卷积分解直接使用少量参数的模型并保持较好的精确度
- 本文使用不对称编解码网络较好的平衡了准确率和速度的问题:
- 参数量少于
1M
- 单张GTX 1080ti GPU上
71FPS
- 综合准确率和速度在CityScapes上效果最好
-
整体网络模型
编码器部分
由残差(residual)
、通道拆分(split)
、通道打乱(shuffle)
三者构成了编码器的基本模块—split-shuffle-bottleneck(SS-bt)
。
具体结构如下图d所示:
图a是resnet中基本残差模块,图b组合了1维的卷积核,图b加入了通道打乱技术。图d中可以发现输入分成了两支,每支都有
一半
的通道数,每支由1维
卷积核组成,可以发现里面还包含了膨胀卷积
,这里的膨胀因子用于控制感受野,主要用于第3次下采样后的卷积,之后会把两支concat
到一起,保持输入输出通道数一致,利用残差
思想,加上输入特征,最后通道随机
打乱。具体网络参数如图所示:
SS-bt
中不包含下采样,和resnet不一样,这里有专门的Downsampling Unit
模块,由两个并行结构组成,一个3x3步长2的卷积核(输出通道数=output-input),另一个是Max-pooling(输出通道数=input),二者会concat到一起。
解码器部分
主要由金字塔状的attention分支
和全局平均池化分支
组成,其中attention这里会做三次下采样和上采样,通过point-wise sum
融合不同层信息,对应卷积核大小分别为 3×3
、5×5
、7×7
,并通过point-wise product
对每个像素点attention;全局池化后点加
到attention后的输出结果;最后上采样8倍大小,还原到输入图片尺寸大小。
实验
- 数据集:cityscape
- batch_size:5
- 显卡:GTX 1080Ti GPU
- 初始学习率:5e-4
- 学习策略:poly
- poly对应的power:0.9
- 动量:0.9
- 权重衰减:1e-4
注意:训练时有使用cityscapes的20K张粗略标注图
-
与其它实验在准确率和速度上的比较:
-
具体分类准确率的对比: