描述
- 在训练神经网络模型的时候,当模型训练完之后,确切地说当训练的session关闭之后,我们训练出来的模型参数会全部丢失,从而无法有效复用模型,而TensorFlow中提供了很好地保存模型和提取模型的方法。
方法
保存模型
- 方法如下
import tensorflow as tf
'''导入其它库'''
pass
'''搭建网络及其他准备工作'''
pass
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
'''设置模型保存器'''
m_saver = tf.train.Saver()
'''迭代训练'''
for i in range(n):
'''训练模型'''
pass
if i % e == 0:
'''每隔e代保存一次模型'''
'''model_path 和model_name分别是保存模型文件的路径和文件名'''
'''global_step设置i作为每个模型文件名的后缀'''
m_saver.save(sess, "model_path/model_name", global_step=i)
-
如果你搭建的网络模型没有问题的话,那么在对应的文件目录下将会看到16个文件,TensorFlow只会保存最近5次的模型,每一个模型会有三个文件,外加一个checkpoint文件,下图我的一个示例:
各文件说明
checkpoint文件保存了一个目录下所有的模型文件列表,这个文件是
tf.train.Saver
类自动生成且自动维护的。.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构 ,TensorFlow通过元图(MetaGraph)来记录计算图中节点的信息以及运行计算图中节点所需要的元数据。
.data-00000-of-00001文件保存了TensorFlow程序中每一个变量的取值,这个文件是通过SSTable格式存储的,可以大致理解为就是一个(key,value)列表。
.index是对应模型的索引文件
提取模型
- 看完上面的文件说明,大概就知道提取模型的步骤了,如下:
'''在一个新的python脚本文件中'''
import tensorflow as tf
'''导入其他库'''
pass
'''其他数据准备工作'''
'''这里不需要重新搭建模型'''
'''提取模型,首先提取计算图,这一步相当于搭建模型'''
saver = tf.train.import_meta_graph("model/mnist.ann-10000.meta")
with tf.Session() as sess:
'''提取保存好的模型参数'''
'''这里注意模型参数文件名要丢弃后缀.data-00000-of-00001'''
saver.restore(sess, "model/mnist.ann-10000")
'''通过张量名获取张量'''
'''这里按张量名获取了我保存的一个模型的三个张量,并换上新的名字'''
new_x = tf.get_default_graph().get_tensor_by_name("x:0")
new_y = tf.get_default_graph().get_tensor_by_name("y:0")
new_y_ = tf.get_default_graph().get_tensor_by_name("y_:0")
'''现在可以进行计算了'''
y_1 = sess.run(new_y_, feed_dict={new_x: new_x_data, new_y: new_y_data})
print(y_1)
- 关于上面代码中按张量名获取张量中的
("x:0")
,如果改成("x")
,则会报错:ValueError: The name 'x' refers to an Operation, not a Tensor. Tensor names must be of the form "<op_name>:<output_index>".
其他
- 其实
tf.train.Saver()
有很多参数可以设置,包括最大保存模型的数量等等,这里给出上面用到的函数的声明: - 完整的Saver类定义在tensorflow/python/training/saver.py.中。
def __init__( self,
var_list=None,
reshape=False,
sharded=False,
max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0,
name=None,
restore_sequentially=False,
saver_def=None,
builder=None,
defer_build=False,
allow_empty=False,
write_version=saver_pb2.SaverDef.V2,
pad_step_number=False,
save_relative_paths=False,
filename=None):
"""Creates a `Saver`."""
def save( self,
sess,
save_path,
global_step=None,
latest_filename=None,
meta_graph_suffix="meta",
write_meta_graph=True,
write_state=True):
"""Saves variables."""
def restore(self, sess, save_path):
"""Restores previously saved variables."""
def import_meta_graph(meta_graph_or_file, clear_devices=False,
import_scope=None, **kwargs):
"""Recreates a Graph saved in a `MetaGraphDef` proto."""