下面要实现的功能是:g1和g2串联(g1的输出作为g2的输入),placeholder输入x为3.0,g1实现y=3*x,g2实现y+3,最后输出12(即3*3+3)。
文件model_b.py如下(需先运行它,生成model_combined.py要加载的g2检查点./models/g2_model.ckpt):
import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.tools import saved_model_utils
import os

MODEL_SAVE_PATH = "./models/"  # Directory where the model checkpoint is saved.

# Build g2: g2_output = g2_input + data, where `data` is a variable
# initialised to 3.0, so it is persisted in the checkpoint written below.
with tf.Graph().as_default() as g2:
    input1 = tf.placeholder(tf.float32, name='g2_input')
    data = tf.Variable(3.)
    add = tf.add(input1, data)
    tf.identity(add, name='g2_output')
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

with tf.Session(graph=g2) as sess:
    sess.run(init)
    # tf.train.Saver.save() refuses to write when the parent directory of the
    # checkpoint prefix does not exist, so create it up front.
    os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
    # Resolves to "./models/g2_model.ckpt", the prefix model_combined.py loads.
    saver.save(sess, os.path.join(MODEL_SAVE_PATH, "g2_model.ckpt"))
文件model_combined.py如下:
import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.tools import saved_model_utils
MODEL_SAVE_PATH = "./models/"  # Directory that holds the saved models.

# g1 and g2 are chained in series: with input x = 3.0, g1 computes y = 3 * x
# and g2 then computes y + 3, so the combined graph should output 12.
with tf.Graph().as_default() as g1:
    input1 = tf.placeholder(tf.float32, name='g1_input')
    data = tf.Variable(3.)
    mul = tf.multiply(input1, data)
    tf.identity(mul, name='g1_output')
    init = tf.global_variables_initializer()

# Freeze g1: replace its variables with constants holding their current
# values, so the resulting GraphDef is self-contained.
with tf.Session(graph=g1) as sess:
    sess.run(init)
    g1def = graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        output_node_names=["g1_output"],
    )
# Rebuild g2 from the checkpoint written by model_b.py, then freeze it.
# The session must be bound to g2, the graph that import_meta_graph imports
# into, which is why both are opened in a single `with` (left to right).
with tf.Graph().as_default() as g2, tf.Session(graph=g2) as sess:
    saver = tf.train.import_meta_graph('./models/g2_model.ckpt.meta')
    saver.restore(sess, './models/g2_model.ckpt')
    g2def = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, output_node_names=["g2_output"])
##------------------------------------------------------------
# Chain the two frozen graphs in series:
#   my_input -> g1 (3 * x) -> g2 (+ 3) -> my_output
with tf.Graph().as_default() as g_combined:
    x = tf.placeholder(tf.float32, name="my_input")
    # import_graph_def returns a *list* with one element per entry in
    # return_elements, so unpack it. Feeding the list itself into the next
    # input_map would make TF auto-convert it to a shape-[1] tensor, and the
    # result would be [12.] instead of the scalar 12.
    y, = tf.import_graph_def(g1def,
                             input_map={"g1_input:0": x},
                             return_elements=["g1_output:0"],
                             name="g1")
    z, = tf.import_graph_def(g2def,
                             input_map={"g2_input:0": y},
                             return_elements=["g2_output:0"],
                             name="g2")
    tf.identity(z, "my_output")

    with tf.Session(graph=g_combined) as sess:
        print(sess.run(z, feed_dict={'my_input:0': 3.}))
        # Save option 1: freeze the combined graph into a single .pb file.
        # g_combineddef = graph_util.convert_variables_to_constants(sess, sess.graph_def, ["my_output"])
        # tf.train.write_graph(g_combineddef, MODEL_SAVE_PATH, 'my_model.pb', as_text=False)
        # Save option 2: export as a SavedModel.
        # tf.saved_model.simple_save(sess,
        #                            "./modelbase",
        #                            inputs={"my_input": x},
        #                            outputs={"my_output": z})