TensorFlow 自定义模型训练 Mask R-CNN

        上一个系列文章 TensorFlow 训练自己的目标检测器 以及 TensorFlow 训练 Mask R-CNN 模型, 说明了怎么用 TensorFlow 开源的目标检测和实例分割接口 TensorFlow/models/research/object_detection 来基于自己的数据训练 Mask R-CNN 模型。但这些都是使用它自带的、而不是自己定义的模型来训练,因此不一定适用自己的数据集,比如,对于一些很简单的数据集,就没必要用它自带的非常复杂的模型。这一篇文章试图弥补这个缺点,来讲述怎么,基于 Tensorflow Detection API,自定义模型训练 Mask R-CNN,相关官方简易文档见 So you want to create a new model!

        作为开始,让我们回顾一下 Mask R-CNN 目标检测与实例分割的过程:对于一张给定的图像,首先经过一个卷积神经网络(称为特征提取器)从图像中提取特征,得到一张特征映射;接着,从这张特征映射预测 2 个分支,分别预测:目标得分目标边框;以上整个过程便组成了 Mask R-CNN 的第一阶段,整个网络称为候选区域网络(Region Proposal Network, RPN);其次,从第一阶段预测的结果中筛选出高置信度的目标区域,将这些区域从特征映射中裁剪出来,送入第二阶段的预测网络,更精细的预测 3 个分支:类概率目标边框目标实例;最后,通过非极大值抑制等后处理后操作后输出最终的检测结果。从以上过程可见,整个 Mask R-CNN 框架最灵活的地方是特征提取器(不同的卷积神经网络定义不同的特征提取器),而其它地方基本都可以固定不变。因此,借助 Tensorflow Detection API 搭建的 Mask R-CNN 框架,要实现自定义模型训练已变得非常简单,只需要重写一个特征提取器(Feature Extractor)即可。而这可以模仿 models/research/object_detection/models 文件夹的文件来写,这个文件夹内的所有文件都是特征提取器的定义。

        本文给出一个简单的示例,编写一个深度较小的特征提取器,嵌入到 Tensorflow Detection API 来实现自定义模型训练 Mask R-CNN。

        所有代码见 GitHub: mask_rcnn_customized

一、自定义特征提取器

        要自定义特征提取器,需要重载 models/research/object_detection/meta_architectures 文件夹内的 Mask R-CNN 框架类 faster_rcnn_meta_arch.py 中的特征提取器抽象类 FasterRCNNFeatureExtractor。因此,只需要继承该类,再重新定义该类的抽象函数 process、_extract_proposal_features、_extract_box_classifier_features 即可。我们仿造文件夹 models/research/object_detection/models 内的 faster_rcnn_resnet_v1_feature_extractor.py 来写。

(1) 定义简单卷积模型:ResNet-20

        前面已经交代过,特征提取器是一个卷积神经网络,它负责从图像中提取特征,用于后续两阶段的预测。简单起见,我们用 ResNet 的残差模块(Residual Module)来定义一个 20 层的网络,如下(见文件 custom_resnet.py):

from tensorflow.contrib.slim import nets

resnet_v1_block = nets.resnet_v1.resnet_v1_block

def resnet_v1_20(inputs,
                 num_classes=None,
                 is_training=True,
                 global_pool=True,
                 output_stride=None,
                 spatial_squeeze=True,
                 store_non_strided_activations=False,
                 reuse=None,
                 scope='resnet_v1_20'):
    """ResNet-20 model. See resnet_v1() for arg and return description."""
    blocks = [
        resnet_v1_block('block1', base_depth=64, num_units=1, stride=2),
        resnet_v1_block('block2', base_depth=128, num_units=1, stride=2),
        resnet_v1_block('block3', base_depth=256, num_units=1, stride=2),
        resnet_v1_block('block4', base_depth=512, num_units=3, stride=1)
    ]
    return nets.resnet_v1.resnet_v1(
        inputs, 
        blocks, 
        num_classes, 
        is_training,
        global_pool=global_pool, 
        output_stride=output_stride,
        include_root_block=True, 
        reuse=reuse,
        scope=scope)

一个残差模块由 3 个卷积层组成(见下图-右),在 TensorFlow 的实现里称为一个 Unit,多个 Unit 的组合称为一个 Block。如上面的代码,使用了 4 个 Block,它们的 Unit 个数(num_units)分别为 1,1,1,3,因此总共有 20 = (1 + 1 + 1 + 3) x 3 + 1 + 1 个卷积层,最后的 1 + 1 指的分别是卷积核为 11 x 11 的网络第一个卷积层和最后的 softmax 输出层。

residual module

