医学分割的Tricks


将医学分割模型分为六个阶段,分别是预训练模型数据预处理数据扩增模型部署模型推理结果后处理:这些Tricks很重要,可以显著地影响模型的识别性能;某些Tricks具有跨数据、跨模态、跨模型的迁移性;Tricks本质上对应的是解决了语义分割中的一些Challenges,因此在部署Tricks的时候需要考虑到想要解决的Challenges之间的内在关联;基于这些Tricks的实验结论可以帮助我们将来在解决2D和3D图像语义分割Challenges的时候提供实际的指导。

预训练权重

一般来说,与从头开始训练的基线模型相比 (如果不使用预先训练过的权重),微调后的模型可以获得更好的整体性能。这一观察结果不仅验证了预训练的权重的有效性,而且验证了不同的预训练的权重的影响是可变的。

数据预处理

数据预处理是获得令人满意的性能的必要条件。3D-UNet中四种常用的图像预处理技巧:patching、过采样(OverSam)、重采样(ReSam)和强度归一化(IntesNorm)。

Patching

一些特定类别的医学图像( MRI和病理图像)在空间尺寸上往往非常大,在定量上缺乏足够的训练样本。因此,直接使用这些图像来训练模型是不切实际的。相反,人们通常在更小的空间尺度上将整个图像重新采样成不同的图像补丁,这样模型就可以用更少的GPU内存成本实现,并且可以得到更好的训练。直观地看,补丁大小是影响模型性能的最重要因素之一。一般而言,随着补丁尺寸的增加,模型的性能增益逐渐增加。其可以通过RandomCrop实现,集成到transform里面。

class RandomCrop(object):
    """
    Crop randomly the image in a sample
    Args:
    output_size (int): Desired output size
    """

    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        # pad the sample if necessary
        if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= \
                self.output_size[2]:
            pw = max((self.output_size[0] - label.shape[0]) // 2 + 3, 0)
            ph = max((self.output_size[1] - label.shape[1]) // 2 + 3, 0)
            pd = max((self.output_size[2] - label.shape[2]) // 2 + 3, 0)
            image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
            label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)

        (w, h, d) = image.shape
        w1 = np.random.randint(0, w - self.output_size[0])
        h1 = np.random.randint(0, h - self.output_size[1])
        d1 = np.random.randint(0, d - self.output_size[2])

        label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
        image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
        return {'image': image, 'label': label}

推理的时候取重叠的patch,拼接时重叠处取平均,也有高斯加权的方法(高斯权重码拼接Patch)。

  def test_single_case(self, image):
        w, h, d = image.shape
        tta = TTA(if_flip=self.opt.test['flip'], if_rot=self.opt.test['rotate'])
        patch_size = self.opt.model['input_size']
        stride_xy = patch_size[0]//2
        stride_z = patch_size[2]//2
        # if the size of image is less than patch_size, then padding it
        add_pad = False
        if w < patch_size[0]:
            w_pad = patch_size[0]-w
            add_pad = True
        else:
            w_pad = 0
        if h < patch_size[1]:
            h_pad = patch_size[1]-h
            add_pad = True
        else:
            h_pad = 0
        if d < patch_size[2]:
            d_pad = patch_size[2]-d
            add_pad = True
        else:
            d_pad = 0
        wl_pad, wr_pad = w_pad//2,w_pad-w_pad//2
        hl_pad, hr_pad = h_pad//2,h_pad-h_pad//2
        dl_pad, dr_pad = d_pad//2,d_pad-d_pad//2
        if add_pad:
            image = np.pad(image, [(wl_pad,wr_pad),(hl_pad,hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0)
        ww,hh,dd = image.shape

        sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1
        sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1
        sz = math.ceil((dd - patch_size[2]) / stride_z) + 1
        # print("{}, {}, {}".format(sx, sy, sz))
        score_map = np.zeros((self.opt.model['num_class'], ) + image.shape).astype(np.float32)
        cnt = np.zeros(image.shape).astype(np.float32)

        for x in range(0, sx):
            xs = min(stride_xy*x, ww-patch_size[0])
            for y in range(0, sy):
                ys = min(stride_xy * y,hh-patch_size[1])
                for z in range(0, sz):
                    zs = min(stride_z * z, dd-patch_size[2])
                    test_patch = image[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]]
                    # apply tta
                    test_patch_list = tta.img_list(test_patch)
                    y_list = []
                    for img in test_patch_list:
                        img = np.expand_dims(np.expand_dims(img,axis=0),axis=0).astype(np.float32)
                        img = torch.from_numpy(img).cuda()
                        if not self.opt.train['deeps']:
                            y = self.net(img)
                        else:
                            y = self.net(img)[0]
                        y = F.softmax(y, dim=1)
                        y = y.cpu().detach().numpy()
                        y = np.squeeze(y)
                        y_list.append(y)
                    y_list = tta.img_list_inverse(y_list)
                    y = np.mean(y_list, axis=0)
                    score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \
                    = score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + y
                    cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \
                    = cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + 1
        score_map = score_map/np.expand_dims(cnt,axis=0)
        label_map = np.argmax(score_map, axis = 0)
        if add_pad:
            label_map = label_map[wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d]
            score_map = score_map[:,wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d]
        return label_map, score_map

过采样(OverSam)

提出了OverSam策略来解决正负样本之间类别失衡的问题。OverSam主要用于少数类样本。目前,已经提出了一组OverSam方案——随机过采样、合成少数过采、边界采样和自适应合成采样。大量的实验结果验证了OverSam方案不影响模型的斜率,但可以放大模型的截距。一种普遍的过采样策略:70%的被选择的训练样本来自随机的图像位置,而30%的图像切片保证持有至少一个训练前景类。这样,每个训练样本可以同时包含一个训练前景图像补丁和一个随机图像补丁。

class SelectedCrop(object):
    def __init__(self, output_size, oversample_foreground_percent=0.3):
        self.output_size = output_size
        self.percent = oversample_foreground_percent

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        if np.random.random() < self.percent:
            pixels = np.argwhere(label != 0)
            if len(pixels) == 0:
                return RandomCrop(self.output_size)(sample)
            else:
                selected_pixel = pixels[np.random.choice(len(pixels))]
                pw = self.output_size[0] // 2 + 1
                ph = self.output_size[1] // 2 + 1
                pd = self.output_size[2] // 2 + 1

                image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
                label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
                bbox_x_lb = selected_pixel[0]
                bbox_y_lb = selected_pixel[1]
                bbox_z_lb = selected_pixel[2]

                label = label[bbox_x_lb:bbox_x_lb + self.output_size[0], bbox_y_lb:bbox_y_lb + self.output_size[1], bbox_z_lb:bbox_z_lb + self.output_size[2]]
                image = image[bbox_x_lb:bbox_x_lb + self.output_size[0], bbox_y_lb:bbox_y_lb + self.output_size[1], bbox_z_lb:bbox_z_lb + self.output_size[2]]
                return {'image': image, 'label': label}

        else:
            return RandomCrop(self.output_size)(sample)

重采样(ReSam)

ReSam策略,通过机器学习模型来提高所使用数据集的表征能力。由于可用的样本能力有时是有限的和异质的,因此可以通过随机/非随机ReSam策略获得更好的子样本数据集。在其实现中,ReSam主要包括四个步骤:1)间隔插值;2)窗口变换;3)掩模有效范围的获取,4)子图像的生成。基于重组后的子样本数据集, 可以训练一个性能更好的识别模型。

