tensorflo的retrain.py进行重新训练

说一说调试的尴尬事

这是flower_photos图片地址
(http://download.tensorflow.org/example_images/flower_photos.tgz)

retain>python retrain.py --bottleneck_dir=./bottlenecks --how_many_training_steps=10 --output_graph=retrained_graph.pb --output_labels=retrained_labels.txt --summaries_dir=./retrain_logs --image_dir=./flower_photos

在含有retrain.py文件家里输入后出现如下错误:(未解决哈,谁能留言给点提示)
也就是说不能在flower_photos出现超过一个文件夹,一个文件夹是可以的。

2018-08-29 21:02:20.606407: W T:\src\github\tensorflow\tensorflow\core\framework\op_def_util.cc:346] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
Looking for images in 'dandelion'
Looking for images in 'roses'
Looking for images in 'sunflowers'
Looking for images in 'tulips'
More than one folder found inside flower_photos directory.
In order to prevent validation issues, put all the training images into 
one folder inside flower_photos directory and delete everything else inside the flower_photos directory.

输入简单版:

python retrain.py --image_dir=./flower_photos

结果:上面的未解决

INFO:tensorflow:2018-08-29 20:54:08.346688: Step 3990: Train accuracy = 97.0%
INFO:tensorflow:2018-08-29 20:54:08.347188: Step 3990: Cross entropy = 0.140886
INFO:tensorflow:2018-08-29 20:54:08.562333: Step 3990: Validation accuracy = 87.0% (N=100)
INFO:tensorflow:2018-08-29 20:54:10.727268: Step 3999: Train accuracy = 97.0%
INFO:tensorflow:2018-08-29 20:54:10.727770: Step 3999: Cross entropy = 0.136108
INFO:tensorflow:2018-08-29 20:54:10.953419: Step 3999: Validation accuracy = 90.0% (N=100)
INFO:tensorflow:Final test accuracy = 91.1% (N=744)
INFO:tensorflow:Froze 2 variables.
Converted 2 variables to const ops.

出现如下的文件了:这是自动生成的


好吧,上一段代码看看测试结果

import tensorflow as tf
import os
import numpy as np
import re
from PIL import Image
import matplotlib.pyplot as plt

lines = tf.gfile.GFile('output_labels.txt').readlines()
uid_to_human = {}
#一行一行读取数据
for uid,line in enumerate(lines) :
    #去掉换行符
    line=line.strip('\n')
    uid_to_human[uid] = line

def id_to_string(node_id):
    if node_id not in uid_to_human:
        return ''
    return uid_to_human[node_id]


#创建一个图来存放google训练好的模型
with tf.gfile.FastGFile('output_graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
    #遍历目录
    for root,dirs,files in os.walk('images/'):
        for file in files:
            #载入图片
            image_data = tf.gfile.FastGFile(os.path.join(root,file), 'rb').read()
            predictions = sess.run(softmax_tensor,{'DecodeJpeg/contents:0': image_data})#图片格式是jpg格式
            predictions = np.squeeze(predictions)#把结果转为1维数据

            #打印图片路径及名称
            image_path = os.path.join(root,file)
            print(image_path)
            #显示图片
            img=Image.open(image_path)
            plt.imshow(img)
            plt.axis('off')
            plt.show()

            #排序
            top_k = predictions.argsort()[::-1]
            print(top_k)
            for node_id in top_k:     
                #获取分类名称
                human_string = id_to_string(node_id)
                #获取该分类的置信度
                score = predictions[node_id]
                print('%s (score = %.5f)' % (human_string, score))
            print()

这是结果当然会有的哟

images/timg[7].jpg
[3 4 2 0 1]
sunflowers (score = 0.83506)
daisy (score = 0.15763)
tulips (score = 0.00446)
roses (score = 0.00214)
dandelion (score = 0.00071)

这ju花被识别为太阳花,可能两种花都和太阳有关吧!!!!!!!!!!!

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容