The overall network structure of DBNet is shown below.
The network consists of a ResNet-18 backbone, an FPN, and a head; a minimal sketch composing all three appears at the end of this section.
1. ResNet-18 backbone
ResNet-18 extracts features and outputs four feature maps x2, x3, x4, x5 at 1/4, 1/8, 1/16, and 1/32 of the input resolution, respectively.
import math
import torch.utils.model_zoo as model_zoo
from torch import nn
from torch.nn import BatchNorm2d

# torchvision's ImageNet checkpoint for ResNet-18
model_urls = {'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth'}

def constant_init(module, constant, bias=0):
    nn.init.constant_(module.weight, constant)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):
super(BasicBlock, self).__init__()
self.with_dcn = dcn is not None
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.with_modulated_dcn = False
if not self.with_dcn:
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
else:
from torchvision.ops import DeformConv2d
deformable_groups = dcn.get('deformable_groups', 1)
offset_channels = 18
self.conv2_offset = nn.Conv2d(planes, deformable_groups * offset_channels, kernel_size=3, padding=1)
self.conv2 = DeformConv2d(planes, planes, kernel_size=3, padding=1, bias=False)
self.bn2 = BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
if not self.with_dcn:
out = self.conv2(out)
else:
offset = self.conv2_offset(out)
out = self.conv2(out, offset)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, layers, in_channels=3, dcn=None):
self.dcn = dcn
self.inplanes = 64
super(ResNet, self).__init__()
self.out_channels = []
self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dcn=dcn)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dcn=dcn)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dcn=dcn)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
        if self.dcn is not None:
            for m in self.modules():
                # zero-init the DCN offset conv so training starts as a plain conv
                # (only BasicBlock is defined in this snippet)
                if isinstance(m, BasicBlock) and hasattr(m, 'conv2_offset'):
                    constant_init(m.conv2_offset, 0)
def _make_layer(self, block, planes, blocks, stride=1, dcn=None):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, dcn=dcn))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, dcn=dcn))
self.out_channels.append(planes * block.expansion)
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x2 = self.layer1(x)
x3 = self.layer2(x2)
x4 = self.layer3(x3)
x5 = self.layer4(x4)
return x2, x3, x4, x5
def resnet18(pretrained=True, **kwargs):
"""Constructs a ResNet-18 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
if pretrained:
        assert kwargs.get('in_channels', 3) == 3, 'in_channels must be 3 when pretrained is True'
print('load from imagenet')
model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False)
return model
if __name__ == '__main__':
    import torch
    x = torch.zeros(1, 3, 640, 640)
    net = resnet18(pretrained=True, in_channels=3)
    x2, x3, x4, x5 = net(x)
    print(x2.shape, x3.shape, x4.shape, x5.shape)
    # torch.Size([1, 64, 160, 160]) torch.Size([1, 128, 80, 80])
    # torch.Size([1, 256, 40, 40]) torch.Size([1, 512, 20, 20])
2. FPN structure
FPN structure diagram
Code
The convolution + BatchNorm + ReLU building block:
from torch import nn
class ConvBnRelu(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', inplace=True):
super().__init__()
self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation,
groups=groups, bias=bias, padding_mode=padding_mode)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=inplace)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
if __name__ == '__main__':
import torch
conv_bn_relu = ConvBnRelu(3, 64, 3, padding=1)
inputs = torch.randn([1, 3, 224, 224])
print(conv_bn_relu(inputs).size())
FPN code
import torch
import torch.nn.functional as F
from torch import nn
from models.basic import ConvBnRelu
class FPN(nn.Module):
def __init__(self, in_channels, inner_channels=256, **kwargs):
"""
:param in_channels: 基础网络输出的维度
:param kwargs:
"""
super().__init__()
inplace = True
self.conv_out = inner_channels
inner_channels = inner_channels // 4
# reduce layers
self.reduce_conv_c2 = ConvBnRelu(in_channels[0], inner_channels, kernel_size=1, inplace=inplace)
self.reduce_conv_c3 = ConvBnRelu(in_channels[1], inner_channels, kernel_size=1, inplace=inplace)
self.reduce_conv_c4 = ConvBnRelu(in_channels[2], inner_channels, kernel_size=1, inplace=inplace)
self.reduce_conv_c5 = ConvBnRelu(in_channels[3], inner_channels, kernel_size=1, inplace=inplace)
# Smooth layers
self.smooth_p4 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, padding=1, inplace=inplace)
self.smooth_p3 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, padding=1, inplace=inplace)
self.smooth_p2 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, padding=1, inplace=inplace)
self.conv = nn.Sequential(
nn.Conv2d(self.conv_out, self.conv_out, kernel_size=3, padding=1, stride=1),
nn.BatchNorm2d(self.conv_out),
nn.ReLU(inplace=inplace)
)
self.out_channels = self.conv_out
def forward(self, x):
c2, c3, c4, c5 = x
# Top-down
        # 1x1 convs reduce c2..c5 to inner_channels // 4 channels each
p5 = self.reduce_conv_c5(c5)
p4 = self._upsample_add(p5, self.reduce_conv_c4(c4))
p4 = self.smooth_p4(p4)
p3 = self._upsample_add(p4, self.reduce_conv_c3(c3))
p3 = self.smooth_p3(p3)
p2 = self._upsample_add(p3, self.reduce_conv_c2(c2))
p2 = self.smooth_p2(p2)
x = self._upsample_cat(p2, p3, p4, p5)
x = self.conv(x)
return x
    def _upsample_add(self, x, y):
        """Upsample x to y's spatial size (nearest-neighbor by default) and add."""
        return F.interpolate(x, size=y.size()[2:]) + y
def _upsample_cat(self, p2, p3, p4, p5):
h, w = p2.size()[2:]
p3 = F.interpolate(p3, size=(h, w))
p4 = F.interpolate(p4, size=(h, w))
p5 = F.interpolate(p5, size=(h, w))
return torch.cat([p2, p3, p4, p5], dim=1)
if __name__ == '__main__':
    in_channels = [64, 128, 256, 512]
    fpn = FPN(in_channels)
    x = (torch.randn((1, 64, 160, 160)),
         torch.randn((1, 128, 80, 80)),
         torch.randn((1, 256, 40, 40)),
         torch.randn((1, 512, 20, 20)))
    # four 64-channel maps are concatenated (256 channels) at 1/4 resolution
    print(fpn(x).size())  # torch.Size([1, 256, 160, 160])
3. Head
The head takes the 1/4-resolution FPN feature map and restores full resolution with two stride-2 transposed convolutions, so we first review transposed convolution (transpose convolution).
Example figures (animations omitted): s=1, p=0, k=3 | s=2, p=0, k=3 | s=2, p=1, k=3
Steps of the transposed convolution operation (verified step by step in the sketch after the PyTorch experiment below):
1. Insert s-1 rows and columns of zeros between the elements of the input feature map
2. Pad k-p-1 rows and columns of zeros around the border of the input feature map
3. Flip the kernel parameters up-down and left-right
4. Apply an ordinary convolution (padding=0, stride=1)
Output feature map size: H_out = (H_in - 1) × s - 2p + k (likewise for the width). For the experiment below, H_in = 2, s = 1, p = 0, k = 3, so H_out = 4.
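A quick sanity check of this formula with nn.ConvTranspose2d, using the three configurations from the figures above (the 4x4 input size is an arbitrary choice for illustration):

import torch
from torch import nn

x = torch.randn(1, 1, 4, 4)  # H_in = 4
for s, p, k in [(1, 0, 3), (2, 0, 3), (2, 1, 3)]:
    deconv = nn.ConvTranspose2d(1, 1, kernel_size=k, stride=s, padding=p)
    # formula: H_out = (H_in - 1) * s - 2p + k
    print(f"s={s}, p={p}, k={k}: H_out={deconv(x).shape[-1]}")  # 6, 9, 7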
Transposed convolution computation process (figure)
PyTorch experiment
Code
from torch import nn
import torch
b = torch.Tensor([[[[3,5],
[2,1]]]])
k = torch.Tensor([[1,0,1],
[0,0,1],
[0,1,1]])
kernel = k.view(1, 1, 3, 3)  # (out_channels, in_channels, kH, kW)
weight = nn.Parameter(data=kernel, requires_grad=False)
conv = nn.ConvTranspose2d(1, 1, 3, bias=False)  # stride=1, padding=0
conv.weight = weight
print(conv.weight)
print(conv(b))
# tensor([[[[ 3., 5., 3., 5.],
# [ 2., 1., 5., 6.],
# [ 0., 3., 10., 6.],
# [ 0., 2., 3., 1.]]]])
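To tie this back to the four steps above, here is a minimal sketch that reproduces nn.ConvTranspose2d by hand: with s=1, step 1 is a no-op, so we pad k-p-1=2 zeros around the input, flip the kernel, and run an ordinary convolution.

import torch
import torch.nn.functional as F

x = torch.tensor([[[[3., 5.], [2., 1.]]]])
w = torch.tensor([[[[1., 0., 1.], [0., 0., 1.], [0., 1., 1.]]]])

# Step 1: insert s-1 = 0 zeros between elements (no-op since s=1)
# Step 2: pad k-p-1 = 2 rows/columns of zeros around the input
x_pad = F.pad(x, [2, 2, 2, 2])
# Step 3: flip the kernel up-down and left-right
w_flip = torch.flip(w, dims=[2, 3])
# Step 4: ordinary convolution with padding=0, stride=1
out = F.conv2d(x_pad, w_flip)
print(out)  # same 4x4 result as the ConvTranspose2d experiment above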
Revisiting ordinary convolution
An ordinary convolution can be rewritten as a matrix multiplication: rearrange the kernel into an equivalent (sparse) matrix, flatten the input, and multiply. Multiplying by the transpose of that same matrix is exactly the transposed convolution. Note that transposed convolution is not the inverse of ordinary convolution: it restores the spatial size, but not the original element values. The sketch below demonstrates this matrix view.
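A minimal sketch of this matrix view using F.unfold (im2col) and F.fold (col2im); the 4x4 input and 3x3 kernel here are arbitrary choices for illustration:

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 4, 4)
w = torch.randn(1, 1, 3, 3)

# Ordinary convolution as a matrix product: unfold (im2col) turns each
# 3x3 window into a column, then the flattened kernel multiplies them.
cols = F.unfold(x, kernel_size=3)                 # (1, 9, 4): four 3x3 windows
y = (w.view(1, -1) @ cols).view(1, 1, 2, 2)
print(torch.allclose(y, F.conv2d(x, w)))          # True

# Transposed convolution multiplies by the transpose of the same matrix:
# scatter each output value back through the flattened kernel, then fold.
cols_t = w.view(-1, 1) @ y.view(1, 1, -1)         # (1, 9, 4)
x_t = F.fold(cols_t, output_size=(4, 4), kernel_size=3)
print(torch.allclose(x_t, F.conv_transpose2d(y, w)))  # True
print(torch.allclose(x_t, x))                     # False: not an inverse

With transposed convolution understood, here is the DBHead code: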
import torch
from torch import nn
class DBHead(nn.Module):
    def __init__(self, in_channels, out_channels, k=50):
super().__init__()
self.k = k
self.binarize = nn.Sequential(
nn.Conv2d(in_channels, in_channels // 4, 3, padding=1),
nn.BatchNorm2d(in_channels // 4),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2),
nn.BatchNorm2d(in_channels // 4),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(in_channels // 4, 1, 2, 2),
nn.Sigmoid())
self.binarize.apply(self.weights_init)
self.thresh = self._init_thresh(in_channels)
self.thresh.apply(self.weights_init)
def forward(self, x):
shrink_maps = self.binarize(x)
threshold_maps = self.thresh(x)
if self.training:
binary_maps = self.step_function(shrink_maps, threshold_maps)
y = torch.cat((shrink_maps, threshold_maps, binary_maps), dim=1)
else:
y = torch.cat((shrink_maps, threshold_maps), dim=1)
return y
def weights_init(self, m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.kaiming_normal_(m.weight.data)
elif classname.find('BatchNorm') != -1:
m.weight.data.fill_(1.)
m.bias.data.fill_(1e-4)
    def _init_thresh(self, inner_channels, serial=False, smooth=False, bias=False):
        """
        :param inner_channels: number of channels of the FPN output
        :param serial: if True, the thresh branch takes the segmentation result as an extra input channel
        :param smooth: use interpolation + conv instead of a transposed convolution for upsampling
        :param bias: whether the conv layers have a bias
        :return:
        """
in_channels = inner_channels
if serial:
in_channels += 1
self.thresh = nn.Sequential(
nn.Conv2d(in_channels, inner_channels // 4, 3, padding=1, bias=bias),
nn.BatchNorm2d(inner_channels // 4),
nn.ReLU(inplace=True),
self._init_upsample(inner_channels // 4, inner_channels // 4, smooth=smooth, bias=bias),
nn.BatchNorm2d(inner_channels // 4),
nn.ReLU(inplace=True),
self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias),
nn.Sigmoid())
return self.thresh
def _init_upsample(self, in_channels, out_channels, smooth=False, bias=False):
if smooth:
inter_out_channels = out_channels
if out_channels == 1:
inter_out_channels = in_channels
            module_list = [
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(in_channels, inter_out_channels, 3, 1, 1, bias=bias)]
            if out_channels == 1:
                module_list.append(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=1, bias=True))
            return nn.Sequential(*module_list)  # unpack: nn.Sequential takes modules, not a list
else:
return nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
    def step_function(self, x, y):
        # DB's differentiable binarization: 1 / (1 + e^{-k(x - y)}) = sigmoid(k * (x - y))
        return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))
if __name__ == '__main__':
    db = DBHead(in_channels=256, out_channels=1)
    x = torch.randn((1, 256, 160, 160))
    print(db(x).shape)  # training mode: shrink, threshold and binary maps
    # torch.Size([1, 3, 640, 640])
    db.eval()
    print(db(x).shape)  # eval mode: shrink and threshold maps only
    # torch.Size([1, 2, 640, 640])
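Finally, putting the three parts together: a minimal sketch that wires the backbone, FPN and head defined above into one module (assuming resnet18, FPN and DBHead are importable from this post's files):

import torch
from torch import nn

class DBNet(nn.Module):
    """Sketch: compose the ResNet-18 backbone, FPN and DBHead defined above."""
    def __init__(self):
        super().__init__()
        self.backbone = resnet18(pretrained=False, in_channels=3)
        self.neck = FPN(self.backbone.out_channels)  # [64, 128, 256, 512]
        # note: DBHead's out_channels argument is unused in the code above
        self.head = DBHead(self.neck.out_channels, out_channels=2)

    def forward(self, x):
        return self.head(self.neck(self.backbone(x)))

net = DBNet().eval()
with torch.no_grad():
    out = net(torch.randn(1, 3, 640, 640))
print(out.shape)  # torch.Size([1, 2, 640, 640]): shrink map + threshold map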