TensorFlow常见模型
- Checkpoint(.ckpt):二进制文件,保存模型权值,不包含图结构。由
tf.train.Saver()
对象调用 saver.save()
生成,由saver.restore(session, checkpoint_path)
载入。
- GraphDef(.pb):包含 protobuf 对象序列化后的数据,包含计算图,但不包含权值。
- FrozenGraph:使用
tensorflow/python/tools/freeze_graph.py
对上述2项进行整合得到。
- SavedModel:包含权值和计算图。使用
tf.saved_model.signature_def_utils.build_signature_def()
构建signature_def签名;使用build_tensor_info
方法将输入Tensor和输出Tensor相关信息序列化为TensorInfo Protocol Buffer;使用builder.add_meta_graph_and_variables()
和add_meta_graph()
增加Graph和variables。使用 tf.saved_model.loader.load()
来恢复模型。
验证导出模型代码示例
with tf.Session() as sess:
meta_graph_def = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], export_dir = saved_model_dir)
# 从meta_graph_def中取出SignatureDef对象
signature = meta_graph_def.signature_def
# 从signature中找出具体输入输出的tensorname
signature_key = 'predict_images'
input_key = "images"
shape_key = "image_shape"
output_key = "scores"
x_input_tensor_name = signature[signature_key].inputs[input_key].name
x_shape_tensor_name = signature[signature_key].inputs[shape_key].name
y_tensor_name = signature[signature_key].inputs[output_key].name
# 获取输入输出tensor
x_input = sess.graph.get_tensor_by_name(x_input_tensor_name)
x_shape = sess.graph.get_tensor_by_name(x_shape_tensor_name)
y = sess.graph.get_tensor_by_name(y_tensor_name)
# inference
img = Image.open("C:/lena.jpg")
img = img.resize((64, 64))
img = img.tobytes()
emb = sess.run(y, feed_dict={x_input: img, x_shape: [64, 64, 0, 0, 64, 64, 64, 64]})
print(emb)
参考资料