from keras.applications.vgg16 import VGG16, preprocess_input, decode_predictions
from keras.preprocessing.image import load_img, img_to_array
import numpy as np
import os
import re
# Directory that is scanned for the '.JPEG' images to evaluate.
path = '/content/drive/My Drive/test/calibration_set'
# Text file whose lines are parsed for an image number ("val_<digits>") and a
# class index ("JPEG <digits>").
val_path = '/content/drive/My Drive/test/calibration_set/val.txt'
# Text file read line by line; a class index selects the line at that position.
synsets_path = '/content/drive/My Drive/test/calibration_set/synsets.txt'
def ReadTxtName(rootdir):
    """Read a text file and return its lines without their newline characters.

    Args:
        rootdir: Path of the text file to read.

    Returns:
        A list with one string per line, in file order. Blank lines are kept
        as empty strings; an empty file yields an empty list.
    """
    with open(rootdir, 'r') as file_to_read:
        # Iterating the file object yields one line at a time. Only '\n' is
        # stripped so that any other whitespace in a line is preserved.
        return [line.strip('\n') for line in file_to_read]
val = ReadTxtName(val_path)
synset = ReadTxtName(synsets_path)

# Patterns are raw strings (no invalid-escape warnings) and compiled once.
_VAL_NUM_RE = re.compile(r'val_(\d+)')
_CLASS_IDX_RE = re.compile(r'JPEG\s(\d+)')

# Build the mapping from image number (the digits after "val_") to label.
# num_list / map_list keep the raw per-line matches.
num_list = []
map_list = []
for val_line in val:  # not named `map`, which would shadow the builtin
    num_list.append(_VAL_NUM_RE.findall(val_line))
    map_list.append(_CLASS_IDX_RE.findall(val_line))

# Pair numbers with class indices line by line. Flattening the two lists
# independently would let one line missing either field shift every later
# pair; pairing within a line simply drops the incomplete line.
a = []
b = []
for line_numbers, line_indices in zip(num_list, map_list):
    for image_number, class_index in zip(line_numbers, line_indices):
        a.append(image_number)
        b.append(class_index)

# Each class index selects the corresponding line of the synsets file.
synset_list = [synset[int(class_index)] for class_index in b]
map_dict = dict(zip(a, synset_list))  # mapping dictionary: number -> synset

# Look up the ground-truth label of every .JPEG file in the image directory.
# An image without an entry raises (IndexError / KeyError) rather than being
# skipped: labels are later matched to predictions by position, so silently
# dropping one would misalign every following comparison.
label = np.array([
    map_dict[_VAL_NUM_RE.findall(file_name)[0]]
    for file_name in os.listdir(path)
    if file_name.endswith('.JPEG')
])
def vgg_predict(img, top=5):
    """Return the class IDs of the highest-scoring predictions for one image.

    Uses the module-level ``model``, which must be assigned before the first
    call.

    Args:
        img: Image object accepted by ``img_to_array``.
        top: Number of top predictions to return. Defaults to 5.

    Returns:
        A 1-D NumPy array holding the first field of each of the ``top``
        decoded predictions (the class ID, per the Keras documentation of
        ``decode_predictions``), in the order Keras returns them.
    """
    image_data = img_to_array(img)
    # Prepend a batch axis of size 1, since the model predicts on batches.
    image_data = image_data.reshape((1,) + image_data.shape)
    image_data = preprocess_input(image_data)
    prediction = model.predict(image_data)
    # decode_predictions returns one list per batch sample; the batch holds a
    # single image, so take entry 0. Reading the first field directly avoids
    # coercing the whole (id, name, score) tuples into a string array.
    decoded = decode_predictions(prediction, top=top)[0]
    return np.array([entry[0] for entry in decoded])
model = VGG16(weights='imagenet', include_top=True)

# Predict the top-5 class IDs for every .JPEG image in the image directory.
pre_result = []
for imgs in os.listdir(path):
    if imgs.endswith('.JPEG'):
        img_path = os.path.join(path, imgs)
        img = load_img(img_path, target_size=(224, 224))
        pre_result.append(vgg_predict(img))
pre_result = np.array(pre_result)

# Labels and predictions are compared by position, so the two sequences must
# have the same length; a mismatch means they do not describe the same images.
total = len(pre_result)
if len(label) != total:
    raise ValueError(
        'number of labels (%d) does not match number of predictions (%d)'
        % (len(label), total)
    )

# Count the images whose true label is among the predicted class IDs.
cnt = 0
for y, y_hat in zip(label, pre_result):
    if y in y_hat:
        cnt += 1

if total:
    # The original author recorded "top5 error: 0.136" for this output.
    print('top5 error:', 1 - cnt / total)
else:
    print('top5 error: no .JPEG images found in', path)