在医学图像中,重采样是指将医疗图像中大小不同的体素归一化到相同的大小。体素是体积元素(Volume Pixel)的简称,一张3D医学图像可以看成是由若干个体素构成的,体素是一张3D医疗图像在空间上的最小单元。
重采样过程:Spacing(0.7422, 0.7422, 8.0)表示的是原始图像体素的大小,也可以将Spacing想象成大小为(0.7422, 0.7422, 8.0)的长方体。而原始图像的Size为 (512, 512, 22),表示的是原始在X轴,Y轴,Z轴中体素的个数。原始图像的大小对应的Spacing既可以得到真实3D图像大小(512*0.7422,512*0.7422,8*22 ),在图像重采样只是修改体素的大小,而真实3D图像大小是保持不变的,因此假设我们将Spacing修改成(1.4844, 1.4844, 2.75)的时候,则修改之后其对应的size应该为((512*0.7422)/ 1.4844,(512*0.7422)/ 1.4844,(22*8)/ 2.75)即(256, 256, 64)。

def resample(self, image, label, spacing, new_spacing=[1,1,1]):
      spacing, new_spacing = np.array(spacing), np.array(new_spacing)
      resize_factor = spacing / new_spacing
      old_shape = np.array(image.shape)
      new_real_shape = old_shape * resize_factor
      new_shape = np.round(new_real_shape)
      real_resize_factor = new_shape / old_shape
      new_spacing = spacing / real_resize_factor

      # image = scipy.ndimage.interpolation.zoom(image, real_resize_factor, mode='nearest')
      image = np.moveaxis(image, 0, -1) # x, y, z
      label = np.moveaxis(label, 0, -1) # x, y, z
      image = np.expand_dims(image, axis=0) # 1, x, y, z
      label = np.expand_dims(label, axis=0) # 1, x, y, z

      # target_size = tuple(new_shape.transpose(1,2,0).astype(int)) # x, y, z
      target_size = tuple(np.array((new_shape[1], new_shape[2], new_shape[0])).astype(int)) # x, y, z
      out_img, out_seg = augment_resize(sample_data=image, sample_seg=label, target_size=target_size)
      out_img, out_seg = out_img[0], out_seg[0] # x,y,z
      return out_img, out_seg

强度归一化(IntesNorm)

IntesNorm是一种规范化该策略主要用于医学图像。通常有两种常用的IntesNorm方法:

  • CT:通过统计整个数据集中mask内像素的HU值范围,clip出[0.05,99.5]百分比范围的HU值范围,然后使用z-score方法进行归一化;
  • MR:对每个患者数据单独执行z-score归一化。如果crop导致数据集的平均尺寸减小到1/4甚至更小,则只在mask内执行标准化,mask设置为0。
