《PyTorch深度学习实践》(3)

主题

加载数据集

总结1

  1. 用Dataset构造数据集,用DataLoader按Mini-Batch取数据。

  2. 用Mni-Batch可以并行计算N个样本,提升速度。

  3. 训练循环

for epoch in range(trainning_epochs):
    for i in range(total_batch)
  1. 三个概念:epoch、Batch-Size、Iteration。一次epoch指的是所有样本都参与训练。

  2. Dataset是抽象类,能被继承,不能被实例化。

  3. Dataset的子类需要能被索引访问和被len。即要有__getitem__()__len__()。在它的__init()__中可以有两种实现方法,一种是把所有数据都读到内存里面,另一种是当被索引访问时才从硬盘读取。

  4. DataLoader可以被实例化。它的常见参数有DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2)。shuffle为是否打乱,num_workers是采用几个CPU核心并行读取。它的示例是一个可迭代对象,迭代出Tensor类。

总结2

  1. 导入Dataset类用于自定义自己数据集类:from torch.utils.data import Dataset
  2. 导入DataLoader类用于读取数据:from torch.utils.data import DataLoader
  3. pytorch内置数据集:import torchvision.datasets as datasets
  4. 导入transforms模块用于数据预处理:import torchvision.transforms as transforms
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。