Loading the file
Reference: https://medium.com/@alexkn15/tensorflow-save-model-for-use-in-java-or-c-ab351a708ee4
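import tensorflow as tf  # TensorFlow 1.x API (in 2.x, use tf.compat.v1)

# The file must contain a serialized GraphDef (e.g. a frozen graph);
# the saved_model.pb inside a SavedModel directory holds a SavedModel proto instead.
model_filename = './saved_model.pb'
with tf.gfile.GFile(model_filename, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# TensorFlow adds an "import/" prefix to every tensor name when it imports a graph definition, e.g. "import/input:0",
# so we explicitly tell TensorFlow to use an empty prefix: name=""
tf.import_graph_def(graph_def, name="")
print(tf.get_default_graph().get_operations())  # print all operations for debugging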
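Because of the prefix rule described in the comments above, the tensor name you look up after importing depends on the `name` argument of `import_graph_def`. Below is a minimal sketch of that naming convention, assuming tensors are named "<op name>:<output index>" and that a non-empty `name` (default "import") is prepended as "<name>/". The helper `tensor_name_for` is hypothetical, not a TensorFlow API.

```python
def tensor_name_for(op_name: str, output_index: int = 0, import_name: str = "import") -> str:
    """Build the tensor name to look up after tf.import_graph_def(..., name=import_name).

    Hypothetical helper: with import_name="" no prefix is added; otherwise every
    imported op is scoped as "<import_name>/<op_name>".
    """
    prefix = f"{import_name}/" if import_name else ""
    return f"{prefix}{op_name}:{output_index}"


print(tensor_name_for("input"))                       # default prefix kept
print(tensor_name_for("input", import_name=""))       # matches name="" above
print(tensor_name_for("prediction", import_name=""))
```

With `name=""` as in the code above, the result can be passed to `tf.get_default_graph().get_tensor_by_name(...)`.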
You can also inspect the node information in the loaded graph:
node_list = graph_def.node
## Each element is a NodeDef; printing the last one gives something like:
print(node_list[-1])
##----------------------
name: "prediction"
op: "Reshape"
input: "score"
input: "prediction/shape"
attr {
  key: "T"
  value {
    type: DT_INT32
  }
}
attr {
  key: "Tshape"
  value {
    type: DT_INT32
  }
}
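Each entry in a node's `input` list (such as "score" and "prediction/shape" above) names the output of another node. In the GraphDef format an input is written "node_name:output_index", where ":0" is normally omitted, and a control dependency is written with a leading "^". A small parsing sketch, assuming that convention; `parse_input` and `InputRef` are hypothetical helpers, not part of TensorFlow.

```python
from typing import NamedTuple, Optional


class InputRef(NamedTuple):
    node: str               # name of the producing node
    index: Optional[int]    # output index of that node; None for control inputs
    is_control: bool        # True for control-dependency inputs ("^name")


def parse_input(spec: str) -> InputRef:
    """Split one NodeDef input string into producer name, output index and control flag."""
    if spec.startswith("^"):
        # Control dependency: only enforces ordering, no tensor flows along it.
        return InputRef(spec[1:], None, True)
    name, sep, idx = spec.rpartition(":")
    if sep and idx.isdigit():
        # Explicit output index, e.g. "split:1" is the second output of "split".
        return InputRef(name, int(idx), False)
    # No ":" suffix means output 0.
    return InputRef(spec, 0, False)


# The two inputs of "prediction" above, plus two hypothetical names ("split", "init").
for spec in ["score", "prediction/shape", "split:1", "^init"]:
    print(parse_input(spec))
```

On a real graph, `[parse_input(s) for s in node.input]` works because `node.input` is a repeated string field.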
To list the names of all nodes:
for node in graph_def.node:
    print(node.name)
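Beyond printing names, it is often useful to index nodes by name and count how many nodes of each op type the graph contains. The sketch below works on any iterable of objects with `.name` and `.op` attributes, which `graph_def.node` satisfies; to stay self-contained it uses a hypothetical stand-in `Node` dataclass and made-up node names instead of a real GraphDef.

```python
from collections import Counter
from dataclasses import dataclass, field
from typing import Dict, Iterable, List


@dataclass
class Node:
    """Hypothetical stand-in for a NodeDef: only the fields used here."""
    name: str
    op: str
    input: List[str] = field(default_factory=list)


def index_by_name(nodes: Iterable) -> Dict[str, object]:
    """Map each node name to its node (names are unique within one GraphDef)."""
    return {node.name: node for node in nodes}


def op_histogram(nodes: Iterable) -> Counter:
    """Count nodes per op type, e.g. how many Const or Placeholder nodes exist."""
    return Counter(node.op for node in nodes)


# Hypothetical toy graph shaped like the "prediction" example above.
demo_nodes = [
    Node("input", "Placeholder"),
    Node("w", "Const"),
    Node("score", "MatMul", ["input", "w"]),
    Node("prediction/shape", "Const"),
    Node("prediction", "Reshape", ["score", "prediction/shape"]),
]

print(op_histogram(demo_nodes).most_common())
print(index_by_name(demo_nodes)["prediction"].input)
```

With a loaded graph, the same calls become `op_histogram(graph_def.node)` and `index_by_name(graph_def.node)`.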
To find all Placeholder nodes:
placeholder_nodes = [node for node in graph_def.node if node.op == "Placeholder"]
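Placeholders are usually the graph's inputs; a common companion heuristic treats nodes that no other node consumes as candidate outputs. This is only a heuristic: leftover nodes (an unused constant, or a NoOp such as an "init" op) also show up, so treat the result as a list of candidates. The sketch below assumes the input-string convention "name:index" / "^name" and uses a hypothetical stand-in `Node` dataclass in place of a real `graph_def.node`.

```python
from dataclasses import dataclass, field
from typing import Iterable, List, Tuple


@dataclass
class Node:
    """Hypothetical stand-in for a NodeDef with the fields used below."""
    name: str
    op: str
    input: List[str] = field(default_factory=list)


def producer_name(spec: str) -> str:
    """Turn an input string like "^init" or "split:1" into the producing node's name."""
    return spec.lstrip("^").split(":")[0]


def find_io_nodes(nodes: Iterable) -> Tuple[List[str], List[str]]:
    """Heuristic: inputs are Placeholder ops, outputs are nodes nobody consumes."""
    nodes = list(nodes)
    consumed = {producer_name(s) for node in nodes for s in node.input}
    inputs = [n.name for n in nodes if n.op == "Placeholder"]
    outputs = [n.name for n in nodes if n.name not in consumed]
    return inputs, outputs


# Hypothetical toy graph shaped like the "prediction" example above.
demo = [
    Node("input", "Placeholder"),
    Node("w", "Const"),
    Node("score", "MatMul", ["input", "w"]),
    Node("prediction/shape", "Const"),
    Node("prediction", "Reshape", ["score", "prediction/shape"]),
]
inputs, outputs = find_io_nodes(demo)
print("inputs:", inputs)
print("outputs:", outputs)
```

On the real graph, `find_io_nodes(graph_def.node)` gives node names; appending ":0" (with the `name=""` import above) yields tensor names suitable for `feed_dict` and fetches in a TF 1.x `Session.run` call.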