from importlib import import_module
from dataloader import MSDataLoader #
from torch.utils.data.dataloader import default_collate
class Data:
def __init__(self, args, model):
kwargs = {}
if not args.cpu:
kwargs['collate_fn'] = default_collate
kwargs['pin_memory'] = True
else:
kwargs['collate_fn'] = default_collate
kwargs['pin_memory'] = False
self.loader_train = None
if not args.test_only:
if args.data_train.lower() != 'rrl':
module_train = import_module('data.' + args.data_train.lower())
trainset = getattr(module_train, args.data_train)(args)
else:
module_train = import_module('data.' + args.rrl_data.lower())
trainclass = getattr(module_train, args.rrl_data)
module_train = import_module('data.rrl')
trainset = getattr(module_train, 'RRL')(trainclass, args, model)
self.loader_train = MSDataLoader(
args,
trainset,
batch_size=args.batch_size,
shuffle=True,
**kwargs
)
if args.data_test in ['Set5', 'Set14', 'B100', 'Urban100']:
if not args.benchmark_noise:
module_test = import_module('data.benchmark')
testset = getattr(module_test, 'Benchmark')(args, train=False)
else:
module_test = import_module('data.benchmark_noise')
testset = getattr(module_test, 'BenchmarkNoise')(
args,
train=False
)
else:
if args.data_test.lower() != 'rrl':
module_test = import_module('data.' + args.data_test.lower())
testset = getattr(module_test, args.data_test)(args, train=False)
else:
module_test = import_module('data.' + args.rrl_data.lower())
testclass = getattr(module_test, args.rrl_data)
module_test = import_module('data.rrl')
testset = getattr(module_test, 'RRL')(testclass, args, model, False)
self.loader_test = MSDataLoader(
args,
testset,
batch_size=1,
shuffle=False,
**kwargs
)
__init__.py
©著作权归作者所有,转载或内容合作请联系作者
- 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
- 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
- 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
推荐阅读更多精彩内容
- 原文: http://zengrong.net/post/2192.htm 本站文章除注明转载外,均为本站原创或...
- init.py 文件的作用是将文件夹变为一个Python模块,Python 中的每个模块的包中,都有init.py...
- encoding=utf-8 该文件本身的作用是是python2版本可以识别包,即import testpacke...