def _get_voxels_in_foreground(self,voxels,label):
    mask = label> 0
    # image = list(voxels[mask][::10]) # no need to take every voxel
    image = list(voxels[mask])
    median = np.median(image)
    mean = np.mean(image)
    sd = np.std(image)
    percentile_99_5 = np.percentile(image, 99.5)
    percentile_00_5 = np.percentile(image, 00.5)
    return percentile_99_5,percentile_00_5, median,mean,sd

def do_preprocessing(self, minimun=0, maxmun=0, new_spacing=(3.22, 1.62, 1.62)):
    maybe_mkdir_p(self.out_base_preprocess)
    self.data_info = pickle.load(open(join(self.out_base_raw, 'dataset_pro.pkl'), 'rb'))
    for i in range(len(self.data_info['patient_names'])):
        print(f"Preprocessing {i}/{len(self.data_info['patient_names'])}")
        # voxels = self.images[i]
        # label = self.labels[i]
        voxels = np.load(join(self.out_base_raw, "imagesTr", self.data_info['patient_names'][i] + "_image.npy"))
        label = np.load(join(self.out_base_raw, "imagesTr", self.data_info['patient_names'][i] + "_label.npy"))
        if minimun:
            lower_bound = minimun
            upper_bound = maxmun
        else:
            upper_bound, lower_bound, median, mean_before, sd_before = self._get_voxels_in_foreground(voxels, label)
        voxels = np.clip(voxels, lower_bound, upper_bound)
        ### Convert to [0, 1]
        voxels = (voxels - voxels.min()) / (voxels.max() - voxels.min())

        # resample to isotropic voxel size
        spacing = self.data_info['dataset_properties'][self.data_info['patient_names'][i]]['spacing']
        spacing = (spacing[2], spacing[0], spacing[1])

        voxels, label = self.resample(voxels, label, spacing, new_spacing)
        np.save(join(self.out_base_preprocess, self.data_info['patient_names'][i] + "_image.npy"),
                voxels.astype(np.float32))
        np.save(join(self.out_base_preprocess, self.data_info['patient_names'][i] + "_label.npy"), label)
    save_pickle(self.data_info, join(self.out_base_preprocess, 'dataset_pro.pkl'))
    with open(self.out_base_preprocess + '/all.txt', 'w') as f:
        for train_patient in self.data_info['patient_names']:
            f.write(train_patient)
            f.write('\n')

数据增强

数据增强技术是计算机视觉领域中最基本的技术之一。它通常用于定量处理训练样本不足的问题,可以用来缓解过拟合问题,给出了较强的模型泛化能力,并赋予了鲁棒性。特别是对于医学图像,数据增强通常用于解决数据短缺问题。所使用的数据增强方案主要可分为以下两类:基于几何变换的数据增强(GTAug)和基于生成式对抗网络(GAN)的数据增强(GANAug)。

GTAug

包含两种类型的数据增强,分别为 GTAug-A (像素级变换)和GTAug-B(空间级变换):GTAug-A中包括随机亮度对比、随机噪声、随机伽马和CLAHE,GTAug-B中包括位移、尺度、旋转、水平翻转和垂直翻转等。

class RandomScale(object):
    def __init__(self, scale_factor):
        self.scale_factor = scale_factor
    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        scale = np.random.uniform(self.scale_factor[0], self.scale_factor[1])
        new_shape = (int(image.shape[0] * scale), int(image.shape[1] * scale), int(image.shape[2] * scale))
        image = np.expand_dims(image, axis=0)
        label = np.expand_dims(label, axis=0)
        image, label = augment_resize(image, label, new_shape)
        image = np.squeeze(image)
        label = np.squeeze(label)
        return {'image': image, 'label': label}

class RandomRotation(object):
    """
    Crop randomly flip the dataset in a sample
    Args:
    output_size (int): Desired output size
    """

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        k = np.random.randint(0, 4)
        image = np.rot90(image, k)
        label = np.rot90(label, k)
        return {'image': image, 'label': label}


class RandomMirroring(object):
    def __init__(self, axes=(0, 1, 2)):
        self.axes = axes
    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        if 0 in self.axes  and np.random.uniform() < 0.5:
            image[:] = image[::-1]
            label[:] = label[::-1]
        if 1 in self.axes  and np.random.uniform() < 0.5:
            image[:, :] = image[:,::-1]
            label[:, :] = label[:,::-1]
        if 2 in self.axes  and np.random.uniform() < 0.5:
            image[:, :, :] = image[:, :, ::-1]
            label[:, :, :] = label[:, :, ::-1]
        return {'image': image, 'label': label}


class RandomNoise(object):
    def __init__(self, mu=0, sigma=0.1):
        self.mu = mu
        self.sigma = sigma

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        noise = np.clip(self.sigma * np.random.randn(image.shape[0], image.shape[1], image.shape[2]), -2*self.sigma, 2*self.sigma)
        noise = noise + self.mu
        image = image + noise
        return {'image': image, 'label': label}


