构建分割模型的基本库 使用几行代码构建和训练用于图像分割的神经网络模型(教程含源码)

网络模型已被证明在解决分割问题方面非常有效,达到了最先进的准确性。它们导致各种应用的显着改进,包括医学图像分析、自动驾驶、机器人技术、卫星图像、视频监控等等。然而,构建这些模型通常需要很长时间,但在阅读本指南后,您只需几行代码就可以构建一个模型。

主要内容

  • 介绍
  • 建筑模块
  • 建立一个模型
  • 训练模型

介绍

分割是根据某些特征或属性将图像分成多个片段或区域的任务。分割模型将图像作为输入并返回分割掩码:

截屏2023-03-07 08.59.23.png

分割神经网络模型由两部分组成:

  • 编码器:获取输入图像并提取特征。编码器的例子有 ResNet、EfficentNet 和 ViT。
  • 解码器:获取提取的特征并生成分割掩码。解码器因架构而异。架构的例子有 U-Net、FPN 和 DeepLab。

因此,在为特定应用构建分割模型时,您需要选择架构和编码器。但是,如果不测试几个,很难选择最佳组合。这通常需要很长时间,因为更改模型需要编写大量样板代码。Segmentation Models库解决了这个问题。它允许您通过指定架构和编码器在一行中创建模型。然后您只需修改该行即可更改其中任何一个。

要从 PyPI 安装最新版本的分段模型,请使用:

pip install segmentation-models-pytorch

建筑模块

该库为大多数分段架构提供了一个类,并且它们中的每一个都可以与任何可用的编码器一起使用。在下一节中,您将看到要构建模型,您需要实例化所选架构的类并将所选编码器的字符串作为参数传递。下图展示了库提供的各个架构的类名:

截屏2023-03-07 09.01.59.png
截屏2023-03-07 09.02.17.png

编码器有 400 多种,因此无法全部显示,但您可以在此处找到完整列表。

https://github.com/qubvel/segmentation_models.pytorch#encoders

建立一个模型

一旦从上图中选择了架构和编码器,构建模型就非常简单:

import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet50",        # choose encoder
    encoder_weights="imagenet",     # choose pretrained (not required)
    in_channels=3,                  # model input channels
    classes=10,                     # model output channels
    activation="None"               # None|"sigmoid"|"softmax"; default is None
)

参数:

  • encoder_name是所选编码器的名称(例如 resnet50、efficentnet-b7、mit_b5)。
  • encoder_weights是预训练的数据集。如果encoder_weights等于"imagenet"编码器权重,则使用预训练的 ImageNet 进行初始化。所有的编码器都至少有一个预训练的,这里有一个完整的列表。
  • in_channels是输入图像的通道数(如果是 RGB,则为 3)。
    即使in_channels不是 3,也可以使用预训练的 ImageNet:第一层将通过重新使用预训练的第一个卷积层的权重来初始化(过程在此处描述
  • out_classes是数据集中的类数。
  • activation是输出层的激活函数。可能的选择是None(默认)sigmoidsoftmax
    注意:当使用期望 logits 作为输入的损失函数时,激活函数必须为 None。例如,使用CrossEntropyLoss函数时,activation必须是None.

训练模型

本节显示执行培训所需的所有代码。但是,这个库不会改变通常用于训练和验证模型的管道。为了简化流程,该库提供了许多损失函数的实现,例如Jaccard Loss、Dice Loss、Dice Cross-Entropy Loss、Focal Loss,以及Accuracy、Precision、Recall、F1Score 和 IOUScore 等指标。有关它们及其参数的完整列表,请查看损失和指标部分中的文档。

提议的训练示例是使用Oxford-IIIT Pet Dataset 的二进制分割(它将通过代码下载)。这是数据集中的两个样本:

截屏2023-03-07 09.11.26.png

最后,这些是执行此类分割任务的所有步骤:

1.建立模型。

import os
from pprint import pprint
import torch
from torch.utils.data import DataLoader
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.datasets import SimpleOxfordPetDataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# I don't use any activation function on the last layer
# because I set from_logits=True on the DiceLoss
model = smp.FPN(
    encoder_name='efficientnet-b0',
    encoder_weights='imagenet',
    in_channels=3,
    classes=1,
    activation=None
)
model.to(device)

根据您要使用的损失函数设置最后一层的激活函数。

2. 定义参数。

# get_processing_params returns mean and std you should use to normalize the input
params = smp.encoders.get_preprocessing_params('efficientnet-b0')
mean = torch.tensor(params["mean"]).view(1, 3, 1, 1).to(device)
std = torch.tensor(params["std"]).view(1, 3, 1, 1).to(device)

num_epochs = 50
loss_fn = smp.losses.DiceLoss('binary', from_logits=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, verbose=True)

root = 'data'
SimpleOxfordPetDataset.download(root)

train_dataset = SimpleOxfordPetDataset(root, 'train')
val_dataset = SimpleOxfordPetDataset(root, 'valid')

n_cpu = os.cpu_count()
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=n_cpu)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=n_cpu)

