PyTorch 知识

PyTorch使用总览

原文链接:https://blog.csdn.net/u014380165/article/details/79222243

参考:PyTorch学习之路(level1)——训练一个图像分类模型PyTorch学习之路(level2)——自定义数据读取PyTorch源码解读之torchvision.transformsPyTorch源码解读之torch.utils.data.DataLoaderPyTorch源码解读之torchvision.models

一、数据读取

官方代码库中有一个接口例子:torchvision.ImageFolder -- 针对的数据存放方式是每个文件夹包含一个类的图像,但往往实际应用中可能你的数据不是这样维护的,此时需要自定义一个数据读取接口(使用PyTorch中数据读取基类:torch.utils.data.Dataset)

数据读取接口
class customData(data.Dataset):

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader):
        """
        提供数据地址(data path)、每一文件所属的类别(label),and other Info wanted(transform\loader\...) --> self.(attributes)
        
        :param root(string): Root directory path.
        :param transform (callable, optional): A function/transform that  takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop``
        :param target_transform(callable, optional): A function/transform that takes in the target and transforms it. 
        :param loader (callable, optional): A function to load an image given its path. 
        
        the data loader where the images are arranged in this way: ::
                root/class_1_xxx.png    
                root/class_2_xxx.png
        ...
                root/class_n_xxx.png    # 此例中,文件名包含label信息,__init__中可不需要额外提供
            
        """
        self.dataset = [os.path.join(root, npy_data) for npy_data in os.listdir(root)]  # 整个数据集(图像)文件的路径
                
                self.transform = transform  # (optional)
        self.target_transform = target_transform    # (optional)
        self.loader = loader    # (optional)
        
    def __getitem__(self, index):
        """
        :return 相应index的data && label

                """
        data = np.load(self.dataset[index])
        
        if self.transform is not None:  # (optional)
            img = self.transform(img)
        if self.target_transform is not None:  # (optional)
            target = self.target_transform(target)
        
        label_txt = self.dataset[index].split('/')[-1][:2]  # (class_n)_xxxx.npy → (class_n)

        if label_txt == 'class_1':
            label = 0
        elif label_txt == 'class_2':
            label = 1
        else:
            raise RuntimeError('Now only support class_1 vs class_2.')

        return data, label

    def __len__(self):
        """
                :return 数据集数量
                
        """
        return len(self.dataset)

上述提到的transforms数据预处理,可以通过torchvision.transforms接口来实现。具体请看博客:PyTorch源码解读之torchvision.transforms

接口调用
root_dir = r'xxxxxxxx'  
image_datasets = {x: customData(root=root_dir+x) for x in ['train', 'val', 'test']}

返回的image_datasets(自定义数据读取接口)就和用torchvision.datasets.ImageFolder类(官方提供的数据读取接口)返回的数据类型一样

数据迭代器封装
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4)
                   for x in ['train', 'valid', 'test']}

torch.utils.data.DataLoader接口将每个batch的图像数据和标签都分别封装成Tensor,方便以batch进行模型批训练,具体可以参考博客: PyTorch源码解读之torch.utils.data.DataLoader

至此,从图像和标签文件就生成了Tensor类型的数据迭代器,后续仅需将Tensor对象用torch.autograd.Variable接口封装成Variable类型(比如train_data=torch.autograd.Variable(train_data),如果要在gpu上运行则是:train_data=torch.autograd.Variable(train_data.cuda()))就可以作为模型的输入

二、网络构建

PyTorch框架中提供了一些方便使用的网络结构及预训练模型接口:torchvision.models,具体可以看博客:PyTorch源码解读之torchvision.models。该接口可以直接导入指定的网络结构,并且可以选择是否用预训练模型初始化导入的网络结构。示例如下:

import torchvision
model = torchvision.models.resnet50(pretrained=True)  # 导入resnet50的预训练模型

