Some awkward moments while debugging
Here is the download link for the flower_photos images:
(http://download.tensorflow.org/example_images/flower_photos.tgz)
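If you would rather fetch the data from a script, the sketch below downloads the archive and unpacks it using only the Python standard library. The URL is the one given above. The helper names `download` and `extract_tgz` are made up for this post and are not part of retrain.py.

```python
import os
import tarfile
import urllib.request

# Archive URL from the post above
FLOWER_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'


def download(url, dest_path):
    """Download url to dest_path, skipping the download if the file already exists."""
    if not os.path.exists(dest_path):
        urllib.request.urlretrieve(url, dest_path)
    return dest_path


def extract_tgz(archive_path, dest_dir):
    """Extract a .tgz archive into dest_dir and return its sorted top-level entries."""
    with tarfile.open(archive_path, 'r:gz') as tar:
        tar.extractall(dest_dir)
        top_level = {member.name.split('/')[0] for member in tar.getmembers()}
    return sorted(top_level)
```

For example, `extract_tgz(download(FLOWER_URL, 'flower_photos.tgz'), '.')` should leave a `flower_photos` folder in the current directory, which is where the commands below expect it.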
retain>python retrain.py --bottleneck_dir=./bottlenecks --how_many_training_steps=10 --output_graph=retrained_graph.pb --output_labels=retrained_labels.txt --summaries_dir=./retrain_logs --image_dir=./flower_photos
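Running this command in the folder that contains retrain.py produced the error below (still unsolved; if anyone has a hint, please leave a comment).
In other words, flower_photos may not contain more than one folder; with a single folder the script runs.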
2018-08-29 21:02:20.606407: W T:\src\github\tensorflow\tensorflow\core\framework\op_def_util.cc:346] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
Looking for images in 'dandelion'
Looking for images in 'roses'
Looking for images in 'sunflowers'
Looking for images in 'tulips'
More than one folder found inside flower_photos directory.
In order to prevent validation issues, put all the training images into
one folder inside flower_photos directory and delete everything else inside the flower_photos directory.
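Since the message is about how many folders sit inside flower_photos, it can help to print exactly what is in there before running retrain.py. The sketch below is a hypothetical helper (`summarize_image_dir` is not part of retrain.py). It assumes the usual layout of one subfolder per class label, counts the JPEG files in each subfolder and ignores loose files. It only shows what the script will see; it does not fix the error.

```python
import os

# Extensions treated as images here (JPEG only, since the test code later
# in this post feeds the model through DecodeJpeg)
IMAGE_EXTENSIONS = ('.jpg', '.jpeg')


def summarize_image_dir(image_dir):
    """Return {subfolder name: number of JPEG files} for each subfolder of image_dir.

    Loose files directly inside image_dir (e.g. a LICENSE file) are skipped.
    """
    summary = {}
    for entry in sorted(os.listdir(image_dir)):
        path = os.path.join(image_dir, entry)
        if not os.path.isdir(path):
            continue
        # Case-insensitive match so .JPG and .JPEG are counted too
        summary[entry] = sum(
            1 for name in os.listdir(path)
            if name.lower().endswith(IMAGE_EXTENSIONS)
        )
    return summary
```

On an unmodified copy of the dataset, `summarize_image_dir('./flower_photos')` should list the five class folders (daisy, dandelion, roses, sunflowers, tulips) with their image counts.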
The simplified command:
python retrain.py --image_dir=./flower_photos
Result (this run works, although the error above is still unsolved):
INFO:tensorflow:2018-08-29 20:54:08.346688: Step 3990: Train accuracy = 97.0%
INFO:tensorflow:2018-08-29 20:54:08.347188: Step 3990: Cross entropy = 0.140886
INFO:tensorflow:2018-08-29 20:54:08.562333: Step 3990: Validation accuracy = 87.0% (N=100)
INFO:tensorflow:2018-08-29 20:54:10.727268: Step 3999: Train accuracy = 97.0%
INFO:tensorflow:2018-08-29 20:54:10.727770: Step 3999: Cross entropy = 0.136108
INFO:tensorflow:2018-08-29 20:54:10.953419: Step 3999: Validation accuracy = 90.0% (N=100)
INFO:tensorflow:Final test accuracy = 91.1% (N=744)
INFO:tensorflow:Froze 2 variables.
Converted 2 variables to const ops.
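If the training output is saved to a file (for example `python retrain.py --image_dir=./flower_photos > retrain.log 2>&1`), the accuracy figures can be pulled out with two regular expressions. This is a sketch written against the log lines above; the patterns and the name `parse_retrain_log` are assumptions and may need adjusting for other versions of retrain.py.

```python
import re

# Patterns matched against the log format shown above
VALIDATION_RE = re.compile(r'Step (\d+): Validation accuracy = ([\d.]+)%')
FINAL_RE = re.compile(r'Final test accuracy = ([\d.]+)% \(N=(\d+)\)')


def parse_retrain_log(text):
    """Return ([(step, validation accuracy %), ...], (final accuracy %, N) or None)."""
    validation = [(int(step), float(acc)) for step, acc in VALIDATION_RE.findall(text)]
    final = FINAL_RE.search(text)
    final_result = (float(final.group(1)), int(final.group(2))) if final else None
    return validation, final_result
```

Applied to the lines above, this yields validation accuracies of 87.0 and 90.0 at steps 3990 and 3999, and a final test accuracy of 91.1% on 744 images.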
The following files then appeared; they are generated automatically:
Alright, here is some code to check the test results:
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.gfile, tf.GraphDef, tf.Session)
from PIL import Image

# Read the class labels written by retrain.py, one per line
with tf.gfile.GFile('output_labels.txt') as label_file:
    lines = label_file.readlines()
uid_to_human = {}
for uid, line in enumerate(lines):
    # Strip the trailing newline
    line = line.strip('\n')
    uid_to_human[uid] = line


def id_to_string(node_id):
    """Map a class index to its label, or '' if the index is unknown."""
    if node_id not in uid_to_human:
        return ''
    return uid_to_human[node_id]


# Create a graph holding the trained model
with tf.gfile.FastGFile('output_graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
    # Walk the test image directory
    for root, dirs, files in os.walk('images/'):
        for file in files:
            image_path = os.path.join(root, file)
            # Load the raw image bytes (the images must be JPEG)
            with tf.gfile.FastGFile(image_path, 'rb') as image_file:
                image_data = image_file.read()
            predictions = sess.run(softmax_tensor,
                                   {'DecodeJpeg/contents:0': image_data})
            # Drop the batch dimension so the scores are a 1-D array
            predictions = np.squeeze(predictions)
            # Print the image path and display the image
            print(image_path)
            img = Image.open(image_path)
            plt.imshow(img)
            plt.axis('off')
            plt.show()
            # Class indices sorted from highest to lowest score
            top_k = predictions.argsort()[::-1]
            print(top_k)
            for node_id in top_k:
                # Class name and its confidence score
                human_string = id_to_string(node_id)
                score = predictions[node_id]
                print('%s (score = %.5f)' % (human_string, score))
            print()
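The ranking step at the end of the loop (argsort, reverse, look up each label) does not need TensorFlow, so it can be tried on its own with any score array. `rank_predictions` is a hypothetical name; labels are assumed to be in the same order as the lines of output_labels.txt, and unknown indices map to an empty string, just as `id_to_string` does.

```python
import numpy as np


def rank_predictions(scores, labels):
    """Return (label, score) pairs sorted from most to least likely.

    scores: softmax outputs, one per class (a leading batch dimension is flattened away).
    labels: class names in the same order as output_labels.txt.
    """
    scores = np.asarray(scores).ravel()
    order = scores.argsort()[::-1]  # indices of the largest scores first
    return [(labels[i] if i < len(labels) else '', float(scores[i])) for i in order]
```

Flattening with `ravel()` plays the same role as `np.squeeze` in the script above, so a `(1, num_classes)` output from `sess.run` can be passed in directly.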
And of course, here is the result:
images/timg[7].jpg
[3 4 2 0 1]
sunflowers (score = 0.83506)
daisy (score = 0.15763)
tulips (score = 0.00446)
roses (score = 0.00214)
dandelion (score = 0.00071)
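The scores come from the `final_result` tensor, which retrain.py builds as a softmax over the classes, so they should add up to 1 apart from rounding. The values printed above do: 0.83506 + 0.15763 + 0.00446 + 0.00214 + 0.00071 = 1.00000. For reference, a numerically stable softmax looks like this in NumPy (a generic sketch, not code taken from retrain.py):

```python
import numpy as np


def softmax(logits):
    """Map a 1-D array of logits to probabilities that sum to 1."""
    z = np.asarray(logits, dtype=np.float64)
    e = np.exp(z - z.max())  # subtract the max so large logits do not overflow
    return e / e.sum()


# The five scores printed above for images/timg[7].jpg
PRINTED_SCORES = [0.83506, 0.15763, 0.00446, 0.00214, 0.00071]
```

Subtracting the maximum does not change the result but keeps `np.exp` finite: `softmax([1000.0, 1000.0])` gives 0.5 for both entries instead of overflowing.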