Tensorflow 训练好的模型保存和载入

方法一 这种存储方式在加载模型时需要再次定义网络结构

模型训练和存储

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os

mnist = input_data.read_data_sets("/home/devops/test/TensorFlowOnSpark/mnist/",one_hot=True)
print (mnist)

learning_rate = 0.01
training_epochs = 5
batch_size = 100
display_step = 1

X = tf.placeholder(tf.float32,[None,784])
Y = tf.placeholder(tf.float32,[None,10])

W = tf.Variable(tf.zeros([784,10]),name="W")
b = tf.Variable(tf.zeros([10]),name="b")

pred = tf.nn.softmax(tf.matmul(X,W) + b)
cost = tf.reduce_mean(-tf.reduce_sum(Y * tf.log(pred), reduction_indices =1))

optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
init = tf.global_variables_initializer()

##初始化存储器和存储路径
saver = tf.train.Saver(max_to_keep=4)
model_path = "./model/lr"
path = os.path.dirname(os.path.abspath(model_path))
if os.path.isdir(path) is False:
    os.makedirs(path)

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(training_epochs):
        avg_cost = 0
        total_batch = int(mnist.train.num_examples/batch_size)
        for i in range(total_batch):
            batch_xs,batch_ys = mnist.train.next_batch(batch_size)
            _,c = sess.run([optimizer,cost],feed_dict={X:batch_xs,Y:batch_ys})
            avg_cost += c / total_batch
        if (epoch + 1) % display_step == 0:
            print ("Epoch:","%04d" % (epoch + 1),"cost=","{}".format(avg_cost))
        saver.save(sess,model_path,write_meta_graph=True)
    print ("Optimization Finished")

    correct_prediction = tf.equal(tf.argmax(pred,1),tf.argmax(Y,1))
    accuracy  = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    print ("Accuracy:",accuracy.eval({X:mnist.test.images[:3000],Y:mnist.test.labels[:3000]}))

加载模型

import tensorflow as tf
import os
from tensorflow.examples.tutorials.mnist import input_data


mnist = input_data.read_data_sets("/home/devops/test/TensorFlowOnSpark/mnist/",one_hot=True)

X = tf.placeholder(tf.float32,[None,784])
Y = tf.placeholder(tf.float32,[None,10])

with tf.Session() as sess:
    saver = tf.train.import_meta_graph("/home/devops/test/TensorFlowOnSpark/examples/mnist/my/curve/model/lr.meta")
    saver.restore(sess,tf.train.latest_checkpoint("/home/devops/test/TensorFlowOnSpark/examples/mnist/my/curve/model"))
    graph = tf.get_default_graph()
    W = graph.get_tensor_by_name("W:0")
    b = graph.get_tensor_by_name("b:0")
    pred = tf.nn.softmax(tf.matmul(X,W) + b)

    correct_prediction = tf.equal(tf.argmax(pred,1),tf.argmax(Y,1))

    accuracy  = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

    print ("Accuracy:",accuracy.eval({X:mnist.test.images[:3000],Y:mnist.test.labels[:3000]}))

方法二 这种存储方式在加载模型时不用定义网络结构

模型训练和存储

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os

mnist = input_data.read_data_sets("/home/devops/test/TensorFlowOnSpark/mnist/",one_hot=True)
print (mnist)

learning_rate = 0.01
training_epochs = 5
batch_size = 100
display_step = 1

X = tf.placeholder(tf.float32,[None,784],name="X")
Y = tf.placeholder(tf.float32,[None,10],name="Y")

W = tf.Variable(tf.zeros([784,10]),name="W")
b = tf.Variable(tf.zeros([10]),name="b")

pred = tf.nn.softmax(tf.matmul(X,W) + b)
cost = tf.reduce_mean(-tf.reduce_sum(Y * tf.log(pred), reduction_indices =1))

optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
init = tf.global_variables_initializer()

saver = tf.train.Saver(max_to_keep=4)

##把要加载的对象提前加入集合
tf.add_to_collection("pred",pred)

