tensorflow2.0(3)-Resnet模型

  tensorflow2不再需要静态建图启动session(),抛弃很多繁杂的功能设计,代码上更加简洁清晰,而在工程上也更加灵活。
但是一些基础的用法,单靠api接口去训练模型是远远无法满足实际的应用,基于这种框架,更多还需要自己在其上自定义开发。

例如:model.fit() 虽然能一句代码把训练跑起来,但你根本无法知道整个模型内部数据的变化,也难以去查看某些变量。我们不可能永远停留在MNIST之类的数据集上。

Resnet

  个人更倾向在实战中学习深化基础,而不是把基础理论学好了再去实践。本篇基于tf2.0是搭建Resnet网络,Resnet有很多变种,也作为很多模型的骨干网络,这次实战项目就从它开始
需要对Resnet有一定的认知了解,本文只是代码实现

网络结构

   官方给出的Resnet网络结构,分别为18,34,50,101,152层,可以看出,不同层数之间总体的结构是一样的,这样就很方便用类去实例化每一个模块了


image

基础模块

  从conv2_x到conv5_x,18和34layer的结构是一样的,50,101和152是一样的,具体分别为:

image

先定义18or34layer的模块

# for 18 or 34 layers
class Basic_Block(keras.Model):

    def __init__(self, filters, downsample=False, stride=1):
        self.expasion = 1
        super(Basic_Block, self).__init__()

        self.downsample = downsample

        self.conv2a = keras.layers.Conv2D(filters=filters,
                                          kernel_size=3,
                                          strides=stride,
                                          kernel_initializer='he_normal',
                                          )
        self.bn2a = keras.layers.BatchNormalization(axis=-1)

        self.conv2b = keras.layers.Conv2D(filters=filters,
                                          kernel_size=3,
                                          padding='same',
                                          kernel_initializer='he_normal'
                                          )
        self.bn2b = keras.layers.BatchNormalization(axis=-1)

        self.relu = keras.layers.ReLU()

        if self.downsample:
            self.conv_shortcut = keras.layers.Conv2D(filters=filters,
                                                     kernel_size=1,
                                                     strides=stride,
                                                     kernel_initializer='he_normal',
                                                     )
            self.bn_shortcut = keras.layers.BatchNormalization(axis=-1)

    def call(self, inputs, **kwargs):
        x = self.conv2a(inputs)
        x = self.bn2a(x)
        x = self.relu(x)

        x = self.conv2b(x)
        x = self.bn2b(x)
        x = self.relu(x)

        if self.downsample:
            shortcut = self.conv_shortcut(inputs)
            shortcut = self.bn_shortcut(shortcut)
        else:
            shortcut = inputs

        x = keras.layers.add([x, shortcut])
        x = self.relu(x)

代码虽然长了点,但看一下call() 里面就很清晰了,就是2个 conv+bn+relu,最后与input做点加操作
同理应用在50,101or152layer:

# for 50, 101 or 152 layers
class Block(keras.Model):

    def __init__(self, filters, block_name,
                 downsample=False, stride=1, **kwargs):
        self.expasion = 4
        super(Block, self).__init__(**kwargs)

        conv_name = 'res' + block_name + '_branch'
        bn_name = 'bn' + block_name + '_branch'
        self.downsample = downsample

        self.conv2a = keras.layers.Conv2D(filters=filters,
                                          kernel_size=1,
                                          strides=stride,
                                          kernel_initializer='he_normal',
                                          name=conv_name + '2a')
        self.bn2a = keras.layers.BatchNormalization(axis=3, name=bn_name + '2a')

        self.conv2b = keras.layers.Conv2D(filters=filters,
                                          kernel_size=3,
                                          padding='same',
                                          kernel_initializer='he_normal',
                                          name=conv_name + '2b')
        self.bn2b = keras.layers.BatchNormalization(axis=3, name=bn_name + '2b')

        self.conv2c = keras.layers.Conv2D(filters=4 * filters,
                                          kernel_size=1,
                                          kernel_initializer='he_normal',
                                          name=conv_name + '2c')
        self.bn2c = keras.layers.BatchNormalization(axis=3, name=bn_name + '2c')

        if self.downsample:
            self.conv_shortcut = keras.layers.Conv2D(filters=4 * filters,
                                                     kernel_size=1,
                                                     strides=stride,
                                                     kernel_initializer='he_normal',
                                                     name=conv_name + '1')
            self.bn_shortcut = keras.layers.BatchNormalization(axis=3, name=bn_name + '1')

    def call(self, inputs, **kwargs):
        x = self.conv2a(inputs)
        x = self.bn2a(x)
        x = tf.nn.relu(x)

        x = self.conv2b(x)
        x = self.bn2b(x)
        x = tf.nn.relu(x)

        x = self.conv2c(x)
        x = self.bn2c(x)

        if self.downsample:
            shortcut = self.conv_shortcut(inputs)
            shortcut = self.bn_shortcut(shortcut)
        else:
            shortcut = inputs

        x = keras.layers.add([x, shortcut])
        x = tf.nn.relu(x)

        return x

