__init__.py

from importlib import import_module
from dataloader import MSDataLoader      #
from torch.utils.data.dataloader import default_collate

class Data:
    def __init__(self, args, model):
        kwargs = {}
        if not args.cpu:
            kwargs['collate_fn'] = default_collate
            kwargs['pin_memory'] = True
        else:
            kwargs['collate_fn'] = default_collate
            kwargs['pin_memory'] = False

        self.loader_train = None
        if not args.test_only:

            if args.data_train.lower() != 'rrl':
                module_train = import_module('data.' + args.data_train.lower())
                trainset = getattr(module_train, args.data_train)(args)
            else: 
                module_train = import_module('data.' + args.rrl_data.lower())
                trainclass = getattr(module_train, args.rrl_data)
                
                module_train = import_module('data.rrl')
                trainset = getattr(module_train, 'RRL')(trainclass, args, model)

            self.loader_train = MSDataLoader(
                    args,
                    trainset,
                    batch_size=args.batch_size,
                    shuffle=True,
                    **kwargs
                )

        if args.data_test in ['Set5', 'Set14', 'B100', 'Urban100']:
            if not args.benchmark_noise:
                module_test = import_module('data.benchmark')
                testset = getattr(module_test, 'Benchmark')(args, train=False)
            else:
                module_test = import_module('data.benchmark_noise')
                testset = getattr(module_test, 'BenchmarkNoise')(
                    args,
                    train=False
                )

        else:
            if args.data_test.lower() != 'rrl': 
                module_test = import_module('data.' +  args.data_test.lower())
                testset = getattr(module_test, args.data_test)(args, train=False)
            else: 
                module_test = import_module('data.' + args.rrl_data.lower())
                testclass = getattr(module_test, args.rrl_data)

                module_test = import_module('data.rrl')
                testset = getattr(module_test, 'RRL')(testclass, args, model, False)

        self.loader_test = MSDataLoader(
            args,
            testset,
            batch_size=1,
            shuffle=False,
            **kwargs
        )
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容