Pytorch基础(一)

Pytorch基本使用

Date:2021-10-15
Author:CJPP

1. 自定义Datasets & Dataloaders

PyTorch 提供了两个数据处理类:torch.utils.data.DataLoader 和 torch.utils.data.Dataset,方便我们预加载的数据集以及我们自己的数据。 Dataset 存储样本及其相应的标签,DataLoader 在 Dataset 周围包装一个可迭代对象,以便轻松访问样本。

加载集成的Dataset,使用Dataloader迭代加载输出到tensorboard显示

  • root:string类型,train/test数据集保存的目录
  • train:bool类型,True表示指定train数据集,False表示指定test数据集
  • download:bool类型,默认True,当root目录没有指定数据时会自动下载该数据集
  • transform & target_transform:指定feature和label变换操作,详情查看【Transforms使用】
from torchvision import transforms
import torchvision
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('logs')
trains_data = torchvision.datasets.CIFAR10('./dataset', train=True, transform=transforms.ToTensor(), download=True)
dataloader = DataLoader(trains_data, batch_size=256,num_workers=1)
if __name__ == '__main__':
    for index, batch_ds in enumerate(dataloader):
        writer.add_image("CIFAR10", make_grid(batch_ds[0]), index)
    writer.close()

终端运行tensorboard查看输出

tensorboard --logdir=logs

浏览器查看输出结果


Tensorboard输出

实现自定义数据集

  • 准备数据集

自定义数据集需要继承Dataset类,并且重写一下三个方法:__init__,__len__和__getitem__。准备好一下结构的数据集(以此结构为例)


目标
标签
  • 自定义数据集类
import torch
import os
from PIL import Image


class MyDataset(torch.utils.data.Dataset):
   """Some Information about MyDataset"""

   def __init__(self, root_dir, image_dir, label_dir, transform):
       super(MyDataset, self).__init__()
       self.root_dir = root_dir
       self.image_dir = image_dir
       self.label_dir = label_dir
       self.label_path = os.path.join(self.root_dir, self.label_dir)
       self.image_path = os.path.join(self.root_dir, self.image_dir)
       self.image_list = os.listdir(self.image_path)
       self.lable_list = os.listdir(self.label_path)
       self.transform = transform

       self.image_list.sort()
       self.lable_list.sort()

   def __getitem__(self, index):
       img_name = self.image_list[index]
       label_name = self.lable_list[index]
       img_item_path = os.path.join(self.root_dir, self.image_dir, img_name)
       label_item_path = os.path.join(
           self.root_dir, self.label_dir, label_name)
       img = Image.open(img_item_path)

       with open(label_item_path, 'r') as f:
           label = f.readline()
       img = self.transform(img)
       return {'img': img, 'label': label}

   def __len__(self):
       assert len(self.image_list) == len(self.lable_list)
       return len(self.image_list)

  • 使用自定义数据集
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from MyData import MyDataset

writer = SummaryWriter('logs')
transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
ants_ds = MyDataset('dataset/ds/train', 'ants_image', 'ants_label', transform)
dataloader = DataLoader(ants_ds, batch_size=32,num_workers=1)
if __name__ == '__main__':
    for index, batch_ds in enumerate(dataloader):
        writer.add_image("ANTS_BEES", make_grid(batch_ds['img']), index)
    writer.close()
  • 查看tensorboard
tensorboard --logdir=logs
自定义数据集

2. Transforms

前面两部分我们已经涉及并使用了transforms的部分方法,下面介绍其常用功能

  • transforms.Compose():接受一个数组作为参数,数组的元素是每一个具体变换,如:transforms.ToTensor()
  • transforms.ToTensor():将数据转化成pytorch的tensor类型
  • transforms.Resize():参数是一个size,将图片转变成此大小
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。
禁止转载,如需转载请通过简信或评论联系作者。

推荐阅读更多精彩内容