请记住,在使用预训练时,应使用用于训练预训练的数据的均值和标准差对输入进行归一化。

3.定义train函数。

def train():
    best_accuracy = 0.0
    for epoch in range(num_epochs):
        mean_loss = 0.0
        for i, batch in enumerate(train_dataloader):
            image = batch["image"].to(device)
            mask = batch["mask"].to(device)
            # normalize input
            image = (image - mean) / std

            optimizer.zero_grad()
            logits_mask = model(image)
            loss = loss_fn(logits_mask, mask)
            loss.backward()
            optimizer.step()

            mean_loss += loss.item()
            print(f'[epoch {epoch + 1}, batch {i + 1}/{len(train_dataloader)}] step_loss: {loss.item():.4f}, mean_loss: {(mean_loss / (i + 1)):.4f}')

        scheduler.step()

        # compute validation metrics of this epoch
        metrics = validate()
        epoch_accuracy = metrics["accuracy"]

        # save the model if accuracy has improved
        if epoch_accuracy > best_accuracy:
            torch.save(model.state_dict(), 'best_model.pth')
            best_accuracy = epoch_accuracy

        print(f'For epoch {epoch + 1} the validation metrics are:')
        pprint(metrics)

与您在不使用库的情况下为训练模型而编写的训练函数相比,此处没有任何变化。

4. 定义验证函数。

def validate():
    with torch.no_grad():
        # total true positives, false positives, true negatives and false negatives
        total_tp, total_fp, total_fn, total_tn = None, None, None, None
        for batch in val_dataloader:
            image = batch["image"].to(device)
            mask = batch["mask"].to(device).long()

            image = (image - mean) / std
            logits_mask = model(image)
            loss = loss_fn(logits_mask, mask)

            # we need to convert the logits to classes to compute metrics
            prob_mask = logits_mask.sigmoid()
            pred_mask = (prob_mask > 0.5).long()

            # computing true positives, false positives, true negatives and false negatives of the batch
            tp, fp, fn, tn = smp.metrics.get_stats(pred_mask, mask, mode="binary")
            total_tp = torch.cat([total_tp, tp]) if total_tp != None else tp
            total_fp = torch.cat([total_fp, fp]) if total_fp != None else fp
            total_fn = torch.cat([total_fn, fn]) if total_fn != None else fn
            total_tn = torch.cat([total_tn, tn]) if total_tn != None else tn

    # metrics are computed using tp, fp, tn, fn values
    metrics = {
        "loss": loss,
        "accuracy": smp.metrics.accuracy(tp, fp, fn, tn, reduction="micro"),
        "precision": smp.metrics.precision(tp, fp, fn, tn, reduction="micro"),
        "recall": smp.metrics.recall(tp, fp, fn, tn, reduction="micro"),
        "f1_score": smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro")
    }

    return metrics

批次中的真阳性、假阳性、假阴性和真阴性全部加在一起,仅在批次结束时计算指标。请注意,必须先将 logits 转换为类,然后才能计算指标。调用训练函数开始训练。

5.使用模型。

test_dataset = SimpleOxfordPetDataset(root, 'test')
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=n_cpu)
# take a single batch
batch = next(iter(test_dataloader))

