胃肠道分割比赛个人方案总结

赛题解析

本次竞赛中将创建一个模型,在核磁共振扫描中自动分割胃和肠。MRI扫描来自实际的癌症患者,他们在不同日期进行了1-5次MRI扫描。基于这些扫描的数据集来构建算法从而提出创造性的深度学习解决方案,帮助癌症患者获得更好的护理。
数据分析: 图像采用16位灰度的PNG格式。每个案例都由多组扫描切片代表(每组由扫描发生的日期来标识)。有些案例是按时间划分的(早期在训练中,后期在测试中),而有些案例是按案例划分的——整个案例都在训练或测试中。这个比赛的目标是能够推广到部分和完全看不见的情况。
赛题难点:

  • 每个病例有完整的扫描文件——引申出2.5D或者3D的思想,这里我采用的横断面和冠状面双模型来获取3D信息,从金牌方案来看直接3D训练是可行的,和2D模型融合会达到优异的性能。
  • 病例中可能存在少许切片存在胃肠器官但是没有进行标签勾画的(可能受到伪影影响),目前优秀的方案都考虑到了这一点。top1的方案更是直接以此建立了分类模型。

数据EDA

训练集

获取文件信息:
The image filenames include 4 numbers (ex. 276_276_1.63_1.63.png).

  • slice height (integer in pixels)
  • slice width (integer in pixels)
  • heigh pixel spacing (floating point in mm)
  • width pixel spacing (floating point in mm)

The first two defines the resolution of the slide. The last two record the physical size of each pixel.

路径信息
def path2info(row):
    path = row['image_path']
    data = path.split('/')
    slice_ = int(data[-1].split('_')[1])
    case = int(data[-3].split('_')[0].replace('case',''))
    day = int(data[-3].split('_')[1].replace('day',''))
    width = int(data[-1].split('_')[2])
    height = int(data[-1].split('_')[3])
    row['height'] = height
    row['width'] = width
    row['case'] = case
    row['day'] = day
    row['slice'] = slice_
    # row['id'] = f'case{case}_day{day}_slice_{slice_}'
    return row

图像大小分布情况
分割的分布情况

mask的RLE编码

RLE编码原理
RLE编码又叫行程编码,是最简单、最古老的数据压缩技术之一,它的原理是通过检测统计数据流中重复的位或字符序列,并用它们出现的次数和每次出现的个数形成新的代码。从而达到数据压缩的目的。

编码:

  • 定义mask:
mask = np.array([ [0, 1, 1, 0], [1, 1, 1, 1], [1, 0, 0, 1] ])

mask的图像就是上方那个简单的图像。

  • 扁平化
mask.flatten()
> array([0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1])
  • 首尾加0
np.concatenate([[0], mask, [0]]) 
> array([0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0])
  • 每个元素与前一个元素比较,找到不同的位置
runs = np.where(mask[1:] != mask[:-1])[0] + 1 
> array([ 2, 4, 5, 10, 12, 13])
  • 作差,找到所有连续1的起始位置和个数
runs[1::2] -= runs[::2] 
> array([ 2, 2, 5, 5, 12, 1])

至此,编码就完成了。

code = ' '.join(str(x) for x in runs)

注意到它这里的起始位置是从1开始算的

def mask2rle(msk, thr=0.5):
    '''
    img: numpy array, 1 - mask, 0 - background
    Returns run length as string formated
    '''
    msk    = cp.array(msk)
    pixels = msk.flatten()
    pad    = cp.array([0])
    pixels = cp.concatenate([pad, pixels, pad])
    runs   = cp.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)

我的解决方案

公榜0.882,私榜0.875,银牌top 3%,主要思想包括2.5D数据横断面模型和冠状面模型融合,需要的完整流程见思维导图:


具体步骤如下:

数据预处理

1. 如何生成2.5D数据
采用的是在线生成2.5D数据的思想,并且隔层采样的非连续生成->[t-2, t ,t+2]
sample_step用于控制采样的间隔,这里我用的是隔一层采样,最终生成3通道的2.5D数据

