[Code Walkthrough] pytorch-3dunet-master

Source code: https://github.com/wolny/pytorch-3dunet

Installation

1. Install with conda

2. Install by running setup.py
For more on setup.py, see:
Python study notes | setuptools
How setup packages a project
Packaging your own Python program with setuptools, explained

setup.py

from setuptools import setup, find_packages
# Execute __version__.py so that __version__ is defined in this scope
exec(open('pytorch3dunet/__version__.py').read())
setup(
    name="pytorch3dunet",    # package name -- the name of the generated egg
    # Discover packages automatically: by default, search next to setup.py for
    # directories containing __init__.py. exclude: leave out the tests package
    packages=find_packages(exclude=["tests"]),
    version=__version__,      # package version -- the version of the generated egg
    author="Adrian Wolny, Lorenzo Cerrone",
    url="https://github.com/wolny/pytorch-3dunet",  # project home page
    license="MIT",
    python_requires='>=3.7'   # minimum Python version required
)
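The exec(open(...).read()) trick makes __version__ available inside setup.py without importing the package. A minimal illustration (the version value here is made up):

# pytorch3dunet/__version__.py contains a single assignment, e.g.:
#   __version__ = '1.0.0'

# exec() runs that file's source in the current namespace,
# so __version__ is defined here afterwards:
exec(open('pytorch3dunet/__version__.py').read())
print(__version__)  # -> '1.0.0'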

Training

train.py

  • def main():

 0   logger = get_logger('UNet3DTrain')

def main():
      # Load and log experiment configuration
 1    config = load_config()
 2    logger.info(config)

 3    manual_seed = config.get('manual_seed', None)
 4    if manual_seed is not None:
 5        logger.info(f'Seed the RNG for all devices with {manual_seed}')
 6        torch.manual_seed(manual_seed)
          # see https://pytorch.org/docs/stable/notes/randomness.html
 7        torch.backends.cudnn.deterministic = True
 8        torch.backends.cudnn.benchmark = False

      # Create the model
 9    model = get_model(config)
      # use DataParallel if more than 1 GPU available
10    device = config['device']
11    if torch.cuda.device_count() > 1 and not device.type == 'cpu':
12        model = nn.DataParallel(model)
13        logger.info(f'Using {torch.cuda.device_count()} GPUs for training')

      # put the model on GPUs
14    logger.info(f"Sending the model to '{config['device']}'")
15    model = model.to(device)

      # Log the number of learnable parameters
16    logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}')

      # Create loss criterion
17    loss_criterion = get_loss_criterion(config)
      # Create evaluation metric
18    eval_criterion = get_evaluation_metric(config)

      # Create data loaders
19    loaders = get_train_loaders(config)

      # Create the optimizer
20    optimizer = _create_optimizer(config, model)

      # Create learning rate adjustment strategy
22    lr_scheduler = _create_lr_scheduler(config, optimizer)

      # Create model trainer
23    trainer = _create_trainer(config, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler,
24                              loss_criterion=loss_criterion, eval_criterion=eval_criterion, loaders=loaders)
      # Start training
25    trainer.fit()

1. config = load_config()

Loads the config and logs which device will be used.

Related reading: Python's logging module explained

logger = utils.get_logger('ConfigLoader')

def load_config():
    parser = argparse.ArgumentParser(description='UNet3D')
    parser.add_argument('--config', type=str, help='Path to the YAML config file', required=True)
    args = parser.parse_args()
    config = _load_config_yaml(args.config)  # open the file passed via --config
    # Get a device to train on
    device_str = config.get('device', None)
    if device_str is not None:
        logger.info(f"Device specified in config: '{device_str}'")
        if device_str.startswith('cuda') and not torch.cuda.is_available():
            logger.warn('CUDA not available, using CPU')
            device_str = 'cpu'
    else:
        device_str = "cuda:0" if torch.cuda.is_available() else 'cpu'
        logger.info(f"Using '{device_str}' device")

    device = torch.device(device_str)
    config['device'] = device
    return config

def _load_config_yaml(config_file):
    return yaml.safe_load(open(config_file, 'r'))
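For orientation, here is a minimal sketch of the YAML file that --config points to. Only the device, manual_seed, loaders and trainer keys below come from code quoted in this post; a real config also needs model/loss/optimizer sections and transformer/slice_builder blocks under train/val (see create_datasets below), so treat this as a skeleton, not a runnable config:

import yaml

config = yaml.safe_load("""
manual_seed: 0
device: cuda:0
loaders:
  dataset: StandardHDF5Dataset
  batch_size: 1
  num_workers: 4
  train:
    file_paths: [./data/train]
  val:
    file_paths: [./data/val]
trainer:
  checkpoint_dir: ./checkpoints
  epochs: 100
  iters: 150000
  validate_after_iters: 100
  log_after_iters: 100
  eval_score_higher_is_better: true
""")
print(config['device'])  # -> cuda:0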

--------------------------------------------------------------------------------------------------------------------------
import sys
import logging

loggers = {}

def get_logger(name, level=logging.INFO):
    global loggers
    if loggers.get(name) is not None:
        return loggers[name]
    else:
        logger = logging.getLogger(name)   # get (or create) the logger with this name
        logger.setLevel(level)             # set the logger's log level
        # Logging to console
        stream_handler = logging.StreamHandler(sys.stdout)   # console handler
        # formatter object
        formatter = logging.Formatter(
            '%(asctime)s [%(threadName)s] %(levelname)s %(name)s - %(message)s')
        stream_handler.setFormatter(formatter)    # attach the formatter to the console handler
        logger.addHandler(stream_handler)         # attach the console handler to the logger

        loggers[name] = logger

        return logger
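A quick check of what this produces (the logger name is arbitrary):

logger = get_logger('Demo')
logger.info('hello')
# prints something like:
# 2021-01-01 12:00:00,000 [MainThread] INFO Demo - hello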

0. logger = get_logger('UNet3DTrain')
2. logger.info(config)

This prints the contents of the file passed via --config.
(Screenshot omitted.)

9. model = get_model(config)

This line is straightforward: it builds the model. You can edit model.py to add your own model.
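get_model itself is not quoted in this post. As a rough sketch, it dispatches on config['model']['name'] in the same getattr style as _get_cls shown later (the module path and key names here are assumptions, so verify against the source):

import importlib

def get_model(config):
    model_config = dict(config['model'])
    name = model_config.pop('name')        # e.g. 'UNet3D'
    m = importlib.import_module('pytorch3dunet.unet3d.model')
    model_class = getattr(m, name)
    # the remaining keys of the 'model' section become constructor kwargs
    return model_class(**model_config)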
5. logger.info(f'Seed the RNG for all devices with {manual_seed}')
14. logger.info(f"Sending the model to '{config['device']}'")
16. logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}')

These lines simply log the seed, the target device and the number of learnable parameters.
(Screenshot omitted.)

17. loss_criterion = get_loss_criterion(config)

The loss is a bit involved; I'll fill it in later.


18. eval_criterion = get_evaluation_metric(config)

The evaluation metric. I haven't looked at this closely; it shouldn't be too hard, and you can add your own metric here.
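get_evaluation_metric is not quoted here either; presumably it follows the same name-based dispatch. A hedged sketch (the module path and the 'eval_metric' key are assumptions):

import importlib

def get_evaluation_metric(config):
    metric_config = dict(config['eval_metric'])
    name = metric_config.pop('name')       # e.g. 'MeanIoU'
    m = importlib.import_module('pytorch3dunet.unet3d.metrics')
    metric_class = getattr(m, name)
    return metric_class(**metric_config)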


19. loaders = get_train_loaders(config)

This line calls through quite a few functions. In short, it returns ready-to-use data with the patch/slice indices already built. The walkthrough below may not be entirely accurate; it is only my rough understanding, and corrections are welcome.
from pytorch3dunet.datasets.utils import get_train_loaders

def get_train_loaders(config):
    """
    Returns dictionary containing the training and validation loaders (torch.utils.data.DataLoader).

    :param config: a top level configuration object containing the 'loaders' key
    :return: dict {
        'train': <train_loader>
        'val': <val_loader>
    }
    """
1     assert 'loaders' in config, 'Could not find data loaders configuration'
2     loaders_config = config['loaders']

3     logger.info('Creating training and validation set loaders...')

     # get dataset class
4     dataset_cls_str = loaders_config.get('dataset', None)  # StandardHDF5Dataset
5     if dataset_cls_str is None:
6         dataset_cls_str = 'StandardHDF5Dataset'
7         logger.warn(f"Cannot find dataset class in the config. Using default '{dataset_cls_str}'.")
8     dataset_class = _get_cls(dataset_cls_str)

9     assert set(loaders_config['train']['file_paths']).isdisjoint(loaders_config['val']['file_paths']), \
         "Train and validation 'file_paths' overlap. One cannot use validation data for training!"

10    train_datasets = dataset_class.create_datasets(loaders_config, phase='train')

11    val_datasets = dataset_class.create_datasets(loaders_config, phase='val')

12    num_workers = loaders_config.get('num_workers', 1)
13    logger.info(f'Number of workers for train/val dataloader: {num_workers}')
14    batch_size = loaders_config.get('batch_size', 1)
15    if torch.cuda.device_count() > 1 and not config['device'].type == 'cpu':
16        logger.info(
            f'{torch.cuda.device_count()} GPUs available. Using batch_size = {torch.cuda.device_count()} * {batch_size}')
17        batch_size = batch_size * torch.cuda.device_count()

18    logger.info(f'Batch size for train/val loader: {batch_size}')
    # when training with volumetric data use batch_size of 1 due to GPU memory constraints
19    return {
        'train': DataLoader(ConcatDataset(train_datasets), batch_size=batch_size, shuffle=True,
                            num_workers=num_workers),
        'val': DataLoader(ConcatDataset(val_datasets), batch_size=batch_size, shuffle=True, num_workers=num_workers)
    }

1. assert 'loaders' in config, 'Could not find data loaders configuration'
2. loaders_config = config['loaders']
3. logger.info('Creating training and validation set loaders...')

These fetch the loaders-related parameters from the config.
(Screenshot omitted.)

Lines 4 to 8 work out which dataset class will be used to load the H5 data files.

8. dataset_class = _get_cls(dataset_cls_str)

def _get_cls(class_name):
    modules = ['pytorch3dunet.datasets.hdf5', 'pytorch3dunet.datasets.dsb', 'pytorch3dunet.datasets.utils']
    for module in modules:
        m = importlib.import_module(module)
        clazz = getattr(m, class_name, None)
        if clazz is not None:
            return clazz
    raise RuntimeError(f'Unsupported dataset class: {class_name}')

getattr(m, class_name) is equivalent to m.class_name.
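A tiny standalone illustration of the lookup pattern, using a standard-library module instead of the repo's:

import importlib

m = importlib.import_module('collections')
clazz = getattr(m, 'OrderedDict', None)   # same as collections.OrderedDict
print(clazz)                              # <class 'collections.OrderedDict'>
print(getattr(m, 'NoSuchClass', None))    # None -> _get_cls tries the next module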


9. assert set(loaders_config['train']['file_paths']).isdisjoint(loaders_config['val']['file_paths']), \
   "Train and validation 'file_paths' overlap. One cannot use validation data for training!"

set.isdisjoint() checks whether two sets share any elements; here it verifies that the training and validation sets are not the same data.
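For example:

>>> {'a.h5', 'b.h5'}.isdisjoint({'c.h5'})   # no overlap: the assert passes
True
>>> {'a.h5', 'b.h5'}.isdisjoint({'a.h5'})   # 'a.h5' appears in both: the assert fires
False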
10. train_datasets = dataset_class.create_datasets(loaders_config, phase='train')

This calls create_datasets, defined on AbstractHDF5Dataset(ConfigDataset) in hdf5.py and inherited by StandardHDF5Dataset.
It fetches the transformer, slice_builder and file_paths settings for the 'train' phase.
file_paths may contain both files and directories; if an entry is a directory, all the H5 files inside it are added to the final file_paths.

    @classmethod
    def create_datasets(cls, dataset_config, phase):
        phase_config = dataset_config[phase]
        transformer_config = phase_config['transformer']
        slice_builder_config = phase_config['slice_builder']
        file_paths = phase_config['file_paths']
        # file_paths may contain both files and directories; if an entry is a
        # directory, all H5 files inside it are included in the final file_paths
        file_paths = cls.traverse_h5_paths(file_paths)

        datasets = []
        for file_path in file_paths:
            try:
                logger.info(f'Loading {phase} set from: {file_path}...')
                dataset = cls(file_path=file_path,
                              phase=phase,
                              slice_builder_config=slice_builder_config,
                              transformer_config=transformer_config,
                              mirror_padding=dataset_config.get('mirror_padding', None),
                              raw_internal_path=dataset_config.get('raw_internal_path', 'raw'),
                              label_internal_path=dataset_config.get('label_internal_path', 'label'),
                              weight_internal_path=dataset_config.get('weight_internal_path', None))
                datasets.append(dataset)
            except Exception:
                logger.error(f'Skipping {phase} set: {file_path}', exc_info=True)
        return datasets

    @staticmethod
    def traverse_h5_paths(file_paths):
        assert isinstance(file_paths, list)  # make sure file_paths is a list
        results = []
        for file_path in file_paths:
            if os.path.isdir(file_path):
                # if file path is a directory take all H5 files in that directory
                iters = [glob.glob(os.path.join(file_path, ext)) for ext in ['*.h5', '*.hdf', '*.hdf5', '*.hd5']]
                for fp in chain(*iters):
                    results.append(fp)
            else:
                results.append(file_path)
        return results
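A quick illustration of what traverse_h5_paths does (the paths are hypothetical):

# given a directory layout like:
#   data/train/a.h5
#   data/train/b.hdf5
# a mixed list of directories and files expands to the individual H5 files:
traverse_h5_paths(['data/train', 'data/extra.h5'])
# -> ['data/train/a.h5', 'data/train/b.hdf5', 'data/extra.h5']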

Note that cls(file_path=..., ...) inside create_datasets also invokes the dataset's __init__().
For the train and val phases no mirror_padding is applied; for the test phase, mirror padding is applied, which effectively adds a reflected border around the whole volume.
I won't go through __init__() in detail: it applies the configured transforms to the data and records the position of every patch.
Line 19 then returns the DataLoaders we need.
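Mirror (reflection) padding itself is easy to picture with numpy; this illustrates the concept only and is not the repo's exact code:

import numpy as np

a = np.array([[1, 2, 3],
              [4, 5, 6]])
# pad one element on every side, reflecting values across the border
padded = np.pad(a, pad_width=1, mode='reflect')
print(padded)
# [[5 4 5 6 5]
#  [2 1 2 3 2]
#  [5 4 5 6 5]
#  [2 1 2 3 2]]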


20. optimizer = _create_optimizer(config, model)

This one is straightforward: it builds an ordinary optimizer.
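A sketch of roughly what _create_optimizer does, assuming an 'optimizer' config section with learning_rate and weight_decay keys and an Adam optimizer (check the source for the exact keys):

import torch.optim as optim

def _create_optimizer(config, model):
    # assumed config section: optimizer: {learning_rate: ..., weight_decay: ...}
    optimizer_config = config['optimizer']
    learning_rate = optimizer_config['learning_rate']
    weight_decay = optimizer_config['weight_decay']
    return optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)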


22. lr_scheduler = _create_lr_scheduler(config, optimizer)

This one is also straightforward: it builds an ordinary lr_scheduler.
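Again a hedged sketch: presumably the scheduler class is looked up by name in torch.optim.lr_scheduler, in the same getattr style as _get_cls (the 'lr_scheduler' key and its layout are assumptions):

import importlib

def _create_lr_scheduler(config, optimizer):
    # assumed config section, e.g.:
    #   lr_scheduler: {name: MultiStepLR, milestones: [10, 30, 60], gamma: 0.2}
    lr_config = dict(config['lr_scheduler'])
    name = lr_config.pop('name')
    m = importlib.import_module('torch.optim.lr_scheduler')
    scheduler_class = getattr(m, name)
    return scheduler_class(optimizer, **lr_config)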


23 trainer = _create_trainer(config, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, loss_criterion=loss_criterion, eval_criterion=eval_criterion, loaders=loaders)

def _create_trainer(config, model, optimizer, lr_scheduler, loss_criterion, eval_criterion, loaders):
    assert 'trainer' in config, 'Could not find trainer configuration'
    trainer_config = config['trainer']
    # resume training from a checkpoint after an interruption
    resume = trainer_config.get('resume', None)
    # fine-tune from a pre-trained model
    pre_trained = trainer_config.get('pre_trained', None)
    # skip computing the evaluation metric on the training set
    skip_train_validation = trainer_config.get('skip_train_validation', False)

    # get tensorboard formatter
    # (I'm not entirely sure what this does, but it includes min-max normalization)
    tensorboard_formatter = get_tensorboard_formatter(trainer_config.get('tensorboard_formatter', None))

    if resume is not None:
        # continue training from a given checkpoint
        return UNet3DTrainer.from_checkpoint(resume, model,
                                             optimizer, lr_scheduler, loss_criterion,
                                             eval_criterion, loaders, tensorboard_formatter=tensorboard_formatter)
    elif pre_trained is not None:
        # fine-tune a given pre-trained model
        return UNet3DTrainer.from_pretrained(pre_trained, model, optimizer, lr_scheduler, loss_criterion,
                                             eval_criterion, device=config['device'], loaders=loaders,
                                             max_num_epochs=trainer_config['epochs'],
                                             max_num_iterations=trainer_config['iters'],
                                             validate_after_iters=trainer_config['validate_after_iters'],
                                             log_after_iters=trainer_config['log_after_iters'],
                                             eval_score_higher_is_better=trainer_config['eval_score_higher_is_better'],
                                             tensorboard_formatter=tensorboard_formatter,
                                             skip_train_validation=skip_train_validation)
    else:
        # start training from scratch
        return UNet3DTrainer(model, optimizer, lr_scheduler, loss_criterion, eval_criterion,
                             config['device'], loaders, trainer_config['checkpoint_dir'],
                             max_num_epochs=trainer_config['epochs'],
                             max_num_iterations=trainer_config['iters'],
                             validate_after_iters=trainer_config['validate_after_iters'],
                             log_after_iters=trainer_config['log_after_iters'],
                             eval_score_higher_is_better=trainer_config['eval_score_higher_is_better'],
                             tensorboard_formatter=tensorboard_formatter,
                             skip_train_validation=skip_train_validation)

25 trainer.fit()

    def fit(self):
        for _ in range(self.num_epoch, self.max_num_epochs):
            # train for one epoch
            should_terminate = self.train(self.loaders['train'])

            if should_terminate:
                logger.info('Stopping criterion is satisfied. Finishing training')
                return

            self.num_epoch += 1
        logger.info(f"Reached maximum number of epochs: {self.max_num_epochs}. Finishing training...")

train() is a standard training loop: it saves checkpoints based on the evaluation metric on the validation set, and uses TensorBoardX to log the learning_rate, loss_avg, eval_score_avg and example prediction images.
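A condensed paraphrase of what happens inside train() (a generic skeleton, not the repo's exact code; everything except the attribute names visible above is illustrative):

def train(self, train_loader):
    for batch in train_loader:
        input, target = batch                      # the real code also handles weights
        output = self.model(input.to(self.device))
        loss = self.loss_criterion(output, target.to(self.device))

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # every validate_after_iters iterations: validate, step the lr scheduler,
        # and save a checkpoint if the validation score improved;
        # every log_after_iters iterations: push lr / loss_avg / eval_score_avg
        # and example predictions to TensorBoardX
        if self.num_iterations >= self.max_num_iterations:
            return True                            # stopping criterion satisfied
        self.num_iterations += 1
    return False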


Prediction

predict.py

predict.py is broadly similar to train.py. It saves the predicted probabilities for each .nii volume separately, but does not run any evaluation.
To be honest, I didn't fully follow the post-processing of the predicted probabilities. I only know that it removes some voxels near patch borders to avoid stitching artifacts, and that it undoes the mirror padding so the prediction is restored to the original shape.

def main():
    # Load configuration
    config = load_config()

    # Create the model
    model = get_model(config)

    # Load model state
    model_path = config['model_path']
    logger.info(f'Loading model from {model_path}...')
    utils.load_checkpoint(model_path, model)
    # use DataParallel if more than 1 GPU available
    device = config['device']
    if torch.cuda.device_count() > 1 and not device.type == 'cpu':
        model = nn.DataParallel(model)
        logger.info(f'Using {torch.cuda.device_count()} GPUs for prediction')

    logger.info(f"Sending the model to '{device}'")
    model = model.to(device)

    logger.info('Loading HDF5 datasets...')
    for test_loader in get_test_loaders(config):
        logger.info(f"Processing '{test_loader.dataset.file_path}'...")

        output_file = _get_output_file(test_loader.dataset)

        predictor = _get_predictor(model, test_loader, output_file, config)
        # run the model prediction on the entire dataset and save to the 'output_file' H5
        predictor.predict()
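The predictor writes the network output into the H5 output_file. To peek at the result, something like this works (the 'predictions' dataset name is the repo's default as far as I know; the file name is hypothetical):

import h5py

with h5py.File('my_volume_predictions.h5', 'r') as f:
    print(list(f.keys()))            # e.g. ['predictions']
    probs = f['predictions'][...]    # (C, D, H, W) probability map
    print(probs.shape, probs.dtype)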

References:
Python's getattr() function explained
Python set isdisjoint() method
