TorchVision 预训练模型进行推断

torchvision.models 里包含了许多模型,用于解决不同的视觉任务:图像分类、语义分割、物体检测、实例分割、人体关键点检测和视频分类。

本文将介绍 torchvision 中模型的入门使用,一起来创建 Faster R-CNN 预训练模型,预测图像中有什么物体吧。

import torch
import torchvision
from PIL import Image

创建预训练模型

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

print(model) 可查看其结构:

FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    ...
  )
  (rpn): RegionProposalNetwork(
    (anchor_generator): AnchorGenerator()
    (head): RPNHead(
      (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (cls_logits): Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1))
      (bbox_pred): Conv2d(256, 12, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (roi_heads): RoIHeads(
    (box_roi_pool): MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'], output_size=(7, 7), sampling_ratio=2)
    (box_head): TwoMLPHead(
      (fc6): Linear(in_features=12544, out_features=1024, bias=True)
      (fc7): Linear(in_features=1024, out_features=1024, bias=True)
    )
    (box_predictor): FastRCNNPredictor(
      (cls_score): Linear(in_features=1024, out_features=91, bias=True)
      (bbox_pred): Linear(in_features=1024, out_features=364, bias=True)
    )
  )
)

此预训练模型是于 COCO train2017 上训练的,可预测的分类有:

COCO_INSTANCE_CATEGORY_NAMES = [
  '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
  'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
  'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
  'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
  'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
  'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
  'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
  'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
  'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
  'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
  'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
  'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

指定 CPU or GPU

获取支持的 device

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

模型移到 device

model.to(device)

读取输入图像

img = Image.open('data/bicycle.jpg').convert("RGB")
img = torchvision.transforms.ToTensor()(img)

准备模型入参 images

images = [img.to(device)]

例图 data/bicycle.jpg

image

进行模型推断

模型切为 eval 模式:

# For inference
model.eval()

模型在推断时,只需要给到图像数据,不用标注数据。推断后,会返回每个图像的预测结果 List[Dict[Tensor]]Dict 包含字段有:

  • boxes (FloatTensor[N, 4]): 预测框 [x1, y1, x2, y2], x 范围 [0,W], y 范围 [0,H]
  • labels (Int64Tensor[N]): 预测类别
  • scores (Tensor[N]): 预测评分
predictions = model(images)
pred = predictions[0]
print(pred)

预测结果如下:

{'boxes': tensor([[750.7896,  56.2632, 948.7942, 473.7791],
        [ 82.7364, 178.6174, 204.1523, 491.9059],
        ...
        [174.9881, 235.7873, 351.1031, 417.4089],
        [631.6036, 278.6971, 664.1542, 353.2548]], device='cuda:0',
       grad_fn=<StackBackward>), 'labels': tensor([ 1,  1,  2,  1,  1,  1,  2,  2,  1, 77,  1,  1,  1,  2,  1,  1,  1,  1,
         1,  1, 27,  1,  1, 44,  1,  1,  1,  1, 27,  1,  1, 32,  1, 44,  1,  1,
        31,  2, 38,  2,  2,  1,  1, 31,  1,  1,  1,  1,  2,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  2,  2,  1,  1,  1,  2,  1,  1,  1,  1,  2,  1,  2,
         1,  1,  1,  1,  1,  1, 31,  2, 27,  1,  2,  1,  1, 31,  2, 77,  2,  1,
         2,  2,  2, 44,  2, 31,  1,  1,  1,  1], device='cuda:0'), 'scores': tensor([0.9990, 0.9976, 0.9962, 0.9958, 0.9952, 0.9936, 0.9865, 0.9746, 0.9694,
        0.9679, 0.9620, 0.9395, 0.8984, 0.8979, 0.8847, 0.8537, 0.8475, 0.7865,
        0.7822, 0.6896, 0.6633, 0.6629, 0.6222, 0.6132, 0.6073, 0.5383, 0.5248,
        0.4891, 0.4881, 0.4595, 0.4335, 0.4273, 0.4089, 0.4074, 0.3679, 0.3357,
        0.3192, 0.3102, 0.2797, 0.2655, 0.2640, 0.2626, 0.2615, 0.2375, 0.2306,
        0.2174, 0.2129, 0.1967, 0.1912, 0.1907, 0.1739, 0.1722, 0.1669, 0.1666,
        0.1596, 0.1586, 0.1473, 0.1456, 0.1408, 0.1374, 0.1373, 0.1329, 0.1291,
        0.1290, 0.1289, 0.1278, 0.1205, 0.1182, 0.1182, 0.1103, 0.1060, 0.1025,
        0.1010, 0.0985, 0.0959, 0.0919, 0.0887, 0.0886, 0.0873, 0.0832, 0.0792,
        0.0778, 0.0764, 0.0693, 0.0686, 0.0679, 0.0671, 0.0668, 0.0636, 0.0635,
        0.0607, 0.0605, 0.0581, 0.0578, 0.0572, 0.0568, 0.0557, 0.0556, 0.0555,
        0.0533], device='cuda:0', grad_fn=<IndexBackward>)}

绘制预测结果

获取 score >= 0.9 的预测结果:

scores = pred['scores']
mask = scores >= 0.9

boxes = pred['boxes'][mask]
labels = pred['labels'][mask]
scores = scores[mask]

引入 utils.plots.plot_image 绘制结果:

from utils.colors import golden
from utils.plots import plot_image

lb_names = COCO_INSTANCE_CATEGORY_NAMES
lb_colors = golden(len(lb_names), fn=int, scale=0xff, shuffle=True)
lb_infos = [f'{s:.2f}' for s in scores]
plot_image(img, boxes, labels, lb_names, lb_colors, lb_infos,
           save_name='result.png')

utils.plots.plot_image 函数实现可见后文源码,注意其要求 torchvision >= 0.9.0/nightly

image

源码

utils.colors.golden:

import colorsys
import random


def golden(n, h=random.random(), s=0.5, v=0.95,
           fn=None, scale=None, shuffle=False):
  if n <= 0:
    return []

  coef = (1 + 5**0.5) / 2

  colors = []
  for _ in range(n):
    h += coef
    h = h - int(h)
    color = colorsys.hsv_to_rgb(h, s, v)
    if scale is not None:
      color = tuple(scale*v for v in color)
    if fn is not None:
      color = tuple(fn(v) for v in color)
    colors.append(color)

  if shuffle:
    random.shuffle(colors)
  return colors

utils.plots.plot_image:

from typing import Union, Optional, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from PIL import Image


def plot_image(
  image: Union[torch.Tensor, Image.Image, np.ndarray],
  boxes: Optional[torch.Tensor] = None,
  labels: Optional[torch.Tensor] = None,
  lb_names: Optional[List[str]] = None,
  lb_colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
  lb_infos: Optional[List[str]] = None,
  save_name: Optional[str] = None,
  show_name: Optional[str] = 'result',
) -> torch.Tensor:
  """
  Draws bounding boxes on given image.
  Args:
    image (Image): `Tensor`, `PIL Image` or `numpy.ndarray`.
    boxes (Optional[Tensor]): `FloatTensor[N, 4]`, the boxes in `[x1, y1, x2, y2]` format.
    labels (Optional[Tensor]): `Int64Tensor[N]`, the class label index for each box.
    lb_names (Optional[List[str]]): All class label names.
    lb_colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of all class label names.
    lb_infos (Optional[List[str]]): Infos for given labels.
    save_name (Optional[str]): Save image name.
    show_name (Optional[str]): Show window name.
  """
  if not isinstance(image, torch.Tensor):
    image = torchvision.transforms.ToTensor()(image)

  if boxes is not None:
    if image.dtype != torch.uint8:
      image = torchvision.transforms.ConvertImageDtype(torch.uint8)(image)
    draw_labels = None
    draw_colors = None
    if labels is not None:
      draw_labels = [lb_names[i] for i in labels] if lb_names is not None else None
      draw_colors = [lb_colors[i] for i in labels] if lb_colors is not None else None
    if draw_labels and lb_infos:
      draw_labels = [f'{l} {i}' for l, i in zip(draw_labels, lb_infos)]
    # torchvision >= 0.9.0/nightly
    #  https://github.com/pytorch/vision/blob/master/torchvision/utils.py
    res = torchvision.utils.draw_bounding_boxes(image, boxes,
      labels=draw_labels, colors=draw_colors)
  else:
    res = image

  if save_name or show_name:
    res = res.permute(1, 2, 0).contiguous().numpy()
    if save_name:
      Image.fromarray(res).save(save_name)
    if show_name:
      plt.gcf().canvas.set_window_title(show_name)
      plt.imshow(res)
      plt.show()

  return res

参考

GoCoding 个人实践的经验分享,可关注公众号!

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 204,053评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,527评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 150,779评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,685评论 1 276
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,699评论 5 366
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,609评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,989评论 3 396
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,654评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,890评论 1 298
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,634评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,716评论 1 330
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,394评论 4 319
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,976评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,950评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,191评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 44,849评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,458评论 2 342

推荐阅读更多精彩内容