class build_dataset(Dataset):
    def __init__(self, df, label=True, sample_step=1, transforms=None, cfg=None):
        self.df = df
        self.label = label
        self.img_paths = df['image_path'].tolist() # image
        self.ids = df['id'].tolist()

        if 'mask_path' in df.columns:
            self.mask_paths  = df['mask_path'].tolist() # mask
        else:
            self.mask_paths = None
        
        self.sample_step = sample_step
        self.transforms = transforms
        self.n_25d_shift = cfg.n_25d_shift

    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        #### load id
        id       = self.ids[index]
        #### load image
        img_path  = self.img_paths[index]
        img = self.load_2_5d_slice(img_path, self.sample_step) # [h, w, c]
        h, w = img.shape[:2]
        
        if self.label: # train
            #### load mask
            mask_path = self.mask_paths[index]
            mask = np.load(mask_path).astype('float32')
            mask/=255.0 # scale mask to [0, 1]

            ### augmentations
            data = self.transforms(image=img, mask=mask)
            img  = data['image']
            mask  = data['mask']
            class_label = np.array(self.df['class_new'][index])
            img = np.transpose(img, (2, 0, 1)) # [h, w, c] => [c, h, w]
            mask = np.transpose(mask, (2, 0, 1)) # [h, w, c] => [c, h, w]
            return torch.tensor(img), torch.tensor(mask), torch.tensor(class_label)
        
        else:  # test
            ### augmentations
            data = self.transforms(image=img)
            img  = data['image']
            img = np.transpose(img, (2, 0, 1)) # [h, w, c] => [c, h, w]
            return torch.tensor(img), id, h, w

    ###############################################################
    ##### >>>>>>> trick: construct 2.5d slice images <<<<<<
    ###############################################################
    def load_2_5d_slice(self, middle_img_path, step):
        #### 步骤1: 获取中间图片的基本信息
        #### eg: middle_img_path: 'slice_0005_266_266_1.50_1.50.png' 
        middle_slice_num = os.path.basename(middle_img_path).split('_')[1] # eg: 0005
        middle_str = 'slice_'+middle_slice_num

        new_25d_imgs = []

        ##### 步骤2:按照左右n_25d_shift数量进行填充,如果没有相应图片填充为Nan.
        ##### 注:经过EDA发现同一天的所有患者图片的shape是一致的
        for i in range(-self.n_25d_shift, self.n_25d_shift+1, step): # eg: i = {-2, -1, 0, 1, 2}
            shift_slice_num = int(middle_slice_num) + i
            shift_str = 'slice_'+str(shift_slice_num).zfill(4)
            shift_img_path = middle_img_path.replace(middle_str, shift_str)
            
            if os.path.exists(shift_img_path):
                shift_img = cv2.imread(shift_img_path, cv2.IMREAD_UNCHANGED) # [w, h]
                new_25d_imgs.append(shift_img)
            else:
                new_25d_imgs.append(None)
        
        ##### 步骤3:从中心开始往外循环,依次填补None的值
        ##### eg: n_25d_shift = 2, 那么形成5个channel, idx为[0, 1, 2, 3, 4], 所以依次处理的idx为[1, 3, 0, 4]
        shift_left_idxs = []
        shift_right_idxs = []
        for related_idx in range(int(self.n_25d_shift/step)):
            shift_left_idxs.append(int(self.n_25d_shift/step) - related_idx - 1)
            shift_right_idxs.append(int(self.n_25d_shift/step) + related_idx + 1)

        for left_idx, right_idx in zip(shift_left_idxs, shift_right_idxs):
            if new_25d_imgs[left_idx] is None:
                new_25d_imgs[left_idx] = new_25d_imgs[left_idx+1]
            if new_25d_imgs[right_idx] is None:
                new_25d_imgs[right_idx] = new_25d_imgs[right_idx-1]

        new_25d_imgs = np.stack(new_25d_imgs, axis=2).astype('float32') # [w, h, c]
        mx_pixel = new_25d_imgs.max()
        if mx_pixel != 0:
            new_25d_imgs /= mx_pixel
        return new_25d_imgs

2.数据增强方式

经过实验,去除了原本的垂直翻转,使用的数据增强:

  • 随机放射变换(ShiftScaleRotate):该方法可以对图片进行平移(translate)、缩放(scale)和旋转(roatate)
  • 三种非刚体变换方法:弹性变换(ElasticTransform)、网格失真(GridDistortion) 和 畸变(OpticalDistortion)


  • 矩形丢弃增强器(CoarseDropout):将图像中的矩形区域设置为零。


def build_transforms(CFG):
    data_transforms = {
        "train": A.Compose([
            A.OneOf([
                A.Resize(*CFG.img_size, interpolation=cv2.INTER_NEAREST, p=1.0),
            ], p=1),
            A.HorizontalFlip(p=0.5),
            # A.VerticalFlip(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.05, rotate_limit=10, p=0.5),
            # A.RandomBrightnessContrast(p=0.5),
            A.OneOf([
                A.GridDistortion(num_steps=5, distort_limit=0.05, p=1.0),
                A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=1.0)
            ], p=0.25),
            A.CoarseDropout(max_holes=8, max_height=CFG.img_size[0]//20, max_width=CFG.img_size[1]//20,
                            min_holes=5, fill_value=0, mask_fill_value=0, p=0.5),
            ], p=1.0),
        
        "valid_test": A.Compose([
            A.Resize(*CFG.img_size, interpolation=cv2.INTER_NEAREST),
            ], p=1.0)
        }
    return data_transforms

模型选择

目前大家比较常用的模型包括UNet、UNet++、UperNet等,Bacbone包括EfficientNet、Swin-Transformer、ConvNeXt。比赛中更重要的是调参和模型融合。本人使用基本的Unet并加入了一些模块改进模型,包括ASPP、Hypercolumn。模块在smp源码的基础上直接加入使用:

  1. ASPP:传统 U 形网络的缺陷是由一个小的感受野造成的。因此,如果模型需要对大对象的分割做出决定,特别是对于大图像分辨率,它可能会因为只能查看对象的一部分而感到困惑。增加感受野并实现图像不同部分之间交互的一种方法是使用具有不同扩张的卷积块组合(在 ASPP 块中具有不同速率的 Atrous 卷积)。
