from importlib import import_module
from dataloader import MSDataLoader #
from torch.utils.data.dataloader import default_collate
class Data:
def __init__(self, args, model):
kwargs = {}
if not args.cpu:
kwargs['collate_fn'] = default_collate
kwargs['pin_memory'] = True
else:
kwargs['collate_fn'] = default_collate
kwargs['pin_memory'] = False
self.loader_train = None
if not args.test_only:
if args.data_train.lower() != 'rrl':
module_train = import_module('data.' + args.data_train.lower())
trainset = getattr(module_train, args.data_train)(args)
else:
module_train = import_module('data.' + args.rrl_data.lower())
trainclass = getattr(module_train, args.rrl_data)
module_train = import_module('data.rrl')
trainset = getattr(module_train, 'RRL')(trainclass, args, model)
self.loader_train = MSDataLoader(
args,
trainset,
batch_size=args.batch_size,
shuffle=True,
**kwargs
)
if args.data_test in ['Set5', 'Set14', 'B100', 'Urban100']:
if not args.benchmark_noise:
module_test = import_module('data.benchmark')
testset = getattr(module_test, 'Benchmark')(args, train=False)
else:
module_test = import_module('data.benchmark_noise')
testset = getattr(module_test, 'BenchmarkNoise')(
args,
train=False
)
else:
if args.data_test.lower() != 'rrl':
module_test = import_module('data.' + args.data_test.lower())
testset = getattr(module_test, args.data_test)(args, train=False)
else:
module_test = import_module('data.' + args.rrl_data.lower())
testclass = getattr(module_test, args.rrl_data)
module_test = import_module('data.rrl')
testset = getattr(module_test, 'RRL')(testclass, args, model, False)
self.loader_test = MSDataLoader(
args,
testset,
batch_size=1,
shuffle=False,
**kwargs
)
__init__.py
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。
相关阅读更多精彩内容
- 原文: http://zengrong.net/post/2192.htm 本站文章除注明转载外,均为本站原创或...
- init.py 文件的作用是将文件夹变为一个Python模块,Python 中的每个模块的包中,都有init.py...
- encoding=utf-8 该文件本身的作用是是python2版本可以识别包,即import testpacke...