TensorFlow1.15 模型持久化源码级解析(一) Saver与存储原理

官方guide: https://github.com/tensorflow/docs/blob/master/site/en/r1/guide/saved_model.md

官方API:https://tensorflow.google.cn/versions/r1.15/api_docs/python/tf/train/Saver

所有pb格式数据文件的格式定义文件proto:https://github.com/tensorflow/tensorflow/tree/r1.15/tensorflow/core/protobuf

由于tf2开始,tf.train.saver被砍了,于是这里重点介绍tf1.xx版本的模型持久化方式,手动保存与恢复,tf官方定义为low-level API。(high-level 使用 estimator)

一、Saver类

tf.reset_default_graph()
v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1),name='v1')
v2 = tf.Variable(tf.random_normal([1], stddev=1, seed=2),name='v2')
result = v1 + v2
result2 = v1 * v2
c1 = tf.zeros([2,2], name="c1")

init_op = tf.global_variables_initializer()
saver = tf.train.Saver(max_to_keep=2)  # Add ops to save and restore all the variables.

with tf.Session() as sess:
    sess.run(init_op)
    for x in range(10000):
        if x % 1000 == 0:
            print('saved:',x,'of 1w')
            save_path = saver.save(sess, "Saved_model/test.ckpt", global_step=x)
            print(save_path)
    print(v1.eval(), v2.eval())

对Saver设置max_to_keep参数,能自动保存下最新的n个模型。得到结果:


image.png

Saver中传入的文件名都只需要前缀即可,即若要恢复只要传入test.ckpt-9000,其中-后面的数字代表了step数。00000-of-00001代表device信息,有一个GPU,且在第0个上。checkpoint文件保存并维护着模型列表,以及最新模型的文件名。可使用如下函数获取保存模型的最新文件位置:

ckpt = tf.train.get_checkpoint_state(checkpoint_dir='Saved_model')
print(ckpt.model_checkpoint_path)  # Saved_model\test.ckpt-9000

Saver中可传入list或dict,用来指定保存或恢复时候的变量。默认是所有变量。注意,一旦传入,则只会保存或恢复list或dict中的变量,不管其余变量。

  • 传入list。
    若保存时候的Saver([v1,v2]),则恢复的时候,也要这么指定(除非当前graph里只有v1、v2),否则会报错:变量v3未找到。即恢复的变量应当是保存时候变量的一个子集。要恢复的变量必须存在。
  • 传入dict。
    有时候,保存模型的时候v1变量名name='v1',但是恢复模型的时候,graph里v1的变量名设定的是name='v11',又因tf是通过对应的变量名去加载的,因此会发生冲突。此时只要在dict中指定{'v1':v1}即可。dict中的key-value对:<String name: Variable 变量的引用>。案例:若我们使用上面的代码保存了模型,变量名为v1和v2,可通过方法查看:
    image.png

在新文件中,我们重新构建好graph结构,但是我们此时变量v1的name改了,于是需要设定dict映射,来将ckpt中的v1加载到新模型的变量v1(name='v11')中:


image.png

注意此时,我们新建的graph中,v3并没有被初始化,也未被Saver指定恢复数值。

那么问题来了,若是模型有100个variable,新构建的graph中,部分variable的name和之前不一样,难道我们还需要手动写dict吗?解决方法:通过tf的collection机制,获取所有variable,组成dict后手动修改部分key值:


image.png

官方Notes:

  • You can create as many Saver objects as you want if you need to save and restore different subsets of the model variables. The same variable can be listed in multiple saver objects; its value is only changed when the Saver.restore() method is run.

  • If you only restore a subset of the model variables at the start of a session, you have to run an initialize op for the other variables. See tf.variables_initializer for more information.

二、存储机制详解

可以发现,存储后的文件有3种后缀,data、index与meta。同时,tf还提供了大量接口让人混淆:

  • tf.train.Saver()/saver.restore()
  • saver.export_meta_graph()/ tf.train.import_meta_graph
  • tf.train.write_graph()/tf.Import_graph_def()

https://yq.aliyun.com/articles/620067
https://www.jianshu.com/p/ca637520002f
https://zhuanlan.zhihu.com/p/31308381

2.1 GraphDef 之 tf.Import_graph_def()

学术界适合使用上面所阐述的Saver.save()方法持久化模型,能方便之后继续训练或测试。但是工业界需要通用的模型文件,使得Java/C++也能直接部署,调用模型获得输出。所以工业界部署模型推荐tf.Import_graph_def()方法。

graph序列化的protobuf叫做graphDef,就是define graph的意思,一个graph的定义,包含了计算图上的节点信息。这个graphDef可以用tf.train.write_graph()/tf.Import_graph_def()来写入和导出。然而graphDef里面其实是没有存储变量具体数值的,因此无法拿来训练,但是可以存常量,就是constant。因此也可将所有session中持有的变量转constant后(graph.util.convert_variables_to_constants),存储为pb,拿来部署做inference。这样graph结构信息与变量权重就能归并到一个pb文件中,没了变量初始化、模型保存等辅助节点后,模型文件更小更简洁,是无视语种的数据描述文件,适合工业部署做predict。

