Pytorch基本使用
Date:2021-10-15
Author:CJPP
1. 自定义Datasets & Dataloaders
PyTorch 提供了两个数据处理类:torch.utils.data.DataLoader 和 torch.utils.data.Dataset,方便我们预加载的数据集以及我们自己的数据。 Dataset 存储样本及其相应的标签,DataLoader 在 Dataset 周围包装一个可迭代对象,以便轻松访问样本。
加载集成的Dataset,使用Dataloader迭代加载输出到tensorboard显示
- root:string类型,train/test数据集保存的目录
- train:bool类型,True表示指定train数据集,False表示指定test数据集
- download:bool类型,默认True,当root目录没有指定数据时会自动下载该数据集
- transform & target_transform:指定feature和label变换操作,详情查看【Transforms使用】
from torchvision import transforms
import torchvision
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('logs')
trains_data = torchvision.datasets.CIFAR10('./dataset', train=True, transform=transforms.ToTensor(), download=True)
dataloader = DataLoader(trains_data, batch_size=256,num_workers=1)
if __name__ == '__main__':
for index, batch_ds in enumerate(dataloader):
writer.add_image("CIFAR10", make_grid(batch_ds[0]), index)
writer.close()
终端运行tensorboard查看输出
tensorboard --logdir=logs
浏览器查看输出结果
Tensorboard输出
实现自定义数据集
- 准备数据集
自定义数据集需要继承Dataset类,并且重写一下三个方法:__init__,__len__和__getitem__。准备好一下结构的数据集(以此结构为例)
目标
标签
- 自定义数据集类
import torch
import os
from PIL import Image
class MyDataset(torch.utils.data.Dataset):
"""Some Information about MyDataset"""
def __init__(self, root_dir, image_dir, label_dir, transform):
super(MyDataset, self).__init__()
self.root_dir = root_dir
self.image_dir = image_dir
self.label_dir = label_dir
self.label_path = os.path.join(self.root_dir, self.label_dir)
self.image_path = os.path.join(self.root_dir, self.image_dir)
self.image_list = os.listdir(self.image_path)
self.lable_list = os.listdir(self.label_path)
self.transform = transform
self.image_list.sort()
self.lable_list.sort()
def __getitem__(self, index):
img_name = self.image_list[index]
label_name = self.lable_list[index]
img_item_path = os.path.join(self.root_dir, self.image_dir, img_name)
label_item_path = os.path.join(
self.root_dir, self.label_dir, label_name)
img = Image.open(img_item_path)
with open(label_item_path, 'r') as f:
label = f.readline()
img = self.transform(img)
return {'img': img, 'label': label}
def __len__(self):
assert len(self.image_list) == len(self.lable_list)
return len(self.image_list)
- 使用自定义数据集
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from MyData import MyDataset
writer = SummaryWriter('logs')
transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
ants_ds = MyDataset('dataset/ds/train', 'ants_image', 'ants_label', transform)
dataloader = DataLoader(ants_ds, batch_size=32,num_workers=1)
if __name__ == '__main__':
for index, batch_ds in enumerate(dataloader):
writer.add_image("ANTS_BEES", make_grid(batch_ds['img']), index)
writer.close()
- 查看tensorboard
tensorboard --logdir=logs
自定义数据集
2. Transforms
前面两部分我们已经涉及并使用了transforms的部分方法,下面介绍其常用功能
- transforms.Compose():接受一个数组作为参数,数组的元素是每一个具体变换,如:transforms.ToTensor()
- transforms.ToTensor():将数据转化成pytorch的tensor类型
- transforms.Resize():参数是一个size,将图片转变成此大小