TensorFlow模型导出


TensorFlow常见模型

  • Checkpoint(.ckpt):二进制文件,保存模型权值,不包含图结构。由tf.train.Saver()对象调用 saver.save()生成,由saver.restore(session, checkpoint_path)载入。
  • GraphDef(.pb):包含 protobuf 对象序列化后的数据,包含计算图,但不包含权值。
  • FrozenGraph:使用tensorflow/python/tools/freeze_graph.py对上述2项进行整合得到。
  • SavedModel:包含权值和计算图。使用tf.saved_model.signature_def_utils.build_signature_def()构建signature_def签名;使用build_tensor_info方法将输入Tensor和输出Tensor相关信息序列化为TensorInfo Protocol Buffer;使用builder.add_meta_graph_and_variables()add_meta_graph()增加Graph和variables。使用 tf.saved_model.loader.load()来恢复模型。

验证导出模型代码示例

with tf.Session() as sess:
  meta_graph_def = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], export_dir = saved_model_dir)

  # 从meta_graph_def中取出SignatureDef对象
  signature = meta_graph_def.signature_def

  # 从signature中找出具体输入输出的tensorname
  signature_key = 'predict_images'
  input_key = "images"
  shape_key = "image_shape"
  output_key = "scores"
  x_input_tensor_name = signature[signature_key].inputs[input_key].name
  x_shape_tensor_name = signature[signature_key].inputs[shape_key].name
  y_tensor_name = signature[signature_key].inputs[output_key].name

  # 获取输入输出tensor
  x_input = sess.graph.get_tensor_by_name(x_input_tensor_name)
  x_shape = sess.graph.get_tensor_by_name(x_shape_tensor_name)
  y = sess.graph.get_tensor_by_name(y_tensor_name)

  # inference
  img = Image.open("C:/lena.jpg")
  img = img.resize((64, 64))
  img = img.tobytes()

  emb = sess.run(y, feed_dict={x_input: img, x_shape: [64, 64, 0, 0, 64, 64, 64, 64]})
  print(emb)

参考资料

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容