class GammaAdjust(object):
    def __init__(self, gamma_range=(0.5, 2), epsilon=1e-7,retain_stats = False):
        self.gamma_range = gamma_range
        self.epsilon = epsilon
        self.retain_stats = retain_stats

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        if self.retain_stats:
            mn = image.mean()
            sd = image.std()
        if np.random.random() < 0.5 and self.gamma_range[0] < 1:
            gamma = np.random.uniform(self.gamma_range[0], 1)
        else:
            gamma = np.random.uniform(max(self.gamma_range[0], 1), self.gamma_range[1])
        minm = image.min()
        rnge = image.max() - minm
        image = np.power(((image - minm) / float(rnge + self.epsilon)), gamma) * rnge + minm
        if self.retain_stats:
            image = image - image.mean() + mn
            image = image / (image.std() + 1e-8) * sd
        return {'image': image, 'label': label}

GANAug

数据增强的内在前提是将领域知识或其他增量信息引入训练数据集。从这方面来看,GANAug可以被看作是一个损失函数,它专注于引导网络生成一些接近源数据集域的真实数据。特别是由于医学图像中的数据集较小,GAN生成模型拟合的数据分布优于判别模型。基于条件对抗网络的经典像素级图像到图像转换(pix2pix)可以被用于其默认设置的数据增强。
GTAug-A和GAug-B比GANAug获得更好的实验结果,证明了位移尺度旋转、水平翻转和垂直翻转在医学图像分割中的重要性。当面对不同的数据集时,还是应该选择合适的数据增强技巧。

模型实现

模型实现技巧对于医学分割模型至关重要。 三类常用的实现技巧:深度监督(DeepS);类平衡损失(CBL),其中包括四个损失函数(CBL_{Dice}CBL_{Focal}CBL_{Tvers}CBL_{WCE})和实例规范化(IntNorm)。

DeepS

DeepS是DSN中提出的一种辅助学习技巧,通过在一些中间隐藏层上以直接或间接的方式添加一个辅助分类器或分割器来实现监督主干网络的。它可用于解决训练梯度消失或收敛速度较慢的问题。对于图像分割,这个技巧通常通过添加图像级分类损失来实现。 可以从最后三个解码器层中提取特征图,并使用1*1卷积层将掩膜投射到相同的通道大小中。然后,通过双线性插值将分割头网络不同层的输出特征图上采样到与输入图像相同的空间大小。

CBL

CBL通常用于学习一般的类权重,每个类的权重只与对象类别相关。与一些传统的分割损失函数(交叉熵损失) 相比,在类不平衡数据集上CBL可以提高模型的表示能力。在所使用的数据集中,CBL引入了有效样本的数量来表示所选数据集的期望体量表示,并通过有效样本的数量而不是原始样本的数量来加权不同的类。四种常用的医学图像领域的CBL损失函数,包括骰子损失(CBL_{Dice})、焦点损失(CBL_{Focal}) ,Tversky损失(CBL_{Tvers})和加权交叉熵损失(CBL_{WCE})。

class CELoss(nn.Module):
    def __init__(self, weight=None, reduction='mean'):
        self.weight = weight
        self.reduction = reduction

    def __call__(self, y_pred, y_true):
        y_true = y_true.long()
        if self.weight is not None:
            self.weight = self.weight.to(y_pred.device)
        if len(y_true.shape) == 5:
            y_true = y_true[:, 0, ...]
        loss = nn.CrossEntropyLoss(weight=self.weight, reduction=self.reduction)
        return loss(y_pred, y_true)


class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-8):
        super(DiceLoss, self).__init__()
        
        self.smooth = smooth

    def forward(self, y_pred, y_true):
        # first convert y_true to one-hot format
        axis = identify_axis(y_pred.shape)
        y_pred = nn.Softmax(dim=1)(y_pred)
        tp, fp, fn, _ = get_tp_fp_fn_tn(y_pred, y_true, axis)
        intersection = 2 * tp + self.smooth
        union = 2 * tp + fp + fn + self.smooth
        dice = 1 - (intersection / union)
        return dice.mean()


