训练模型时通常保存为 ckpt,方便之后接着训练;而上线部署时可以保存为 pb 模型,加载时不需要重新定义模型结构,只需通过输入、输出张量即可调用模型。
import os
import shutil

import tensorflow as tf
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import (signature_constants, signature_def_utils, tag_constants, utils)
class model:
    """Minimal linear model ``y = a * w + b`` used to demonstrate exporting.

    All ops are created in whatever graph is the default graph at
    construction time.
    """

    def __init__(self):
        # Input: a 1-D float32 placeholder of any length.
        self.a = tf.placeholder(tf.float32, [None])
        # Weight (initial value 2.0) is kept on the instance so callers can
        # reassign it; the bias (initial value 0.5) is only used to build y.
        self.w = tf.Variable(tf.constant(2.0, shape=[1]), name="w")
        bias = tf.Variable(tf.constant(0.5, shape=[1]), name="b")
        # Keep the plain * and + operators: the auto-generated op names are
        # presumably what load_pb_one looks up ("Placeholder:0" / "add:0").
        self.y = self.a * self.w + bias
# Save the model as a checkpoint (ckpt).
def save_model():
    """Build the model, set ``w`` to 10, print one prediction and save a ckpt.

    The checkpoint is written under the prefix ``model_pb/model.ckpt``,
    which is the same prefix the ``__main__`` section restores from.
    """
    train_graph = tf.Graph()
    with train_graph.as_default():
        net = model()

    # Entering the session also makes train_graph the default graph, so the
    # ops created below (initializer, assign, saver) live in train_graph.
    with tf.Session(graph=train_graph) as sess:
        sess.run(tf.global_variables_initializer())
        # Overwrite the weight so the saved value differs from its initial 2.0.
        sess.run(tf.assign(net.w, [10]))
        # With w = 10 and b = 0.5, feeding 3.0 gives 30.5.
        print(sess.run(net.y, feed_dict={net.a: [3.0]}))
        tf.train.Saver().save(sess, "model_pb/model.ckpt")
# Export as a SavedModel (saved_model.pb + variables directory).
def export_model(session, m, export_path="pb_model/1"):
    """Export the graph held by ``session`` as a SavedModel.

    Args:
        session: a ``tf.Session`` whose graph contains ``m`` and whose
            variables are already initialised or restored.
        m: a ``model`` instance; ``m.a`` is exported under the signature
            input key ``"input"`` and ``m.y`` under the output key ``"output"``.
        export_path: destination directory. Anything already at this path is
            removed first. The trailing ``1`` in the default is presumably the
            model version directory used by TensorFlow Serving.

    Export failures are printed, not raised (best-effort, as in the original).
    """
    # Usually only this part needs editing: declare the input/output tensors.
    model_signature = signature_def_utils.build_signature_def(
        inputs={"input": utils.build_tensor_info(m.a)},
        outputs={"output": utils.build_tensor_info(m.y)},
        method_name=signature_constants.PREDICT_METHOD_NAME)

    # Clear any previous export. Use shutil/os instead of os.system("rm -rf"):
    # no shell involved, no path quoting problems, and it works off POSIX.
    if os.path.isdir(export_path) and not os.path.islink(export_path):
        shutil.rmtree(export_path)
    elif os.path.lexists(export_path):
        os.remove(export_path)
    print("Export the model to {}".format(export_path))
    try:
        # Create the init op and export from the session's graph even when the
        # caller has not made that graph the default one.
        with session.graph.as_default():
            legacy_init_op = tf.group(
                tf.tables_initializer(), name='legacy_init_op')
            builder = saved_model_builder.SavedModelBuilder(export_path)
            builder.add_meta_graph_and_variables(
                session, [tag_constants.SERVING],
                clear_devices=True,
                signature_def_map={
                    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                        model_signature,
                },
                legacy_init_op=legacy_init_op)
            builder.save()
    except Exception as e:
        print("Fail to export saved model, exception: {}".format(e))
# Load the SavedModel and run a few test predictions.
def load_pb(model_file_path="pb_model/1"):
    """Load the SavedModel at ``model_file_path`` and print predictions.

    Inputs and outputs are located through the saved signature, so no
    model-building code is needed. The session is closed on return.

    Raises:
        ValueError: if the SavedModel contains no signatures.
    """
    with tf.Session(graph=tf.Graph()) as session:
        meta_graph = tf.saved_model.loader.load(
            session, [tf.saved_model.tag_constants.SERVING], model_file_path)
        signatures = meta_graph.signature_def
        if not signatures:
            raise ValueError(
                "No signatures found in SavedModel at {}".format(model_file_path))
        # Prefer the default serving signature (the key export_model uses);
        # otherwise fall back to any available signature.
        default_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        if default_key in signatures:
            model_graph_signature = signatures[default_key]
        else:
            model_graph_signature = next(iter(signatures.values()))
        output_tensor_names = [
            info.name for info in model_graph_signature.outputs.values()]
        print("load model finish!")

        # Test the model: values are fed under their signature input keys.
        for test_x in [[1], [2], [3], [4], [5]]:
            inputs_by_key = {"input": test_x}
            feed_dict_map = {
                info.name: inputs_by_key[key]
                for key, info in model_graph_signature.inputs.items()
            }
            predict_y = session.run(output_tensor_names, feed_dict=feed_dict_map)
            print("predict pb y:", predict_y)
if __name__ == "__main__":
    # 1) Build the model, change its weight and write a checkpoint.
    save_model()

    # 2) Rebuild the model in a separate graph, restore the checkpoint into
    #    it and export a SavedModel from that session.
    restore_graph = tf.Graph()
    with restore_graph.as_default():
        net = model()
        ckpt_saver = tf.train.Saver()
    with tf.Session(graph=restore_graph) as sess:
        ckpt_saver.restore(sess, "model_pb/model.ckpt")  # load the ckpt
        export_model(sess, net)

    # 3) Reload the SavedModel without any model-building code and predict.
    load_pb()
save_model 和 load_pb 两个函数要分开执行:第一次注释掉 load,只执行 save;第二次执行 load 时注释掉 save。这是因为如果都在默认图中声明模型,同名变量会被依次命名为 w、w_1(b、b_1),加载时名字就对应不上。(上面的 `__main__` 代码为保存、导出和加载分别创建了独立的 tf.Graph(),这种情况下不会出现名字冲突。)
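保存好的 pb 模型(SavedModel)目录结构如下:
pb_model/1/
├── saved_model.pb
└── variables/
    ├── variables.data-00000-of-00001
    └── variables.index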
还有一种保存 pb 模型的方法:把变量固化为常量(freeze),这样保存下来的只有一个 model.pb 文件,没有 variables 目录。
# Export as a single frozen model.pb: variables are folded into constants,
# so no variables directory is written.
def export_model_one(session, m, graph, output_graph='pb_model/model.pb'):
    """Freeze ``graph`` with its current variable values and write it to disk.

    Args:
        session: session holding the current variable values of ``graph``.
        m: ``model`` instance; the op producing ``m.y`` is the frozen graph's
            output node.
        graph: the ``tf.Graph`` that contains ``m``.
        output_graph: destination file for the serialized ``GraphDef``.
    """
    output_names = [m.y.op.name]  # output node(s) of the frozen graph
    input_graph_def = graph.as_graph_def()
    # Print the output/input node names; the loader fetches them as "<name>:0".
    print(m.y.op.name, m.a.op.name)
    # Reach the function through the tf namespace; the bare name
    # convert_variables_to_constants is not imported in this file.
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        session, input_graph_def, output_names)
    # Make sure the parent directory exists before writing the file.
    out_dir = os.path.dirname(output_graph)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with tf.gfile.GFile(output_graph, 'wb') as f:
        f.write(output_graph_def.SerializeToString())
# Load the frozen model.pb (no variables directory) and run test predictions.
def load_pb_one(pb_path='pb_model/model.pb', input_name="Placeholder:0",
                output_name="add:0"):
    """Load a frozen ``GraphDef`` from ``pb_path`` and print predictions.

    Args:
        pb_path: file written by ``export_model_one``.
        input_name: name of the tensor to feed.
        output_name: name of the tensor to fetch. The defaults for both are
            presumably the names TensorFlow auto-assigns to ``model``'s
            placeholder and final add op when it is the first model built
            in its graph. Confirm them against the names printed by
            export_model_one.
    """
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        with open(pb_path, "rb") as f:
            graph_def.ParseFromString(f.read())
        # name="" keeps the original node names, so input_name/output_name
        # can be looked up without a scope prefix.
        tf.import_graph_def(graph_def, name="")

    # Bind the session to the imported graph explicitly and close it when
    # done. The frozen graph holds constants instead of variables, so no
    # initializer needs to run.
    with tf.Session(graph=graph) as sess:
        input_tensor = graph.get_tensor_by_name(input_name)
        output_tensor = graph.get_tensor_by_name(output_name)
        # Test the frozen model.
        for test_x in [[1], [2], [3], [4], [5]]:
            predict_y = sess.run(output_tensor, feed_dict={input_tensor: test_x})
            print("predict pb y:", predict_y)
如果需要把模型发布成服务,请参考:https://www.jianshu.com/p/5b74f1bc0178