TensorFlow 使用预训练模型 ResNet-50

        升级版见:TensorFlow 使用 tf.estimator 训练模型(预训练 ResNet-50)

        前面的文章已经说明了怎么使用 TensorFlow 来构建、训练、保存、导出模型等,现在来说明怎么使用 TensorFlow 调用预训练模型来精调神经网络。为了简单起见,以调用预训练的 ResNet-50 用于图像分类为例,使用的模块仍然是 tf.contrib.slim

        TensorFlow 的所有用于图像分类的预训练模型的下载地址为 models/research/slim,包含常用的 VGG,Inception,ResNet,MobileNet 以及最新的 NasNet 模型等。要使用这些预训练模型的关键是将这些预训练的参数正确的导入到定义好的神经网络,这可以通过函数 slim.assign_from_checkpoint_fn 来方便的实现。下面,用代码来说明。

        所有代码见 GitHub/finetune_classification

一、Fine tuning 模型定义

        前已提及,TensorFlow 所有预训练模型均在 GitHub 项目 models/research/slim,而其对应的神经网络实现则在其子文件夹 nets。我们以调用 ResNet-50 为例(其它模型类似),首先来定义网络结构:

import tensorflow as tf

from tensorflow.contrib.slim import nets

slim = tf.contrib.slim


def predict(self, preprocessed_inputs):
    """Predict prediction tensors from inputs tensor.

    Outputs of this function can be passed to loss or postprocess functions.

    Args:
        preprocessed_inputs: A float32 tensor with shape [batch_size,
            height, width, num_channels] representing a batch of images.
            
    Returns:
        prediction_dict: A dictionary holding prediction tensors to be
            passed to the Loss or Postprocess functions.
    """
    net, endpoints = nets.resnet_v1.resnet_v1_50(
        preprocessed_inputs, num_classes=None,
        is_training=self._is_training)
    net = tf.squeeze(net, axis=[1, 2])
    net = slim.fully_connected(net, num_outputs=self.num_classes,
                               activation_fn=None, scope='Predict')
    prediction_dict = {'logits': net}
    return prediction_dict

        我们假设要分类的图像有 self.num_classes 个类,随机选择一个批量的图像,对这些图像进行预处理后,把它们作为参数传入 predict 函数,此时直接调用 TensorFlow-slim 封装好的 nets.resnet_v1.resnet_v1_50 神经网络得到图像特征,因为 ResNet-50 是用于 1000 个类的分类的,所以需要设置参数 num_classes=None 禁用它的最后一个输出层。我们假设输入的图像批量形状为 [None, 224, 224, 3],则 resnet_v1_50 函数返回的形状为 [None, 1, 1, 2048],为了输入到全连接层,需要用函数 tf.squeeze 去掉形状为 1 的第 1,2 个索引维度。最后,连接再一个全连接层得到 self.num_classes 个类的预测输出。

        可以看到,使用 tf.contrib.slim 模块,调用 ResNet-50 等神经网络变得异常简单。而接下来的关键问题是怎么导入预训练的参数,进而使用我们自己的数据来对预训练模型进行精调。在阐述怎么解决这个问题之前,先将整个模型定义的文件 model.py 列出以方便阅读:

# -*- coding: utf-8 -*-
"""
Created on Thu Oct 11 17:21:12 2018

@author: shirhe-lyh
"""

import tensorflow as tf

from tensorflow.contrib.slim import nets

import preprocessing