class _ASPPModule(nn.Module):
    def __init__(self, inplanes, planes, kernel_size, padding, dilation):
        super(_ASPPModule, self).__init__()
        self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                                     stride=1, padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.atrous_conv(x)
        x = self.bn(x)

        return self.relu(x)


class ASPP(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels=None, dilations=[1, 6, 12, 18]):
        super(ASPP, self).__init__()
        self.context = nn.Sequential(nn.Conv2d(in_channels,  mid_channels, kernel_size=3, padding=1, dilation=1, bias=True),
                                     nn.BatchNorm2d(mid_channels),
                                     nn.ReLU(),
                                     BaseOC_Context_Module(in_channels=mid_channels, out_channels=mid_channels,
                                                           key_channels=mid_channels // 2, value_channels=mid_channels,
                                                           dropout=0, sizes=([2])))
        self.aspp1 = _ASPPModule(in_channels, mid_channels, 1, padding=0, dilation=dilations[0])
        self.aspp2 = _ASPPModule(in_channels, mid_channels, 3, padding=dilations[1], dilation=dilations[1])
        self.aspp3 = _ASPPModule(in_channels, mid_channels, 3, padding=dilations[2], dilation=dilations[2])
        self.aspp4 = _ASPPModule(in_channels, mid_channels, 3, padding=dilations[3], dilation=dilations[3])

        self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                                             nn.Conv2d(in_channels, mid_channels, 1, stride=1, bias=False),
                                             nn.BatchNorm2d(mid_channels),
                                             nn.ReLU())
        out_channels = out_channels if out_channels is not None else mid_channels
        self.conv1 = nn.Conv2d(mid_channels * 5, out_channels, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.context(x)
        # x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear')
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self.conv1(x)
        x = self.relu(x)
        x = self.bn1(x)

        return x
  1. Hypercolumn: 本质上就是特征金字塔网络(FPN)的思想,解码器的不同上采样块和输出层之间的附加跳过连接。因此,最终预测是基于 U-net 输出与中间层调整大小的输出串联接产生的。这些跳跃连接为梯度传导提供了捷径以提高模型性能和收敛速度。由于中间层有许多通道,它们的上采样和用作最后一层的输入会在计算时间和内存方面引入大量开销。因此,在调整大小之前应用 3*3+3*3 卷积(分解)以减少通道数。
class FPN(nn.Module):
    def __init__(self, input_channels: list, output_channels: list):
        super().__init__()
        self.convs = nn.ModuleList(
            [nn.Sequential(nn.Conv2d(in_ch, out_ch * 2, kernel_size=3, padding=1),
                           nn.ReLU(inplace=True), nn.BatchNorm2d(out_ch * 2),
                           nn.Conv2d(out_ch * 2, out_ch, kernel_size=3, padding=1))
             for in_ch, out_ch in zip(input_channels, output_channels)])

    def forward(self, xs: list, last_layer):
        hcs = [F.interpolate(c(x), scale_factor=2 ** (len(self.convs) - i), mode='bilinear', align_corners=True)
               for i, (c, x) in enumerate(zip(self.convs, xs))]
        hcs.append(last_layer)
        return torch.cat(hcs, dim=1)
  1. backbone的选择:横断面的分辨率设置为[384, 384]。从efficientnetb2~b6的实验中最终选择了b6,其显示的本地CV最佳,另外尝试了se_resnext50_32x4d、se_resnext101_32x4d,效果依然不如efficientnetb6。

Efficientnet

  • 根据以往的经验,增加网络的深度depth能够得到更加丰富、复杂的特征并且能够很好的应用到其它任务中。但网络的深度过深会面临梯度消失,训练困难的问题。
  • 增加网络的width能够获得更高细粒度的特征并且也更容易训练,但对于width很大而深度较浅的网络往往很难学习到更深层次的特征。
  • 增加输入网络的图像分辨率能够潜在得获得更高细粒度的特征模板,但对于非常高的输入分辨率,准确率的增益也会减小。并且大分辨率图像会增加计算量。

    当计算资源增加时,如果彻底的去搜索宽度、深度、图片分辨率这三个变量的各种组合,那么搜索空间将无限大,搜索效率会非常低。作者提出了组合缩放的方法:
    在达到相同效率的情况下,EfficientNet具有巨大参数量和计算量优势,比同级网络少了好几倍。总结一句话:模型小,计算量小,性能高。在迁移学习方面,EfficientNet也凭借高性能,低参数量,占得巨大优势。这说明了该网络在数据集扩展方面具有很强的鲁棒性,易于扩展别的计算机视觉任务。

损失函数的设置

  1. 组合了三种loss:BCE+dice+边缘loss。从提交结果来看边缘loss带来小的提升,加入边缘loss的目的在于希望模型更加注重边缘的分割效果。边缘loss尝试过Active Contour Loss和Hausdorff Loss,Active Contour Loss没有什么效果,使用开源的Hausdorff Loss出错但是没有找到解决方案,后续可以尝试。最终采取的方案对图像进行拉普拉斯变换突出边缘后计算loss。
    边缘loss
laplace = torch.tensor([[1, 1, 1],
                        [1, -8, 1],
                        [1, 1, 1]], dtype=torch.float, requires_grad=False).view(1, 1, 3, 3)
avgpool = torch.tensor([[1/9, 1/9, 1/9],
                         [1/9, 1/9, 1/9],
                         [1/9, 1/9, 1/9]], dtype=torch.float, requires_grad=False).view(1, 1, 3, 3)

def conv_operator(x, avgpool, laplace, in_channels=3):
    x = nn.functional.conv2d(x, avgpool.repeat(1, in_channels, 1, 1), stride=1, padding=1,)
    x = nn.functional.conv2d(x, avgpool.repeat(1, in_channels, 1, 1), stride=1, padding=1,)
    x = nn.functional.conv2d(x, avgpool.repeat(1, in_channels, 1, 1), stride=1, padding=1,)
    x = nn.functional.conv2d(x, laplace.repeat(1, in_channels, 1, 1), stride=1, padding=1,)
    return x
def build_loss():
    BCELoss     = smp.losses.SoftBCEWithLogitsLoss()
    TverskyLoss = smp.losses.TverskyLoss(mode='multilabel', log_loss=False)
    FocalLoss = smp.losses.FocalLoss(mode='multilabel')
    return {"BCELoss":BCELoss, "TverskyLoss":TverskyLoss, "FocalLoss":FocalLoss}

def train_one_epoch(model, train_loader, optimizer, scheduler, losses_dict, CFG):
    model.train()
    scaler = amp.GradScaler() 
    losses_all, bce_all, tverskly_all, boundary_all, cls_all = 0, 0, 0, 0, 0
    
    pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc='Train ')
    for _, (images, masks, class_label) in pbar:
        optimizer.zero_grad()

        images = images.to(CFG.device, dtype=torch.float) # [b, c, w, h]
        masks  = masks.to(CFG.device, dtype=torch.float)  # [b, c, w, h]
        class_label = class_label.to(CFG.device, dtype=torch.float)

        with amp.autocast(enabled=True):
            y_preds, class_preds = model(images) # [b, c, w, h]
            class_loss = torch.nn.BCEWithLogitsLoss()(class_preds, class_label)
            bce_loss = losses_dict["BCELoss"](y_preds, masks)
            tverskly_loss = losses_dict["TverskyLoss"](y_preds, masks)
            boundary_loss = nn.MSELoss()(conv_operator(torch.sigmoid(y_preds).cpu().float()), conv_operator(masks.cpu().float()))
            losses = 10*bce_loss + tverskly_loss + 1000*boundary_loss + 5*class_loss
        
        scaler.scale(losses).backward()
        scaler.step(optimizer)
        scaler.update()
                    
        scheduler.step()
       
        losses_all += losses.item() / images.shape[0]
        bce_all += bce_loss.item() / images.shape[0]
        tverskly_all += tverskly_loss.item() / images.shape[0]
        boundary_all += boundary_loss.item() / images.shape[0]
        cls_all += class_loss.item() / images.shape[0]
   
    current_lr = optimizer.param_groups[0]['lr']
    print("lr: {:.4f}".format(current_lr), flush=True)
    print("loss: {:.3f}, bce_all: {:.3f}, tverskly_all: {:.3f}, boundary_all: {:.3f}, cls_all: {:.3f}".\
        format(losses_all, bce_all, tverskly_all, boundary_all, cls_all), flush=True)
  1. 加入了深监督: 考虑到有的层没有标注,有的层可能只有1-2个标签,故加入深监督希望在encoder部分实现多标签的分类,将encoder的输出头连接全连接层实现三分类输出,相当于实现了先分类再分割。对应前面提到的赛题难点修改推理部分,以0.3作为阈值,大于0.3则将预测输出生成mask,小于0.3则不管预测值直接生成全0的mask。该损失函数依然是多标签loss所以选择的是torch.nn.BCEWithLogitsLoss()。

