调用 Keras 的 VGG16 模型进行测试

from keras.applications.vgg16 import VGG16, preprocess_input, decode_predictions
from keras.preprocessing.image import load_img, img_to_array
import numpy as np
import os
import re

# Google Drive (Colab) directory holding the .JPEG test images and the two
# metadata text files read below.
path = '/content/drive/My Drive/test/calibration_set'
# Presumably one "<image file> <class index>" entry per line — TODO confirm.
val_path = os.path.join(path, 'val.txt')
# Presumably one synset id per line, addressed by class index — TODO confirm.
synsets_path = os.path.join(path, 'synsets.txt')


def ReadTxtName(rootdir):  # 读取txt文件内容
    lines = []
    with open(rootdir, 'r') as file_to_read:
        while True:
            line = file_to_read.readline()
            if not line:
                break
            line = line.strip('\n')
            lines.append(line)
    return lines


val = ReadTxtName(val_path)          # lines of val.txt
synset = ReadTxtName(synsets_path)   # lines of synsets.txt; line k is looked up by class index k

# Build the mapping: image number (the digits after "val_" in a file name)
# -> synset line. Each val.txt line is assumed to look like
# "..._val_<digits>.JPEG <class index>"; this is inferred from the regexes,
# TODO confirm against the data file.
#
# The number and the class index are parsed from the SAME line. The earlier
# approach flattened two independent findall() lists and zipped them, so a
# line matching only one pattern shifted every later pairing.
_VAL_NUM_RE = re.compile(r'val_(\d+)')      # raw string: avoids the invalid "\d" escape
_CLASS_IDX_RE = re.compile(r'JPEG\s(\d+)')

map_dict = {}  # image number (str) -> synset line (str); later duplicates win, as with dict(zip(...))
for entry in val:  # renamed from `map`, which shadowed the builtin
    num_match = _VAL_NUM_RE.search(entry)
    idx_match = _CLASS_IDX_RE.search(entry)
    if num_match is None and idx_match is None:
        continue  # blank or unrelated line: nothing to pair
    if num_match is None or idx_match is None:
        raise ValueError('malformed val.txt line: %r' % (entry,))
    map_dict[num_match.group(1)] = synset[int(idx_match.group(1))]


# True synset id for every .JPEG in `path`, looked up through map_dict using
# the digits that follow "val_" in the file name. Files are visited in
# os.listdir(path) order. The prediction loop also iterates os.listdir(path),
# and that shared order is what keeps label[i] aligned with the i-th
# prediction.
label = []
for img_name in os.listdir(path):
    if img_name.endswith('.JPEG'):
        # Raw string: a plain '\d' is an invalid escape sequence
        # (DeprecationWarning, SyntaxWarning on Python 3.12+).
        number = re.findall(r'val_(\d+)', img_name)
        label.append(map_dict[number[0]])

label = np.array(label)  # test-set labels obtained from the mapping dict
def vgg_predict(img):  # return the top-5 class ids
    """Return the ids of VGG16's five highest-scoring classes for *img*.

    *img* is an image as produced by ``load_img``. It is converted to an
    array and given a leading batch axis of size 1, passed through
    ``preprocess_input``, and then run through the module-level ``model``.
    The result is the first field of each of ``decode_predictions``'s top-5
    entries, squeezed down to a 1-D array.
    """
    batch = preprocess_input(np.expand_dims(img_to_array(img), axis=0))
    top5 = np.array(decode_predictions(model.predict(batch), top=5))
    return np.squeeze(top5[:, :, 0])


# VGG16 with ImageNet weights, including the fully connected classifier top.
model = VGG16(weights='imagenet', include_top=True)

# Top-5 predicted ids for each .JPEG in `path`, visited in os.listdir(path)
# order (the same order used to build `label`). Each image is loaded at
# 224x224 before prediction.
jpeg_files = [name for name in os.listdir(path) if name.endswith('.JPEG')]
pre_result = np.array([
    vgg_predict(load_img(os.path.join(path, name), target_size=(224, 224)))
    for name in jpeg_files
])


# Top-5 error: the fraction of images whose true synset id is NOT among the
# five ids predicted for them. label[i] and pre_result[i] describe the same
# file because both were built by iterating os.listdir(path) in the same
# order. The image count is taken from the data rather than hard-coded, so
# any number of images is scored correctly.
num_images = len(label)
if num_images != len(pre_result):
    raise ValueError('got %d labels but %d predictions' % (num_images, len(pre_result)))
if num_images == 0:
    raise ValueError('no .JPEG images found in %r' % (path,))
cnt = sum(1 for y, y_hat in zip(label, pre_result) if y in y_hat)  # top-5 hits
print('top5 error:', 1 - cnt / num_images)  # reported result: top5 error: 0.136
©著作权归作者所有，转载或内容合作请联系作者
平台声明：文章内容（如有图片或视频亦包括在内）由作者上传并发布，文章内容仅代表作者本人观点，简书系信息发布平台，仅提供信息存储服务。