经典分类网络 ResNet 论文阅读及PYTORCH示例代码

上一篇说要尝试一下用 se_ResNeXt 来给 WS-DAN 网络提取特征,在此之前需要先搞懂 ResNeXt 的原理,而 ResNeXt 则是在 ResNet 基础上的改进,所以绕了一大圈,还得从 ResNet 开始。说来惭愧,之前只是用过 ResNet 来做分类任务,论文还真没有仔细读过,正好趁这个机会读一读这篇“神作”。

论文地址: https://arxiv.org/pdf/1512.03385.pdf

论文阅读

其实论文的思想在今天看来是不难的,不过在当时 ResNet 提出的时候可是横扫了各大分类任务,这个网络解决了随着网络的加深,分类的准确率不升反降的问题。通过一个名叫“残差”的网络结构(如下图所示),使作者可以只通过简单的网络深度堆叠便可达到提升准确率的目的。

残差结构

残差结构的处理过程分成两个部分,左边的 与右边的 ,最后结果为两者相加。其中右边那根线不会对 做任何处理,所以没有可学习的参数; 为网络中负责学习特征的部分,把整个残差结构看做是一个 函数的话,则负责学习的部分可以表示为 ,这个结构学习的其实是输出结果与输入的差值,这也是残差名字的由来。完整的 ResNet 网络由多个上图中所示的残差结构组成,每个结构学习的都是输出与输入之间的差值,通过步步逼近,达到了比直接学习输入好得多的效果。

文中残差结构的具体实现分为两种,首先介绍 ResNet-18 与 ResNet-34 使用的残差结构称为 Basic Block,如下图所示,图中的结构包含了两个卷积操作用于提取特征。

Basic Block

对应到代码中,这是 Pytorch 自带的 ResNet 实现中的一部分,跟上图对应起来看更加好理解,我个人比较喜欢论文与代码结合起来看,因为我除了需要知道原理之外,也要知道如何去使用,而代码更给我一种一目了然的感觉:

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

另一种残差结构称为 Bottleneck,就是瓶颈的意思:

瓶颈

作者起名字真的很形象,网络结构也正如这瓶颈一样,首先做一个降维,然后做卷积,然后升维,这样做的好处是可以大大减少计算量,专门用于网络层数较深的的网络,ResNet-50 以上的网络都有这种基础结构构成(不同层级的输入输出维度可能会不一样,但结构类似):
Bottleneck

Pytorch 中的代码,注意到上图中为了减少计算量,作者将 256 维的输入缩小了 4 倍变为 64 进入卷积,在升维时需要升到 256 维,对应代码中的 expansion 参数:

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

由上面介绍的基本结构再加上池化以及全连接层,就构成了各种完整的网络:


各网络结构

图中的网络在 Pytorch 中都已经集成进去了,而且都是预训练好的,我们可以在预训练好的模型上面训练自己的分类器,大大减少我们的训练时间。下面简单介绍一下如何使用 ResNet。

在 Pytorch 中使用 ResNet

Pytorch 是一个对初学者很友好的深度学习框架,入门的话非常推荐,官方提供了一小时入门教程:https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html
在 Pytorch 中使用 ResNet 只需要 4 行代码:

from torch import nn
# torchvision 专用于视觉方面
import torchvision 
  
# pretrained :使用在 ImageNet 数据集上预训练的模型
model = torchvision.models.resnet18(pretrained=True)
# 修改模型的全连接层使其输出为你需要类型数,这里是10
# 由于使用了预训练的模型 而预训练的模型输出为1000类,所以要修改全连接层
# 若不使用预训练的模型可以直接在创建模型时添加参数 num_classes=10 而不需要修改全连接层
model.fc = nn.Linear(model.fc.in_features, 10)

下面你就可以使用这个模型来做分类了,当然到这里还没在自己的数据集上进行训练,关于如何训练可以参考官方教程:https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
如果对代码以及源码有疑问的话可以在下面留言我们一起讨论。

最后,求赞求关注,欢迎关注我的微信公众号[MachineLearning学习之路] ,深度学习 & CV 方向的童鞋不要错过!!

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 215,384评论 6 497
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,845评论 3 391
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 161,148评论 0 351
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,640评论 1 290
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,731评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,712评论 1 294
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,703评论 3 415
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,473评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,915评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,227评论 2 331
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,384评论 1 345
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,063评论 5 340
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,706评论 3 324
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,302评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,531评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,321评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,248评论 2 352

推荐阅读更多精彩内容