因为pytorch数据加载速度太慢,影响训练速度,实训快速加载数据方式,提前获取要加载的数据,整体速度能快1/6.
操作步骤如下所示:
1、激活自己的torch虚拟环境:
source activate torch
2、安装prefetch_generator包
pip install prefetch_generator
3、定义DataLoaderX,继承torch原有的DataLoaderX的属性
class DataLoaderX(DataLoader):
def __iter__(self):
return BackgroundGenerator(super().__iter__())
4、使用的时候只需要把DataLoader变成DataLoaderX:
如下所示:
gen= DataLoaderX(train_dataset,shuffle=True,
batch_size=batch_size,
num_workers=4,pin_memory=True,
drop_last=False,
collate_fn=yolo_dataset_collate)