对于downsample的操作,如果input和最后一层输出的chanels不一样就需要downsample来保持chanel一致,这样才能相加,一般解析resnet的文章都会提到。
用类封装了模块的功能,接下来只需要在主体网路结构里添加这个模块就好了

主体结构

用subclassing的方式去搭建model,就像砌墙一样,一个模块一个模块拼上去就好了,先在init()里面定义好需要用到的方法,再在call()把他们调用起来。
对于resnet的主体结构,先看一下call()里是该如何写的:

def call(self, inputs, **kwargs):
    x = self.padding(inputs)
    x = self.conv1(x)
    x = self.bn_conv1(x)
    x = tf.nn.relu(x)
    x = self.max_pool(x)

    # layer2
    x = self.res2(x)
    # layer3
    x = self.res3(x)
    # layer4
    x = self.res4(x)
    # layer5
    x = self.res5(x)

    x = self.avgpool(x)
    x = self.fc(x)
    return x

一目了然,跟文章开头的结构图一摸一样,
最重要的是中间conv2-5 的操作,这个需要对resnet结构熟悉

在Resnet的init()里面,这样去定义中间的4个层

# layer2
self.res2 = self.mid_layer(block, 64, layers[0], stride=1, layer_number=2)

# layer3
self.res3 = self.mid_layer(block, 128, layers[1], stride=2, layer_number=3)

# layer4
self.res4 = self.mid_layer(block, 256, layers[2], stride=2, layer_number=4)

# layer5
self.res5 = self.mid_layer(block, 512, layers[3], stride=2, layer_number=5)

函数self.mid_layer() 就是把block模块串起来

def mid_layer(self, block, filter, block_layers, stride=1, layer_number=1):
    layer = keras.Sequential()
    if stride != 1 or filter * 4 != 64:
        layer.add(block(filters=filter,
                        downsample=True, stride=stride,
                        block_name='{}a'.format(layer_number)))

    for i in range(1, block_layers):
        p = chr(i + ord('a'))
        layer.add(block(filters=filter,
                        block_name='{}'.format(layer_number) + p))

    return layer

到此主体的结构就定义好了,官方源码Resnet,是直接从上到下直接编写的,就是一边构建网络一边计算,类似于这样

x = input()
x = keras.layers.Conv2D()(x)
x = keras.layers.MaxPooling2D()(X)
x = keras.layers.Dense(num_classes)(x)

  相对来说更喜欢用subclassing的方式去搭建model,虽然代码量多了点,但是结构清晰,自己要中间修改的时候也很简单,也方便别的地方直接调用,但有一点不好就是,当想打印模型model.summary() 的时候,看不到图像在各个操作后的shape,直接显示multiple,目前不知道有没其他的方法。。

image

代码

  上述代码呈现了Resnet的大部分内容,可以随便实现18-152layer,全部代码放在了我的github里:https://github.com/angryhen/learning_tensorflow2.0/blob/master/base_model/ResNet.py

  持续更新中,tensorflow2.0这一系列的代码也会放在上面,包括VGG,Mobilenet的基础网络,以后也会更新引入senet这种变种网络。

Thanks

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,816评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,729评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,300评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,780评论 1 285
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,890评论 6 385
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,084评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,151评论 3 410
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,912评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,355评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,666评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,809评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,504评论 4 334
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,150评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,882评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,121评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,628评论 2 362
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,724评论 2 351

推荐阅读更多精彩内容

  • Training spaCy’s Statistical Models训练spaCy模型 This guide d...
    Joe_Gao_89f1阅读 6,511评论 1 5
  • Swift1> Swift和OC的区别1.1> Swift没有地址/指针的概念1.2> 泛型1.3> 类型严谨 对...
    cosWriter阅读 11,093评论 1 32
  • 朋友问我,你会一个人过日子吗?我想反问你,你听说过谁,在这世界上,不是孤独的生,不是孤独的死?有谁?请你告...
    顾城的诗阅读 241评论 0 0
  • “我曾经喜欢过一个人。”这面有天蓝色边框的四四方方的小镜子和我说。“那个人是主人的男朋友,他第一次来主人的家里的时...
    洛可可_阅读 333评论 0 0
  • 到底要怎样才能做到既简略又全面呢?为了写出精练而深刻的读书笔记,你应该在读完一本书以后认真回想需要摘抄哪一页哪一行...
    naughty心阅读 76评论 0 0