model.load_state_dict(torch.load("best_model.pth"))
with torch.no_grad():
    model.eval()
    image = batch["image"].to(device)
    mask = batch["mask"].to(device).long()
    image_norm = (image - mean) / std
    logits = model(image_norm)
pred_mask = logits.sigmoid()

for i, (im, pr, gt) in enumerate(zip(image, pred_mask, mask)):
    fig, axes = plt.subplots(1, 3, figsize=(9, 3))
    # show input
    axes[0].imshow(im.cpu().numpy().transpose(1, 2, 0))
    axes[0].set_title("Image")
    axes[0].get_xaxis().set_visible(False)
    axes[0].get_yaxis().set_visible(False)
    # show prediction
    axes[1].imshow(pr.cpu().numpy().squeeze())
    axes[1].set_title("Prediction")
    axes[1].get_xaxis().set_visible(False)
    axes[1].get_yaxis().set_visible(False)
    # show target
    axes[2].imshow(gt.cpu().numpy().squeeze())
    axes[2].set_title("Ground truth")
    axes[2].get_xaxis().set_visible(False)
    axes[2].get_yaxis().set_visible(False)

    plt.tight_layout()
    plt.savefig(f"pred_{i}.png")

这些是一些细分:


截屏2023-03-07 10.19.32.png

结束语

这个库拥有你进行分割实验所需的一切。构建模型和应用更改非常容易,并且提供了大多数损失函数和指标。此外,使用这个库不会改变我们习惯的管道。有关详细信息,请参阅官方文档。我还在参考资料中包含了一些最常见的编码器和架构。

项目参考文献

[1] O. Ronneberger, P. Fischer and T. Brox, U-Net: Convolutional Networks for Biomedical Image Segmentation (2015)

[2] Z. Zhou, Md. M. R. Siddiquee, N. Tajbakhsh and J. Liang, UNet++: A Nested U-Net Architecture for Medical Image Segmentation (2018)

[3] L. Chen, G. Papandreou, F. Schroff, H. Adam, Rethinking Atrous Convolution for Semantic Image Segmentation (2017)

[4] L. Chen, Y. Zhu, G. Papandreou, F. Schroff, H. Adam, Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation (2018)

[5] R. Li, S. Zheng, C. Duan, C. Zhang, J. Su, P.M. Atkinson, Multi-Attention-Network for Semantic Segmentation of Fine Resolution Remote Sensing Images (2020)

[6] A. Chaurasia, E. Culurciello, LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation (2017)

[7] T. Lin, P. Dollár, R. Girshick, K. He, B. Hariharan, S. Belongie, Feature Pyramid Networks for Object Detection (2017)

[8] H. Zhao, J. Shi, X. Qi, X. Wang, J. Jia, Pyramid Scene Parsing Network (2016)

[9] H. Li, P. Xiong, J. An, L. Wang, Pyramid Attention Network for Semantic Segmentation (2018)

[10] K. Simonyan, A. Zisserman, Very Deep Convolutional Networks for Large-Scale Image Recognition (2014)

[11] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun, Deep Residual Learning for Image Recognition (2015)

[12] S. Xie, R. Girshick, P. Dollár, Z. Tu, K. He, Aggregated Residual Transformations for Deep Neural Networks (2016)

[13] J. Hu, L. Shen, S. Albanie, G. Sun, E. Wu, Squeeze-and-Excitation Networks (2017)

[14] G. Huang, Z. Liu, L. van der Maaten, K. Q. Weinberger, Densely Connected Convolutional Networks (2016)

[15] M. Tan, Q. V. Le, EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks (2019)

[16] E. Xie, W. Wang, Z. Yu, A. Anandkumar, J. M. Alvarez, P. Luo, SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers (2021)
还有 0% 的精彩内容
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
支付 ¥9.90 继续阅读
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 216,692评论 6 501
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,482评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 162,995评论 0 353
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,223评论 1 292
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,245评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,208评论 1 299
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,091评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,929评论 0 274
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,346评论 1 311
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,570评论 2 333
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,739评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,437评论 5 344
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,037评论 3 326
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,677评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,833评论 1 269
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,760评论 2 369
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,647评论 2 354

推荐阅读更多精彩内容