variable转constants并持久化pb模型

2.2 MetaGraph 之 tf.train.import_meta_graph()

tf.train.import_meta_graph() 方法可以直接从.meta文件中恢复Graph结构,其包含以下几种主要成分:

MetaGraph

  • MetaInfoDef 这个是存metadata的,像版本信息啊,用户信息,运算方法信息(比如定义了加法、乘法等,供GraphDef使用)
  • GraphDef 上面说的就是这个GraphDef,包含了节点信息。(节点使用了哪种运算操作、输入输出都是什么)
  • SaverDef 记录了所有持久化相关的参数,包括存储与恢复使用的op的名字、保存频率等
  • CollectionDef 集合名称到集合内容的映射
  • signature_def 记号标记用于saved_model保存pb的时候使用,定义统一的输入输出名
  • AssetFileDef 记录外置文件位置
可以看到constant也被恢复了
restore只是去restore variable,常量是在MetaGraph的GraphDef里的

restore只是去restore variable,常量是在MetaGraph的GraphDef里的。故实验发现没有restore,常量依旧已经获取到了。
总结来看,saver.save()和saver.restore()保存和读取的东西不一致,save会保存所有一坨信息,而restore只是将data里的variable值恢复到当前graph中的对应节点里,graph你得自己新建或使用tf.train.import_meta_graph()。

三、 Saver源码解析

3.1 Saver([var_list]).init()

当传入var_list初始化Saver的时候,若未指定saver_def,则会自动使用build() ---> _build() ---> BaseSaverBuilder() 来创建新的saver_def.

Saver._build():

 if not self.saver_def or context.in_eager_mode():
      if self._builder is None:
        self._builder = BaseSaverBuilder(self._write_version)  # 创建BaseSaverBuilder
      if self._var_list is None:
        self._var_list = variables._all_saveable_objects()  # 若未传入var_list则默认设置为所有variable

      self.saver_def = self._builder._build_internal(  # 使用BaseSaverBuilder来创建saver_def
          self._var_list,
          reshape=self._reshape,
          sharded=self._sharded,
          max_to_keep=self._max_to_keep,
          keep_checkpoint_every_n_hours=self._keep_checkpoint_every_n_hours,
          name=self._name,
          restore_sequentially=self._restore_sequentially,
          filename=checkpoint_path,
          build_save=build_save, build_restore=build_restore)

再来看BaseSaverBuilder中返回saver_def的关键函数:

  def _build_internal(self,
                      names_to_saveables,  # 就是Saver初始化时候的var_list
                      reshape=False,
                      sharded=False,
                      max_to_keep=5,
                      keep_checkpoint_every_n_hours=10000.0,
                      name=None,
                      restore_sequentially=False,
                      filename="model",
                      build_save=True,
                      build_restore=True):

    # 首先将var_list转换成names_to_saveables,格式为<k,v>dict键值对:<op_name:op>
    # 随后将op一个个取出,将variable包装为VariableSaveable后存入list:saveables并返回
    saveables = self._ValidateAndSliceInputs(names_to_saveables)

    # 创建op的name前缀:save
    with ops.name_scope(name, "save",
                        [saveable.op for saveable in saveables]) as name:
      # Add the Constant string tensor for the filename.
      filename_tensor = constant_op.constant(filename or "model")

      # Add the save ops. 创建保持和恢复的ops【重要】
      if sharded:
        ... ...
      else:
        if build_save:
          # 为每个saveables中的op添加保存op,并对op_list进行组合并返回组合依赖后的输出tensor
          #(通过control_flow_ops.with_dependencies)代表运行此tensor前必须运行全部的保存op
          save_tensor = self._AddSaveOps(filename_tensor, saveables)
        if build_restore:
          restore_op = self._AddRestoreOps(filename_tensor, saveables,
                                           restore_sequentially, reshape)

    if context.in_graph_mode():
      # 正式构建并返回saver_def
      return saver_pb2.SaverDef(
          filename_tensor_name=filename_tensor.name,
          save_tensor_name=save_tensor.name,
          restore_op_name=restore_op.name,
          max_to_keep=max_to_keep,
          sharded=sharded,
          keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
          version=self._write_version)

总结:Saver初始化的时候,就已经根据传入的var,对每个var添加了对应的保存和恢复的op操作。同时构建了saverDef,是metaGraph重要的一部分。通过该saverDef可以将很多记录和参数持久化为pb,比如文件名的constant op的name,保存流程后的输出tensor的name等等。通过这些信息,就能从metaGraph中还原出Saver实例。(Saver只要知道节点name即可,就能通过name从graphDef中获得对应的op,而这些op就是保存/恢复op_list后输出的op,运行这个op即可运行之前的(依赖着的)所有的保存恢复ops。)

