最近在用 java 改写一个用 python 编写的 model,遇到了有关模型保存与恢复的问题,发现网上的资料有些混乱,在这里做一些记录。
.ckpt
1. .ckpt 全称为 checkpoint,代表着一个检查点,即为 model 训练过程中的一个快照,可能是在训练开始,也可能是在训练完成。
2. .ckpt 是由 Saver 调用 save 产生的:
saver.save(sess,"/tmp/model.ckpt")
3. 由 Saver 调用 restore 来复原 model 的数据:
saver.restore(sess,path)
注意这里,复原的只有数据,不含 graph 信息。
4. .ckpt 不是单独的一个文件,而是一系列文件。
其内部包含了:
①checkpoint: .ckpt 的标记信息。
②.data: model 中 graph 的数据,包括各种变量,不含常量。
③.index: 索引信息。
④.meta: graph 信息。
在这里要搞明白一点,一个 model 是由 graph(④) + 数据(②) 组成的。
graph 代表着执行逻辑,在 tensorflow 中,每个算子用一个 node 来表示,众多 node 组合起来便是一张图(graph),也就是我们的执行逻辑,而这些执行逻辑在 Saver 调用 save 时,会被存到 .meta 中(不含数据)。各个 node 中含有各种参数(变量,比如训练的权重),这些参数则被存储到 .data 中。graph 与数据是分别存储的。
tf.train.import_meta_graph
该方法只能恢复 graph,不恢复数据。
注意与上面提及的 saver.restore 区分,saver.restore 只恢复数据,不恢复 graph。
recover model
现在我们来讨论下,如何能恢复一个model。前面已经提过了,一个 model 由 graph 和 数据组成,所以只要能恢复这两部分就可以了,依据恢复的方法不同,可以分为两类。
①分别恢复 graph 和数据:
对于数据来说,可以用 saver.restore 来恢复。
对于graph来说,依据恢复方法不同可以分为两种:
A.硬编码恢复:在调用方法中,重新书写 graph 信息。
B. .meta 恢复:通过调用 tf.train.import_meta_graph 方法获得 graph,并配合 get_tensor_by_name 的方法来调用 model 中特定的算子(node)。
saver = tf.train.import_meta_graph('~/tmp/model.ckpt-1000.meta')
graph = tf.get_default_graph()
input = graph.get_tensor_by_name('input:0')
② freezing(固化):
该方法将变量(训练的权重)固化在 graph 中,即用常量来替换 graph 中的变量,从而达到无需恢复数据,直接调用 graph 即可。权重一旦被固化就不能再修改,该方法一般用于生产环境。
注:笔者在测试 Java API 时,其只支持调用 freezing 后的图。
References: