PyTorch入门学习(七):数据加载与处理

写在前面

未经允许,不得转载,谢谢~~

在深度学习的问题中处理数据都会占据比较大的时间,只有把数据处理好了才有可能对模型进行训练、测试等后续工作。

PyTorch提供了很多用于让数据加载变得更加方便的工具,接下来我们就来学习一下怎么样处理那些PyTorch没有提供直接接口的数据。

在学习这个之前,首先要保证电脑上已经安装了下面这两样东西:

  • scikit-image:用于图像输入输出和转换
  • pandas:用于更好的处理csv数据

这篇文章内容还是比较多的,但认真看完应该就可以掌握各种数据集的处理了。

1. 头文件导入

from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

2. 数据集介绍及下载

2.1 数据集介绍

接下来我们要处理的数据集是关于脸部姿势的,每张图片都会被注释成这样,每张脸上都会有68各不同的标记点:


2.2 数据集下载与展示

  1. 戳这里下载需要教程中用到的脸部数据集,跟数据集一起的还有一个注释文件face_landmarks.csv
    直接打开如下图所示:

    csv文件

    即每张图片都对应一个文件名和对应的N个脸部特征标记点。

  2. 在注释文件中的是N个坐标点,每个坐标点由两个横纵坐标组成。所以先用pandas工具把注释文件处理一下。

landmarks_frame = pd.read_csv('faces/face_landmarks.csv')

n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].as_matrix()
landmarks = landmarks.astype('float').reshape(-1, 2)

print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))

得到的结果为:


3.将图像和对应的特征点标记出来展示。

def show_landmarks(image, landmarks):
    """Show image with landmarks"""
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    plt.pause(0.001)  # pause a bit so that plots are updated

plt.figure()
show_landmarks(io.imread(os.path.join('faces/', img_name)),
               landmarks)
plt.show()

得到的结果为:


3. Dataset类介绍

3.1 原理介绍

torch.utils.data.Dataset是一个PyTorch用来表示数据集的抽象类。我们用这个类来处理自己的数据集的时候必须继承Dataset,然后重写下面的函数:

  1. __len__: 使得len(dataset)返回数据集的大小;
  2. __getitem__:使得支持dataset[i]能够返回第i个数据样本这样的下标操作。

3.2 创建脸部图像数据集

  1. 在类的__init__函数中完成csv文件的读取工作;
  2. 在类的__getitem__函数中完成图片的读取工作。这样是为了减小内存开销,只要在需要用到的时候才将图片读入。
  3. 除此,数据集还会接收一个可以选择的参数transform,用来对图像做一些改变,具体的会在下面进行介绍。
  4. 最终返回的样本数据是一个字典形式的,如下所示:{‘image':image,'landmarks':landmarks}

那么现在我们就可以写出类的定义:

class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix()
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample

3.3 实例化类

接下来我们对上面定义好的类做实例化,然后在数据样本上进行迭代。我们会打印前4个样本图像及其对应的坐标点。

face_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv',
                                    root_dir='faces/')

fig = plt.figure()

for i in range(len(face_dataset)):
    sample = face_dataset[i]

    print(i, sample['image'].shape, sample['landmarks'].shape)

    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    show_landmarks(**sample)

    if i == 3:
        plt.show()
        break

结果如下所示:

4. Transforms

从上面显示的图片我们可以看到每张图片的大小都不一样,但往往我们在处理神经网络的输入图像的时候都希望它们有一个相对固定的大小。因此,我们需要一些对图像进行预处理的工作。

4.1 实现常用变换功能

我们试着写一下这三个常用的变换功能:

  • Rescale:重新调整图像大小;
  • RandomCrop:随机从图像中截取一部分;
  • ToTensor:将numpy类型表示的图像转换成torch表示的图像。

我们用类而不是函数来实现以上这三个功能,主要是考虑到如果用函数的话,每次都需要传入参数,但是用类就可以省掉很多麻烦。我们只需要实现每个类的__call__函数和__init__函数。

下面是对这三个功能的实现:

class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}

class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        top = np.random.randint(0, h - new_h)
        left = np.random.randint(0, w - new_w)

        image = image[top: top + new_h,
                      left: left + new_w]

        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}

4.2 组合以上变换功能

假设我们现在需要将图像的较短边调整到256,然后从中随机截取224的正方形图像。我们就可以调用torchvision.transforms.Compose将以上的RescaleRandomCrop两个变换组合起来。

以下的代码段展示了分开进行变换以及用Compose组合进行变换的结果图

scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),
                               RandomCrop(224)])

# Apply each of the above transforms on sample.
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
    transformed_sample = tsfrm(sample)

    ax = plt.subplot(1, 3, i + 1)
    plt.tight_layout()
    ax.set_title(type(tsfrm).__name__)
    show_landmarks(**transformed_sample)

plt.show()

5. 合并dataset与transform、遍历数据集

简单回顾一下:

  • 第3小节我们介绍了dataset类;
  • 第4小节我们我们介绍了怎么样实现各个转换函数,然后将其组合起来。

