Tensorflow 数据保存(2)

1. saver=tf.train.Saver (tf.global_variables(),max_to_keep)
(1)max_to_keep
这个是用来设置保存模型的个数,默认为5,即 max_to_keep=5,保存最近的5个模型。如果你想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置为None或者0,如:

saver=tf.train.Saver(max_to_keep=0)

当然,如果你只想保存最后一代的模型,则只需要将max_to_keep设置为1即可,即

saver=tf.train.Saver(max_to_keep=1)

(2)tf.global_variables()
只保存tf.global_variables()里的这些变量,如果saver=tf.train.Saver()里面不传入参数,默认保存全部变量

weight=[weights['wc1'],weights['wc2'],weights['wc3a']]
saver = tf.train.Saver(weight)#创建一个saver对象,.values是以列表的形式获取字典值
saver.save(sess,'model.ckpt')


2. saver.save ()
创建完saver对象后,就可以保存训练好的模型了,如:

saver.save(sess,'ckpt/mnist.ckpt',global_step=step)

第一个参数sess,第二个参数设定保存的路径和名字,第三个参数将训练的次数作为后缀加入到模型名字中。

3. saver.restore ()
模型的恢复用的是restore()函数,它需要两个参数restore(sess, save_path),save_path指的是保存的模型路径。我们可以使用tf.train.latest_checkpoint ()来自动获取最后一次保存的模型。如:

model_file=tf.train.latest_checkpoint('ckpt/')
saver.restore(sess,model_file)
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

  • 在这篇tensorflow教程中,我会解释: 1) Tensorflow的模型(model)长什么样子? 2) 如...
    JunsorPeng阅读 8,851评论 1 6
  • 这篇文章是针对有tensorflow基础但是记不住复杂变量函数的读者,文章列举了从输入变量到前向传播,反向优化,数...
    horsetif阅读 4,910评论 0 1
  • 今天就是2016年最后两个小时了,回想一年,感慨良多,今年一年虽然生活和工作好像都是按部就班,但有半年都在烦躁的心...
    一嘉一阅读 1,598评论 0 0
  • 线性回归原理 如图所示,这是一组二维的数据,我们先想想如何通过一条直线较好的拟合这些散点了?直白的说:尽量让拟合的...
    罗罗攀阅读 8,054评论 5 3
  • 小时候我在父母的光环下生活的无忧无虑,不曾知道生活的无奈和社会生存的残酷,因为父母那坚强的臂膀会帮我撑起整片蓝天,...
    尚孟_b1d4阅读 1,477评论 0 0

友情链接更多精彩内容