所谓深监督(Deep Supervision),就是在深度神经网络的某些中间隐藏层加了一个辅助的分类器作为一种网络分支来对主干网络进行监督的技巧,用来解决深度神经网络训练梯度消失和收敛速度过慢等问题。通常而言,增加神经网络的深度可以一定程度上提高网络的表征能力,但随着深度加深,会逐渐出现神经网络难以训练的情况,其中就包括像梯度消失梯度爆炸等现象。为了更好的训练深度网络,我们可以尝试给神经网络的某些层添加一些辅助的分支分类器来解决这个问题。这种辅助的分支分类器能够起到一种判断隐藏层特征图质量好坏的作用。

模型集成

为了充分利用数据的3D特性,将每个病例的横断面拼接在一起构成完整的3D体素,然后由3D体素生成每个病例的冠状面和矢状面,建立相应的冠状面和矢状面分割模型。本地CV来看冠状面可以达到和横断面相近的分割性能,但是矢状面差强人意,所以最终的模型集成采用了横断面和冠状面模型的集成。并将设置横断面分割模型的权重为0.6。

这里冠状面分割模型的输入分辨率根据冠状面原始分辨率的特点设置成宽高比为1:1.5。

生成3D数据获得冠状面和矢状面:

def build_3D_img(folder_path):
    '''
    Take all slices and store them into a 3D numpy array
    '''
    slice_list = sorted(os.listdir(folder_path))
    slice_pixel_list = []
    for slice_filename in slice_list:
        pixel_slice = Image.open(folder_path+slice_filename)
        slice_pixel_list.append(np.array(pixel_slice))
        pixel_slice.close()
    return np.array(slice_pixel_list)

