Pytorch 中比较重要的是对数据的处理,其中,进行数据读取的一般有三个类:
- Dataset
- DataLoader
其中,这两个是一个依次封装的关系:“Dataset
被封装进DataLoader
,DataLoader
再被封装进DataLoaderIter
”
Dataset
Dataset
位于torch.utils.data.Dataset
,每当我们自定义类MyDataset
必须要继承它并实现其两个成员函数:
__len__()
__getitem__()
import torch
from torch.utils.data import Dataset
import pandas as pd
# 定义自己的类
class MyDataset(Dataset):
# 初始化
def __init__(self, file_name):
# 读入数据
self.data = pd.read_csv(file_name, sep='\t', usecols=['Phrase', 'Sentiment'])
# 返回df的长度
def __len__(self):
return len(self.data)
# 获取第idx+1列的数据
def __getitem__(self, idx):
return self.data.iloc[idx].Phrase, self.data.iloc[idx].Sentiment
# 通过实例化对象来访问该类
# 假设同目录下存在名为train.tsv的文件
ds = MyDataset('../datasets/train.tsv')
print(ds.data.head()) # 头数据
print(ds.data.iloc[1]) # 按行索引获取数据
# 结果
Phrase Sentiment
0 A series of escapades demonstrating the adage ... 1
1 A series of escapades demonstrating the adage ... 2
2 A series 2
3 A 2
4 series 2
Phrase A series of escapades demonstrating the adage ...
Sentiment 2
Name: 1, dtype: object
DataLoader
DataLoader
位于torch.utils.data.DataLoader
, 为我们提供了对Dataset
的读取操作
# 仅仅列举了常用的几个参数
torch.nn.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
-
dataset
: 上面所实现的自定义类Dataset
-
batch_size
: 默认为1,每次读取的batch的大小 -
shuffle
: 默认为False, 是否对数据进行shuffle操作(简单理解成将数据集打乱) -
num_works
: 默认为0,表示在加载数据的时候每次使用子进程的数量,即简单的多线程预读数据的方法
DataLoader
返回的是一个迭代器,我们通过这个迭代器来获取数据
Dataloder
的目的是将给定的n个数据, 经过Dataloader
操作后, 在每一次调用时调用一个小batch, 如:
- 给出的是: (5000,28,28) , 表示有5000个样本,每个样本的size为(28,28)
- 经过
Dataloader
处理后, 一次得到的是(100,28,28)(假设batch_size大小为100), 表示本次取出100个样本, 每个样本的size为(28,28)
# 连接上面的Dataset实现代码
from torch.utils.data import DataLoader
dl = DataLoader(ds, batch_size=10, shuffle=True, num_workers=2)
通过迭代器来分次获取数据:
dl_data = iter(dl)
print(next(dl_data))
# 结果
[('thematic ironies', 'whimsical and relevant today', "director George Hickenlooper 's approach to the material is too upbeat", 'direct-to-video\\/DVD category', 'Four Feathers', 'may well be the only one laughing at his own joke', 'the end credits', "What sets Ms. Birot 's film apart from others in the genre", 'overcoming-obstacles', 'homage pokepie hat , but as a character'), tensor([2, 3, 2, 1, 2, 1, 2, 3, 2, 2])]
或,直接通过for循环进行遍历输出
for i, data in enumerate(dl):
print(i, data)
# 这里只循环一次,所以用break
break
#结果
0 [('huge action sequence', ', characterization , poignancy , and intelligence', 'potentially incredibly twisting mystery', 'felt disrespected', 'a rather bland', 'the character dramas', 'a key strength', "'s never dull and always looks good", "the Queen 's", 'uncompromising knowledge'), tensor([3, 3, 3, 1, 1, 2, 3, 3, 2, 3])]