介绍一下另一种数据集的类型,即可迭代类型(Iterable-style)的数据集。相比映射类型的数据集,这个数据集并不需要实现getitem方法或者len方法,它本身更像一个Python迭代器。torch.utils.data.InterableDataset是要给可迭代数据集类型的构造方法。不同于映射类型,因为索引之间相互独立,在使用多进程载入数据的情况下(DataLoader中的参数num_works>1),多个进程可以独立分配索引,迭代器在使用过程中,因为索引之间有前后顺序,需要考虑如何分割数据,使得不同的进程可以得到不同的数据。
#torch.utils.data.InterableDataset
class MyIterableDataset(torch.utils.data.IterableDataset):
def __init__(self, start, end):
super(MyIterableDataset).__init__()
assert end > start, \
"this example code only works with end >= start"
self.start = start
self.end = end
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
iter_start = self.start
iter_end = self.end
else:
per_worker = int(math.ceil((self.end - self.start)/float(worker_info.num_workers)))
worker_id = worker_info.id
iter_start = self.start + worker_id * per_worker
iter_end = min(iter_start + per_worker, self.end)
return iter(range(iter_start, iter_end))
从上述代码来看,根据不同的工作进程的序号worker_id, 设定了不同进程数据迭代器取值的范围,这样就能保证不同的进程获取不同的迭代器,而且迭代器的返回的数据各不相同。