def build_3D_segmentation(segmentation_df, img_width, img_height):
    '''
    Take all segmentation slices, decode the RLE and 
    concatenate the slices together to obtain a 3D numpy array
    '''
    slice_pixel_list = []
    for segmentation in segmentation_df["segmentation"].values:
        if pd.isnull(segmentation):
            slice_pixel_list.append(np.zeros((img_height,img_width)).astype("float64"))
        else:
            slice_pixel_list.append(decode_rle(segmentation, img_height, img_width))
    return np.array(slice_pixel_list)

def generate_sagittal_slices(df, case_id, day, mask_path, MRI_3D_img, 
                             large_bowel_3D_mask, small_bowel_3D_mask, stomach_3D_mask):
    '''
    Create slices for an MRI from a sagittal plane and
    returns a dataframe with filepaths and groundtruths
    '''
    
    slice_info_list = []
    for i in range(MRI_3D_img.shape[2]):
        filename = "case{}_{}_sagittal_{}.png".format(case_id, day, i)
        sagittal_slice = Image.fromarray(MRI_3D_img[:,:,i])
        width, height = sagittal_slice.size
        
        sagittal_slice.save("{}/{}".format(image_output_folder,filename))
        large_bowel_rle = rle_encoding(large_bowel_3D_mask[:,:,i].T)
        small_bowel_rle = rle_encoding(small_bowel_3D_mask[:,:,i].T)
        stomach_rle = rle_encoding(stomach_3D_mask[:,:,i].T)

        slice_info_list.append({"filename":filename, 
                        "organ_class":"large_bowel", 
                        "segmentation":large_bowel_rle, 
                        "slice_plane":"sagittal",
                        "width":width,
                        "height":height,
                        "mask_path": mask_path})
        slice_info_list.append({"filename":filename, 
                        "organ_class":"small_bowel", 
                        "segmentation":small_bowel_rle, 
                        "slice_plane":"sagittal",
                        "width":width,
                        "height":height,
                        "mask_path": mask_path})
        slice_info_list.append({"filename":filename, 
                        "organ_class":"stomach", 
                        "segmentation":stomach_rle, 
                        "slice_plane":"sagittal",
                        "width":width,
                        "height":height,
                        "mask_path": mask_path})
    return df.append(slice_info_list)

def generate_coronal_slices(df, case_id, day, mask_path, MRI_3D_img, 
                            large_bowel_3D_mask, small_bowel_3D_mask, stomach_3D_mask):
    '''
    Create slices for an MRI from a coronal plane and
    returns a dataframe with filepaths and groundtruths
    '''
    
    slice_info_list = []
    for i in range(MRI_3D_img.shape[1]):
        filename = "case{}_{}_coronal_{}.png".format(case_id, day, i)
        coronal_slice = Image.fromarray(MRI_3D_img[:,i,:])
        width, height = coronal_slice.size
        
        coronal_slice.save("{}/{}".format(image_output_folder,filename))
        large_bowel_rle = rle_encoding(large_bowel_3D_mask[:,i,:].T)
        small_bowel_rle = rle_encoding(small_bowel_3D_mask[:,i,:].T)
        stomach_rle = rle_encoding(stomach_3D_mask[:,i,:].T)
        
        slice_info_list.append({"filename":filename, 
                        "organ_class":"large_bowel", 
                        "segmentation":large_bowel_rle, 
                        "slice_plane":"coronal",
                        "width":width,
                        "height":height,
                        "mask_path": mask_path})
        slice_info_list.append({"filename":filename, 
                        "organ_class":"small_bowel", 
                        "segmentation":small_bowel_rle, 
                        "slice_plane":"coronal",
                        "width":width,
                        "height":height,
                        "mask_path": mask_path})
        slice_info_list.append({"filename":filename, 
                        "organ_class":"stomach", 
                        "segmentation":stomach_rle, 
                        "slice_plane":"coronal",
                        "width":width,
                        "height":height,
                        "mask_path": mask_path})
    return df.append(slice_info_list)

