tensorflow 恢复(restore)模型的两种方式

image

1. 介绍

首先我们要理解TensorFlow的一个规则,首先构建计算图(graph),然后初始化graph中的data,这两步是分开的。

2. 如何恢复模型

有两种方式(这两种方式有比较大的不同):

2.1 重新使用代码构建图

举个例子(完整代码):

def build_graph():
    w1 = tf.Variable([1,3,10,15],name='W1',dtype=tf.float32)
    w2 = tf.Variable([3,4,2,18],name='W2',dtype=tf.float32)
    w3 = tf.placeholder(shape=[4],dtype=tf.float32,name='W3')
    w4 = tf.Variable([100,100,100,100],dtype=tf.float32,name='W4')
    add = tf.add(w1,w2,name='add')
    add1 = tf.add(add,w3,name='add1')
    return w3,add1

with tf.Session() as sess:
    ckpt_state = tf.train.get_checkpoint_state('./temp/')
    if ckpt_state:
        w3,add1=build_graph()
        saver = tf.train.Saver()
        saver.restore(sess, ckpt_state.model_checkpoint_path)
    else:
        w3,add1=build_graph()
        init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
        sess.run(init_op)
        saver = tf.train.Saver()
    a = sess.run(add1,feed_dict={
            w3:[1,2,3,4]
        })
    print(a)
    saver.save(sess,'./temp/model')

上面的流程很简单,首先build_graph(),然后如果有ckpt文件就从该文件中读取数据,否则用sess.run(init_op)初始化数据。

那么第一种restore方法就出来了:

build_graph()
saver = tf.train.Saver()
saver.restore(sess, ckpt_state.model_checkpoint_path)

首先build graph,等于是将图重新建立了一遍,和之前图的一样,然后将ckpt文件里的数据restore到图里的变量里。

当然,在build graph的过程中,你可以在原有的图里加一些变量,但是加的变量一定要初始化,但是要注意到一个问题,如果使用:

init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
sess.run(init_op)

这种方式时,如果定义init_op时的graph中已经存在原有图的变量,那么sess.run(init_op)会将加载进来的数据清空。

为了解决这个问题,两种方式:

  1. 新定义的变量放在init_op之前,在init_op之后restore(注意,加载好变量后才run(init_op)同样会覆盖)
    即,init_op得到当前图中的所有变量,sess.run(init_op)对init_op中的变量进行初始化,所以什么时候定义init_op和什么时候运行run(init_op)都很重要

  2. 只初始化未初始化的变量

def get_uninitialized_variables(sess):
global_vars = tf.global_variables()

# print([str(i.name) for i in global_vars])

is_not_initialized = sess.run([tf.is_variable_initialized(var) for var in global_vars])
not_initialized_vars = [v for (v, f) in zip(global_vars, is_not_initialized) if not f]
print([str(i.name) for i in not_initialized_vars])
return not_initialized_vars
sess.run(tf.variables_initializer(get_uninitialized_variables(sess)))

PS:注意saver = tf.train.Saver()要定义在图构建完成之后

​ 即将被restore的变量不用初始化,但是只有在restore之后,这些变量才会被初始化,所以在restore之前运行这些值会报没有初始化的错。

2.2 利用保存的.meta文件恢复图

参考:Tensorflow如何保存、读取model (即利用训练好的模型测试新数据的准确度)

上面的方式适用于断点续训,且自己有构建图的完整代码,如果我要用别人的网络(fine tune),或者在自己原有网络上修改(即修改原有网络的某个部分),那么将网络的图重新构建一遍会很麻烦,那么我们可以直接从.meta文件中加载网络结构。

2.2.1 get_tensor_by_name

完整代码:

def build_graph():
    w1 = tf.Variable([1,3,10,15],name='W1',dtype=tf.float32)
    w2 = tf.Variable([3,4,2,18],name='W2',dtype=tf.float32)
    w3 = tf.placeholder(shape=[4],dtype=tf.float32,name='W3')
    w4 = tf.Variable([100,100,100,100],dtype=tf.float32,name='W4')
    add = tf.add(w1,w2,name='add')
    add1 = tf.add(add,w3,name='add1')
    return w3,add1

with tf.Session() as sess:
    ckpt_state = tf.train.get_checkpoint_state('./temp/')
    if ckpt_state:
        saver = tf.train.import_meta_graph('./temp/model.meta')
        graph = tf.get_default_graph()
        w3 = graph.get_tensor_by_name('W3:0')
        add1 = graph.get_tensor_by_name('add1:0')
        saver.restore(sess, tf.train.latest_checkpoint('./temp/'))
        print(sess.run(tf.get_collection('w1')[0]))
    else:
        w3,add1=build_graph()
        init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
        sess.run(init_op)
        saver = tf.train.Saver()
    a = sess.run(add1,feed_dict={
            w3:[1,2,3,4]
        })
    print(a)
    saver.save(sess,'./temp/model')

上面使用了import_meta_graph()来加载图,并用restore给变量赋值。

通过get_tensor_by_name来获取保存的图中的op或变量,之后可以对获取的值进行操作,如果之后save的话,也会将import_meta_graph()中图引用的部分保存下来。

2.2.2

def build_graph():
    w1 = tf.Variable([1,3,10,15],name='W1',dtype=tf.float32)
    w2 = tf.Variable([3,4,2,18],name='W2',dtype=tf.float32)
    w3 = tf.placeholder(shape=[4],dtype=tf.float32,name='W3')
    w4 = tf.Variable([100,100,100,100],dtype=tf.float32,name='W4')
    add = tf.add(w1,w2,name='add')
    add1 = tf.add(add,w3,name='add1')
    tf.add_to_collection('w1','W1:0')
    tf.add_to_collection('w3',w3)
    tf.add_to_collection('add1',add1)
    return w3,add1

with tf.Session() as sess:
    ckpt_state = tf.train.get_checkpoint_state('./temp/')
    if ckpt_state:
        saver = tf.train.import_meta_graph('./temp/model.meta')
        w3 = tf.get_collection('w3')[0]
        add1 = tf.get_collection('add1')[0]
        # run init_op before restore
        saver.restore(sess, tf.train.latest_checkpoint('./temp/'))
    else:
        w3,add1=build_graph()
        init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
        sess.run(init_op)
        saver = tf.train.Saver()
    a = sess.run(add1,feed_dict={
            w3:[1,2,3,4]
        })
    print(a)
    saver.save(sess,'./temp/model')

通过import_meta_graph引进图,通过get_collection获得变量,其实和get_tensor_by_name差不多,但是可能会更方便一点。

3. 总结

总的来说,两种方式都是先构造好图,然后通过restore来给图里的变量赋值。

一个常见的问题是,要引入新的变量,对以前的图进行改造,那么如何初始化新的变量且不覆盖原来的数据?

  • 可以先啥都不管把所有的图相关的部分构造好后,得到init_op,然后在restore前run(init_op)
  • 对未初始化的变量进行初始化

4. 最后

image
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 213,014评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,796评论 3 386
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,484评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,830评论 1 285
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,946评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,114评论 1 292
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,182评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,927评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,369评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,678评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,832评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,533评论 4 335
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,166评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,885评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,128评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,659评论 2 362
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,738评论 2 351

推荐阅读更多精彩内容