-
使用 Estimator 导出 SavedModel
# Export the trained Estimator as a SavedModel.
# NOTE(review): the original note says model parameters are loaded via a
# tf.train.Scaffold; that presumably happens inside model_fn, which is not
# shown here -- confirm.
model = tf.estimator.Estimator(model_fn=model_fn, params=params,
                               model_dir=model_dir, config=config)


def serving_input_fn():
    """Build the serving input receiver used for export.

    Declares a single int32 feature ``'x'`` with shape ``[None, 4]`` and
    wraps it with ``build_raw_serving_input_receiver_fn``.

    Returns:
        The receiver produced by calling the raw serving-input builder.
    """
    receiver = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'x': tf.placeholder(tf.int32, [None, 4], name='x')
    })()
    return receiver


# checkpoint_path must be passed by keyword: the third positional parameter
# of export_saved_model is assets_extra (a dict), not checkpoint_path.
model.export_saved_model(model_dir, serving_input_fn,
                         checkpoint_path=checkpoint_path)
-
使用 Session 固化图（变量转为常量）并序列化为 pb 文件
import tensorflow as tf
from tensorflow.python.framework import graph_util


def freeze_graph(path='model.ckpt', output='model.pb', output_node_names=None):
    """Freeze a checkpoint into a self-contained binary GraphDef (.pb) file.

    Restores the graph structure from ``path + '.meta'`` and the variable
    values from ``path``. It then converts the variables needed by the
    output nodes into constants and writes the resulting GraphDef to
    ``output``.

    Args:
        path: Checkpoint prefix; the meta graph is read from
            ``path + '.meta'``.
        output: Destination file for the frozen, binary-serialized graph.
        output_node_names: Names of the nodes the frozen graph must keep.
            Defaults to ``['output/scores']``.
    """
    if output_node_names is None:
        output_node_names = ['output/scores']
    # Import into a fresh graph instead of the process-wide default graph so
    # repeated calls do not keep adding copies of the model to one graph.
    graph = tf.Graph()
    with graph.as_default():
        saver = tf.train.import_meta_graph(path + '.meta', clear_devices=True)
    with tf.Session(graph=graph) as sess:
        saver.restore(sess, path)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess=sess,
            input_graph_def=graph.as_graph_def(),
            output_node_names=list(output_node_names))
    with tf.gfile.GFile(output, 'wb') as fgraph:
        fgraph.write(output_graph_def.SerializeToString())


freeze_graph('model.ckpt', 'model.pb')
-
加载 pb 模型并进行预测测试
def model(path='./logs/pb/model.pb'):
    """Load a frozen GraphDef and open an inference session on it.

    The graph is imported with an empty name prefix, so its tensors keep the
    names they had when the model was frozen.

    Args:
        path: Location of the binary-serialized GraphDef.

    Returns:
        A tuple ``(sess, input_x, keep_prob, pred)``:
        - ``sess`` is a ``tf.Session`` bound to the imported graph; the
          caller is responsible for closing it.
        - ``input_x``, ``keep_prob`` and ``pred`` are the tensors named
          ``input_x:0``, ``dropout_keep_prob:0`` and ``output/scores:0``.
    """
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(path, 'rb') as fgraph:
        graph_def.ParseFromString(fgraph.read())

    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name='')

    tensor_names = ('input_x:0', 'dropout_keep_prob:0', 'output/scores:0')
    input_x, keep_prob, pred = (graph.get_tensor_by_name(t) for t in tensor_names)
    return tf.Session(graph=graph), input_x, keep_prob, pred


def predict(x, sess, input_x, keep_prob, pred, topk=1):
    """Evaluate ``pred`` for the input batch ``x``.

    Args:
        x: Value fed to the ``input_x`` tensor.
        sess, input_x, keep_prob, pred: As returned by ``model()``.
        topk: Currently unused; the full output of ``pred`` is returned.

    Returns:
        Whatever ``sess.run(pred, ...)`` produces for the fed batch.
    """
    # Feeding 1.0 as the dropout keep probability keeps every unit, i.e.
    # dropout is disabled for inference.
    return sess.run(pred, feed_dict={input_x: x, keep_prob: 1.0})
-
加载 pb 模型
def load_pb(path='model.pb'):
    """Read a binary-serialized ``tf.GraphDef`` from ``path``.

    Only parses the protobuf; nothing is imported into any graph.

    Args:
        path: File containing the serialized GraphDef.

    Returns:
        The parsed ``tf.GraphDef``.
    """
    with tf.gfile.GFile(path, 'rb') as handle:
        serialized = handle.read()
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(serialized)
    return graph_def
-
将两个图合并
def combined_graph(pb_paths=('./logs/pb/model.pb', './logs/pb/model_rnn.pb'),
                   output_node_names=('output/scores_rnn', 'output/scores'),
                   logdir='./logs/pb/', name='model_combine.pb'):
    """Merge several frozen graphs into one graph and write it as a .pb file.

    Each input GraphDef is loaded with ``load_pb`` and imported with an
    empty name prefix into a single new graph. The graph is then pruned and
    frozen to ``output_node_names`` and written in binary form to
    ``logdir/name``. Called with no arguments, it behaves exactly like the
    original hard-coded version.

    Args:
        pb_paths: Paths of the frozen GraphDefs to merge, in import order.
        output_node_names: Nodes the combined graph must keep.
        logdir: Directory the combined graph is written to.
        name: File name of the combined graph inside ``logdir``.
    """
    graph_defs = [load_pb(pb_path) for pb_path in pb_paths]
    with tf.Graph().as_default() as g_combine:
        for graph_def in graph_defs:
            # name='' keeps the original node names so lookups such as
            # 'output/scores:0' keep working on the combined graph.
            # NOTE(review): nodes whose names collide across inputs are not
            # shared between the models; confirm how this TF version's
            # importer resolves such clashes.
            tf.import_graph_def(graph_def, name='')
        with tf.Session(graph=g_combine) as sess:
            g_combine_def = graph_util.convert_variables_to_constants(
                sess=sess, input_graph_def=sess.graph_def,
                output_node_names=list(output_node_names))
    tf.train.write_graph(g_combine_def, logdir, name, as_text=False)
-
读取 ckpt 模型权重
def read_ckpt(ckpt):
    """Read every variable stored in a checkpoint into a dict.

    Args:
        ckpt: Checkpoint path/prefix accepted by
            ``tf.train.NewCheckpointReader``.

    Returns:
        A dict mapping each variable name in the checkpoint to the value
        returned by ``reader.get_tensor`` for that name.
    """
    reader = tf.train.NewCheckpointReader(ckpt)
    # Only the variable names are needed; the shapes in the map are unused.
    return {var_name: reader.get_tensor(var_name)
            for var_name in reader.get_variable_to_shape_map()}


weights = read_ckpt('model.ckpt')
将 ckpt 转成 pb 文件
©著作权归作者所有,转载或内容合作请联系作者
- 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
- 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
- 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...