def create_3D_img_and_seg(df, new_dataset_df, img_id, img_width, img_height):
    '''
    concatenate MRI slices and their segmentation information
    '''

    img_info_df = df[df.id.str.contains(img_id)]
    case_id = img_info_df.iloc[0].case_id
    day = img_info_df.iloc[0].day
    mask_path = img_info_df.iloc[0].mask_path
    segmentation_large_bowel_df = img_info_df.query("organ_class=='large_bowel'")
    segmentation_small_bowel_df = img_info_df.query("organ_class=='small_bowel'")
    segmentation_stomach_df = img_info_df.query("organ_class=='stomach'")
    
    MRI_3D_img = build_3D_img("{}/case{}/case{}_{}/scans/".format(train_folder, case_id, case_id, day))
    large_bowel_3D_mask = build_3D_segmentation(segmentation_large_bowel_df, img_width, img_height)
    small_bowel_3D_mask = build_3D_segmentation(segmentation_small_bowel_df, img_width, img_height)
    stomach_3D_mask = build_3D_segmentation(segmentation_stomach_df, img_width, img_height)
    
    new_dataset_df = generate_sagittal_slices(new_dataset_df, case_id, day, mask_path, MRI_3D_img, 
                                             large_bowel_3D_mask, small_bowel_3D_mask, stomach_3D_mask)
    new_dataset_df = generate_coronal_slices(new_dataset_df, case_id, day, mask_path, MRI_3D_img, 
                                             large_bowel_3D_mask, small_bowel_3D_mask, stomach_3D_mask)
    return new_dataset_df

new_dataset_df = pd.DataFrame()
for img_id in tqdm(train_df.image_id.unique()):
    id_df = train_df.query("image_id==@img_id").iloc[0]
    img_width = id_df.width
    img_height = id_df.height
    
    new_dataset_df = create_3D_img_and_seg(train_df, new_dataset_df, img_id, img_width, img_height)

推理的模型集成:

@torch.no_grad()
def generate_3D(slice_info, ckpt_paths, CFG):       
    data_transforms = build_transforms(CFG)
    test_dataset = build_dataset(slice_info, transforms=data_transforms['valid_test'], cfg=CFG)
    test_loader  = DataLoader(test_dataset, batch_size=CFG.valid_bs, num_workers=8, shuffle=False, pin_memory=False)

    pbar = tqdm(enumerate(test_loader), total=len(test_loader))
    large_bowel_list, small_bowel_list,  stomach_list = [], [], []

    for _, (images, ids, cases, days, h, w) in pbar:

        images  = images.to(CFG.device, dtype=torch.float) # [b, c, w, h]
        size = images.size()
        masks = torch.zeros((size[0], 3, size[2], size[3]), device=CFG.device, dtype=torch.float32) # [b, c, w, h]
        labels = torch.zeros((size[0], 3), device=CFG.device, dtype=torch.float32) # [b, c]

        for sub_ckpt_path in ckpt_paths:
            model = build_model(CFG, test_flag=True)
            model.load_state_dict(torch.load(sub_ckpt_path))
            model.eval()
            y_preds, label_preds = model(images) # [b, c, w, h]
            y_preds  = torch.nn.Sigmoid()(y_preds)
            label_preds  = torch.nn.Sigmoid()(label_preds)
            masks += y_preds
            labels += label_preds
            if CFG.tta:
                rotates = [10, 350]
                for rotate in rotates:
                    images_f = F.rotate(images, rotate, resample=2, expand=False)
                    y_preds, label_preds = model(images_f) # [b, c, w, h]
                    y_preds = F.rotate(y_preds, 360-rotate, resample=2, expand=False)
                    y_preds  = torch.nn.Sigmoid()(y_preds)
                    label_preds  = torch.nn.Sigmoid()(label_preds)
                    masks += y_preds
                    labels += label_preds
                    
                images_f = torch.flip(images, [-1])
                y_preds, label_preds = model(images_f) # [b, c, w, h]
                y_preds = torch.flip(y_preds, [-1])
                y_preds   = torch.nn.Sigmoid()(y_preds)
                label_preds  = torch.nn.Sigmoid()(label_preds)
                masks += y_preds
                labels += label_preds
            
            del model, y_preds, label_preds
            gc.collect()

        if CFG.tta:
            total_ckpt_paths = len(ckpt_paths) * 4
        else:
            total_ckpt_paths = len(ckpt_paths)
            
        masks  /= total_ckpt_paths
        labels /= total_ckpt_paths
        masks = masks.permute((0, 2, 3, 1)).to(torch.float32).cpu().detach().numpy()
        labels = labels.to(torch.float32).cpu().detach().numpy()
        
        for idx in range(masks.shape[0]):
            height = h[idx].item()
            width = w[idx].item()
            msk = cv2.resize(masks[idx], 
                            dsize=(width, height), 
                            interpolation=cv2.INTER_LINEAR) # back to original shape
            large_bowel = msk[:,:,0] if labels[idx, 0] > CFG.class_thr else np.zeros((height, width))
            small_bowel = msk[:,:,1] if labels[idx, 1] > CFG.class_thr else np.zeros((height, width))
            stomach = msk[:,:,2] if labels[idx, 2] > CFG.class_thr else np.zeros((height, width))
            large_bowel_list.append(large_bowel)
            small_bowel_list.append(small_bowel)
            stomach_list.append(stomach)
         
        del images, masks
        gc.collect() 

    del test_dataset, test_loader
    gc.collect()

    if CFG.anatomical_plane == "transverse":
        large_bowel_preds = np.stack(large_bowel_list,axis=0)
        small_bowel_preds = np.stack(small_bowel_list,axis=0)
        stomach_preds = np.stack(stomach_list,axis=0)
    if CFG.anatomical_plane == "coronal":
        large_bowel_preds = np.stack(large_bowel_list,axis=1)
        small_bowel_preds = np.stack(small_bowel_list,axis=1)
        stomach_preds = np.stack(stomach_list,axis=1)
    if CFG.anatomical_plane == "sagittal":
        large_bowel_preds = np.stack(large_bowel_list,axis=2)
        small_bowel_preds = np.stack(small_bowel_list,axis=2)
        stomach_preds = np.stack(stomach_list,axis=2)
    
    del large_bowel_list, small_bowel_list, stomach_list
    gc.collect()
    
    return large_bowel_preds, small_bowel_preds, stomach_preds