那么如何自定义网络结构呢?在PyTorch中,构建网络结构的类都是基于torch.nn.Module这个基类进行的,也就是说所有网络结构的构建都可以通过继承该类来实现,包括torchvision.models接口中的模型实现类也是继承这个基类进行重写的。自定义网络结构可以参考:1、https://github.com/miraclewkf/MobileNetV2-PyTorch。该项目中的MobileNetV2.py脚本自定义了网络结构。2、https://github.com/miraclewkf/SENet-PyTorch。该项目中的se_resnet.py和se_resnext.py脚本分别自定义了不同的网络结构。

如果要用某预训练模型为自定义的网络结构进行参数初始化,可以用torch.load接口导入预训练模型,然后调用自定义的网络结构对象的load_state_dict方式进行参数初始化,具体可以看https://github.com/miraclewkf/MobileNetV2-PyTorch项目中的train.py脚本中if args.resume条件语句(如下所示)。

if args.resume:
  if os.path.isfile(args.resume):
    print(("=> loading checkpoint '{}'".format(args.resume)))
    checkpoint = torch.load(args.resume)
    base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.state_dict().items())}
    model.load_state_dict(base_dict)
    else:
      print(("=> no checkpoint found at '{}'".format(args.resume)))

三、其他设置

优化函数通过torch.optim包实现,比如torch.optim.SGD()接口表示随机梯度下降。更多优化函数可以看官方文档:http://pytorch.org/docs/0.3.0/optim.html

学习率策略通过torch.optim.lr_scheduler接口实现,比如torch.optim.lr_scheduler.StepLR()接口表示按指定epoch数减少学习率。更多学习率变化策略可以看官方文档:http://pytorch.org/docs/0.3.0/optim.html

损失函数通过torch.nn包实现,比如torch.nn.CrossEntropyLoss()接口表示交叉熵等。

多GPU训练通过torch.nn.DataParallel接口实现,比如:model = torch.nn.DataParallel(model, device_ids=[0,1])表示在gpu0和1上训练模型。

模块解读

torch.utils.data.DataLoader

将数据读取接口的输入按照batch size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入,因此该接口有承上启下的作用

源码地址:https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataloader.py

示例:

dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4)
                   for x in ['train', 'valid', 'test']}
  • dataset (Dataset): dataset from which to load the data.
  • batch_size (int, optional): how many samples per batch to load (default: 1).
  • shuffle (bool, optional): set to True to have the data reshuffled at every epoch (default: False).
  • num_workers (int, optional): how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
  • ... ...

从torch.utils.data.DataLoader类生成的对象中取数据:

train_data=torch.utils.data.DataLoader(...)
for i, (input, target) in enumerate(train_data):
    # ...
    pass

此时,调用DataLoader类的__iter__方法 ⤵️:

    def __iter__(self):
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            return _MultiProcessingDataLoaderIter(self)

使用队列queue对象,完成多线程调度;通过迭代器iter,完成batch更替(详情读源码)

torchvision.transforms

基本上PyTorch中的data augmentation操作都可以通过该接口实现,包含resize、crop等常见的data augmentation操作

示例:

import torchvision
import torch
train_augmentation = torchvision.transforms.Compose([torchvision.transforms.Resize(256),
                                                     torchvision.transforms.RandomCrop(224),                                                                            
                                                     torchvision.transofrms.RandomHorizontalFlip(),
                                                     torchvision.transforms.ToTensor(),
                                                     torch vision.Normalize([0.485, 0.456, -.406],[0.229, 0.224, 0.225])
                                                     ])

class custom_dataread(torch.utils.data.Dataset):  # 数据读取接口
    def __init__():
        ...
    def __getitem__():
        # use self.transform for input image
    def __len__():
        ...

train_loader = torch.utils.data.DataLoader(  # 数据迭代器
    custom_dataread(transform=train_augmentation),
    batch_size = batch_size, shuffle = True,
    num_workers = workers, pin_memory = True)

