为什么要用?
习惯于自己实现业务逻辑的每一步,以至于没有意识去寻找框架本身自有的数据预处理方法,Pytorch的Dataset 和 DataLoader便于加载和迭代处理数据,并且可以傻瓜式实现各种常见的数据预处理,以供训练使用。
调包侠
from torch.utils.data.dataset import Dataset, DataLoader
from torchvision import transforms ##可方便指定各种transformer,直接传入DataLoader
Dataset 和 DataLoader是什么?
Dataset是一个包装类,可对数据进行张量(tensor)的封装,其可作为DataLoader的参数传入,进一步实现基于tensor的数据预处理。
如何处理自己的数据集?
很多教程里分两种情况:数据同在一个文件夹;数据按类别分布在不同文件夹。其实刚开始我是一头雾水,后来总结后发现,两种情况均可用一种方法来处理,即:只要有一份文件,记录图像数据路径及对应的标签即可,如下所示:
record.txt 示例:
pic_path label
./pic_01/aaa.bmp 1
./pic_22/bbb.bmp 0
./pic_03/ccc.bmp 3
./pic_01/ddd.bmp 1
...
其实有了上面的一份数据对照表文件,即可不用管是否在同一文件夹或是不同文件夹的情况,我自己感觉是要方便一些。下面就按照这种方法来介绍如何使用。
第一步:实现MyDataset类
既然是要处理自己的数据集,那么一般情况下还是写一个自己的Dataset类,该类要继承Dataset,并重写 __ init __() 和 __ getitem __() 两个方法。
例如:
class MyDataset(Dataset):
def __init__(self, record_path, is_train=True):
## record_path:记录图片路径及对应label的文件
self.data = []
self.is_train = is_train
with open(record_path) as fp:
for line in fp.readlines():
if line == '\n':
break
else:
tmp = line.split("\t")
## tmp[0]:某图片的路径,tmp[1]:该图片对应的label
self.data.append([tmp[0], tmp[1]])
# 定义transform,将数据封装为Tensor
self.transformations = transforms.Compose([transforms.ToTensor()])
# 获取单条数据
def __getitem__(self, index):
img = self.transformations (Image.open(self.data[index][0]).resize((256,256)).convert('RGB'))
label = int(self.data[index][1])
return img, label
# 数据集长度
def __len__(self):
return len(self.data)
上面是一个简单的MyDataset类,仅依赖记录了图像位置以及相应label的record文件,实现对数据集的读取和Tensor的转换
当然,根据个人对数据预处理的需求不同,该类的实现可进一步完善,例如:
class MyDataset(Dataset):
def __init__(self, base_path, is_train=True):
self.data = []
self.is_train = is_train
with open(base_path) as fp:
for line in fp.readlines():
if line == '\n':
break
else:
tmp = line.split("\t")
self.data.append([tmp[0], tmp[1]])
## transforms.Normalize:对R G B三通道数据做均值方差归一化,因此给出下方三个均值和方差
normMean = [0.49139968, 0.48215827, 0.44653124]
normStd = [0.24703233, 0.24348505, 0.26158768]
normTransform = transforms.Normalize(normMean, normStd)
## 可由 transforms.Compose([transformer_01, transformer_02, ...])实现一些数据的处理和增强
self.trainTransform = transforms.Compose([ ## train训练集处理
transforms.RandomCrop(32, padding=4), ## 图像裁剪的transforms
transforms.RandomHorizontalFlip(p=0.5), ## 以50%概率水平翻转
transforms.ToTensor(), ## 转为Tensor形式
normTransform ## 进行 R G B数据归一化
])
## 测试集的transforms数据处理
self.testTransform = transforms.Compose([
transforms.ToTensor(),
normTransform
])
# 获取单条数据
def __getitem__(self, index):
img = self.trainTransform(Image.open(self.data[index][0]).resize((256,256)).convert('RGB'))
if not self.is_train:
img = self.testTransform(Image.open(self.data[index][0]).resize((256, 256)).convert('RGB'))
label = int(self.data[index][1])
return img, label
# 数据集长度
def __len__(self):
return len(self.data)
或许已经看出来了,所有可能的数据处理或数据增强操作,都可通过transforms来进行调用与封装,是不是一下变得很方便呢!
第二步:将MyDataset装入DataLoader中
MyDataset类中的init方法要求传入记录数据路径及label的文件,因此可如下所示进行操作:
import MyDataset
train_data = MyDataset.MyDataset("./train_record.txt")
test_data = myDataset.MyDataset("./test_record.txt")
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
trainLoader = DataLoader(dataset=train_data,batch_size=64,shuffle=True,**kwargs)
testLoader = DataLoader(dataset=test_data,batch_size=64,shuffle=False, **kwargs)
这样,便生成了trainLoader 和testLoader
第三步:在训练中使用DataLoader
for epoch in range(1, args.nEpochs + 1):
## 定义好的train方法
train(args, epoch, model, trainLoader, optimizer)
## 定义好的val方法,用于测试或验证
val(args, epoch, model, testLoader, optimizer)
最后
以上便是使用 Dataset和DataLoader处理自己数据集的通用方法,当然本次仅记录了图片数据的使用方法,后续记录文本数据处理方法。
彩蛋
ooh~~ 那么对于Pytorch自带数据集如果处理呢?
若直接使用 CIFAR10
数据集,可以如下处理:
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
normMean = [0.49139968, 0.48215827, 0.44653124]
normStd = [0.24703233, 0.24348505, 0.26158768]
normTransform = transforms.Normalize(normMean, normStd)
trainTransform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normTransform
])
testTransform = transforms.Compose([
transforms.ToTensor(),
normTransform
])
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
trainLoader = DataLoader(dset.CIFAR10(root='cifar', train=True, download=True,
transform=trainTransform),batch_size=64, shuffle=True, **kwargs)
testLoader = DataLoader(dset.CIFAR10(root='cifar', train=False, download=True,
transform=testTransform),batch_size=64, shuffle=False, **kwargs)
其实也就是 torchvision.datasets
将这些共用数据集本身就做了 Dataset类的封装,因此直接调用,传入你想要的transforms,再丢给DataLoader即可。