(2)重载特征提取器

        接下来,来写自定义的 ResNet-20 对应的特征提取器,代码如下(见 custom_faster_rcnn_resnet_v1_feature_extractor.py):

# -*- coding: utf-8 -*-
"""
Created on Thu Nov  1 14:18:07 2018

@author: shirhe-lyh


ResNet V1 Faster R-CNN customized implementation.
"""

import tensorflow as tf

from tensorflow.contrib.slim import nets
from object_detection.meta_architectures import faster_rcnn_meta_arch

from object_detection.models import custom_resnet

slim = tf.contrib.slim
resnet_v1_block = nets.resnet_v1.resnet_v1_block


class CustomFasterRCNNResnetV1FeatureExtractor(
    faster_rcnn_meta_arch.FasterRCNNFeatureExtractor):
    """Faster R-CNN ResNet v1 feature extractor customized implementation."""
    
    def __init__(self,
                 architecture,
                 resnet_model,
                 is_training,
                 first_stage_features_stride,
                 batch_norm_trainable=False,
                 reuse_weights=None,
                 weight_decay=0.0):
        """Constructor.
        
        Args:
            architecture: Architecture name of the ResNet V1 model.
            resnet_model: Definition of the ResNet V1 model.
            is_training: See base class.
            batch_norm_trainable: See base class.
            first_stage_features_stride: See base class.
            batch_norm_trainable: See base class.
            reuse_weights: See base class.
            weight_decay: See base class.
            
        Raises:
            ValueError: If `first_stage_features_stride` is not 8 or 16.
        """
        if first_stage_features_stride != 8 and first_stage_features_stride !=16:
            raise ValueError('`first_stage_features_stride` must be 8 or 16.')
        
        self._architecture = architecture
        self._resnet_model = resnet_model
        super(CustomFasterRCNNResnetV1FeatureExtractor, self).__init__(
            is_training, first_stage_features_stride, batch_norm_trainable,
            reuse_weights, weight_decay)
        
    def preprocess(self, resized_inputs):
        """Faster R-CNN ResNet V1 preprocessing.
        
        Args:
            resized_inputs: A [batch, height_in, width_in, channels] float32
                tensor representing a batch of images with values between 0
                and 255.0.
                
        Returns:
            preprocessed_inputs: A [batch, height_out, width_out, channels]
                float32 tensor representing a batch of images.
        """
        channel_means = [123.68, 116.779, 103.939]
        return resized_inputs - [[channel_means]]
    
    def _extract_proposal_features(self, preprocessed_inputs, scope):
        """Extracts first stage RPN features.
        
        Args:
            preprocessed_inputs: A [batch, height, width, channels] float32
                tensor representing a batch of images.
            scope: A scope name.
            
        Returns:
            rpn_feature_map: A tensor with shape [batch, height, width, depth].
            activations: A dictionary mapping feature extractor tensor names
                to tensors.
                
        Raises:
            InvalidArgumentError: If the spatial size of `preprocessed_inputs`
                (height or width) is less than 33.
            ValueError: If the created network is missing the required
                activation.
        """
        if len(preprocessed_inputs.get_shape().as_list()) != 4:
            raise ValueError('`preprocessed_inputs` must be 4 dimensional, '
                             'got a tensor of shape %s' % 
                             preprocessed_inputs.get_shape())
            
        shape_assert = tf.Assert(
            tf.logical_and(
                tf.greater_equal(tf.shape(preprocessed_inputs)[1], 33),
                tf.greater_equal(tf.shape(preprocessed_inputs)[2], 33)),
            ['image size must at least be 33 in both height and width.'])
            
        with tf.control_dependencies([shape_assert]):
            # Disables batchnorm for fine-tuning with smaller batch sizes.
            # TODO(chensun): Figure out if it is needed when image
            # batch size is bigger.
            with slim.arg_scope(nets.resnet_utils.resnet_arg_scope(
                batch_norm_epsilon=1e-5,
                batch_norm_scale=True,
                weight_decay=self._weight_decay)):
                with tf.variable_scope(self._architecture,
                                       reuse=self._reuse_weights) as var_scope:
                    _, activations = self._resnet_model(
                        preprocessed_inputs,
                        num_classes=None,
                        is_training=self._train_batch_norm,
                        global_pool=False,
                        output_stride=self._first_stage_features_stride,
                        spatial_squeeze=False,
                        scope=var_scope)
                    
        handle = scope + '/%s/block3' % self._architecture
        return activations[handle], activations
    
    def _extract_box_classifier_features(self, proposal_feature_maps, scope):
        """Extracts second stage box classifier features.
        
        Args:
            proposal_feature_maps: A 4-D float tensor with shape [batch_size *
                self.max_num_proposals, crop_height, crop_width, depth]
                representing the feature map croped to each proposal.
            scope: A scope name (unused).
            
        Returns:
            proposal_classifier_features: A 4-D float tensor with shape
                [batch_size * self.max_num_proposals, height, width, depth]
                representing box classifier features for each proposal.
        """
        with tf.variable_scope(self._architecture, reuse=self._reuse_weights):
            with slim.arg_scope(nets.resnet_utils.resnet_arg_scope(
                batch_norm_epsilon=1e-5,
                batch_norm_scale=True,
                weight_decay=self._weight_decay)):
                with slim.arg_scope([slim.batch_norm],
                                    is_training=self._train_batch_norm):
                    blocks = [
                        nets.resnet_utils.Block(
                            'block4', nets.resnet_v1.bottleneck,
                            [{'depth': 2048,
                              'depth_bottleneck': 512,
                              'stride': 1
                            }] * 3)
                    ]
                    proposal_classifier_features = (
                        nets.resnet_utils.stack_blocks_dense(
                            proposal_feature_maps, blocks))
        return proposal_classifier_features
    
    