slim = tf.contrib.slim
    
        
class Model(object):
    """xxx definition."""
    
    def __init__(self, num_classes, is_training,
                 fixed_resize_side=368,
                 default_image_size=336):
        """Constructor.
        
        Args:
            is_training: A boolean indicating whether the training version of
                computation graph should be constructed.
            num_classes: Number of classes.
        """
        self._num_classes = num_classes
        self._is_training = is_training
        self._fixed_resize_side = fixed_resize_side
        self._default_image_size = default_image_size
        
    @property
    def num_classes(self):
        return self._num_classes
        
    def preprocess(self, inputs):
        """preprocessing.
        
        Outputs of this function can be passed to loss or postprocess functions.
        
        Args:
            preprocessed_inputs: A float32 tensor with shape [batch_size,
                height, width, num_channels] representing a batch of images.
            
        Returns:
            prediction_dict: A dictionary holding prediction tensors to be
                passed to the Loss or Postprocess functions.
        """
        preprocessed_inputs = preprocessing.preprocess_images(
            inputs, self._default_image_size, self._default_image_size, 
            resize_side_min=self._fixed_resize_side,
            is_training=self._is_training,
            border_expand=True, normalize=False,
            preserving_aspect_ratio_resize=False)
        preprocessed_inputs = tf.cast(preprocessed_inputs, tf.float32)
        return preprocessed_inputs
    
    def predict(self, preprocessed_inputs):
        """Predict prediction tensors from inputs tensor.
        
        Outputs of this function can be passed to loss or postprocess functions.
        
        Args:
            preprocessed_inputs: A float32 tensor with shape [batch_size,
                height, width, num_channels] representing a batch of images.
            
        Returns:
            prediction_dict: A dictionary holding prediction tensors to be
                passed to the Loss or Postprocess functions.
        """
        with slim.arg_scope(nets.resnet_v1.resnet_arg_scope()):
            net, endpoints = nets.resnet_v1.resnet_v1_50(
                preprocessed_inputs, num_classes=None,
                is_training=self._is_training)
        net = tf.squeeze(net, axis=[1, 2])
        logits = slim.fully_connected(net, num_outputs=self.num_classes,
                                      activation_fn=None, scope='Predict')
        prediction_dict = {'logits': logits}
        return prediction_dict
    
    def postprocess(self, prediction_dict):
        """Convert predicted output tensors to final forms.
        
        Args:
            prediction_dict: A dictionary holding prediction tensors.
            **params: Additional keyword arguments for specific implementations
                of specified models.
                
        Returns:
            A dictionary containing the postprocessed results.
        """
        logits = prediction_dict['logits']
        logits = tf.nn.softmax(logits)
        classes = tf.argmax(logits, axis=1)
        postprocessed_dict = {'logits': logits,
                              'classes': classes}
        return postprocessed_dict
    
    def loss(self, prediction_dict, groundtruth_lists):
        """Compute scalar loss tensors with respect to provided groundtruth.
        
        Args:
            prediction_dict: A dictionary holding prediction tensors.
            groundtruth_lists_dict: A dict of tensors holding groundtruth
                information, with one entry for each image in the batch.
                
        Returns:
            A dictionary mapping strings (loss names) to scalar tensors
                representing loss values.
        """
        logits = prediction_dict['logits']
        slim.losses.sparse_softmax_cross_entropy(
            logits=logits, 
            labels=groundtruth_lists,
            scope='Loss')
        loss = slim.losses.get_total_loss()
        loss_dict = {'loss': loss}
        return loss_dict
        
    def accuracy(self, postprocessed_dict, groundtruth_lists):
        """Calculate accuracy.
        
        Args:
            postprocessed_dict: A dictionary containing the postprocessed 
                results
            groundtruth_lists: A dict of tensors holding groundtruth
                information, with one entry for each image in the batch.
                
        Returns:
            accuracy: The scalar accuracy.
        """
        classes = postprocessed_dict['classes']
        accuracy = tf.reduce_mean(
            tf.cast(tf.equal(classes, groundtruth_lists), dtype=tf.float32))
        return accuracy

二、预训练模型导入

        要将预训练模型 ResNet-50 的参数导入到前面定义好的模型,需要继续借助 tf.contrib.slim 模块,而且方法很简单,只需要在训练函数 slim.learning.train 中指定初始化参数来源函数 init_fn 即可,而这可以通过函数

slim.assign_from_checkpoint_fn(model_path, var_list,
                               ignore_missing_vars=False,
                               reshape_variables=False)

很方便的实现。其中,第一个参数 model_path 指定预训练模型 xxx.ckpt 文件的路径,第二个参数 var_list 指定需要导入对应预训练参数的所有变量,通过函数

slim.get_variables_to_restore(include=None,
                              exclude=None)