model_path = "./model/lr"
path = os.path.dirname(os.path.abspath(model_path))
if os.path.isdir(path) is False:
    os.makedirs(path)

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(training_epochs):
        avg_cost = 0
        total_batch = int(mnist.train.num_examples/batch_size)
        for i in range(total_batch):
            batch_xs,batch_ys = mnist.train.next_batch(batch_size)
            _,c = sess.run([optimizer,cost],feed_dict={X:batch_xs,Y:batch_ys})
            avg_cost += c / total_batch
        if (epoch + 1) % display_step == 0:
            print ("Epoch:","%04d" % (epoch + 1),"cost=","{}".format(avg_cost))
        saver.save(sess,model_path,write_meta_graph=True)
    print ("Optimization Finished")

    correct_prediction = tf.equal(tf.argmax(pred,1),tf.argmax(Y,1))
    accuracy  = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    print ("Accuracy:",accuracy.eval({X:mnist.test.images[:3000],Y:mnist.test.labels[:3000]}))

模型加载

import tensorflow as tf
import os
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("/home/devops/test/TensorFlowOnSpark/mnist/",one_hot=True)

with tf.Session() as sess:
    saver = tf.train.import_meta_graph("/home/devops/test/TensorFlowOnSpark/examples/mnist/my/curve/model/lr.meta")
    saver.restore(sess,tf.train.latest_checkpoint("/home/devops/test/TensorFlowOnSpark/examples/mnist/my/curve/model"))

    pred = tf.get_collection("pred")[0]
    graph = tf.get_default_graph()

    X = graph.get_operation_by_name("X").outputs[0]
    Y = graph.get_operation_by_name("Y").outputs[0]

    correct_prediction = tf.equal(tf.argmax(pred,1),tf.argmax(Y,1))
    accuracy  = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    print ("Accuracy:",accuracy.eval({X:mnist.test.images[:300],Y:mnist.test.labels[:300]}))

1.Tensorflow模型文件的组成

    主要包含两个文件
  1. 元图 meta graph

保存完整的图结构 包含所有的变量 操作等 扩展名为meta

2.检查点文件 checkpoint

二进制文件 包含所有的权重 偏差 梯度和其他所有保存的值 扩展名是.ckpt , 0.11版本之后不再仅使用一个.ckpt文件来表示了 而是两个文件 .data-00000-of-00001 和.index

其中.data 是包含训练变量的文件

此外还有一个名为checkpoint的文件 用于保存最新检查点的记录

2.如何保存Tensorflow模型

在模型训练完成后 可调用tf.train.Saver()实例来保存所有的参数和计算图

由于tensorflow中的变量只能存在于session中,因此需要在session中调用save 将模型存储


import tensorflow as tf