# taken from https://github.com/JunMa11/SegLoss/blob/master/test/nnUNetV2/loss_functions/focal_loss.py
class FocalLoss(nn.Module):
    """
    copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
    This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in
    'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
        Focal_Loss= -1*alpha*(1-pt)*log(pt)
    :param num_class:
    :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
    :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
                    focus on hard misclassified example
    :param smooth: (float,double) smooth value when cross entropy
    :param balance_index: (int) balance class index, should be specific when alpha is float
    :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
    """

    def __init__(self, apply_nonlin=None, alpha=0.25, gamma=2, balance_index=0, smooth=1e-5, size_average=True):
        super(FocalLoss, self).__init__()
        self.apply_nonlin = apply_nonlin
        self.alpha = alpha
        self.gamma = gamma
        self.balance_index = balance_index
        self.smooth = smooth
        self.size_average = size_average

        if self.smooth is not None:
            if self.smooth < 0 or self.smooth > 1.0:
                raise ValueError('smooth value should be in [0,1]')

    def forward(self, logit, target):
        if self.apply_nonlin is not None:
            logit = self.apply_nonlin(logit)
        num_class = logit.shape[1]

        if logit.dim() > 2:
            # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
            logit = logit.view(logit.size(0), logit.size(1), -1)
            logit = logit.permute(0, 2, 1).contiguous()
            logit = logit.view(-1, logit.size(-1))
        target = torch.squeeze(target, 1)
        target = target.view(-1, 1)
        alpha = self.alpha

        if alpha is None:
            alpha = torch.ones(num_class, 1)
        elif isinstance(alpha, (list, np.ndarray)):
            assert len(alpha) == num_class
            alpha = torch.FloatTensor(alpha).view(num_class, 1)
            alpha = alpha / alpha.sum()
        elif isinstance(alpha, float):
            alpha = torch.ones(num_class, 1)
            alpha = alpha * (1 - self.alpha)
            alpha[self.balance_index] = self.alpha

        else:
            raise TypeError('Not support alpha type')

        if alpha.device != logit.device:
            alpha = alpha.to(logit.device)

        idx = target.cpu().long()

        one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
        one_hot_key = one_hot_key.scatter_(1, idx, 1)
        if one_hot_key.device != logit.device:
            one_hot_key = one_hot_key.to(logit.device)

        if self.smooth:
            one_hot_key = torch.clamp(
                one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth)
        pt = (one_hot_key * logit).sum(1) + self.smooth
        logpt = pt.log()

        gamma = self.gamma

        alpha = alpha[idx]
        alpha = torch.squeeze(alpha)
        loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt

        if self.size_average:
            loss = loss.mean()
        else:
            loss = loss.sum()
        return loss


class TverskyLoss(nn.Module):
    def __init__(self, alpha=0.3, beta=0.7, eps=1e-7):
        super(TverskyLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.eps = eps

    def forward(self, y_pred, y_true):
        axis = identify_axis(y_pred.shape)
        y_pred = nn.Softmax(dim=1)(y_pred)
        y_true = to_onehot(y_pred, y_true)
        y_pred = torch.clamp(y_pred, self.eps, 1. - self.eps)
        tp, fp, fn, _ = get_tp_fp_fn_tn(y_pred, y_true, axis)
        tversky = (tp + self.eps) / (tp + self.eps + self.alpha * fn + self.beta * fp)
        return (y_pred.shape[1] - tversky.sum()) / y_pred.shape[1]

def to_onehot(y_pred, y_true):
    shp_x = y_pred.shape
    shp_y = y_true.shape
    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            y_true = y_true.view((shp_y[0], 1, *shp_y[1:]))

        if all([i == j for i, j in zip(y_pred.shape, y_true.shape)]):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = y_true 
        else:
            y_true = y_true.long()
            y_onehot = torch.zeros(shp_x, device=y_pred.device)
            y_onehot.scatter_(1, y_true, 1)
    return y_onehot

def get_tp_fp_fn_tn(net_output, gt, axes=None, square=False):
    """
    net_output must be (b, c, x, y(, z)))
    gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
    if mask is provided it must have shape (b, 1, x, y(, z)))
    :param net_output:
    :param gt:
    :return:
    """
    if axes is None:
        axes = tuple(range(2, len(net_output.size())))

    y_onehot = to_onehot(net_output, gt)

    tp = net_output * y_onehot
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot
    tn = (1 - net_output) * (1 - y_onehot)

    if square:
        tp = tp ** 2
        fp = fp ** 2
        fn = fn ** 2
        tn = tn ** 2

    if len(axes) > 0:
        tp = sum_tensor(tp, axes, keepdim=False)
        fp = sum_tensor(fp, axes, keepdim=False)
        fn = sum_tensor(fn, axes, keepdim=False)
        tn = sum_tensor(tn, axes, keepdim=False)

    return tp, fp, fn, tn

def sum_tensor(inp, axes, keepdim=False):
    axes = np.unique(axes).astype(int)
    if keepdim:
        for ax in axes:
            inp = inp.sum(int(ax), keepdim=True)
    else:
        for ax in sorted(axes, reverse=True):
            inp = inp.sum(int(ax))
    return inp

def identify_axis(shape):
    """
    Helper function to enable loss function to be flexibly used for 
    both 2D or 3D image segmentation - source: https://github.com/frankkramer-lab/MIScnn
    """
    # Three dimensional
    if len(shape) == 5 : return [2,3,4]
    # Two dimensional
    elif len(shape) == 4 : return [2,3]
    # Exception - Unknown
    else : raise ValueError('Metric: Shape of tensor is neither 2D or 3D.')

OHEM

OHEM主要思想是,根据输入样本的损失进行筛选,筛选出hard example,表示对分类和检测影响较大的样本,然后将筛选得到的这些样本应用在随机梯度下降中训练。在实际操作中是将原来的一个ROI Network扩充为两个ROI Network,这两个ROI Network共享参数。其中前面一个ROI Network只有前向操作,主要用于计算损失;后面一个ROI Network包括前向和后向操作,以hard example作为输入,计算损失并回传梯度。这种算法的优点在于,对于数据的类别不平衡问题不需要采用设置正负样本比例的方式来解决,且随着数据集的增大,算法的提升更加明显。

class OHEMLoss(nn.CrossEntropyLoss):
    """
    Network has to have NO LINEARITY!
    """
    def __init__(self, weight=None, ignore_index=-100, k=0.7):
        super(OHEMLoss, self).__init__()
        self.k = k
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, y_pred, y_true):
        res = CELoss(reduction='none')(y_pred, y_true)
        num_voxels = np.prod(res.shape, dtype=np.int64)
        res, _ = torch.topk(res.view((-1, )), int(num_voxels * self.k), sorted=False)
        return res.mean()

