优化pytorch DataLoader提升数据加载速度

因为pytorch数据加载速度太慢,影响训练速度,实训快速加载数据方式,提前获取要加载的数据,整体速度能快1/6.

操作步骤如下所示:

1、激活自己的torch虚拟环境:

source activate torch

2、安装prefetch_generator包

pip install prefetch_generator

3、定义DataLoaderX,继承torch原有的DataLoaderX的属性

class DataLoaderX(DataLoader):

    def __iter__(self):

        return BackgroundGenerator(super().__iter__())

4、使用的时候只需要把DataLoader变成DataLoaderX:

如下所示:

gen= DataLoaderX(train_dataset,shuffle=True,

batch_size=batch_size,

num_workers=4,pin_memory=True,

drop_last=False,

collate_fn=yolo_dataset_collate)

©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容