汽车识别篇
看看训练的模型到底有没有学到什么东西
首先在项目文件夹下打开终端,运行jupyter notebook,编写一个汽车识别的函数。它接受模型,汽车类别,以及一张预测的图片。返回预测图片的汽车类型和置信度。
from models import VGG10
import torch
from PIL import Image
from torchvision import transforms as tfs
from dataset import data_set
import matplotlib.pyplot as plt
def car_classifier(model, classes, img):
data_tfs = tfs.Compose([tfs.Resize([160,160]),
tfs.ToTensor(),
tfs.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
img = data_tfs(img).unsqueeze(0)
output = model(img)
p = torch.softmax(output, dim=1).detach().numpy()[0]
_, pred = torch.max(output, 1)
print('预测车型为:{}, 置信度为:{:.1f}%'.format(classes[pred.item()], p[pred.item()]*100))
看一下训练的汽车包括那些
classes = data_set.classes
model = VGG10('VGG10', 10)
model.eval()
model.load_state_dict(torch.load('./checkpoint/weights.pkl'))
print(classes)
['MINI', '奔驰C级', '奥迪A4L', '宋MAX', '宝马3系', '思域', '朗逸', '福克斯', '途观L', '速腾']
我从汽车之家论坛摘下了6张图片,送入模型,看看模型能否识别成功。
大众套娃果然厉害,朗逸、途观L,模型傻傻分不清楚