class CustomFasterRCNNResnet20FeatureExtractor(
    CustomFasterRCNNResnetV1FeatureExtractor):
    """Faster R-CNN ResNet V1 20 feature extractor implementation."""
    
    def __init__(self,
                 is_training,
                 first_stage_features_stride,
                 batch_norm_trainable=False,
                 reuse_weights=None,
                 weight_decay=0.0):
        """Construtor.
        
        Args:
            is_training: See base class.
            first_stage_features_stride: See base class.
            batch_norm_trainable: See base class.
            reuse_weights: See base class.
            weight_decay: See base class.
            
        Raises:
            ValueError: If `first_stage_features_stride` is not 8 or 16, or
                if `architecture` is not supported.
        """
        super(CustomFasterRCNNResnet20FeatureExtractor, self).__init__(
            'resnet_v1_20', custom_resnet.resnet_v1_20, is_training,
            first_stage_features_stride, batch_norm_trainable,
            reuse_weights, weight_decay)

因为我们自定义的模型恰好也是 ResNet,因此基本照抄了官方示例文件 faster_rcnn_resnet_v1_feature_extractor.py 的内容,除了类名和所用卷积模型不同之外。

        重载特征提取器,需要按照官方的规定,一方面要重载 3 个函数 process、_extract_proposal_features、_extract_box_classifier_features,另一方面,还要遵循约定:_extract_proposal_features 返回的应该是自定义模型的 中间 某个卷积层的结果,而剩下的卷积层由 _extract_box_classifier_features 返回。比如,对于我们自定义的 ResNet-20_extract_proposal_features 函数只返回前 3Block,最后的第 4Block_extract_box_classifier_features 返回。官方的 Remark 如下:

将卷积网络拆分的说明

        最后的类 CustomFasterRCNNResnet20FeatureExtractor 就是我们自定义的特征提取器,要使用这个类来训练 Mask R-CNN,还需要将它加入到 TensorFlow Object Detection API 框架里,即需要对它进行注册。

        还需要注意的一点是:_extract_proposal_features 这个函数要求输入的图像分辨率必须不小于 33 x 33,因对图像做预处理(比如对图像做缩放)后不要违背了这个硬性要求。

二、模型注册

        模型注册非常简单:首先将自定义的文件 custom_resnet.pycustom_faster_rcnn_resnet_v1_feature_extractor.py 复制到文件夹 models/research/object_detection/models 里,然后修改文件夹 models/research/object_detection/builders 里的文件 model_builder.py,在模块导入最后加入一条导入语句:

from object_detection.models import \
custom_faster_rcnn_resnet_v1_feature_extractor as custom_frcnn_resnet_v1

之后在该文件的 FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP 字典里加入 key-value 对:

'custom_frcnn_resnet20':
custom_frcnn_resnet_v1.CustomFasterRCNNResnet20FeatureExtractor,

自此,模型注册就完成了。model_builder.py 改动后的文件也也已经上传到 GitHub: mask_rcnn_customized。接下来,只需要准备数据启动训练就可以验证我们前面的工作是否成功了。

三、数据准备

        GitHub: mask_rcnn_customized 这个项目的里的文件 shape_mask_generator.py 文件用来生成简单的正多边形几何形状,我们以此来生成训练数据。简单起见,我们以生成等边三角形正方形,且每幅图像有且仅有一个对象为例:

生成的 triangle 和 rectangle 两类数据

运行项目里的文件 generate_datasets.py(GitHub 项目内已经生成了数据,此步可跳过;如果要重新生成,请删除 datasets 文件夹):

$ python3 generate_datasets.py