可以快速指定,如果需要排除一些变量,也就是如果想让某些变量随机初始化而不是直接使用预训练模型来初始化,则直接在参数 exclude 中指定即可。第三个参数 ignore_missing_vars 非常重要,一定要将其设置为 True,也就是说,一定要忽略那些在定义的模型结构中可能存在的而在预训练模型中没有的变量,因为如果自己定义的模型结构中存在一个参数,而这些参数在预训练模型文件 xxx.ckpt 中没有,那么如果不忽略的话,就会导入失败(这样的变量很多,比如卷积层的偏置项 bias,一般预训练模型中没有,所以需要忽略,即使用默认的零初始化)。最后一个参数 reshape_variabels 指定对某些变量进行变形,这个一般用不到,使用默认的 False 即可。

        有了以上的基础,而且你还阅读过上一篇文章 TensorFlow-slim 训练 CNN 分类模型(续) 的话,那么整个使用预训练模型的训练文件 train.py 就很容易写出了,如下(重点在函数 get_init_fn):

# -*- coding: utf-8 -*-
"""
Created on Thu Oct 11 17:21:35 2018

@author: shirhe-lyh
"""

"""Train a CNN classification model via pretrained ResNet-50 model.

Example Usage:
---------------
python3 train.py \
    --checkpoint_path: Path to pretrained ResNet-50 model.
    --record_path: Path to training tfrecord file.
    --logdir: Path to log directory.
"""

import os
import tensorflow as tf

import model
import preprocessing

slim = tf.contrib.slim
flags = tf.app.flags

flags.DEFINE_string('record_path', 
                    '/data2/raycloud/jingxiong_datasets/AIChanllenger/' +
                    'AgriculturalDisease_trainingset/train.record',
                    'Path to training tfrecord file.')
flags.DEFINE_string('checkpoint_path', 
                    '/home/jingxiong/python_project/model_zoo/' +
                    'resnet_v1_50.ckpt', 
                    'Path to pretrained ResNet-50 model.')
flags.DEFINE_string('logdir', './training', 'Path to log directory.')
flags.DEFINE_float('learning_rate', 0.0001, 'Initial learning rate.')
flags.DEFINE_float(
    'learning_rate_decay_factor', 0.1, 'Learning rate decay factor.')
flags.DEFINE_float(
    'num_epochs_per_decay', 3.0,
    'Number of epochs after which learning rate decays. Note: this flag counts '
    'epochs per clone but aggregates per sync replicas. So 1.0 means that '
    'each clone will go over full epoch individually, but replicas will go '
    'once across all replicas.')
flags.DEFINE_integer('num_samples', 32739, 'Number of samples.')
flags.DEFINE_integer('num_steps', 10000, 'Number of steps.')
flags.DEFINE_integer('batch_size', 48, 'Batch size')

FLAGS = flags.FLAGS


def get_record_dataset(record_path,
                       reader=None, 
                       num_samples=50000, 
                       num_classes=7):
    """Get a tensorflow record file.
    
    Args:
        
    """
    if not reader:
        reader = tf.TFRecordReader
        
    keys_to_features = {
        'image/encoded': 
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': 
            tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/class/label': 
            tf.FixedLenFeature([1], tf.int64, default_value=tf.zeros([1], 
                               dtype=tf.int64))}
        
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(image_key='image/encoded',
                                              format_key='image/format'),
        'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[])}
    decoder = slim.tfexample_decoder.TFExampleDecoder(
        keys_to_features, items_to_handlers)
    
    labels_to_names = None
    items_to_descriptions = {
        'image': 'An image with shape image_shape.',
        'label': 'A single integer.'}
    return slim.dataset.Dataset(
        data_sources=record_path,
        reader=reader,
        decoder=decoder,
        num_samples=num_samples,
        num_classes=num_classes,
        items_to_descriptions=items_to_descriptions,
        labels_to_names=labels_to_names)
    
    
def configure_learning_rate(num_samples_per_epoch, global_step):
    """Configures the learning rate.
    
    Modified from:
        https://github.com/tensorflow/models/blob/master/research/slim/
        train_image_classifier.py
    
    Args:
        num_samples_per_epoch: he number of samples in each epoch of training.
        global_step: The global_step tensor.
        
    Returns:
        A `Tensor` representing the learning rate.
    """
    decay_steps = int(num_samples_per_epoch * FLAGS.num_epochs_per_decay /
                      FLAGS.batch_size)
    return tf.train.exponential_decay(FLAGS.learning_rate,
                                      global_step,
                                      decay_steps,
                                      FLAGS.learning_rate_decay_factor,
                                      staircase=True,
                                      name='exponential_decay_learning_rate')
    
    
