tensorflow提供两种模型格式
- checkpoint:依赖于创建模型的代码
- SavedModel:与模型代码无关
这里尽介绍checkpoint
1. 保存经过部分训练的模型
Estimator自动将如下内容写入磁盘
- checkpoints: 训练期间所创建的模型版本
- event files: 包含有TensorBoard用于创建可视化图标的全部信息
要指定模型的顶级存储目录,可以使用Estimator构造函数的可选参数model_dir
,设置代码如下所示
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3,
model_dir="./models_dir")
当调用Estimator的train
方法时,Estimator会将checkpoint和其他文件保存到model_dir
目录中,保存之后,这个目录中的文件如下所示:
checkpoint
events.out.tfevents.timestamp.hostname
graph.pbtxt
model.ckpt-1.data-00000-of-00001
model.ckpt-1.index
model.ckpt-1.meta
model.ckpt-200.data-00000-of-00001
model.ckpt-200.index
model.ckpt-200.meta
这个目录存储的是Estimator在第一步训练开始和第200不训练结束时创建的checkpoints
2. checkpoint频率
默认情况下,Estimator按照如下时间将checkpoint保存到model_dir
中
- 每600秒保存一次
- 在
train
方法开始以及完成时都要保存checkpoint - 在目录中最多保留5个最近的checkpoints
可以通过如下步骤来更改默认设置:
- 创建
RunConfig
对象来自定义设置 - 在实例化Estimator时,将该
RunConfig
对象传递个Estimatro的config
参数
my_checkpointing_config = tf.estimator.RunConfig(
save_checkpoints_secs = 20*60,
keep_checkpoint_max = 10,
)
3. 从checkpoint中恢复模型
在第一次调用Estimator的train
方法时,Tensorflow会将checkpoint保存到model_dir
中,随后每次调用Estimator的train
、eval
或者predict
方法时,都会发生下列情况:
- Esitmator运行
model_fun()
构建模型图 - Estimator根据最近写入的checkpoint中存储的数据来初始化新模型的权重
4. 避免不当恢复
通过checkpoint恢复模型的状态必须保证模型和checkpoint保存的兼容才可以。例如我们训练了一个DNNClassifier
Estimator,它包含有2个隐藏层且每层都有10个节点,经过训练兵保存了checkpoint到model_dir
中。后续在训练的时候,假如将代码中的隐藏层修改为了每层20个节点,这样用这样的Estimator调用train
时就回报错,因为checkpoint保存的模型结构与代码中的模型是不兼容的。这一点切记。