Competition Overview
In this competition we build a model that automatically segments the stomach and intestines on MRI scans. The scans come from real cancer patients who underwent 1-5 MRI scans on separate days. The task is to build an algorithm on this dataset and propose a creative deep-learning solution that helps cancer patients receive better care.
Data analysis: the images are 16-bit grayscale PNGs. Each case is represented by multiple sets of scan slices (each set identified by the day the scan took place). Some cases are split by time (early days in train, later days in test), while others are split by case (the entire case is in either train or test). The goal of this competition is to generalize to both partially and wholly unseen cases.
Key challenges:
- Each case comes with a full stack of slices, which naturally suggests 2.5D or 3D approaches. I used a two-model ensemble over the transverse and coronal planes to capture 3D information; judging from the gold-medal solutions, direct 3D training is viable, and fusing it with 2D models achieves excellent performance.
- A few slices contain stomach or bowel tissue but have no label drawn (possibly due to artifacts). All of the strong solutions account for this; the 1st-place solution even built a dedicated classification model around it.
Data EDA
Extracting file information:
The image filenames include 4 numbers (e.g. 276_276_1.63_1.63.png):
- slice height (integer in pixels)
- slice width (integer in pixels)
- height pixel spacing (floating point in mm)
- width pixel spacing (floating point in mm)
The first two define the resolution of the slice; the last two record the physical size of each pixel.
def path2info(row):
    path = row['image_path']
    data = path.split('/')
    slice_ = int(data[-1].split('_')[1])
    case = int(data[-3].split('_')[0].replace('case', ''))
    day = int(data[-3].split('_')[1].replace('day', ''))
    # note: the filename order is height then width; slices in this dataset
    # are square, so the naming below is harmless
    width = int(data[-1].split('_')[2])
    height = int(data[-1].split('_')[3])
    row['height'] = height
    row['width'] = width
    row['case'] = case
    row['day'] = day
    row['slice'] = slice_
    # row['id'] = f'case{case}_day{day}_slice_{slice_}'
    return row
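A usage sketch for building the metadata table (the glob pattern is an assumption about the directory layout described above):
import pandas as pd
from glob import glob

# collect every slice PNG, then parse metadata out of each path
df = pd.DataFrame({'image_path': sorted(glob('train/case*/case*_day*/scans/*.png'))})
df = df.apply(path2info, axis=1)  # adds height/width/case/day/slice columns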
How RLE Encoding Works
RLE (run-length encoding) is one of the simplest and oldest data-compression techniques: it detects runs of repeated bits or characters in the data stream and re-encodes each run by where it occurs and how long it is, thereby compressing the data.
Encoding:
- Define a mask:
mask = np.array([ [0, 1, 1, 0], [1, 1, 1, 1], [1, 0, 0, 1] ])
(This mask corresponds to the simple example image shown above.)
- Flatten it:
pixels = mask.flatten()
> array([0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1])
- Pad a 0 at both ends:
pixels = np.concatenate([[0], pixels, [0]])
> array([0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0])
- Compare each element with its predecessor to find the positions where the value changes:
runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
> array([ 2,  4,  5, 10, 12, 13])
- Take pairwise differences to turn the change points into (start, length) pairs for every run of 1s:
runs[1::2] -= runs[::2]
> array([ 2,  2,  5,  5, 12,  1])
That completes the encoding:
code = ' '.join(str(x) for x in runs)
Note that the start positions here are 1-based.
import cupy as cp  # GPU drop-in replacement for numpy

def mask2rle(msk, thr=0.5):
    '''
    msk: numpy array, 1 - mask, 0 - background
    Returns the run-length encoding as a string.
    '''
    msk = cp.array(msk)
    pixels = msk.flatten()
    pad = cp.array([0])
    pixels = cp.concatenate([pad, pixels, pad])
    runs = cp.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)
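The 3D-reconstruction code later calls a decode_rle helper that is not shown in the original; here is a NumPy sketch of a decoder with the matching behavior (1-based starts, row-major order):
import numpy as np

def decode_rle(rle, height, width):
    '''
    rle: 'start length start length ...' string with 1-based starts
    Returns a (height, width) binary mask.
    '''
    s = rle.split()
    starts = np.asarray(s[0::2], dtype=int) - 1  # back to 0-based indexing
    lengths = np.asarray(s[1::2], dtype=int)
    mask = np.zeros(height * width, dtype='float64')
    for st, ln in zip(starts, lengths):
        mask[st:st + ln] = 1
    return mask.reshape(height, width)

# decode_rle('2 2 5 5 12 1', 3, 4) recovers the toy mask from the walkthrough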
My Solution
Public LB 0.882, private LB 0.875, a silver medal in the top 3%. The main ideas are 2.5D data plus an ensemble of a transverse-plane model and a coronal-plane model; the complete pipeline is shown in the mind map:
The concrete steps are as follows:
Data Preprocessing
1. Generating 2.5D data
The 2.5D data is generated on the fly, sampling non-adjacent slices one layer apart -> [t-2, t, t+2].
sample_step controls the sampling interval; I sample every other slice, which yields 3-channel 2.5D inputs.
class build_dataset(Dataset):
    def __init__(self, df, label=True, sample_step=1, transforms=None, cfg=None):
        self.df = df
        self.label = label
        self.img_paths = df['image_path'].tolist()  # image
        self.ids = df['id'].tolist()
        if 'mask_path' in df.columns:
            self.mask_paths = df['mask_path'].tolist()  # mask
        else:
            self.mask_paths = None
        self.sample_step = sample_step
        self.transforms = transforms
        self.n_25d_shift = cfg.n_25d_shift

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        #### load id
        id = self.ids[index]
        #### load image
        img_path = self.img_paths[index]
        img = self.load_2_5d_slice(img_path, self.sample_step)  # [h, w, c]
        h, w = img.shape[:2]
        if self.label:  # train
            #### load mask
            mask_path = self.mask_paths[index]
            mask = np.load(mask_path).astype('float32')
            mask /= 255.0  # scale mask to [0, 1]
            ### augmentations
            data = self.transforms(image=img, mask=mask)
            img = data['image']
            mask = data['mask']
            class_label = np.array(self.df['class_new'][index])
            img = np.transpose(img, (2, 0, 1))   # [h, w, c] => [c, h, w]
            mask = np.transpose(mask, (2, 0, 1))  # [h, w, c] => [c, h, w]
            return torch.tensor(img), torch.tensor(mask), torch.tensor(class_label)
        else:  # test
            ### augmentations
            data = self.transforms(image=img)
            img = data['image']
            img = np.transpose(img, (2, 0, 1))  # [h, w, c] => [c, h, w]
            return torch.tensor(img), id, h, w

    ###############################################################
    ##### >>>>>>> trick: construct 2.5d slice images <<<<<<
    ###############################################################
    def load_2_5d_slice(self, middle_img_path, step):
        #### Step 1: parse basic info from the middle slice.
        #### e.g. middle_img_path: 'slice_0005_266_266_1.50_1.50.png'
        middle_slice_num = os.path.basename(middle_img_path).split('_')[1]  # e.g. 0005
        middle_str = 'slice_' + middle_slice_num
        new_25d_imgs = []

        ##### Step 2: pad n_25d_shift slices on each side; append None when a
        ##### neighbouring slice does not exist.
        ##### Note: EDA showed all slices of the same case/day share one shape.
        for i in range(-self.n_25d_shift, self.n_25d_shift + 1, step):  # e.g. step=1 -> i in {-2, -1, 0, 1, 2}; step=2 -> i in {-2, 0, 2}
            shift_slice_num = int(middle_slice_num) + i
            shift_str = 'slice_' + str(shift_slice_num).zfill(4)
            shift_img_path = middle_img_path.replace(middle_str, shift_str)
            if os.path.exists(shift_img_path):
                shift_img = cv2.imread(shift_img_path, cv2.IMREAD_UNCHANGED)  # [h, w]
                new_25d_imgs.append(shift_img)
            else:
                new_25d_imgs.append(None)

        ##### Step 3: walk outwards from the centre and fill each None with its
        ##### inner neighbour. e.g. n_25d_shift=2, step=1 gives 5 channels with
        ##### idx [0, 1, 2, 3, 4]; the fill order is then [1, 3, 0, 4].
        shift_left_idxs = []
        shift_right_idxs = []
        for related_idx in range(int(self.n_25d_shift / step)):
            shift_left_idxs.append(int(self.n_25d_shift / step) - related_idx - 1)
            shift_right_idxs.append(int(self.n_25d_shift / step) + related_idx + 1)
        for left_idx, right_idx in zip(shift_left_idxs, shift_right_idxs):
            if new_25d_imgs[left_idx] is None:
                new_25d_imgs[left_idx] = new_25d_imgs[left_idx + 1]
            if new_25d_imgs[right_idx] is None:
                new_25d_imgs[right_idx] = new_25d_imgs[right_idx - 1]

        new_25d_imgs = np.stack(new_25d_imgs, axis=2).astype('float32')  # [h, w, c]
        mx_pixel = new_25d_imgs.max()
        if mx_pixel != 0:
            new_25d_imgs /= mx_pixel
        return new_25d_imgs
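A usage sketch wiring the dataset into a DataLoader (train_df is assumed to carry the columns used above; build_transforms is defined in the next subsection):
from torch.utils.data import DataLoader

train_dataset = build_dataset(train_df, label=True, sample_step=CFG.sample_step,
                              transforms=build_transforms(CFG)['train'], cfg=CFG)
train_loader = DataLoader(train_dataset, batch_size=CFG.train_bs,
                          num_workers=CFG.num_worker, shuffle=True, pin_memory=True)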
2. Data augmentation
After experimenting, I dropped the vertical flip that was originally included. The augmentations used:
- Random affine transform (ShiftScaleRotate): translates, scales, and rotates the image.
- Three non-rigid transforms: elastic transform (ElasticTransform), grid distortion (GridDistortion), and optical distortion (OpticalDistortion).
- Rectangular dropout (CoarseDropout): sets rectangular regions of the image to zero.
def build_transforms(CFG):
    data_transforms = {
        "train": A.Compose([
            A.OneOf([
                A.Resize(*CFG.img_size, interpolation=cv2.INTER_NEAREST, p=1.0),
            ], p=1),
            A.HorizontalFlip(p=0.5),
            # A.VerticalFlip(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.05, rotate_limit=10, p=0.5),
            # A.RandomBrightnessContrast(p=0.5),
            A.OneOf([
                A.GridDistortion(num_steps=5, distort_limit=0.05, p=1.0),
                A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=1.0)
            ], p=0.25),
            A.CoarseDropout(max_holes=8, max_height=CFG.img_size[0]//20, max_width=CFG.img_size[1]//20,
                            min_holes=5, fill_value=0, mask_fill_value=0, p=0.5),
        ], p=1.0),

        "valid_test": A.Compose([
            A.Resize(*CFG.img_size, interpolation=cv2.INTER_NEAREST),
        ], p=1.0)
    }
    return data_transforms
Model Selection
The commonly used models include UNet, UNet++, and UPerNet, with backbones such as EfficientNet, Swin Transformer, and ConvNeXt. In a competition like this, parameter tuning and model ensembling matter more than the architecture itself. I used a basic UNet and improved it with a few extra modules, ASPP and hypercolumns, added directly on top of the smp (segmentation_models_pytorch) source code:
- ASPP: the weakness of a traditional U-shaped network is its small receptive field. If the model has to make decisions about the segmentation of large objects, especially at high image resolutions, it can get confused by only seeing part of the object. One way to enlarge the receptive field and enable interaction between distant parts of the image is to combine convolution blocks with different dilations (atrous convolutions with different rates, as in an ASPP block).
class _ASPPModule(nn.Module):
    def __init__(self, inplanes, planes, kernel_size, padding, dilation):
        super(_ASPPModule, self).__init__()
        self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                                     stride=1, padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.atrous_conv(x)
        x = self.bn(x)
        return self.relu(x)


class ASPP(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels=None, dilations=[1, 6, 12, 18]):
        super(ASPP, self).__init__()
        # BaseOC_Context_Module comes from the OCNet codebase
        self.context = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, dilation=1, bias=True),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(),
            BaseOC_Context_Module(in_channels=mid_channels, out_channels=mid_channels,
                                  key_channels=mid_channels // 2, value_channels=mid_channels,
                                  dropout=0, sizes=([2])))
        self.aspp1 = _ASPPModule(in_channels, mid_channels, 1, padding=0, dilation=dilations[0])
        self.aspp2 = _ASPPModule(in_channels, mid_channels, 3, padding=dilations[1], dilation=dilations[1])
        self.aspp3 = _ASPPModule(in_channels, mid_channels, 3, padding=dilations[2], dilation=dilations[2])
        self.aspp4 = _ASPPModule(in_channels, mid_channels, 3, padding=dilations[3], dilation=dilations[3])
        self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                                             nn.Conv2d(in_channels, mid_channels, 1, stride=1, bias=False),
                                             nn.BatchNorm2d(mid_channels),
                                             nn.ReLU())
        out_channels = out_channels if out_channels is not None else mid_channels
        self.conv1 = nn.Conv2d(mid_channels * 5, out_channels, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.context(x)
        # x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear')
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.bn1(x)
        return x
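A quick shape check of the block (illustrative only: the channel sizes are assumptions, and BaseOC_Context_Module must be importable from the OCNet codebase for ASPP to instantiate):
import torch

aspp = ASPP(in_channels=512, mid_channels=256, out_channels=256)
feat = torch.randn(2, 512, 24, 24)  # [b, c, h, w] encoder feature map
out = aspp(feat)                    # five parallel branches concatenated, then 1x1 conv
print(out.shape)                    # torch.Size([2, 256, 24, 24])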
- Hypercolumn: essentially the feature pyramid network (FPN) idea, i.e. additional skip connections between the different upsampling blocks of the decoder and the output layer. The final prediction is then produced from the concatenation of the U-Net output with the resized outputs of the intermediate layers. These skip connections provide shortcuts for gradient flow, improving model performance and convergence speed. Since the intermediate layers have many channels, upsampling them and feeding them into the final layer would add considerable compute and memory overhead, so a factorized pair of 3x3 convolutions is applied before resizing to reduce the number of channels.
class FPN(nn.Module):
    def __init__(self, input_channels: list, output_channels: list):
        super().__init__()
        self.convs = nn.ModuleList(
            [nn.Sequential(nn.Conv2d(in_ch, out_ch * 2, kernel_size=3, padding=1),
                           nn.ReLU(inplace=True), nn.BatchNorm2d(out_ch * 2),
                           nn.Conv2d(out_ch * 2, out_ch, kernel_size=3, padding=1))
             for in_ch, out_ch in zip(input_channels, output_channels)])

    def forward(self, xs: list, last_layer):
        hcs = [F.interpolate(c(x), scale_factor=2 ** (len(self.convs) - i), mode='bilinear', align_corners=True)
               for i, (c, x) in enumerate(zip(self.convs, xs))]
        hcs.append(last_layer)
        return torch.cat(hcs, dim=1)
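A shape sketch for the hypercolumn head (channel counts and spatial sizes are assumptions; xs holds decoder features ordered deepest first):
import torch

fpn = FPN(input_channels=[256, 128, 64], output_channels=[32, 32, 32])
xs = [torch.randn(2, 256, 16, 16),   # deepest decoder feature
      torch.randn(2, 128, 32, 32),
      torch.randn(2, 64, 64, 64)]
last = torch.randn(2, 16, 128, 128)  # final decoder output
out = fpn(xs, last)                  # everything upsampled to 128x128 and concatenated
print(out.shape)                     # torch.Size([2, 112, 128, 128])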
- Backbone choice: the transverse-plane resolution is set to [384, 384]. Among experiments from efficientnet-b2 through b6, I finally chose b6, which showed the best local CV; se_resnext50_32x4d and se_resnext101_32x4d were also tried and still fell short of efficientnet-b6.
EfficientNet
- Past experience shows that increasing network depth yields richer, more complex features that transfer well to other tasks, but overly deep networks face vanishing gradients and become hard to train.
- Increasing network width captures finer-grained features and is easier to train, but very wide, shallow networks struggle to learn higher-level features.
- Increasing the input resolution can potentially expose finer-grained patterns, but the accuracy gain diminishes at very high resolutions, and large images increase the computation cost.
When more compute becomes available, exhaustively searching every combination of width, depth, and image resolution makes the search space essentially unbounded and the search very inefficient. The authors therefore propose compound scaling (see the formula below):
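From the EfficientNet paper (Tan & Le, 2019), compound scaling ties all three factors to a single coefficient $\phi$:

$$d = \alpha^{\phi}, \quad w = \beta^{\phi}, \quad r = \gamma^{\phi}, \qquad \text{s.t. } \alpha \cdot \beta^{2} \cdot \gamma^{2} \approx 2,\; \alpha \ge 1,\; \beta \ge 1,\; \gamma \ge 1$$

where $\alpha, \beta, \gamma$ are found by a small grid search on the B0 baseline, and increasing $\phi$ yields B1 through B7.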
Loss Function Design
Three losses are combined: BCE + Dice (TverskyLoss in the code, a generalization of Dice) + a boundary loss. Judging by the submissions, the boundary loss brings a small gain; its purpose is to make the model pay more attention to segmentation quality at the organ boundaries. For the boundary term I tried Active Contour Loss (no measurable effect) and an open-source Hausdorff loss (it raised errors I could not resolve; worth revisiting later). The final approach applies a Laplacian filter to both prediction and label to highlight edges, then computes a loss between the two.
laplace = torch.tensor([[1, 1, 1],
                        [1, -8, 1],
                        [1, 1, 1]], dtype=torch.float, requires_grad=False).view(1, 1, 3, 3)
avgpool = torch.tensor([[1/9, 1/9, 1/9],
                        [1/9, 1/9, 1/9],
                        [1/9, 1/9, 1/9]], dtype=torch.float, requires_grad=False).view(1, 1, 3, 3)

def conv_operator(x, avgpool=avgpool, laplace=laplace, in_channels=3):
    # depthwise (groups=in_channels) so each channel is smoothed and
    # edge-filtered independently and the channel count is preserved between
    # the stacked convolutions; the kernels have defaults so the single-argument
    # call in train_one_epoch works
    smooth = avgpool.repeat(in_channels, 1, 1, 1)
    edge = laplace.repeat(in_channels, 1, 1, 1)
    x = nn.functional.conv2d(x, smooth, stride=1, padding=1, groups=in_channels)
    x = nn.functional.conv2d(x, smooth, stride=1, padding=1, groups=in_channels)
    x = nn.functional.conv2d(x, smooth, stride=1, padding=1, groups=in_channels)
    x = nn.functional.conv2d(x, edge, stride=1, padding=1, groups=in_channels)
    return x
def build_loss():
    BCELoss = smp.losses.SoftBCEWithLogitsLoss()
    TverskyLoss = smp.losses.TverskyLoss(mode='multilabel', log_loss=False)
    FocalLoss = smp.losses.FocalLoss(mode='multilabel')
    return {"BCELoss": BCELoss, "TverskyLoss": TverskyLoss, "FocalLoss": FocalLoss}


def train_one_epoch(model, train_loader, optimizer, scheduler, losses_dict, CFG):
    model.train()
    scaler = amp.GradScaler()
    losses_all, bce_all, tverskly_all, boundary_all, cls_all = 0, 0, 0, 0, 0

    pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc='Train ')
    for _, (images, masks, class_label) in pbar:
        optimizer.zero_grad()

        images = images.to(CFG.device, dtype=torch.float)  # [b, c, h, w]
        masks = masks.to(CFG.device, dtype=torch.float)    # [b, c, h, w]
        class_label = class_label.to(CFG.device, dtype=torch.float)

        with amp.autocast(enabled=True):
            y_preds, class_preds = model(images)  # [b, c, h, w]
            class_loss = torch.nn.BCEWithLogitsLoss()(class_preds, class_label)
            bce_loss = losses_dict["BCELoss"](y_preds, masks)
            tverskly_loss = losses_dict["TverskyLoss"](y_preds, masks)
            boundary_loss = nn.MSELoss()(conv_operator(torch.sigmoid(y_preds).cpu().float()),
                                         conv_operator(masks.cpu().float()))
            losses = 10 * bce_loss + tverskly_loss + 1000 * boundary_loss + 5 * class_loss

        scaler.scale(losses).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        losses_all += losses.item() / images.shape[0]
        bce_all += bce_loss.item() / images.shape[0]
        tverskly_all += tverskly_loss.item() / images.shape[0]
        boundary_all += boundary_loss.item() / images.shape[0]
        cls_all += class_loss.item() / images.shape[0]

    current_lr = optimizer.param_groups[0]['lr']
    print("lr: {:.4f}".format(current_lr), flush=True)
    print("loss: {:.3f}, bce_all: {:.3f}, tverskly_all: {:.3f}, boundary_all: {:.3f}, cls_all: {:.3f}".
          format(losses_all, bce_all, tverskly_all, boundary_all, cls_all), flush=True)
- Deep supervision: some slices have no annotation at all and some carry only 1-2 of the three labels, so I added deep supervision to perform multi-label classification on the encoder side: the encoder output head is connected to a fully connected layer producing a 3-way multi-label output, effectively classifying before segmenting. Matching the challenge described earlier, inference is adjusted accordingly with a threshold of 0.3: if the class score exceeds 0.3, the predicted mask is emitted; otherwise an all-zero mask is produced regardless of the segmentation output. Since this is still a multi-label loss, torch.nn.BCEWithLogitsLoss() is used.
Deep supervision attaches auxiliary classifiers to some intermediate hidden layers of a deep network as side branches that supervise the backbone, to counter vanishing gradients and slow convergence during training. In general, making a network deeper improves its representational power up to a point, but training gets progressively harder, with phenomena such as vanishing and exploding gradients. Adding auxiliary branch classifiers at selected layers alleviates this, and these branches also act as a check on the quality of the intermediate feature maps.
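One concrete way to obtain the (segmentation, classification) output pair used in train_one_epoch is segmentation_models_pytorch's aux_params, which attaches a classification head to the encoder. A minimal sketch; the actual build_model in this write-up (with ASPP and hypercolumns) differs, so treat this as an assumption:
import torch
import segmentation_models_pytorch as smp

# aux_params adds an encoder-level classification head; forward then
# returns (mask logits, class logits)
model = smp.Unet(
    encoder_name='efficientnet-b6',
    encoder_weights='imagenet',
    in_channels=3,                            # 3-channel 2.5D input
    classes=3,                                # large_bowel / small_bowel / stomach
    aux_params=dict(classes=3, dropout=0.5),  # 3-way multi-label head
)
images = torch.randn(2, 3, 384, 384)
y_preds, class_preds = model(images)          # [2, 3, 384, 384], [2, 3]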
Model Ensembling
To fully exploit the 3D nature of the data, the transverse slices of each case are stacked into a complete 3D volume, from which coronal and sagittal slices are generated and corresponding coronal and sagittal segmentation models are trained. On local CV the coronal model reaches segmentation performance close to the transverse one, but the sagittal model falls short, so the final ensemble combines only the transverse and coronal models, with the transverse model weighted 0.6.
The coronal model's input resolution follows the aspect ratio of the raw coronal slices and is set to a 1:1.5 ratio ([256, 384]).
Generating the 3D volume to obtain coronal and sagittal slices:
def build_3D_img(folder_path):
    '''
    Take all slices and store them into a 3D numpy array
    '''
    slice_list = sorted(os.listdir(folder_path))
    slice_pixel_list = []
    for slice_filename in slice_list:
        pixel_slice = Image.open(folder_path + slice_filename)
        slice_pixel_list.append(np.array(pixel_slice))
        pixel_slice.close()
    return np.array(slice_pixel_list)


def build_3D_segmentation(segmentation_df, img_width, img_height):
    '''
    Take all segmentation slices, decode the RLE and
    concatenate the slices together to obtain a 3D numpy array
    '''
    slice_pixel_list = []
    for segmentation in segmentation_df["segmentation"].values:
        if pd.isnull(segmentation):
            slice_pixel_list.append(np.zeros((img_height, img_width)).astype("float64"))
        else:
            slice_pixel_list.append(decode_rle(segmentation, img_height, img_width))
    return np.array(slice_pixel_list)
def generate_sagittal_slices(df, case_id, day, mask_path, MRI_3D_img,
                             large_bowel_3D_mask, small_bowel_3D_mask, stomach_3D_mask):
    '''
    Create slices for an MRI from a sagittal plane and
    returns a dataframe with filepaths and groundtruths
    '''
    slice_info_list = []
    for i in range(MRI_3D_img.shape[2]):
        filename = "case{}_{}_sagittal_{}.png".format(case_id, day, i)
        sagittal_slice = Image.fromarray(MRI_3D_img[:, :, i])
        width, height = sagittal_slice.size
        sagittal_slice.save("{}/{}".format(image_output_folder, filename))
        # one row per organ class, all sharing the same slice image
        organ_rles = {"large_bowel": rle_encoding(large_bowel_3D_mask[:, :, i].T),
                      "small_bowel": rle_encoding(small_bowel_3D_mask[:, :, i].T),
                      "stomach": rle_encoding(stomach_3D_mask[:, :, i].T)}
        for organ_class, rle in organ_rles.items():
            slice_info_list.append({"filename": filename,
                                    "organ_class": organ_class,
                                    "segmentation": rle,
                                    "slice_plane": "sagittal",
                                    "width": width,
                                    "height": height,
                                    "mask_path": mask_path})
    # df.append was removed in pandas 2.0; with newer pandas use pd.concat
    return df.append(slice_info_list)


def generate_coronal_slices(df, case_id, day, mask_path, MRI_3D_img,
                            large_bowel_3D_mask, small_bowel_3D_mask, stomach_3D_mask):
    '''
    Create slices for an MRI from a coronal plane and
    returns a dataframe with filepaths and groundtruths
    '''
    slice_info_list = []
    for i in range(MRI_3D_img.shape[1]):
        filename = "case{}_{}_coronal_{}.png".format(case_id, day, i)
        coronal_slice = Image.fromarray(MRI_3D_img[:, i, :])
        width, height = coronal_slice.size
        coronal_slice.save("{}/{}".format(image_output_folder, filename))
        organ_rles = {"large_bowel": rle_encoding(large_bowel_3D_mask[:, i, :].T),
                      "small_bowel": rle_encoding(small_bowel_3D_mask[:, i, :].T),
                      "stomach": rle_encoding(stomach_3D_mask[:, i, :].T)}
        for organ_class, rle in organ_rles.items():
            slice_info_list.append({"filename": filename,
                                    "organ_class": organ_class,
                                    "segmentation": rle,
                                    "slice_plane": "coronal",
                                    "width": width,
                                    "height": height,
                                    "mask_path": mask_path})
    return df.append(slice_info_list)
def create_3D_img_and_seg(df, new_dataset_df, img_id, img_width, img_height):
    '''
    concatenate MRI slices and their segmentation information
    '''
    img_info_df = df[df.id.str.contains(img_id)]
    case_id = img_info_df.iloc[0].case_id
    day = img_info_df.iloc[0].day
    mask_path = img_info_df.iloc[0].mask_path
    segmentation_large_bowel_df = img_info_df.query("organ_class=='large_bowel'")
    segmentation_small_bowel_df = img_info_df.query("organ_class=='small_bowel'")
    segmentation_stomach_df = img_info_df.query("organ_class=='stomach'")
    MRI_3D_img = build_3D_img("{}/case{}/case{}_{}/scans/".format(train_folder, case_id, case_id, day))
    large_bowel_3D_mask = build_3D_segmentation(segmentation_large_bowel_df, img_width, img_height)
    small_bowel_3D_mask = build_3D_segmentation(segmentation_small_bowel_df, img_width, img_height)
    stomach_3D_mask = build_3D_segmentation(segmentation_stomach_df, img_width, img_height)
    new_dataset_df = generate_sagittal_slices(new_dataset_df, case_id, day, mask_path, MRI_3D_img,
                                              large_bowel_3D_mask, small_bowel_3D_mask, stomach_3D_mask)
    new_dataset_df = generate_coronal_slices(new_dataset_df, case_id, day, mask_path, MRI_3D_img,
                                             large_bowel_3D_mask, small_bowel_3D_mask, stomach_3D_mask)
    return new_dataset_df


new_dataset_df = pd.DataFrame()
for img_id in tqdm(train_df.image_id.unique()):
    id_df = train_df.query("image_id==@img_id").iloc[0]
    img_width = id_df.width
    img_height = id_df.height
    new_dataset_df = create_3D_img_and_seg(train_df, new_dataset_df, img_id, img_width, img_height)
Model ensembling at inference:
@torch.no_grad()
def generate_3D(slice_info, ckpt_paths, CFG):
    data_transforms = build_transforms(CFG)
    # note: the inference-time build_dataset is a variant (not shown) that reads
    # pre-sliced arrays from slice_info's 'image' column and also yields case/day
    test_dataset = build_dataset(slice_info, transforms=data_transforms['valid_test'], cfg=CFG)
    test_loader = DataLoader(test_dataset, batch_size=CFG.valid_bs, num_workers=8, shuffle=False, pin_memory=False)

    pbar = tqdm(enumerate(test_loader), total=len(test_loader))
    large_bowel_list, small_bowel_list, stomach_list = [], [], []
    for _, (images, ids, cases, days, h, w) in pbar:
        images = images.to(CFG.device, dtype=torch.float)  # [b, c, h, w]
        size = images.size()
        masks = torch.zeros((size[0], 3, size[2], size[3]), device=CFG.device, dtype=torch.float32)  # [b, c, h, w]
        labels = torch.zeros((size[0], 3), device=CFG.device, dtype=torch.float32)  # [b, c]

        for sub_ckpt_path in ckpt_paths:
            model = build_model(CFG, test_flag=True)
            model.load_state_dict(torch.load(sub_ckpt_path))
            model.eval()

            y_preds, label_preds = model(images)  # [b, c, h, w]
            y_preds = torch.nn.Sigmoid()(y_preds)
            label_preds = torch.nn.Sigmoid()(label_preds)
            masks += y_preds
            labels += label_preds

            if CFG.tta:
                # rotation TTA: rotate in, predict, rotate back
                rotates = [10, 350]
                for rotate in rotates:
                    images_f = F.rotate(images, rotate, resample=2, expand=False)
                    y_preds, label_preds = model(images_f)  # [b, c, h, w]
                    y_preds = F.rotate(y_preds, 360 - rotate, resample=2, expand=False)
                    y_preds = torch.nn.Sigmoid()(y_preds)
                    label_preds = torch.nn.Sigmoid()(label_preds)
                    masks += y_preds
                    labels += label_preds
                # horizontal-flip TTA
                images_f = torch.flip(images, [-1])
                y_preds, label_preds = model(images_f)  # [b, c, h, w]
                y_preds = torch.flip(y_preds, [-1])
                y_preds = torch.nn.Sigmoid()(y_preds)
                label_preds = torch.nn.Sigmoid()(label_preds)
                masks += y_preds
                labels += label_preds

            del model, y_preds, label_preds
            gc.collect()

        if CFG.tta:
            total_ckpt_paths = len(ckpt_paths) * 4  # original + 2 rotations + 1 flip
        else:
            total_ckpt_paths = len(ckpt_paths)
        masks /= total_ckpt_paths
        labels /= total_ckpt_paths
        masks = masks.permute((0, 2, 3, 1)).to(torch.float32).cpu().detach().numpy()
        labels = labels.to(torch.float32).cpu().detach().numpy()

        for idx in range(masks.shape[0]):
            height = h[idx].item()
            width = w[idx].item()
            msk = cv2.resize(masks[idx],
                             dsize=(width, height),
                             interpolation=cv2.INTER_LINEAR)  # back to original shape
            # the classification head gates the mask: below CFG.class_thr, emit all zeros
            large_bowel = msk[:, :, 0] if labels[idx, 0] > CFG.class_thr else np.zeros((height, width))
            small_bowel = msk[:, :, 1] if labels[idx, 1] > CFG.class_thr else np.zeros((height, width))
            stomach = msk[:, :, 2] if labels[idx, 2] > CFG.class_thr else np.zeros((height, width))
            large_bowel_list.append(large_bowel)
            small_bowel_list.append(small_bowel)
            stomach_list.append(stomach)

        del images, masks
        gc.collect()

    del test_dataset, test_loader
    gc.collect()

    # stack per-slice predictions back into a 3D volume along the plane axis
    if CFG.anatomical_plane == "transverse":
        large_bowel_preds = np.stack(large_bowel_list, axis=0)
        small_bowel_preds = np.stack(small_bowel_list, axis=0)
        stomach_preds = np.stack(stomach_list, axis=0)
    if CFG.anatomical_plane == "coronal":
        large_bowel_preds = np.stack(large_bowel_list, axis=1)
        small_bowel_preds = np.stack(small_bowel_list, axis=1)
        stomach_preds = np.stack(stomach_list, axis=1)
    if CFG.anatomical_plane == "sagittal":
        large_bowel_preds = np.stack(large_bowel_list, axis=2)
        small_bowel_preds = np.stack(small_bowel_list, axis=2)
        stomach_preds = np.stack(stomach_list, axis=2)

    del large_bowel_list, small_bowel_list, stomach_list
    gc.collect()
    return large_bowel_preds, small_bowel_preds, stomach_preds
pred_strings = []
pred_ids = []
pred_classes = []

for index, row in test_df2.iterrows():
    case_id = row.case
    day = row.day
    if sub_firset:  # flag: read from train (e.g. for a dry run) instead of test
        folder_path = "../input/uw-madison-gi-tract-image-segmentation/train/case{}/case{}_day{}/scans/".format(case_id, case_id, day)
    else:
        folder_path = "../input/uw-madison-gi-tract-image-segmentation/test/case{}/case{}_day{}/scans/".format(case_id, case_id, day)

    # stack this case/day's transverse slices into a 3D volume
    slice_list = sorted(os.listdir(folder_path))
    slice_pixel_list = []
    for slice_filename in slice_list:
        pixel_slice = cv2.imread(folder_path + slice_filename, cv2.IMREAD_UNCHANGED)
        slice_pixel_list.append(pixel_slice)
    MRI_3D_img = np.array(slice_pixel_list).astype('float32')
    del slice_list, slice_pixel_list
    gc.collect()

    # slice the volume along the coronal and transverse axes
    transverse_slice_info_list = []
    coronal_slice_info_list = []
    for i in range(MRI_3D_img.shape[1]):
        slice_ = MRI_3D_img[:, i, :].T
        image_id = "case{}_day{}_slice".format(case_id, day) + str(i).zfill(4)
        width, height = slice_.shape
        slice_number = i
        coronal_slice_info_list.append({"id": image_id,
                                        "image": slice_.T,  # [h, w]
                                        "width": width,
                                        "height": height,
                                        "slice": slice_number})
    for i in range(MRI_3D_img.shape[0]):
        slice_ = MRI_3D_img[i, :, :].T  # volume is [d, h, w]
        image_id = "case{}_day{}_slice".format(case_id, day) + str(i).zfill(4)
        width, height = slice_.shape
        slice_number = i
        transverse_slice_info_list.append({"id": image_id,
                                           "image": slice_.T,  # [h, w]
                                           "width": width,
                                           "height": height,
                                           "slice": slice_number})
    transverse_slice_info = pd.DataFrame(transverse_slice_info_list)
    coronal_slice_info = pd.DataFrame(coronal_slice_info_list)
    display(transverse_slice_info.head(5))
    del MRI_3D_img, coronal_slice_info_list, transverse_slice_info_list
    gc.collect()

    # transverse-plane model
    CFG.anatomical_plane = "transverse"
    CFG.img_size = [384, 384]
    CFG.backbone = 'efficientnet-b6'
    CFG.ckpt_name = "efficientnetb6Plus_img384384_bs48_fold5_2.5d_channel3_step2_boundary_delfault"
    ckpt_path = f"../input/{CFG.ckpt_fold}/{CFG.ckpt_name}"
    ckpt_paths = glob(f'{ckpt_path}/best*')
    assert len(ckpt_paths) == CFG.n_fold, "ckpt path error!"
    large_bowel_preds1, small_bowel_preds1, stomach_preds1 = generate_3D(transverse_slice_info, ckpt_paths, CFG)
    del transverse_slice_info, ckpt_paths
    gc.collect()

    # coronal-plane model
    CFG.anatomical_plane = "coronal"
    CFG.img_size = [256, 384]
    CFG.backbone = 'efficientnet-b6'
    CFG.ckpt_name = "efficientnetb6Plus_img256384_bs64_fold5_2.5d_channel3_step2_online_cls_delfault_coronal"
    ckpt_path = f"../input/{CFG.ckpt_fold}/{CFG.ckpt_name}"
    ckpt_paths = glob(f'{ckpt_path}/best*')
    assert len(ckpt_paths) == CFG.n_fold, "ckpt path error!"
    large_bowel_preds2, small_bowel_preds2, stomach_preds2 = generate_3D(coronal_slice_info, ckpt_paths, CFG)
    del coronal_slice_info, ckpt_paths
    gc.collect()

    # weighted blend of the two planes (transverse weight = 0.6)
    large_bowel_preds1 = CFG.transverse_weight * large_bowel_preds1 + (1 - CFG.transverse_weight) * large_bowel_preds2
    small_bowel_preds1 = CFG.transverse_weight * small_bowel_preds1 + (1 - CFG.transverse_weight) * small_bowel_preds2
    stomach_preds1 = CFG.transverse_weight * stomach_preds1 + (1 - CFG.transverse_weight) * stomach_preds2
    del large_bowel_preds2, small_bowel_preds2, stomach_preds2
    gc.collect()

    # binarize, post-process slice by slice, then RLE-encode
    large_bowel_preds1 = np.where(large_bowel_preds1 > CFG.thr, 1, 0)
    small_bowel_preds1 = np.where(small_bowel_preds1 > CFG.thr, 1, 0)
    stomach_preds1 = np.where(stomach_preds1 > CFG.thr, 1, 0)
    for idx, id in enumerate(row.id):
        large_bowel_preds1[idx, :, :] = area_connection(large_bowel_preds1[idx, :, :], 9)
        small_bowel_preds1[idx, :, :] = area_connection(small_bowel_preds1[idx, :, :], 9)
        stomach_preds1[idx, :, :] = area_connection(stomach_preds1[idx, :, :], 9)
        rle = [mask2rle(large_bowel_preds1[idx, :, :]), mask2rle(small_bowel_preds1[idx, :, :]), mask2rle(stomach_preds1[idx, :, :])]
        pred_strings.extend(rle)
        pred_ids.extend([id] * len(rle))
        pred_classes.extend(['large_bowel', 'small_bowel', 'stomach'])
    del large_bowel_preds1, small_bowel_preds1, stomach_preds1
    gc.collect()
Post-processing
Besides gating low-confidence masks with the encoder's classification output as described earlier, post-processing also removes small holes inside connected components and discards very small components. In my experiments this part is not decisive, but it may account for roughly 0.0001-0.001 on the score. TTA was also tried and did not help.
def area_connection(result, area_threshold):
    """
    result: predicted mask
    area_threshold: minimum size of a connected component; anything smaller is removed
    """
    # remove small objects
    result = skimage.morphology.remove_small_objects(result == 1, min_size=area_threshold, connectivity=1, in_place=True)
    # remove small holes
    result = skimage.morphology.remove_small_holes(result == 1, area_threshold=area_threshold, connectivity=1, in_place=True)
    return result


# applied per slice, as in the inference loop above:
large_bowel_preds1[idx, :, :] = area_connection(large_bowel_preds1[idx, :, :], 9)
small_bowel_preds1[idx, :, :] = area_connection(small_bowel_preds1[idx, :, :], 9)
stomach_preds1[idx, :, :] = area_connection(stomach_preds1[idx, :, :], 9)
Configuration
class CFG:
    # step1: hyper-parameters
    seed = 601  # birthday
    num_worker = 8  # debug => 0
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    ckpt_fold = "ckpt-ruan"
    ckpt_name = "efficientnetb6Plus_img384384_bs48_fold5_2.5d_channel3_step2_boundary_delfault"
    # step2: data
    n_25d_shift = 2
    sample_step = 2
    n_fold = 5
    img_size = [384, 384]
    train_bs = 48
    valid_bs = 64
    # step3: model
    backbone = 'efficientnet-b6'
    num_classes = 3
    # step4: optimizer
    epoch = 30
    lr = 1e-3
    scheduler = 'CosineAnnealingLR'
    min_lr = 1e-6
    T_max = int(30000 / train_bs * epoch) + 50
    T_0 = 25
    warmup_epochs = 0
    wd = 1e-6
    # step5: infer
    thr = 0.4
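A sketch of how the optimizer settings above might be wired up (AdamW is an assumption; CosineAnnealingLR matches the scheduler/T_max/min_lr fields):
import torch

# model is assumed to be built beforehand, e.g. by build_model(CFG)
optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.wd)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                       T_max=CFG.T_max,
                                                       eta_min=CFG.min_lr)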
Takeaways from the Gold-Medal Solutions
Classification + segmentation is the key idea behind the best solutions. The classification part comes in two flavors:
- Direct classification: my solution adds a classification loss for exactly this purpose; other solutions along these lines are worth studying.
- Region extraction: an idea borrowed from object detection; extract the regions that contain labels and use those positive slices as the training set of a second-stage segmentation model, a two-stage design. The first-stage slice predictions can be cropped further to reduce the impact of artifacts and lower memory requirements (see the sketch at the end of this section).
My encoder-level classification approach is admittedly rough. Beyond that, how to exploit the 3D information was the other key to this competition.
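To make the region-extraction idea concrete, here is a hedged sketch of cropping a second-stage input around a first-stage prediction; the helper name and margin are assumptions for illustration, not code from any actual solution:
import numpy as np

def crop_to_mask(image, mask, margin=16):
    '''Crop image to the bounding box of a positive stage-one mask plus a margin.'''
    ys, xs = np.where(mask > 0)
    if len(ys) == 0:
        return image  # no positive region found: keep the full slice
    y0, y1 = max(ys.min() - margin, 0), min(ys.max() + margin, image.shape[0])
    x0, x1 = max(xs.min() - margin, 0), min(xs.max() + margin, image.shape[1])
    return image[y0:y1, x0:x1]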