通过tf.train.export_meta_graph我们可以获得序列化后的metaGraph:


以上代码使用了三次export_meta_graph,分别不同:

  • 第一次未使用Saver,tf.train.export_meta_graph直接输出。
  • 第二次初始化了Saver,tf.train.export_meta_graph直接输出。
  • 第三次初始化Saver之后使用saver.export_meta_graph输出。(默认带上了saver_def)

从结果json我们发现符合我们的代码分析:

  • 第一个json不包含save/xxx节点。
  • 第二个json包含了save/xxx节点,证明了Saver在初始化了时候就已经给图中的variable加上了保存和恢复的ops。但是默认改方法不带saverDef,所以没有这个结构。tf.train.import_meta_graph函数使用后无法重建Saver,所以返回None。
  • 第三个json包含了save/xxx节点与SaverDef,因为saver的export函数默认传入了saver初始化时构建好的saverDef,这样才能在tf.train.import_meta_graph函数使用后返回重建的Saver实例,否则返回None。

3.2 Saver.save() 函数:

  def save(self,
           sess,
           save_path,
           global_step=None,
           latest_filename=None,
           meta_graph_suffix="meta",   # meta_graph默认后缀名.meta
           write_meta_graph=True,   # 若改False,则不会生成.meta
           write_state=True):   # 若改False,则默认保存所有模型文件且无checkpoint文件

需要传入sess,因为当前的session持有着变量相关信息,而save一定会运行Saver()类初始化时候定义的ops从而持久化变量数据(.data与.index)。
write_meta_graph=True代表保存变量数值(.data与.index)的同时会保存metaGraph。而上面已经介绍过,metaGraph中持有重新构建图的所有信息。write_state=True则会默认生成checkpoint文件自动记录训练文件名称。

关键的恢复代码:

        if context.in_graph_mode():
          model_checkpoint_path = sess.run(
              self.saver_def.save_tensor_name,
              {self.saver_def.filename_tensor_name: checkpoint_file})

这一步就是运行之前Saver()初始化之后,创建的saver_def中的op:save_tensor_name。这一个op实际上是graph_def中定义的一个node:"save/control_dependency:0"。而这个op本身是无意义的,其实是为了调用它所依赖的variables身上的store op。同时传入文件名参数。这样就执行了保持variable数值的ops。

但是我们知道save()方法不光光如此,它还会生成.metacheckpoint

        if write_state:
          self._RecordLastCheckpoint(model_checkpoint_path)
          _update_checkpoint_state(
              save_dir=save_path_parent,
              model_checkpoint_path=model_checkpoint_path,
              all_model_checkpoint_paths=self.last_checkpoints,
              latest_filename=latest_filename,
              save_relative_paths=self._save_relative_paths)
           self._MaybeDeleteOldCheckpoints(meta_graph_suffix=meta_graph_suffix)

上述两个操作分别是更新checkpoint文件以及删除过时的旧模型文件。其实checkpoint本身也是pb,只不过它不影响效率,就使用text_format.MessageToString(ckpt)将pb message转换为了text,方便直接打开看和修改。

以下是生成metaGraph(.meta文件)的关键代码:

    if write_meta_graph:
      meta_graph_filename = self._MetaGraphFilename(
          checkpoint_file, meta_graph_suffix=meta_graph_suffix)
      if context.in_graph_mode():
        with sess.graph.as_default():
          self.export_meta_graph(meta_graph_filename)
... ...

def saver().export_meta_graph:
return export_meta_graph(
        filename=filename,
        graph_def=ops.get_default_graph().as_graph_def(add_shapes=True),
        saver_def=self.saver_def,
        collection_list=collection_list,
        as_text=as_text,
        export_scope=export_scope,
        clear_devices=clear_devices,
        clear_extraneous_savers=clear_extraneous_savers)

本质调用了Saver().export_meta_graph() ----> tf.train.export_meta_graph() 最关键的是加入了该类存储着的saver_def,因此输出的.meta文件里是包含saver_def的,下次可以用来恢复Saver(其实是记录着restore_all关键node的name,Saver()初始化的时候已经向graph_def里添加好了所有的save/restore的node)。

3.2 Saver.restore() 函数:

这个函数就非常简单了,关键代码:

    if context.in_graph_mode():
      sess.run(self.saver_def.restore_op_name,
               {self.saver_def.filename_tensor_name: save_path})

就是导入.meta文件后,里面的saver_def二进制流信息重建并返回了Saver实例,然后就能获取到restore所有variables的那个op的名字,然后去运行。该op:restore_op_name: "save/restore_all"同样也是依赖于所有variable的assign操作,即变量赋值。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 204,053评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,527评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 150,779评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,685评论 1 276
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,699评论 5 366
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,609评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,989评论 3 396
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,654评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,890评论 1 298
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,634评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,716评论 1 330
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,394评论 4 319
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,976评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,950评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,191评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 44,849评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,458评论 2 342