pred_strings = []
pred_ids = []
pred_classes = []
for index, row in test_df2.iterrows():
    case_id = row.case
    day = row.day
    if sub_firset:
        folder_path = "../input/uw-madison-gi-tract-image-segmentation/train/case{}/case{}_day{}/scans/".format(case_id, case_id, day)
    else:
        folder_path = "../input/uw-madison-gi-tract-image-segmentation/test/case{}/case{}_day{}/scans/".format(case_id, case_id, day)
    slice_list = sorted(os.listdir(folder_path))
    slice_pixel_list = []
    for slice_filename in slice_list:
        pixel_slice = cv2.imread(folder_path+slice_filename, cv2.IMREAD_UNCHANGED)
        slice_pixel_list.append(pixel_slice)
    MRI_3D_img = np.array(slice_pixel_list).astype('float32')
    del slice_list, slice_pixel_list
    gc.collect()

    transverse_slice_info_list = []
    coronal_slice_info_list = []

    for i in range(MRI_3D_img.shape[1]):
        slice_ = MRI_3D_img[:,i,:].T
        image_id = "case{}_day{}_slice".format(case_id, day)+str(i).zfill(4)
        width, height = slice_.shape
        slice_number = i
        coronal_slice_info_list.append({"id":image_id,
                                "image": slice_.T,#h,w
                                "width":width,
                                "height":height,
                                "slice": slice_number})

    for i in range(MRI_3D_img.shape[0]):
        slice_ = MRI_3D_img[i,:,:].T # d,h,w
        image_id = "case{}_day{}_slice".format(case_id, day)+str(i).zfill(4)
        width, height = slice_.shape
        slice_number = i
        transverse_slice_info_list.append({"id":image_id,
                                "image": slice_.T,#h,w
                                "width":width,
                                "height":height,
                                "slice": slice_number})

    transverse_slice_info = pd.DataFrame(transverse_slice_info_list)
    coronal_slice_info = pd.DataFrame(coronal_slice_info_list)
    display(transverse_slice_info.head(5))

    del MRI_3D_img, coronal_slice_info_list, transverse_slice_info_list
    gc.collect()

    CFG.anatomical_plane = "transverse"
    CFG.img_size = [384,384]
    CFG.backbone = 'efficientnet-b6'
    CFG.ckpt_name = "efficientnetb6Plus_img384384_bs48_fold5_2.5d_channel3_step2_boundary_delfault"
    ckpt_path = f"../input/{CFG.ckpt_fold}/{CFG.ckpt_name}"
    ckpt_paths  = glob(f'{ckpt_path}/best*')
    assert len(ckpt_paths) == CFG.n_fold, "ckpt path error!"
    
    large_bowel_preds1, small_bowel_preds1, stomach_preds1 = generate_3D(transverse_slice_info, ckpt_paths, CFG)
    
    del transverse_slice_info, ckpt_paths
    gc.collect()
    
    CFG.anatomical_plane = "coronal"
    CFG.img_size = [256,384]
    CFG.backbone = 'efficientnet-b6'
    CFG.ckpt_name = "efficientnetb6Plus_img256384_bs64_fold5_2.5d_channel3_step2_online_cls_delfault_coronal"
    ckpt_path = f"../input/{CFG.ckpt_fold}/{CFG.ckpt_name}"
    ckpt_paths  = glob(f'{ckpt_path}/best*')
    assert len(ckpt_paths) == CFG.n_fold, "ckpt path error!"
    
    large_bowel_preds2, small_bowel_preds2, stomach_preds2 = generate_3D(coronal_slice_info, ckpt_paths, CFG)
    del coronal_slice_info, ckpt_paths
    gc.collect()
    
    large_bowel_preds1 = CFG.transverse_weight*large_bowel_preds1 + (1 - CFG.transverse_weight)*large_bowel_preds2
    small_bowel_preds1 = CFG.transverse_weight*small_bowel_preds1 + (1 - CFG.transverse_weight)*small_bowel_preds2
    stomach_preds1 = CFG.transverse_weight*stomach_preds1 + (1 - CFG.transverse_weight)*stomach_preds2
    
    del large_bowel_preds2, small_bowel_preds2, stomach_preds2
    gc.collect()
    
    large_bowel_preds1  = np.where(large_bowel_preds1 > CFG.thr, 1, 0)
    small_bowel_preds1  = np.where(small_bowel_preds1 > CFG.thr, 1, 0)
    stomach_preds1  = np.where(stomach_preds1 > CFG.thr, 1, 0)
    
    for idx, id in enumerate(row.id):
        large_bowel_preds1[idx,:,:] = area_connection(large_bowel_preds1[idx,:,:], 9)
        small_bowel_preds1[idx,:,:] = area_connection(small_bowel_preds1[idx,:,:], 9)
        stomach_preds1[idx,:,:] = area_connection(stomach_preds1[idx,:,:], 9)
        rle = [mask2rle(large_bowel_preds1[idx,:,:]), mask2rle(small_bowel_preds1[idx,:,:]), mask2rle(stomach_preds1[idx,:,:])]
        pred_strings.extend(rle)
        pred_ids.extend([id]*len(rle))
        pred_classes.extend(['large_bowel', 'small_bowel', 'stomach'])
        
    del large_bowel_preds1, small_bowel_preds1, stomach_preds1
    gc.collect()

