代码载入训练好的模型,对输入图片分类识别,并观察识别的结果
一、搭建环境并训练模型
1、安装TensorFlow1.13.1
pip install tensorflow==1.13.1
2、下载TensorFlow的models模块
git clone https://github.com/tensorflow/models
3、部署TensorFlow的slim模块
models中models-master路径下的slim文件夹,复制到本地代码同级路径下
4、下载PNASNet模型
https://github.com/tensorflow/models/tree/master/research/slim -> pnasnet-5...,解压以后放在本地代码同级目录下。
5、准备ImhNet数据集标签
直接调用,获取中文标签.csv和图片
二、代码实现:初始化环境变量、载入imgnet标签
# 初始化环境变量
import sys
nets_path = r'slim'
if net_path not in sys.path:
sys.path.insert(0, nets_path)
else:
print('already add slim')
# 引入模块
import tensorflow as tf
from PIL import Image
from matplotlib import pyplot as plt # import matplotlib.pyplot as plt
from nets.nasnet import pnasnet
import numpy as np
from datasets import imagenet
slim = tf.contrib.slim
tf.reset_default_graph()
# 获取图片尺寸
image_size = pnasnet.build_pnasnet_large.default_image_size
# 获取标签
labels = imagenet.create_readable_names_for_imagenet_lables()
# 显示输出标签
print(len(labels), labels)
# 打开文件,输出中文标签
def getone(onestr):
return onestr.replace(',', ' ')
with open('中文标签.csv', 'r+') as f:
labels = list(map(getone, list(f)))
print(len(labels), type(labels), labels[:5])
三、代码实现-定义网络结构
# 步骤:定义待识别图片占位符->归一化处理占位符,生成张量x1->x1传入pnasnet对象申城处理结果和end_points->获取prob值->找到最大索引即为图片分类
# 定义待测试图片名称
sample_image = ['hy.jpg', 'ps.jpg', '72.jpg']
# 定义占位符
input_imgs = tf.placeholder(tf.float32, [None, image_size, image_size, 3])
# 归一化图片
x1 = 2 * (input_imgs / 255.0) -1.0
# 获得模型命名空间
arg_scope = pnasnet.pnasnet_large_arg_scope()
with slim.arg_scope(arg_scope):
logits, end_points = pnasnet.build_pnasnet_large(x1, num_classes = 1001, is_training=False)
prob = end_points['Predictions']
y = tf.argmax(prob, axis=1) # 获得结果的输出节点
四、代码实现-载入模型进行识别
# 步骤:定义模型路径->建立会话->会话中载入预训练模型->模型识别
# 定义预训练模型路径
checkpoint_file = r'pnasnet-5_large_2017_12_13\model.ckpt)'
# 定义saver,用于加载模型
saver = tf.train.Saver()
# 建立会话,载入模型,定义图片预处理函数
with tf.Session() as sess:
saver.restore(sess, checkpoint_file)
def preimg(img):
ch = 3
if img.model =='RGBA':
ch = 4
imgnp = np.asarray(img.resize((image_size, image_size)), dtype=np.float32).reshape(image_size, image_size, ch)
return imgnp[:, : ,3]
# 获得原始图片与预处理图片
batchImg = [preimg(Image.open(imgfilename)) for imgfilename in sample_images]
orgImg = [Image.open(imgfilename) for imgfilename in sample_images]
# 输入模型
yv, img_norm = sess.run([y, x], feed_dict={input_imgs: batchImg})
print(yv, np.shape(yv))
# 定义显示图片的函数
def showresult(yy, img_norm, img_org):
plt.figure()
p1 = plt.subplot(121)
p2 = plt.subplot(121)
p1.imshow(img_org)
p1.axis('off')
p1.set_title('organization image')
p2.imshow((img_norm * 255).astype(np.uint8))
p2.axis('off')
p2.set_title(input image)
plt.show()
print(yy, labels[yy])
# 显示每条结果与图片
for yy, img1, img2 in zip(yv, batchImg, orgImg):
showresult(yy, img1, img2)