Pytorch实现VOC数据集的Dataset

Pascal VOC2012 数据集下载地址:

http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar

代码

import os
import torch
import xml.etree.ElementTree as ET
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from config import Config
import numpy as np
from PIL import Image

image_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(256),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])


class VOCDataset(Dataset):

    def __init__(self, data_dir, train=True, transform=None):
        super(VOCDataset, self).__init__()
        # 获取txt文件
        self.data_dir = data_dir
        if (train):
            split = 'trainval'
        else:
            split = 'val'
        id_list_file = os.path.join(self.data_dir, 'ImageSets/Main/{0}.txt'.format(split))
        self.ids = [id_.strip() for id_ in open(id_list_file)]

        self.transform = transform

    def __getitem__(self, item):
        id = self.ids[item]
        # 解析xml文件得到图片的bbox, label
        anno = ET.parse(
            os.path.join(self.data_dir, 'Annotations', id + '.xml'))

        bbox = []
        label = []
        for obj in anno.findall('object'):

            bndbox_anno = obj.find('bndbox')
            box = []
            for tag in ('ymin', 'xmin', 'ymax', 'xmax'):
                box.append(int(bndbox_anno.find(tag).text) - 1)
            bbox.append(box)

            name = obj.find('name').text.lower().strip()
            label.append(Config.VOC_BBOX_LABEL_NAMES.index(name))

        bbox = np.stack(bbox).astype(np.float32)
        label = np.stack(label).astype(np.float32)

        # 获取对应图片
        img_file = os.path.join(self.data_dir, 'JPEGImages', id + '.jpg')
        img = Image.open(img_file)
        if self.transform:
            img = self.transform(img)
        if img.ndim == 2:
            img = img[np.newaxis]
        # (H,W,C)->(C,H,W)
        img = img.transpose(2, 0)

        return img, bbox, label

    def __len__(self):
        return len(self.ids)


if __name__ == '__main__':
    dataset = VOCDataset(data_dir=Config.voc_data_dir, train=True, transform=image_transform)
    data_loader = DataLoader(dataset, batch_size=1)
    for idx, (image, bbox, lable) in enumerate(data_loader):
        print (bbox)

常量文件 config.py

class Config:
    voc_data_dir = 'VOCdevkit/VOC2012'

    VOC_BBOX_LABEL_NAMES = (
        'aeroplane',
        'bicycle',
        'bird',
        'boat',
        'bottle',
        'bus',
        'car',
        'cat',
        'chair',
        'cow',
        'diningtable',
        'dog',
        'horse',
        'motorbike',
        'person',
        'pottedplant',
        'sheep',
        'sofa',
        'train',
        'tvmonitor'
    )
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容