主题
加载数据集
总结1
用Dataset构造数据集,用DataLoader按Mini-Batch取数据。
用Mni-Batch可以并行计算N个样本,提升速度。
训练循环
for epoch in range(trainning_epochs):
for i in range(total_batch)
三个概念:epoch、Batch-Size、Iteration。一次epoch指的是所有样本都参与训练。
Dataset是抽象类,能被继承,不能被实例化。
Dataset的子类需要能被索引访问和被len。即要有
__getitem__()
和__len__()
。在它的__init()__
中可以有两种实现方法,一种是把所有数据都读到内存里面,另一种是当被索引访问时才从硬盘读取。DataLoader可以被实例化。它的常见参数有
DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2)
。shuffle为是否打乱,num_workers是采用几个CPU核心并行读取。它的示例是一个可迭代对象,迭代出Tensor类。
总结2
- 导入Dataset类用于自定义自己数据集类:
from torch.utils.data import Dataset
- 导入DataLoader类用于读取数据:
from torch.utils.data import DataLoader
- pytorch内置数据集:
import torchvision.datasets as datasets
- 导入transforms模块用于数据预处理:
import torchvision.transforms as transforms