w1 = tf.Variable(tf.random_normal(shape=[2],name='w1’))

w2 = tf.Variable(tf.random_normal(shape=[5]),name=‘w2’)

saver = tf.train.Saver()

with tf.Session() as sess:

    sess.run(tf.global_variable_initialize())

saver.save(sess,’/path/to/save') 

运行后可得以下文件:

model/

├── checkpoint

├── my_test_model.data-00000-of-00001

├── my_test_model.index

└── my_test_model.meta

如果想在1000次迭代之后再保存模型,可通过传递步数来调用save

saver.save(sess,’model_path’,global_step=1000)

image.png

如果想每1000次保存一下模型,由于.meta文件会在第一次保存时创建 而且图结构不会再变化,因此只需要保存模型进一步迭代的数据 而不用存储网络结构 可调用

saver.save(sess,’model_path’,global_step=step,write_meta_graph=False)

如果只想保存最新的4个模型参数,并且希望在训练阶段每两小时保存一个模型,可调用

saver = tf.train.Saver(max_to_keep=4,keep_checkpoint_every_n_hours=2)

如果在tf.train.Saver() 中没有指定任何东西,那么他会保存模型的所有变量,如果只想保存部分变量则需要通过列表或字典的形式将变量传递进去

import tensorflow as tf

w1 = tf.Variable(tf.random_normal(shape=[2],name='w1’))

w2 = tf.Variable(tf.random_normal(shape=[5]),name=‘w2’)

saver = tf.train.Saver([w1,w2])

with tf.Session() as sess:

    sess.run(tf.global_variable_initializer())

    sess.run(tf.global_variable_initialize())

    saver.save(sess,’/path/to/save')

3.如何导入一个训练好的模型并进行修改和微调

需要完成两件事情

1.构建网络结构

可通过手动编写代码创建每一层网络结构来重构整个网络

保存模型时也会将网络结构存储到.meta文件中,可直接调用tf.train.import()函数来导入这个模型

saver = tf.train.import_meta_graph(‘model-1000.meta’) 这个操作是将.meta文件中的计算图数据直接附加到当前定义的图中,但是我们仍然需要去加载计算图上所有已经训练好的权重参数

2.加载参数

new_saver.restore(sess,tf.train.latest_checkpoint('./‘))   checkpoint文件所在路径

with tf.Session() as sess:

new_saver = tf.train.import_meta_graph(‘my_test_model-1000.meta’) 

new_saver.restore(sess,tf.train.latest_checkpoint('./')) 
#读取参数 
print(sess.run(‘w1:0')) 

4.恢复任何预先训练好的模型用于预测 (工作中的开发方式)

import tensorflow as tf

w1 = tf.placeholder(‘float’,name='w1’)

w2 = tf.placeholder(‘float’,name=‘w2ww’)

b1 = tf.Variable(2.0,name=‘bias’)

w3 = tf.add(w1,w2)

w4 = tf.multiply(w3,b1,name=‘op_to_restore’)

saver = tf.train.Saver()

with tf.Session() as sess:

  sess.run(tf.global_variables_initializer()) 

  print (sess.run(24,feed_dict={w1:4,w2:8})) 

  saver.save(sess,’test_model’,global_step=1000) 

当需要载入这个模型时,不仅需要恢复所有的计算图和权重参数 还需要准备一个新的feed_dict

用于将新的训练数据传送到网络中进行训练,可通过graph.get_tensor_by_name() 来获得对这些保存的操作和占位符变量的引用

w1 = graph.get_tensor_by_name(‘w1:0’)

op_to_restore = graph.get_tensor_by_name(“op_to_restore:0”)

使用不同的数据来运行相同的网络 则需要通过feed_dict来传递数据

with tf.Session() as sess:

  saver = tf.train.import_meta_graph(’test_model-1000.meta’)

  saver.restore(sess,tf.train.latest_checkpoint(‘./‘)) 

  graph = tf.get_default_graph() 

  w1 = graph.get_tensor_by_name(“w1:0”) 

  w2 = graph.get_tensor_by_name(“w2:0”) 

  feed_dict= {w1:13.0,w2:17.0} 

  op_to_restore = graph.get_tensor_by_name(‘op_to_restore:0’) 

  print (sess.run(op_to_restore,feed_dict)) 

如果想在原来的计算图基础上添加更多的操作和图层,并进行训练

import tensorflow as tf

with tf.Session() as sess:

    saver = tf.train.import_meta_graph(‘my_test_model-1000.meta’)

    saver.restore(sess,tf.train.latest_checkpoint(‘./‘)) 

    graph = tf.get_default_graph() 

    w1 = graph.get_tensor_by_name(‘w1:0’) 

    w2 = graph.get_tensor_by_name(‘w2:0’) 

    feed_dict = {w1:13,w2:17} 

    op_to_restore = graph.get_tensor_by_name(‘op_to_restore:0’) 

    #新添加操作 

    add_on_op = tf.multiply(op_to_restore,2) 

    print (sess.run(add_on_op,feed_dict))
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 219,110评论 6 508
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 93,443评论 3 395
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 165,474评论 0 356
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,881评论 1 295
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,902评论 6 392
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,698评论 1 305
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,418评论 3 419
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,332评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,796评论 1 316
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,968评论 3 337
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 40,110评论 1 351
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,792评论 5 346
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,455评论 3 331
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 32,003评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 33,130评论 1 272
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,348评论 3 373
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 45,047评论 2 355

推荐阅读更多精彩内容