IntNorm

IntNorm是一种流行的归一化算法,适用于单个像素要求较高的识别任务。在医学图像领域,使用IntNorm的一个重要原因是,在训练过程中,批处理大小通常被设置为一个很小的值(特别是对于3D图像),这使得使用批处理归一化无效。2D技巧不适合3D数据集的现象提醒我们,在未来的模型设计中需要考虑数据集的格式。

模型推理

两种常用的推理技巧 ,即测试时间增强(TTA)和模型集成。这两个技巧的实现细 如下:TTA是目前模型推理阶段流行的数据增强机制。TTA不需要训练就可以用来提高识别性能,因此它有潜力成为一种即插即用的产品。同时,可以提高模型校准能力,有利于视觉任务。从三个方面遵循与相同的图像增强策略: 1)在基线模型上实施TTA策略(TTA_{baseline});2)TTA_{GTAug-A};3)TTA_{GTAug-B}。集成模型集成策略旨在统一多个训练模型,基于一定的集成制在测试集上实现多模型融合结果,使最终结果能够从每个模型中学习,提高整体泛化能力。常用的模型集成 方法有投票、平均、叠加和非交叉叠加(混合)。

class TTA():
    def __init__(self, if_tta):
        # for ISIC, the shape is (b, c, h, w)
        # for Kit, the shape is (x, y, z)
        self.if_tta = if_tta

    def img_list(self, img):
        out = []
        out.append(img)
        if not self.if_tta:
            return out
        # apply flip
        for i in range(3):
            out.append(np.flip(img, axis=i))
        # apply rotation
        for i in range(1, 4):
            out.append(np.rot90(img, k=i))
        return out
    
    def img_list_inverse(self, img_list):
        out = [img_list[0]]
        if not self.if_tta:
            return img_list
        # apply flip
        for i in range(3):
            out.append(np.flip(img_list[i+1], axis=i))
        if len(img_list) > 4:
            # apply rotation
            for i in range(3):
                out.append(np.rot90(img_list[i+4], k=-(i+1), axes=(1,2)))
        return out



class TTA_2d():
    def __init__(self, flip=False, rotate=False):
        self.flip = flip
        self.rotate = rotate

    def img_list(self, img):
        # for ISIC, the shape is torch.size(b, c, h, w)
        img = img.detach().cpu().numpy()
        out = []
        out.append(img)
        if self.flip:
            # apply flip
            for i in range(2,4):
                out.append(np.flip(img, axis=i))
        if self.rotate:
            # apply rotation
            for i in range(1, 4):
                out.append(np.rot90(img, k=i, axes=(2,3)))
        return out
    
    def img_list_inverse(self, img_list):
        # for ISIC, the shape is numpy(b, h, w)
        out = [img_list[0]]
        if self.flip:
            # apply flip
            for i in range(2):
                out.append(np.flip(img_list[i+1], axis=i+1))
        if self.rotate:
            # apply rotation
            for i in range(3):
                out.append(np.rot90(img_list[i+3], k=-(i+1), axes=(1,2)))
        return out

结果后处理

后处理操作的目的主要是通过不可学习的方法来提高模型性能。例如分割结果可以通过聚合全局图像信息来进行细化。医学图像分析领域的两种常用的结果后处理方案:最大成分抑制(ABL-CS), 和去除小区域(RSA)。 ABL-CS。ABL-CS的目的是基于有机体物理特性的知识,去除分割结果中的一些错误区域。例如,对于心脏分割任务,我们都知道每个人只有一个心脏,所以如果在获得的掩模中有小的分割区域,我们需要去除这个小区域。 RSA :设置一个像素级的阈值来删除一些太小的实例掩码。

import imp
import skimage.morphology as morph
import numpy as np
from scipy.ndimage import label