这里定义了resize、crop、normalize等数据预处理操作,并最终作为数据读取类custom_dataread的一个参数传入,可以在内部方法__getitem__中实现数据增强操作。

源码地址:transformas.py --- 定义各种data augmentation的类、functional.py --- 提供transformas.py中所需功能函数的实现

  • Compose类:Composes several transforms together. 对输入图像img逐次应用输入的[transform_1, transform_2, ...]操作

  • ToTensor类:Convert a PIL Image or numpy.ndarray to tensor. 要强调的是在做数据归一化之前必须要把PIL Image转成Tensor,而其他resize或crop操作则不需要.

  • ToPILImage类:Convert a tensor or an ndarray to PIL Image.

  • Normalize类:Normalize a tensor image with mean and standard deviation.一般都会对输入数据做归一化操作

  • Resize类:Resize the input PIL Image to the given size. 几乎都要用到,这里输入可以是int,此时表示将输入图像的短边resize到这个int数,长边则根据对应比例调整,图像的长宽比不变。如果输入是个(h,w)的序列,h和w都是int,则直接将输入图像resize到这个(h,w)尺寸,相当于force resize,所以一般最后图像的长宽比会变化,也就是图像内容被拉长或缩短。若输入是PIL Image,则将调用Image的各种方法;若输入是Tensor,则对应函数基本是在调用Tensor的各种方法。

  • CenterCrop类:Crops the given PIL Image at the center. 一般数据增强不会采用这个,因为当size固定的时候,在相同输入图像的情况下,N次CenterCrop的结果都是一样的

  • RandomCrop类:Crop the given PIL Image at a random location. 相较CenterCrop,随机裁剪更常用

  • RandomResizedCrop类:Crop the given PIL Image to random size and aspect ratio. 根据随机生成的scale、aspect ratio(缩放比例、长宽比)、中心点裁剪原图,(为可正常训练)再缩放为输入的size大小

  • RandomHorizontalFlip类:Horizontally flip the given PIL Image randomly with a given probability. 随机的图像水平翻转,通俗讲就是图像的左右对调,较常用。 probability of the image being flipped. Default value is 0.5 (水平翻转的概率是0.5)

  • RandomVerticalFlip类:Vertically flip the given PIL Image randomly with a given probability. 随机的图像竖直翻转,通俗讲就是图像的上下对调,较常用。probability of the image being flipped. Default value is 0.5(竖直翻转的概率是0.5)

  • FiveCrop类:Crop the given PIL Image into four corners and the central crop. 曾在TSN算法的看到过这种用法。

  • TenCrop类:Crop the given PIL Image into four corners and the central crop plus the flipped version of
    these (horizontal flipping is used by default) 将输入图像进行水平或竖直翻转,然后再进行FiveCrop操作;加上原始的FiveCrop操作,这样一张输入图像就能得到10张crop结果。

  • LinearTransformation类:Transform a tensor image with a square transformation matrix and a mean_vector computed offline. 用一个变换矩阵去乘输入图像得到输出结果。

  • ColorJitter类:Randomly change the brightness, contrast, saturation and hue (即亮度,对比度,饱和度和色调)of an image,可以根据注释来合理设置这4个参数。(较常用)

  • RandomRotation类:随机旋转输入图像,具体参数可以看注释,在F.rotate()中主要是调用PIL Image的rotate方法。(较常用)

  • Grayscale类:用来将输入图像转成灰度图的,这里根据参数num_output_channels的不同有两种转换方式

  • RandomGrayscale类:Randomly convert image to grayscale with a probability of p (default 0.1).

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 213,186评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,858评论 3 387
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,620评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,888评论 1 285
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,009评论 6 385
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,149评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,204评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,956评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,385评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,698评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,863评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,544评论 4 335
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,185评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,899评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,141评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,684评论 2 362
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,750评论 2 351