def get_init_fn(checkpoint_exclude_scopes=None):
    """Returns a function run by che chief worker to warm-start the training.
    
    Modified from:
        https://github.com/tensorflow/models/blob/master/research/slim/
        train_image_classifier.py
    
    Note that the init_fn is only run when initializing the model during the 
    very first global step.
    
    Args:
        checkpoint_exclude_scopes: Comma-separated list of scopes of variables
            to exclude when restoring from a checkpoint.
    
    Returns:
        An init function run by the supervisor.
    """
    if FLAGS.checkpoint_path is None:
        return None
    
    # Warn the user if a checkpoint exists in the train_dir. Then we'll be
    # ignoring the checkpoint anyway.
    if tf.train.latest_checkpoint(FLAGS.logdir):
        tf.logging.info(
            'Ignoring --checkpoint_path because a checkpoint already exists ' +
            'in %s' % FLAGS.logdir)
        return None
    
    exclusions = []
    if checkpoint_exclude_scopes:
        exclusions = [scope.strip() for scope in 
                     checkpoint_exclude_scopes.split(',')]
    variables_to_restore = []
    for var in slim.get_model_variables():
        excluded = False
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                excluded = True
        if not excluded:
            variables_to_restore.append(var)
    
    if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
        checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    else:
        checkpoint_path = FLAGS.checkpoint_path

    tf.logging.info('Fine-tuning from %s' % checkpoint_path)
    
    return slim.assign_from_checkpoint_fn(
        checkpoint_path,
        variables_to_restore,
        ignore_missing_vars=True)


def get_trainable_variables(checkpoint_exclude_scopes=None):
    """Return the trainable variables.
    
    Args:
        checkpoint_exclude_scopes: Comma-separated list of scopes of variables
            to exclude when restoring from a checkpoint.
    
    Returns:
        The trainable variables.
    """
    exclusions = []
    if checkpoint_exclude_scopes:
        exclusions = [scope.strip() for scope in 
                     checkpoint_exclude_scopes.split(',')]
    variables_to_train = []
    for var in tf.trainable_variables():
        excluded = False
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                excluded = True
        if not excluded:
            variables_to_train.append(var)
    return variables_to_train


def main(_):
    # Specify which gpu to be used
    os.environ["CUDA_VISIBLE_DEVICES"] = '1'
    
    num_samples = FLAGS.num_samples
    dataset = get_record_dataset(FLAGS.record_path, num_samples=num_samples, 
                                 num_classes=61)
    data_provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
    image, label = data_provider.get(['image', 'label'])
    
    # Border expand and resize
    image = preprocessing.border_expand(image, resize=True, output_height=368,
                                        output_width=368)
        
    inputs, labels = tf.train.batch([image, label],
                                    batch_size=FLAGS.batch_size,
                                    #capacity=5*FLAGS.batch_size,
                                    allow_smaller_final_batch=True)
    
    cls_model = model.Model(is_training=True, num_classes=61)
    preprocessed_inputs = cls_model.preprocess(inputs)
    prediction_dict = cls_model.predict(preprocessed_inputs)
    loss_dict = cls_model.loss(prediction_dict, labels)
    loss = loss_dict['loss']
    postprocessed_dict = cls_model.postprocess(prediction_dict)
    acc = cls_model.accuracy(postprocessed_dict, labels)
    tf.summary.scalar('loss', loss)
    tf.summary.scalar('accuracy', acc)

    global_step = slim.create_global_step()
    learning_rate = configure_learning_rate(num_samples, global_step)
    optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, 
                                           momentum=0.9)
#    optimizer = tf.train.AdamOptimizer(learning_rate=0.00001)
    vars_to_train = get_trainable_variables()
    train_op = slim.learning.create_train_op(loss, optimizer,
                                             summarize_gradients=True,
                                             variables_to_train=vars_to_train)
    tf.summary.scalar('learning_rate', learning_rate)
    
    init_fn = get_init_fn()
    
    slim.learning.train(train_op=train_op, logdir=FLAGS.logdir, 
                        init_fn=init_fn, number_of_steps=FLAGS.num_steps,
                        save_summaries_secs=20,
                        save_interval_secs=600)
    
