官方guide: https://github.com/tensorflow/docs/blob/master/site/en/r1/guide/saved_model.md
官方API:https://tensorflow.google.cn/versions/r1.15/api_docs/python/tf/train/Saver
所有pb格式数据文件的格式定义文件proto:https://github.com/tensorflow/tensorflow/tree/r1.15/tensorflow/core/protobuf
由于tf2开始,tf.train.saver被砍了,于是这里重点介绍tf1.xx版本的模型持久化方式,手动保存与恢复,tf官方定义为low-level API。(high-level 使用 estimator)
一、Saver类
tf.reset_default_graph()
v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1),name='v1')
v2 = tf.Variable(tf.random_normal([1], stddev=1, seed=2),name='v2')
result = v1 + v2
result2 = v1 * v2
c1 = tf.zeros([2,2], name="c1")
init_op = tf.global_variables_initializer()
saver = tf.train.Saver(max_to_keep=2) # Add ops to save and restore all the variables.
with tf.Session() as sess:
sess.run(init_op)
for x in range(10000):
if x % 1000 == 0:
print('saved:',x,'of 1w')
save_path = saver.save(sess, "Saved_model/test.ckpt", global_step=x)
print(save_path)
print(v1.eval(), v2.eval())
对Saver设置max_to_keep参数,能自动保存下最新的n个模型。得到结果:
Saver中传入的文件名都只需要前缀即可,即若要恢复只要传入test.ckpt-9000,其中-后面的数字代表了step数。00000-of-00001代表device信息,有一个GPU,且在第0个上。checkpoint文件保存并维护着模型列表,以及最新模型的文件名。可使用如下函数获取保存模型的最新文件位置:
ckpt = tf.train.get_checkpoint_state(checkpoint_dir='Saved_model')
print(ckpt.model_checkpoint_path) # Saved_model\test.ckpt-9000
Saver中可传入list或dict,用来指定保存或恢复时候的变量。默认是所有变量。注意,一旦传入,则只会保存或恢复list或dict中的变量,不管其余变量。
-
传入list。
若保存时候的Saver([v1,v2]),则恢复的时候,也要这么指定(除非当前graph里只有v1、v2),否则会报错:变量v3未找到。即恢复的变量应当是保存时候变量的一个子集。要恢复的变量必须存在。 -
传入dict。
有时候,保存模型的时候v1变量名name='v1',但是恢复模型的时候,graph里v1的变量名设定的是name='v11',又因tf是通过对应的变量名去加载的,因此会发生冲突。此时只要在dict中指定{'v1':v1}即可。dict中的key-value对:<String name: Variable 变量的引用>。案例:若我们使用上面的代码保存了模型,变量名为v1和v2,可通过方法查看:
在新文件中,我们重新构建好graph结构,但是我们此时变量v1的name改了,于是需要设定dict映射,来将ckpt中的v1加载到新模型的变量v1(name='v11')中:
注意此时,我们新建的graph中,v3并没有被初始化,也未被Saver指定恢复数值。
那么问题来了,若是模型有100个variable,新构建的graph中,部分variable的name和之前不一样,难道我们还需要手动写dict吗?解决方法:通过tf的collection机制,获取所有variable,组成dict后手动修改部分key值:
官方Notes:
You can create as many Saver objects as you want if you need to save and restore different subsets of the model variables. The same variable can be listed in multiple saver objects; its value is only changed when the Saver.restore() method is run.
If you only restore a subset of the model variables at the start of a session, you have to run an initialize op for the other variables. See tf.variables_initializer for more information.
二、存储机制详解
可以发现,存储后的文件有3种后缀,data、index与meta。同时,tf还提供了大量接口让人混淆:
- tf.train.Saver()/saver.restore()
- saver.export_meta_graph()/ tf.train.import_meta_graph
- tf.train.write_graph()/tf.Import_graph_def()
https://yq.aliyun.com/articles/620067
https://www.jianshu.com/p/ca637520002f
https://zhuanlan.zhihu.com/p/31308381
2.1 GraphDef 之 tf.Import_graph_def()
学术界适合使用上面所阐述的Saver.save()方法持久化模型,能方便之后继续训练或测试。但是工业界需要通用的模型文件,使得Java/C++也能直接部署,调用模型获得输出。所以工业界部署模型推荐tf.Import_graph_def()方法。
graph序列化的protobuf叫做graphDef,就是define graph的意思,一个graph的定义,包含了计算图上的节点信息。这个graphDef可以用tf.train.write_graph()/tf.Import_graph_def()来写入和导出。然而graphDef里面其实是没有存储变量具体数值的,因此无法拿来训练,但是可以存常量,就是constant。因此也可将所有session中持有的变量转constant后(graph.util.convert_variables_to_constants
),存储为pb,拿来部署做inference。这样graph结构信息与变量权重就能归并到一个pb文件中,没了变量初始化、模型保存等辅助节点后,模型文件更小更简洁,是无视语种的数据描述文件,适合工业部署做predict。
2.2 MetaGraph 之 tf.train.import_meta_graph()
tf.train.import_meta_graph() 方法可以直接从.meta
文件中恢复Graph结构,其包含以下几种主要成分:
MetaGraph
- MetaInfoDef 这个是存metadata的,像版本信息啊,用户信息,运算方法信息(比如定义了加法、乘法等,供GraphDef使用)
- GraphDef 上面说的就是这个GraphDef,包含了节点信息。(节点使用了哪种运算操作、输入输出都是什么)
- SaverDef 记录了所有持久化相关的参数,包括存储与恢复使用的op的名字、保存频率等
- CollectionDef 集合名称到集合内容的映射
- signature_def 记号标记用于saved_model保存pb的时候使用,定义统一的输入输出名
- AssetFileDef 记录外置文件位置
restore只是去restore variable,常量是在MetaGraph的GraphDef里的。故实验发现没有restore,常量依旧已经获取到了。
总结来看,saver.save()和saver.restore()保存和读取的东西不一致,save会保存所有一坨信息,而restore只是将data里的variable值恢复到当前graph中的对应节点里,graph你得自己新建或使用tf.train.import_meta_graph()。
三、 Saver源码解析
3.1 Saver([var_list]).init()
当传入var_list初始化Saver的时候,若未指定saver_def,则会自动使用build() ---> _build() ---> BaseSaverBuilder() 来创建新的saver_def.
Saver._build():
if not self.saver_def or context.in_eager_mode():
if self._builder is None:
self._builder = BaseSaverBuilder(self._write_version) # 创建BaseSaverBuilder
if self._var_list is None:
self._var_list = variables._all_saveable_objects() # 若未传入var_list则默认设置为所有variable
self.saver_def = self._builder._build_internal( # 使用BaseSaverBuilder来创建saver_def
self._var_list,
reshape=self._reshape,
sharded=self._sharded,
max_to_keep=self._max_to_keep,
keep_checkpoint_every_n_hours=self._keep_checkpoint_every_n_hours,
name=self._name,
restore_sequentially=self._restore_sequentially,
filename=checkpoint_path,
build_save=build_save, build_restore=build_restore)
再来看BaseSaverBuilder中返回saver_def的关键函数:
def _build_internal(self,
names_to_saveables, # 就是Saver初始化时候的var_list
reshape=False,
sharded=False,
max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0,
name=None,
restore_sequentially=False,
filename="model",
build_save=True,
build_restore=True):
# 首先将var_list转换成names_to_saveables,格式为<k,v>dict键值对:<op_name:op>
# 随后将op一个个取出,将variable包装为VariableSaveable后存入list:saveables并返回
saveables = self._ValidateAndSliceInputs(names_to_saveables)
# 创建op的name前缀:save
with ops.name_scope(name, "save",
[saveable.op for saveable in saveables]) as name:
# Add the Constant string tensor for the filename.
filename_tensor = constant_op.constant(filename or "model")
# Add the save ops. 创建保持和恢复的ops【重要】
if sharded:
... ...
else:
if build_save:
# 为每个saveables中的op添加保存op,并对op_list进行组合并返回组合依赖后的输出tensor
#(通过control_flow_ops.with_dependencies)代表运行此tensor前必须运行全部的保存op
save_tensor = self._AddSaveOps(filename_tensor, saveables)
if build_restore:
restore_op = self._AddRestoreOps(filename_tensor, saveables,
restore_sequentially, reshape)
if context.in_graph_mode():
# 正式构建并返回saver_def
return saver_pb2.SaverDef(
filename_tensor_name=filename_tensor.name,
save_tensor_name=save_tensor.name,
restore_op_name=restore_op.name,
max_to_keep=max_to_keep,
sharded=sharded,
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
version=self._write_version)
总结:Saver初始化的时候,就已经根据传入的var,对每个var添加了对应的保存和恢复的op操作。同时构建了saverDef,是metaGraph重要的一部分。通过该saverDef可以将很多记录和参数持久化为pb,比如文件名的constant op的name,保存流程后的输出tensor的name等等。通过这些信息,就能从metaGraph中还原出Saver实例。(Saver只要知道节点name即可,就能通过name从graphDef中获得对应的op,而这些op就是保存/恢复op_list后输出的op,运行这个op即可运行之前的(依赖着的)所有的保存恢复ops。)
通过tf.train.export_meta_graph我们可以获得序列化后的metaGraph:
以上代码使用了三次export_meta_graph,分别不同:
- 第一次未使用Saver,tf.train.export_meta_graph直接输出。
- 第二次初始化了Saver,tf.train.export_meta_graph直接输出。
- 第三次初始化Saver之后使用saver.export_meta_graph输出。(默认带上了saver_def)
从结果json我们发现符合我们的代码分析:
- 第一个json不包含save/xxx节点。
- 第二个json包含了save/xxx节点,证明了Saver在初始化了时候就已经给图中的variable加上了保存和恢复的ops。但是默认改方法不带saverDef,所以没有这个结构。tf.train.import_meta_graph函数使用后无法重建Saver,所以返回None。
- 第三个json包含了save/xxx节点与SaverDef,因为saver的export函数默认传入了saver初始化时构建好的saverDef,这样才能在tf.train.import_meta_graph函数使用后返回重建的Saver实例,否则返回None。
3.2 Saver.save() 函数:
def save(self,
sess,
save_path,
global_step=None,
latest_filename=None,
meta_graph_suffix="meta", # meta_graph默认后缀名.meta
write_meta_graph=True, # 若改False,则不会生成.meta
write_state=True): # 若改False,则默认保存所有模型文件且无checkpoint文件
需要传入sess,因为当前的session持有着变量相关信息,而save一定会运行Saver()类初始化时候定义的ops从而持久化变量数据(.data与.index)。
write_meta_graph=True代表保存变量数值(.data与.index)的同时会保存metaGraph。而上面已经介绍过,metaGraph中持有重新构建图的所有信息。write_state=True则会默认生成checkpoint文件自动记录训练文件名称。
关键的恢复代码:
if context.in_graph_mode():
model_checkpoint_path = sess.run(
self.saver_def.save_tensor_name,
{self.saver_def.filename_tensor_name: checkpoint_file})
这一步就是运行之前Saver()初始化之后,创建的saver_def中的op:save_tensor_name。这一个op实际上是graph_def中定义的一个node:"save/control_dependency:0"。而这个op本身是无意义的,其实是为了调用它所依赖的variables身上的store op。同时传入文件名参数。这样就执行了保持variable数值的ops。
但是我们知道save()方法不光光如此,它还会生成.meta
与checkpoint
:
if write_state:
self._RecordLastCheckpoint(model_checkpoint_path)
_update_checkpoint_state(
save_dir=save_path_parent,
model_checkpoint_path=model_checkpoint_path,
all_model_checkpoint_paths=self.last_checkpoints,
latest_filename=latest_filename,
save_relative_paths=self._save_relative_paths)
self._MaybeDeleteOldCheckpoints(meta_graph_suffix=meta_graph_suffix)
上述两个操作分别是更新checkpoint文件以及删除过时的旧模型文件。其实checkpoint本身也是pb,只不过它不影响效率,就使用text_format.MessageToString(ckpt)
将pb message转换为了text,方便直接打开看和修改。
以下是生成metaGraph(.meta
文件)的关键代码:
if write_meta_graph:
meta_graph_filename = self._MetaGraphFilename(
checkpoint_file, meta_graph_suffix=meta_graph_suffix)
if context.in_graph_mode():
with sess.graph.as_default():
self.export_meta_graph(meta_graph_filename)
... ...
def saver().export_meta_graph:
return export_meta_graph(
filename=filename,
graph_def=ops.get_default_graph().as_graph_def(add_shapes=True),
saver_def=self.saver_def,
collection_list=collection_list,
as_text=as_text,
export_scope=export_scope,
clear_devices=clear_devices,
clear_extraneous_savers=clear_extraneous_savers)
本质调用了Saver().export_meta_graph() ----> tf.train.export_meta_graph() 最关键的是加入了该类存储着的saver_def,因此输出的.meta文件里是包含saver_def的,下次可以用来恢复Saver(其实是记录着restore_all关键node的name,Saver()初始化的时候已经向graph_def里添加好了所有的save/restore的node)。
3.2 Saver.restore() 函数:
这个函数就非常简单了,关键代码:
if context.in_graph_mode():
sess.run(self.saver_def.restore_op_name,
{self.saver_def.filename_tensor_name: save_path})
就是导入.meta文件后,里面的saver_def二进制流信息重建并返回了Saver实例,然后就能获取到restore所有variables的那个op的名字,然后去运行。该op:restore_op_name: "save/restore_all"同样也是依赖于所有variable的assign操作,即变量赋值。