Squeeze-and-Excitation Networks (SENet) won the 2017 ImageNet classification challenge.
Paper: https://arxiv.org/abs/1709.01507
This post gives a brief overview of the SENet paper, along with an MXNet implementation of SE-ResNet (using the Gluon API).
The two key operations in SENet are Squeeze and Excitation, illustrated in the schematic below:
[Figure: schematic of the Squeeze-and-Excitation module]
Step 1: Squeeze compresses the features along the spatial dimensions via Global Average Pooling, producing one descriptor per channel.
Step 2: Excitation generates a weight for each feature channel through a sigmoid function; these weights model the dependencies between channels.
Step 3: Reweight multiplies each channel of the CNN feature map by its Excitation weight, recalibrating the original features along the channel dimension (a minimal sketch of all three steps follows this list).
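To make the three steps concrete, here is a minimal NDArray sketch; the shapes and the single random matrix standing in for Excitation's FC layers are illustrative assumptions, not the paper's exact design:

from mxnet import nd

X = nd.random.uniform(shape=(2, 64, 7, 7))          # feature map: (batch, channels, H, W)
z = nd.mean(X, axis=(2, 3))                         # Squeeze -> (2, 64)
W_fc = nd.random.uniform(shape=(64, 64))            # stand-in for the FC gating layers
s = nd.sigmoid(nd.dot(z, W_fc))                     # Excitation -> per-channel weights in (0, 1)
Y = nd.broadcast_mul(X, s.reshape((-1, 64, 1, 1)))  # Reweight -> (2, 64, 7, 7)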
The SE module can be dropped into almost any neural network; below is the SE-ResNet architecture:
[Figure: SE-ResNet network architecture]
Straight to the code.
This is a plain Residual Block, included for reference:
from mxnet import nd
from mxnet.gluon import nn

class Residual(nn.HybridBlock):
    def __init__(self, num_channels, use_1x1conv=False, strides=1, **kwargs):
        super(Residual, self).__init__(**kwargs)
        self.conv1 = nn.Conv2D(num_channels, kernel_size=3, padding=1,
                               strides=strides)
        self.conv2 = nn.Conv2D(num_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            # 1x1 convolution so the shortcut matches the main path's shape
            self.conv3 = nn.Conv2D(num_channels, kernel_size=1,
                                   strides=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm()
        self.bn2 = nn.BatchNorm()

    def hybrid_forward(self, F, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        return F.relu(Y + X)
Now the key part, the SE module. To make it easy to follow, we write Squeeze and Excitation as a standalone module:
def Attention(num_channels):
    net = nn.HybridSequential()
    with net.name_scope():
        net.add(
            nn.GlobalAvgPool2D(),       # Squeeze: one average per channel
            nn.Dense(num_channels),     # Excitation, FC layer 1
                                        # (the paper bottlenecks this to num_channels // r, r = 16)
            nn.Activation('relu'),
            nn.Dense(num_channels),     # Excitation, FC layer 2
            nn.Activation('sigmoid')    # per-channel weights in (0, 1)
        )
    return net
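A quick sanity check on shapes (the values here are illustrative): Dense flattens its input by default, so the pooled (batch, channels, 1, 1) tensor comes out as a (batch, channels) weight matrix.

from mxnet import nd

att = Attention(64)
att.initialize()
X = nd.random.uniform(shape=(2, 64, 7, 7))
print(att(X).shape)   # (2, 64): one weight per channel per sample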
Now embed the SE module into the Residual Block and apply the weights with a broadcast multiply:
class SEResidual(nn.HybridBlock):
    def __init__(self, num_channels, use_1x1conv=False, strides=1, **kwargs):
        super(SEResidual, self).__init__(**kwargs)
        self.num_channels = num_channels
        self.conv1 = nn.Conv2D(num_channels, kernel_size=3, padding=1,
                               strides=strides)
        self.conv2 = nn.Conv2D(num_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2D(num_channels, kernel_size=1,
                                   strides=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm()
        self.bn2 = nn.BatchNorm()
        self.weight = Attention(num_channels)

    def hybrid_forward(self, F, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        W = self.weight(Y)  # W is the (batch, channels) matrix of attention weights
        if self.conv3:
            X = self.conv3(X)
        # Reweight: reshape W to (batch, channels, 1, 1) and broadcast over H and W
        Y = F.broadcast_mul(Y, F.reshape(W, shape=(-1, self.num_channels, 1, 1)))
        return F.relu(Y + X)
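As a smoke test (with hypothetical shapes): a block with use_1x1conv=True and strides=2 should change the channel count while halving the spatial size.

from mxnet import nd

blk = SEResidual(64, use_1x1conv=True, strides=2)
blk.initialize()
X = nd.random.uniform(shape=(4, 32, 96, 96))
print(blk(X).shape)   # (4, 64, 48, 48)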
Finally, stack SE-Residual Blocks like building blocks to assemble the full network.
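For instance, a stage helper along these lines; this is a sketch, and the helper name and stage layout are my own, following the usual ResNet recipe rather than anything specified in this post:

from mxnet.gluon import nn

def se_resnet_stage(num_channels, num_residuals, first_stage=False):
    stage = nn.HybridSequential()
    for i in range(num_residuals):
        if i == 0 and not first_stage:
            # downsample and widen at the stage boundary
            stage.add(SEResidual(num_channels, use_1x1conv=True, strides=2))
        else:
            stage.add(SEResidual(num_channels))
    return stage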
Mwah~