if __name__ == '__main__':
    tf.app.run()

        函数 get_init_fn 从指定路径下读取预训练模型。如果没有指定预训练模型路径(FLAGS.checkpoint_path),则返回 None(表示随机初始化参数)。如果在训练路径下(FLAGS.logdir)已经保存过训练后的模型,也返回 None(即忽略预训练模型参数,而使用最后训练保存下来的模型初始化参数)。如果你只想导入部分层的预训练参数,而忽略其它层的预训练参数,则可以设置 checkpoint_exclude_scopes 这个参数,用来指定你要排除掉(即不需要导入预训练参数)的那些层的名字,比如你要禁用第一卷积层,以及第一个 block1,只需要指定:

checkpoint_exclude_scopes = 'resnet_v1_50/conv1,resnet_v1_50/block1'
init_fn = get_init_fn(checkpoint_exclude_scopes)

函数 get_trainable_variables 的作用是获取需要训练的变量,它默认返回所有可训练的变量。当你需要冻结一些层,让这些层的参数不更新时,通过参数 checkpoint_exclude_scopes 指定,比如我想让 ResNet-50 的 block1block2/unit_1 冻结时,通过:

scopes_to_freeze = 'resnet_v1_50/block1,resnet_v1_50/block2/unit_1'
vars_to_train = get_trainable_variables(scopes_to_freeze )

调用即可。

三、数据集以及训练

        本文 GitHub/finetune_classification 上的代码默认使用 AI Challenger 全球AI挑战赛/农作物病害检测 数据集。下载好数据集之后,执行如下指令:

$ python3 generate_tfrecord.py \
    --images_dir Path/to/AgriculturalDisease_trainingset/images \
    --annotation_path Path/to/AgriculturalDisease_train_annotations.json \
    --output_path Path/to/train.record

将训练集图像写入到 train.record 文件中。之后,执行:

$ python3 train.py \
    --record_path Path/to/train.record \
    --checkpoint_path Path/to/pretrained_ResNet-50_model/resnet_v1_50.ckpt

开始训练。训练开始之后,会在当前 train.py 路径下生成一个文件夹 training 用来保存训练模型。需要额外说明的是,训练过程不会在终端输出准确率、损失等数据,需要在终端执行:

$ tensorboard --logdir Path/to/training

之后,打开返回的 http 链接在浏览器查看准确率、损失等训练曲线(训练过程中,训练结束后都可查看)。训练正常启动后,每 10 分钟会保存一次模型到 training 文件夹(诸如 model.ckpt-xxx 之类的文件),你可以选择使用其中的 model.ckpt-xxx 模型来直接进行预测,也可以选择将 model.ckpt-xxx 转化为 .pb 文件之后再进行预测,如果选择转化,执行:

$ python3 export_inference_graph.py \
    --trained_checkpoint_prefix Path/to/model.ckpt-xxx \
    --output_directory Path/to/exported_pb_file_directory

之后,在指定的输出路径下(Path/to/exported_pb_file_directory)会生成一个文件夹,该文件内的 frozen_inference_graph.pb 即是转化成的固化模型文件(固化指的是所有参数都转化成了常数)。之后就可以使用 evaluate.py 或者 predict.py 进行验证或预测了。

        如果你使用其它数据集,整个训练过程和上面的步骤一样,只需要根据具体的标注文件来修改文件data_provider.py 中函数 provide,该函数返回一个字典,其中 key 代表训练数据集中图像的路径,value 代表图像对应的类标号;其它参数,比如训练图像个数,类别数目,学习率等,在 train.py 中修改。

预告:下一篇文章将要介绍如何用 TensorFlow 来训练多任务多标签模型,敬请期待!

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 205,236评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 87,867评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 151,715评论 0 340
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,899评论 1 278
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,895评论 5 368
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,733评论 1 283
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,085评论 3 399
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,722评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 43,025评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,696评论 2 323
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,816评论 1 333
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,447评论 4 322
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,057评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,009评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,254评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,204评论 2 352
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,561评论 2 343