如果你还记得的话,我们在之前定义dataset的时候是有一个transform参数的,但我们在第4节中是先取了样本数据,然后再进行变换操作,并没有将其作为参数传到dataset中。所以我们现在要做的工作就是将所有的内容集成到一起。每次抽取一个样本,都会有以下步骤:

  1. 从文件中读取图片;
  2. 将转换应用于读入的图片;
  3. 由于做了随机选取的操作,所以起到了数据增强的效果。

其实我们只要把Transform的部分作为形参传入dataset就可以了,其他的都不变。

然后用for循环来依次获得数据集样本。

transformed_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv',
                                           root_dir='faces/',
                                           transform=transforms.Compose([
                                               Rescale(256),
                                               RandomCrop(224),
                                               ToTensor()
                                           ]))

for i in range(len(transformed_dataset)):
    sample = transformed_dataset[i]

    print(i, sample['image'].size(), sample['landmarks'].size())

    if i == 3:
        break

取到的四个数据样本如下所示:


6. DataLoader类

以上我们已经实现了datasettransform的合并,也实现了用for循环来获取每一个样本数据,好像事情就已经结束了。

但等等,真的结束了吗?emmmm,我们好像还落了什么事情,是的没错:

  • 按照batch_size获得批量数据;
  • 打乱数据顺序;
  • 用多线程multiprocessing来加载数据;

torch.utils.data.DataLoader这个类为我们解决了以上所有的问题,是不是很腻害~

只要按照要求设置DataLoader的参数即可:

  1. 第一个参数传入transformed_dataset,即已经用了transformDataset实例。
  2. 第二个参数传入batch_size,表示每个batch包含多少个数据。
  3. 第三个参数传入shuffle,布尔型变量,表示是否打乱。
  4. 第四个参数传入num_workers表示使用几个线程来加载数据。

如下所示即实现了DataLoader函数的使用,及批样本数据的展示。

dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=4)


# Helper function to show a batch
def show_landmarks_batch(sample_batched):
    """Show image with landmarks for a batch of samples."""
    images_batch, landmarks_batch = \
            sample_batched['image'], sample_batched['landmarks']
    batch_size = len(images_batch)
    im_size = images_batch.size(2)

    grid = utils.make_grid(images_batch)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    for i in range(batch_size):
        plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size,
                    landmarks_batch[i, :, 1].numpy(),
                    s=10, marker='.', c='r')

        plt.title('Batch from dataloader')

for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size(),
          sample_batched['landmarks'].size())

    # observe 4th batch and stop.
    if i_batch == 3:
        plt.figure()
        show_landmarks_batch(sample_batched)
        plt.axis('off')
        plt.ioff()
        plt.show()
        break

这样呢其实就完成了对数据集完整的处理了。

7. torchvision

torchvision包提供了一些常用的数据集和转换函数。使用torchvision甚至不需要自己写处理函数。

torchvision中最通用的数据集是ImageFolder,它假设数据结构为如下:

root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png

这里的root指代根目录,ants bees指的是不同的类标签,后面的是具体的图片名称。

当然它还提供了对PIL.Image的常用操作,包括RandomHorizontalFlip Scale等等。

以下为用torchvision实现的超简化版本的数据处理方法:

import torch
from torchvision import transforms, datasets

data_transform = transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=4)

整理总结

我们来整理一下整个实现思路哦~
主要分以下三种情况:

1 对于torchvision提供的数据集

  • 这是最简单的一种情况。
  • 对于这一类数据集,就是PyTorch已经帮我们做好了所有的事情,连数据源都不需要自己下载。
  • Imagenet,CIFAR10,MNIST等等PyTorch都提供了数据加载的功能,所以可以先看看你要用的数据集是不是这种情况。
  • 具体的使用方法详见之前的博客Pytorch入门学习(四)-training a classifier

2 对于特定结构的数据集

  • 这种情况就是不在上述PyTorch提供数据库之列,但是满足下面的形式:
     root/ants/xxx.png
     root/ants/xxy.jpeg
     root/ants/xxz.png
    .
    .
    .
    root/bees/123.jpg
    root/bees/nsdf3.png
    root/bees/asd932_.png
    
  • 那么就可以通过torchvision中的通用数据集ImageFolder来完成加载。
  • 具体使用方法见上文。

3 对于最普通的数据集

  • 最后一种情况是既不是自带数据集,又不满足ImageFolder,这种时候就自己进行处理。
  • 首先,定义数据集的类(myDataset),这个类要继承dataset这个抽象类,并实现__len__以及__getitem__这两个函数,通常情况还包括初始函数__init__.
  • 然后,实现用于特定图像预处理的功能,并封装成类。当然常用的一些变换可以在torchvision中找到。用torchvision.transforms.Compose将它们进行组合成(transform)
  • transform作为上面myDataset类的参数传入,并得到实例化myDataset得到(transformed_dataset)对象。
  • 最后,将transformed_dataset作为torch.utils.data.DataLoader类的形参,并根据需求设置自己是否需要打乱顺序,批大小...
  • 具体见上文。

参考资料

Data Loading and Processing Tutorial

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,029评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,395评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 157,570评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,535评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,650评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,850评论 1 290
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,006评论 3 408
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,747评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,207评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,536评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,683评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,342评论 4 330
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,964评论 3 315
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,772评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,004评论 1 266
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,401评论 2 360
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,566评论 2 349

推荐阅读更多精彩内容