Dataloader主要是拿出一些Mini-Batch来供训练时能够快速使用。使用batch可以提升计算速度,但是其求值的性能会有些问题。因此选用了Mini-Batch来进行综合;
使用了Mini-Batch后都得用下面这样的嵌套循环:
#train cycle
for epoch in range(training_epochs):
#每次迭代执行一个Mini-Batch
for i in range(total_batch):
- 需要了解的概念:
- epoch:所有的样本都进行了一次正,反向传播。即所有样本都进行了一次训练;
- Batch-Size:每次训练时,所用的样本数量;
-
Iteration:分了多少个batch,也就是内层的那个迭代执行了多少次;
如:现有1w个样本,batch是1k个,即每次拿1k个样本。那么Iteration就是
Dataloader:需要知道目标数据的索引[i]以及长度len。这样一来,dataloader就可以自动对dataset进行小批量的数据集的生成:
第一步,shuffle:就是打乱顺序;第二步,将打乱后的数据进行分组,这里将两个样本作为一个batch
- 常见的读取数据集的方法:
- 直接读取所有数据,这种方法适用于数据集本身就不算大的数据;
- 对于数据量很大的一堆文件/图片之类的,可以通过一个list来保存其地址一类的,然后在用到的时候再进行读取;
如何去定义一个数据集
import torch
#Dataset是个抽象类
from torch.utils.data import Dataset
#torch中帮助加载数据的类
from torch.utils.data import DataLoader
import numpy as np
class DiabetesDataset(Dataset):
def __init__(self):
pass
#通过这个方法来支持下标操作
def __getitem__(self, index):
pass
def __len__(self):
pass
dataset = DiabetesDataset()
#torch直接提供的,一般都是设置这四个参数
#num_workers表示读数据的时候可以兼容的线程数
train_loader = DataLoader(dataset = dataset
,batch_size = 32
,shuffle = True
,num_workers = 2
)
但是在windows下运行上面的loader代码好像以后会报错:
解决方法:将loader封装到if中,而不是直接顶格写出来
if __name__ == '__main__':
for epoch in range(100):
#将train_loader所拿出的x, y 放入到data中去
for i, data in enumerate(train_loader, 0):
至此,加载数据集的功能就写好了,然后再加上前面写的model的代码就变成了:
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear1 = torch.nn.Linear(8, 6)
self.linear2 = torch.nn.Linear(6, 4)
self.linear3 = torch.nn.Linear(4, 1)
#注意这里用的是nn下的sigmoid,
self.sigmoid = torch.nn.Sigmoid()
def forward(self, x):
x = self.sigmoid(self.linear1(x))
x = self.sigmoid(self.linear2(x))
x = self.sigmoid(self.linear3(x))
return x
model = Model()
#3. 构造损失函数和优化器
#和之前一样
criterion = torch.nn.BCELoss(size_average=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
注意,上面这段代码在jupyter中无法运行,因为这玩意儿不能多线程。然后在pycharm中可以运行,但是速度很慢。这应该是因为数据量太少,所以多线程的调用反而影响了读取速度照成的;
主要就是在1和4进行了改造:1中不再是加载所有数据了,而是构造并使用了dataset和dataloader;4中则是改成了嵌套循环,适配mini-batch;
这样一来,就完成了对于糖尿病数据集进行分类的神经网络学习流程。
torchvision中提供的数据集们:
这些数据集都派生自dataset,所以都可以用dataloader进行加载,也有getitem, len等方法,还可以用多进程进行加速
transform是指要将数据转为想要的数据类型,这里是张量;
在test_loader那里是不用shuffle的,以保证每次输出的顺序都是一样的