def abl(image: np.ndarray, for_which_classes: list, volume_per_voxel: float = None,
                                                   minimum_valid_object_size: dict = None):
    """
    removes all but the largest connected component, individually for each class
    :param image:
    :param for_which_classes: can be None. Should be list of int. Can also be something like [(1, 2), 2, 4].
    Here (1, 2) will be treated as a joint region, not individual classes (example LiTS here we can use (1, 2)
    to use all foreground classes together)
    :param minimum_valid_object_size: Only objects larger than minimum_valid_object_size will be removed. Keys in
    minimum_valid_object_size must match entries in for_which_classes
    :return:
    """
    if for_which_classes is None:
        for_which_classes = np.unique(image)
        for_which_classes = for_which_classes[for_which_classes > 0]
    assert 0 not in for_which_classes, "cannot remove background"

    if volume_per_voxel is None:
        volume_per_voxel = 1

    largest_removed = {}
    kept_size = {}
    for c in for_which_classes:
        if isinstance(c, (list, tuple)):
            c = tuple(c)  # otherwise it cant be used as key in the dict
            mask = np.zeros_like(image, dtype=bool)
            for cl in c:
                mask[image == cl] = True
        else:
            mask = image == c
        # get labelmap and number of objects
        lmap, num_objects = label(mask.astype(int))

        # collect object sizes
        object_sizes = {}
        for object_id in range(1, num_objects + 1):
            object_sizes[object_id] = (lmap == object_id).sum() * volume_per_voxel

        largest_removed[c] = None
        kept_size[c] = None

        if num_objects > 0:
            # we always keep the largest object. We could also consider removing the largest object if it is smaller
            # than minimum_valid_object_size in the future but we don't do that now.
            maximum_size = max(object_sizes.values())
            kept_size[c] = maximum_size

            for object_id in range(1, num_objects + 1):
                # we only remove objects that are not the largest
                if object_sizes[object_id] != maximum_size:
                    # we only remove objects that are smaller than minimum_valid_object_size
                    remove = True
                    if minimum_valid_object_size is not None:
                        remove = object_sizes[object_id] < minimum_valid_object_size[c]
                    if remove:
                        image[(lmap == object_id) & mask] = 0
                        if largest_removed[c] is None:
                            largest_removed[c] = object_sizes[object_id]
                        else:
                            largest_removed[c] = max(largest_removed[c], object_sizes[object_id])
    # return image, largest_removed, kept_size
    return image


def rsa(image: np.array, for_which_classes: list, volume_per_voxel: float = None, minimum_valid_object_size: dict = None):
    """
    Remove samll objects, smaller than minimum_valid_object_size, individually for each class
    :param image:
    :param for_which_classes: can be None. Should be list of int. Can also be something like [(1, 2), 2, 4].
    Here (1, 2) will be treated as a joint region, not individual classes (example LiTS here we can use (1, 2)
    to use all foreground classes together)
    :param minimum_valid_object_size: Only objects larger than minimum_valid_object_size will be removed. Keys in
    minimum_valid_object_size must match entries in for_which_classes
    :return:
    """
    if for_which_classes is None:
        for_which_classes = np.unique(image)
        for_which_classes = for_which_classes[for_which_classes > 0]
    assert 0 not in for_which_classes, "cannot remove background"

    if volume_per_voxel is None:
        volume_per_voxel = 1
    
    for c in for_which_classes:
        if isinstance(c, (list, tuple)):
            c = tuple(c)
            mask = np.zeros_like(image, dtype=bool)
            for cl in c:
                mask[image == cl] = True
        else:
            mask = image == c
        # get labelmap and number of objects
        lmap, num_objects = label(mask.astype(int))

        # collect object sizes
        object_sizes = {}
        for object_id in range(1, num_objects + 1):
            object_sizes[object_id] = (lmap == object_id).sum() * volume_per_voxel

        if num_objects > 0:
            # removing the largest object if it is smaller than minimum_valid_object_size.
            for object_id in range(1, num_objects + 1):
                # we only remove objects that are smaller than minimum_valid_object_size
                if object_sizes[object_id] < minimum_valid_object_size[c]:
                    image[(lmap == object_id) & mask] = 0
    
    return image

其他的写法

# 通过连通成分分析,移除小区域
import SimpleITK as sitk
import os
import argparse
from pathlib import Path