会在当前路径下生成一个 datasets 的文件夹,里面包含 5000 张图像(images 文件夹)和对应的掩模(masks 文件夹),另外还有一个记录图像与掩模对应关系,以及记录每个目标的 boundingbox 的标注文件:annotations.json。如果,你想生成更多的图像,请修改文件 generate_datasets.py 的参数 num_samples。因为,每张掩模都是 0-1 二值的灰度图,所以我们直接观看 masks 文件夹里面的 .png 图像时都是全黑的。如果要将其 mask 显示出来,请执行:

$ python3 visualize_masks.py

然后到 /datasets/masks_recgonized 文件夹内查看。

        图像和掩模生成好后,需要将它们写入 TFRecord 文件,执行(GitHub 项目内已经生成了 .record 文件,此步可跳过;如果要重新生成,请直接执行):

$ python3 generate_tfrecord.py

datasets 文件夹内生成 train.recordval.record 文件,下面,就可以开始训练了。

四、模型训练

        mask_rcnn_customized 里已经配置好了类名与类标号转化文件 shape_label_map.pbtxt 以及模型参数配置文件 mask_rcnn_customized_resnet_v1_shape.config(需要修改路径train_input_reader: {...} 中的 input_pathlabel_map_path,以及 eval_input_reader: {...} 中的 input_pathlabel_map_path),要启动训练,进入你配置好的 TensorFlow Object Detection API 项目的文件夹 models/research/object_detection 内,执行:

$ python3 model_main.py \
--model_dir Path/to/mask_rcnn_customized/training \
--pipeline_config_path Path/to/mask_rcnn_customized_resnet_v1_shape.config

即可。因为要使用自己定义的特征提取器,在配置 .config 文件时必须要将其中的

model {
    faster_rcnn {
        feature_extractor {
            type: 'custom_frcnn_resnet20'
            ...
        }
        ...
    }
    ...
}

type 字段修改为自己在 中注册特征提取器时的特征提取器名。

        训练开始后,使用:

$ tensorboard --logdir Path/to/mask_rcnn_customized/training 

命令得到浏览器链接,打开该链接可实时监督训练过程。训练过程的精度/损失曲线大致如下:

检测召回率曲线

损失曲线

损失曲线

        训练结束后,在路径 models/research/object_detection 下执行:

$ python3 export_inference_graph.py \
--trained_checkpoint_prefix Path/to/training/model.ckpt-40000 \
--output_directory Path/to/converted_pb_file_saving_directory \
--pipeline_config_path Path/to/mask_rcnn_customized_resnet_v1_shape.config

将训练保存的 .ckpt 模型转化为 .pb 格式,方便后续调用。其中,参数 trained_checkpoint_prefix 指定训练后 .ckpt 模型保存的路径(详细指定到某个训练次数时的模型),output_directory 指定转化后的 .pb 格式模型的保存路径(填写某个路径下的文件夹,该文件夹可以不存在,比如填写 /home/.../mask_rcnn_customized/training/frozen_inference_graph_pb),pipeline_config_path 指定 mask_rcnn_customized_resnet_v1_shape.config 文件路径。

五、结果展示

        当顺利完成以上所有步骤之后,就可以运行:

$ python3 predict.py

来进行预测了(注意:需要将 predict.py 文件中的 PATH_TO_CKPT 填写为上一步转化来的 .pb 文件所在的路径。如果,执行 python3 export_inference_graph.py ... 命令时,你填写的 output_directory/home/.../mask_rcnn_customized/training/frozen_inference_graph_pb,则使用默认路径而不需要修改)。执行后,会在当前路径下生成 test_images 文件夹,里面会输出 10 张测试图像,以及他们的检测结果:

检测及分割结果

image_6.jpg
image_6_out.jpg
image_8.jpg
image_8_out.jpg

        由于只使用了 5000 个训练样本,且没有仔细调参,可以发现模型对于三角形的检测和分割效果不是太理想。如果,你想改善效果,可以使用更多的训练数据。

说明
        1.如果不使用预训练模型,比如我们自定义的 ResNet-20 就找不到训练好的参数,在修改配置文件 xxx.config 时要将 train_config: {...} 里的其中两行:

#fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"
#from_detection_checkpoint: true

注释掉,这样所有模型参数都会随机初始化。

        2.训练过程如果报如下错误:TypeError: can't pickle dict_values objects,则将 models/research/object_detection/model_lib.py 中第 418 行的 category_index.values() 改成 list(category_index.values()) 即可

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,222评论 6 493
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,455评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 157,720评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,568评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,696评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,879评论 1 290
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,028评论 3 409
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,773评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,220评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,550评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,697评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,360评论 4 332
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,002评论 3 315
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,782评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,010评论 1 266
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,433评论 2 360
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,587评论 2 350

推荐阅读更多精彩内容