先上代码
import tensorflow as tf
import numpy as np
from tensorflow.python.platform import gfile
from tensorflow.python.lib.io import file_io
input_tensor_key = 'Placeholder:0'
def loadNpData(filename):
tensor_key_feed_dict = {}
#inputs = preprocess_inputs_arg_string(inputs_str)
data = np.load(file_io.FileIO(filename, mode='r'))
# When no key is specified for the input file.
# Check if npz file only contains a single numpy ndarray.
if isinstance(data, np.lib.npyio.NpzFile):
variable_name_list = data.files
if len(variable_name_list) != 1:
raise RuntimeError(
'Input file %s contains more than one ndarrays. Please specify '
'the name of ndarray to use.' % filename)
tensor_key_feed_dict[input_tensor_key] = data[variable_name_list[0]]
else:
tensor_key_feed_dict[input_tensor_key] = data
return tensor_key_feed_dict
with tf.Session() as sess:
# 定义模型文件及样本测试文件
model_filename = 'merge1_graph.pb'
example_png = 'examples.npy'
# 加载npy格式的图片测试样本数据
image_data = loadNpData(example_png)
#加载模型文件
with gfile.FastGFile(model_filename, 'rb') as f:
graph_def = tf.GraphDef();
graph_def.ParseFromString(f.read())
# 获取输入节点的tensor
inputs = sess.graph.get_tensor_by_name("Placeholder:0");
#打印输入节点的信息
#print inputs
# 导入计算图,定义输入节点及输出节点
output = tf.import_graph_def(graph_def, input_map={'Placeholder:0':inputs}, return_elements=[ 'ArgMax:0','Softmax:0'])
# 打印输出节点的信息
#print output
results = sess.run(output, feed_dict={inputs:image_data[input_tensor_key]})
print 'ArgMax result(预测结果对应的标签值):'
print results[0]
print 'Softmax result(最后一层的输出):'
print results[1]
# 输出node详细信息,此处默认只打印第一个节点
for node in graph_def.node:
print node
break
运行输出
ArgMax result(预测结果对应的标签值):
[3 3]
Softmax result(最后一层的输出):
[[4.1668140e-12 9.0696268e-18 6.4261091e-13 9.9999940e-01 1.7161388e-30
5.4321697e-07 7.6357951e-09 6.3293229e-19 1.3812791e-13 1.5360580e-12]
[1.1472046e-05 3.3404951e-10 6.0365837e-09 9.9997592e-01 9.8635665e-15
5.7557719e-07 1.1977763e-05 1.6275100e-16 7.2288098e-10 5.0601763e-08]]
此处加载的关键在于tf.import_graph_def
函数的参数配置,三个参数graph_def
input_map
return_elements
第一个参数是导入的图
input_map
是指定输入节点,如果不指定,后面run的时候会报错 ==You must feed a value for placeholder tensor 'Placeholder'==
return_elements
是指定运算后的输出节点,此处就是我们想要得到的标签估计值 ArgMax
以及 最后一层节点输出 Softmax
模型的测试参考 将Tensorflow模型导出为单个文件