def RemoveSmallConnectedCompont(sitk_maskimg, rate=0.5):
    '''
    two steps:
        step 1: Connected Component analysis: 将输入图像分成 N 个连通域
        step 2: 假如第 N 个连通域的体素小于最大连通域 * rate,则被移除
    :param sitk_maskimg: input binary image 使用 sitk.ReadImage(path, sitk.sitkUInt8) 读取,
                        其中sitk.sitkUInt8必须注明,否则使用 sitk.ConnectedComponent 报错
    :param rate: 移除率,默认为0.5, 小于 1/2最大连通域体素的连通域被移除
    :return:  binary image, 移除了小连通域的图像
    '''

    # step 1 Connected Component analysis
    cc = sitk.ConnectedComponent(sitk_maskimg)
    stats = sitk.LabelIntensityStatisticsImageFilter()
    stats.Execute(cc, sitk_maskimg)
    maxlabel = 0   # 获取最大连通域的索引
    maxsize = 0    # 获取最大连通域的体素大小

    # 遍历每一个连通域, 获取最大连通域的体素大小和索引
    for l in stats.GetLabels():  # stats.GetLabels()  (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)
        size = stats.GetPhysicalSize(l)   # stats.GetPhysicalSize(5)=75  表示第5个连通域的体素有75个
        if maxsize < size:
            maxlabel = l
            maxsize = size

    # step 2 获取每个连通域的大小,保留 size >= maxsize * rate 的连通域
    not_remove = []
    for l in stats.GetLabels():
        size = stats.GetPhysicalSize(l)
        if size >= maxsize * rate:
            not_remove.append(l)

    labelmaskimage = sitk.GetArrayFromImage(cc)
    outmask = labelmaskimage.copy()
    outmask[labelmaskimage != maxlabel] = 0
    for i in range(len(not_remove)):
        outmask[labelmaskimage == not_remove[i]] = 1
  # 保存图像
    outmask = outmask.astype('float32')

    out = sitk.GetImageFromArray(outmask)
    out.SetDirection(sitk_maskimg.GetDirection())
    out.SetSpacing(sitk_maskimg.GetSpacing())
    out.SetOrigin(sitk_maskimg.GetOrigin())   # 使 out 的层厚等信息同输入一样

    return out  # to save image: sitk.WriteImage(out, 'largecc.nii.gz')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="remove small connected domains")
    parser.add_argument('--input', type=str, default="./123.nii.gz")
    parser.add_argument("--output", type=str, default='./123.nii.gz')
    args = parser.parse_args()

    # for single image
  
    sitk_maskimg = sitk.ReadImage(args.input, sitk.sitkUInt8)
    out = RemoveSmallConnectedCompont(sitk_maskimg, rate=0.5)  # 可以设置不同的比率
    sitk.WriteImage(out, args.output)

通过填补孔洞提高分割准确度

from scipy.ndimage.morphology import binary_fill_holes
import numpy as np
from scipy import ndimage
import nibabel as nib
from skimage.measure import label
import matplotlib.pyplot as plt


def hole_filling(bw, hole_min, hole_max, fill_2d=True):
    bw = bw > 0
    if len(bw.shape) == 2:
        background_lab = label(~bw, connectivity=1)
        fill_out = np.copy(background_lab)
        component_sizes = np.bincount(background_lab.ravel())
        too_big = component_sizes > hole_max
        too_big_mask = too_big[background_lab]
        fill_out[too_big_mask] = 0
        too_small = component_sizes < hole_min
        too_small_mask = too_small[background_lab]
        fill_out[too_small_mask] = 0
    elif len(bw.shape) == 3:
        if fill_2d:
            fill_out = np.zeros_like(bw)
            for zz in range(bw.shape[1]):
                background_lab = label(~bw[:, zz, :], connectivity=1)   # 1表示4连通, ~bw[zz, :, :]1变为0, 0变为1
                # 标记背景和孔洞, target区域标记为0
                out = np.copy(background_lab)
                # plt.imshow(bw[:, :, 87])
                # plt.show()
                component_sizes = np.bincount(background_lab.ravel())
                # 求各个类别的个数
                too_big = component_sizes > hole_max
                too_big_mask = too_big[background_lab]

                out[too_big_mask] = 0
                too_small = component_sizes < hole_min
                too_small_mask = too_small[background_lab]
                out[too_small_mask] = 0
                # 大于最大孔洞和小于最小孔洞的都标记为0, 所以背景部分被标记为0了。只剩下符合规则的孔洞
                fill_out[:, zz, :] = out
                # 只有符合规则的孔洞区域是1, 背景及target都是0
        else:
            background_lab = label(~bw, connectivity=1)
            fill_out = np.copy(background_lab)
            component_sizes = np.bincount(background_lab.ravel())
            too_big = component_sizes > hole_max
            too_big_mask = too_big[background_lab]
            fill_out[too_big_mask] = 0
            too_small = component_sizes < hole_min
            too_small_mask = too_small[background_lab]
            fill_out[too_small_mask] = 0
    else:
        print('error')
        return

    return np.logical_or(bw, fill_out)  # 或运算,孔洞的地方是1,原来target的地方也是1

参数简介
bw: array, 待填补的数组
hole_min: 孔洞像素的个数最小值,一般为0
hole_max:孔洞像素的个数最大值。
fill_2d:True:二维填充。False:三维填充 只有当孔洞像素值个数在 [hole_min, hole_max] 才会被填补。

模型设计

有关语义分割的奇技淫巧有哪些? - 知乎 (zhihu.com)
从Kaggle学语义分割技巧 - mdnice 墨滴
(42条消息) 分割网络中的奇技淫巧_小白学视觉的博客-CSDN博客

参考:
Deep Learning for Medical Image Segmentation:Tricks, Challenges and Future Directions
hust-linyi/MedISeg (github.com)

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 216,287评论 6 498
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,346评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 162,277评论 0 353
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,132评论 1 292
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,147评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,106评论 1 295
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,019评论 3 417
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,862评论 0 274
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,301评论 1 310
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,521评论 2 332
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,682评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,405评论 5 343
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,996评论 3 325
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,651评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,803评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,674评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,563评论 2 352

推荐阅读更多精彩内容