后处理

这里后处理除了前面通过encoder输出生成置信度高的mask以外,还去除了连通域中小的空洞和很小的连通域。实验来看这部分作用并不显著,但是可能有0.0001-0.001的影响。尝试过TTA的手段,没有作用。

def area_connection(result,area_threshold):
    """
    result:预测影像
    area_threshold:最小连通尺寸,小于该尺寸的都删掉
    """
    # 去除小物体
    result = skimage.morphology.remove_small_objects(result==1, min_size=area_threshold, connectivity=1, in_place=True) 
    # 去除孔洞
    result = skimage.morphology.remove_small_holes(result==1, area_threshold=area_threshold, connectivity=1, in_place=True) 
    return result

large_bowel_preds1[idx,:,:] = area_connection(large_bowel_preds1[idx,:,:], 9)
small_bowel_preds1[idx,:,:] = area_connection(small_bowel_preds1[idx,:,:], 9)
stomach_preds1[idx,:,:] = area_connection(stomach_preds1[idx,:,:], 9)

参数设置

    class CFG:
        # step1: hyper-parameter
        seed = 601  # birthday
        num_worker = 8 # debug => 0
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        ckpt_fold = "ckpt-ruan"
        ckpt_name = "efficientnetb6Plus_img384384_bs48_fold5_2.5d_channel3_step2_boundary_delfault"
        
        # step2: data
        n_25d_shift = 2
        sample_step = 2
        n_fold = 5
        img_size = [384, 384]
        train_bs = 48
        valid_bs = 64

        # step3: model
        backbone = 'efficientnet-b6'
        num_classes = 3

        # step4: optimizer
        epoch = 30
        lr  = 1e-3
        scheduler = 'CosineAnnealingLR'
        min_lr = 1e-6
        T_max = int(30000/train_bs*epoch)+50
        T_0 = 25
        warmup_epochs = 0
        wd = 1e-6

        # step5: infer
        thr = 0.4

结合金牌方案的总结

分类+分割是最佳方案的思路关键。分类中又可以分为两种思路:

  • 直接分类:本人方案中就是加入了classfication loss就是为了实现分类,除外可以借鉴方案
    在灰色部分使用了一个分类模型,该模型将预测中大于12像素的切片保留,否则切片生成的为全0预测,其于作为最终的预测后处理结合分割结果。
  • 区域提取:目标检测的一种思想,提取出包含标签的区域将该positive的切片作为下一阶段分割模型的训练集,是一种两阶段的思想。可以将第一阶段的切片预测结果进一步裁剪减少伪影的影响并降低内存需求。

    本人在encoder部分加入了分类的思想,可能比较粗糙。另外,如何利用3D的信息也是本次比赛的关键。
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
禁止转载,如需转载请通过简信或评论联系作者。
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 216,287评论 6 498
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,346评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 162,277评论 0 353
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,132评论 1 292
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,147评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,106评论 1 295
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,019评论 3 417
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,862评论 0 274
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,301评论 1 310
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,521评论 2 332
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,682评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,405评论 5 343
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,996评论 3 325
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,651评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,803评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,674评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,563评论 2 352

推荐阅读更多精彩内容