tensroflow 模型保存、加载

参考:
[http://blog.csdn.net/scotthuang1989/article/details/77769412]
[http://blog.csdn.net/LordofRobots/article/details/77719020]

tensorflow由于其开源特性,因此API经常发生变化。保存加载模型也发生了一些变化。
本次博客针对与tensorflow1.0以后1.3以前的版本,之后又有什么变化就不知道了。下面进入正题。

常见的持久化方法有两种:一种是保存为ckpt(以前的,现在该改后缀了就叫多文件),一种是graph_def文件。

一、多文件

源代码位于 tensorflow/python/training/saver.py
Saver类可以使用保存以及从某一个检查点恢复数据。
你可以之保存固定数量的检查点(恢复点),比如你可以只保存最近的多少个检查点文件或者是在训练时每隔几个小时保存一次。

tf.train.Saver的初始化

__init__(
    var_list=None,#一系列的Variable,SaveableObject, 或者dict名称,如果没有则表示所有可保存的变量
    reshape=False,#如果为True,表示允许恢复数据室variables有着不同的shape。
    sharded=False,#如果为Ture,表示在不同的设备上通向检查点
    max_to_keep=5,#保持几个检查点,默认是五个
    keep_checkpoint_every_n_hours=10000.0,#间隔多久保存一次检查点。
    name=None,#string,可选的变量,具体没看明白
    restore_sequentially=False,#一个Bool值,可以在恢复大量模型的时候减少内存的使用
    saver_def=None,#用于替换当前的保存builder,这个builder是之前的只保存Graph的那种Saver。将其保存成proto的形式。有点没有理解。
    builder=None,#当没有提供saver_def的时候可选的,默认为BaseSaverBuilder()
    defer_build=False,
    allow_empty=False,
    write_version=tf.train.SaverDef.V2,#是用那个版本进行保存
    pad_step_number=False,
    save_relative_paths=False,
    filename=None
)

之后你会得到tf.train.Saver的一个变量。

save

保存变量。
需要包含了需要保存的图的session,并且这些变量需要已经初始化过了,这个方法会返回最新的检查点的文件名称位置,这个位置可以用于调用restore().

save(
    sess,
    save_path,
    global_step=None,
    latest_filename=None,
    meta_graph_suffix='meta',
    write_meta_graph=True,
    write_state=True
)

示例代码:

import tensorflow as tf

# prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')  # name is very important in restoration
w2 = tf.Variable(tf.random_normal(shape=[2]), name='w2')
b1 = tf.Variable(2.0, name='bias1')
feed_dict = {w1: [10, 3], w2: [5, 5]}
# define a test operation that will be restored
w3 = tf.add(w1, w2)  # without name, w3 will not be stored
w4 = tf.multiply(w3, b1, name="op_to_restore")
#最多备份四次,默认每一个小时保存一次
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=1)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(w1))
print(sess.run(w4, feed_dict))
# 保存模型,保存在metaTest文件夹下,文件的名称为my_test_model
saver.save(sess, 'metaTest/my_test_model')

生成的结果为:

2017-09-13 21-51-25 的屏幕截图.png

其中,.meta(存储网络结构)、.data和.index(存储训练好的参数)、checkpoint(记录最新的模型)。

模型恢复

import tensorflow as tf
sess = tf.Session()
#恢复网络结构
saver = tf.train.import_meta_graph('metaTest/my_test_model.meta')
#恢复参数tf.train.latest_checkpoint('save/'),获取保存在那个位置的最新的文件
saver.restore(sess,tf.train.latest_checkpoint('metaTest/'))
#获取当前的默认图结构
graph = tf.get_default_graph()
#获取某一个tensor
w1 = graph.get_tensor_by_name('w1:0')
print(sess.run(w1))
w2 = graph.get_tensor_by_name('w2:0')
feed_dict = {w1:[-1,1],w2:[4,6]}
op_to_restore = graph.get_tensor_by_name('op_to_restore:0')
print(sess.run(op_to_restore,feed_dict))

如果删除saver.restore(sess,tf.train.latest_checkpoint('metaTest /')),则会报错。

二、graph_def文件

我们需要将TensorFlow的模型导出为单个文件(同时包含模型架构定义与权重),方便在其他地方使用(如在c++中部署网络)。利用tf.train.write_graph()默认情况下只导出了网络的定义(没有权重),而利用tf.train.Saver().save()导出的文件graph_def与权重是分离的,因此需要采用别的方法。

graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标。

#保存模型以及参数
import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

# 构造网络
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
# 一定要给输出tensor取一个名字!!
output = tf.add(a, b, name='out')

# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # 这里需要填入输出tensor的名字
    graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
    tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

会在当前文件夹下生成graph.pb文件,其中包含了网络结构以及所有的参数。

#恢复参数
import tensorflow as tf
with tf.Session() as sess:
    with open('./graph.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read()) 
        output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
        print(sess.run(output))

恢复模型的API:
作用:将graph_def引入到默认的Graph中。该方法提供了将序列化了的graph_def的pb文件重新加载到网络中的功能,并且将graph_def中的每一个objects转化为tf.Tensor或者tf.Operation的格式。一旦使用了这个方法,这些结构就会在当前的Graph中。可看 tf.Graph.as_graph_def查看关于GraphDef更详细的定义。

import_graph_def(
    graph_def,#一个graphDef文件
    input_map=None,#将graph_def中定义的map名称的输入转化为Tensor,方便给值
    return_elements=None,#将在[]中出现的graph_def中的操作转化为Operation,或者将graph_def中的tensor names转化为Tensor。
    name=None,
    op_dict=None,
    producer_op_list=None
)
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 204,445评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,889评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 151,047评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,760评论 1 276
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,745评论 5 367
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,638评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,011评论 3 398
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,669评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,923评论 1 299
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,655评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,740评论 1 330
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,406评论 4 320
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,995评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,961评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,197评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,023评论 2 350
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,483评论 2 342

推荐阅读更多精彩内容

  • 近期做了一些反垃圾的工作,除了使用常用的规则匹配过滤等手段,也采用了一些机器学习方法进行分类预测。我们使用Tens...
    liuyan731阅读 12,700评论 0 19
  • 明天军训就要结束了,我想写写这几天的感悟。通过这几天的军训,我的感悟很多,首先,认识了很多的新同事,通过这几天的了...
    Slowsoul_dc8c阅读 254评论 0 0