TensorFlow 模型保存与恢复

        上一篇文章 TensorFlow 训练 CNN 分类器 中说明了训练简单 CNN 模型的整个过程,并在训练结束后使用 .save 函数来保存训练的结果,其后通过使用 tf.train.import_meta_graph.restore 函数来导入模型进行推断。本文承接上文,对模型保存与恢复做一个总结。

        总的来说,模型在保存和恢复时最重要的是留下数据接口,方便使用时传入数据和获取结果。TensorFlow 中常用的模型保存格式为 .ckpt 和 .pb,下面分别进行详细说明。

一、ckpt 格式模型保存与恢复

        .ckpt 格式保存与恢复都很简单,具体可参考 TensorFlow 训练 CNN 分类器

1. ckpt 格式模型保存

inputs = tf.placeholder(tf.float32, shape=[None, ···], name='inputs')  <-- 入口
···
prediction = tf.nn.softmax(logits, name='prediction')  <-- 出口(仅作为例子,下同)
···
saver = tf.train.Saver()
···

with tf.Session() as sess:
    ···    <-- 训练过程
    saver.save(sess, './xxx/xxx.ckpt')  <-- 模型保存

        如上述代码所示,假设你定义了一个 TensorFlow 模型,数据入口由占位符 inputs 给定,结果出口由张量 prediction 给定。通过语句 saver = tf.train.Saver() 定义了模型保存的一个实例对象 saver,当模型训练结束之后只需要简单的一条语句:

saver.save(sess, path_to_model.ckpt)

就把训练结果保存到了指定的路径。

        以上代码之所以把变量 inputsprediction 单独列出,一方面是因为它们是模型 Graph 的起点和终点(戏称为数据入口、出口),另一方面的原因是它们被特别的指定了名称,因而在模型恢复时可以通过它们的名称而得到 Graph 中对应的节点。

2. ckpt 格式模型恢复

        当你需要导入模型进行推断时,只需要通过张量名获取数据入口和出口,然后传入数据即可:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('./xxx/xxx.ckpt.meta')
    saver.restore(sess, './xxx/xxx.ckpt')

    inputs = tf.get_default_graph().get_tensor_by_name('inputs:0')
    prediction = tf.get_default_graph().get_tensor_by_name('prediction:0')

    pred = sess.run(prediction, feed_dict={inputs: xxx}

        保存为 .ckpt 模型的一个好处是,当需要继续训练时,只需要将训练过的模型结果导入,然后在这个基础上再继续训练。而下面的 .pb 格式则不能继续训练,因为这种格式保存的模型参数都已经转化为了常量(而不再是变量)。

二、pb 格式模型保存与恢复

        .pb 格式模型保存与恢复相比于前面的 .ckpt 格式而言要稍微麻烦一点,但使用更灵活,特别是模型恢复,因为它可以脱离会话(Session)而存在,便于部署。

1. pb 格式模型保存

        与 .ckpt 格式模型保存类似,首先定义数据入口、出口:

from tensorflow.python.framework import graph_util

···
inputs = tf.placeholder(tf.float32, shape=[None, ···], name='inputs') 
···
prediction = tf.nn.softmax(logits, name='prediction') 
···

with tf.Session() as sess:
    ···    <-- 训练过程
    graph_def = tf.get_default_graph().as_graph_def()
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, 
        graph_def, 
        ['prediction']  <-- 参数:output_node_names,输出节点名
    )
    with tf.gfile.GFile('./xxx/xxx.pb', 'wb') as fid:
        serialized_graph = output_graph_def.SerializeToString()
        fid.write(serialized_graph)

然后通过函数 graph_util.convert_variables_to_constants 将模型固话,使得所有变量转化为常量,之后写入到指定的路径完成模型保存过程。

2. pb 格式模型恢复

        .pb 格式模型恢复自由度较大,不需要在会话里进行操作,可以独立存在:

import os

def load_model(path_to_model.pb):
    if not os.path.exists(path_to_model.pb):
        raise ValueError("'path_to_model.pb' is not exist.")

    model_graph = tf.Graph()
    with model_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(path_to_model.pb, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
    return model_graph

模型导入之后,便可以获取数据入口和出口,然后进行推断:

model_graph = load_model('./xxx/xxx.pb')

inputs = model_graph.get_tensor_by_name('inputs:0')
prediction = model_graph.get_tensor_by_name('prediction:0')

with model_graph.as_default():
    with tf.Session(graph=model_graph) as sess:
        ···
        pred = sess.run(prediction, feed_dict={inputs: xxx}

三、ckpt 格式转 pb 格式

        一般情况下,为了便于从断点之处继续训练,模型通常保存为 .ckpt 格式,而一旦对训练结果很满意之后则可能需要将 .ckpt 格式转化为 .pb 格式。转化方法很简单,只需要综合前面的一、二两步即可:

from tensorflow.python.framework import graph_util

with tf.Session() as sess:
    # Load .ckpt file
    saver = tf.train.import_meta_graph('./xxx/xxx.ckpt.meta')
    saver.restore(sess, './xxx/xxx.ckpt')

    # Save as .pb file
    graph_def = tf.get_default_graph().as_graph_def()
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, 
        graph_def, 
        ['prediction']  <-- 输出节点名,以实际情况为准
    )
    with tf.gfile.GFile('./xxx/xxx.pb', 'wb') as fid:
        serialized_graph = output_graph_def.SerializeToString()
        fid.write(serialized_graph)

        预告:下一篇文章将简单介绍 tensorflow.contrib.slim 的应用,敬请关注!

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,332评论 6 493
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,508评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 157,812评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,607评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,728评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,919评论 1 290
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,071评论 3 410
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,802评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,256评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,576评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,712评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,389评论 4 332
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,032评论 3 316
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,798评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,026评论 1 266
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,473评论 2 360
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,606评论 2 350

推荐阅读更多精彩内容