多器官分割比赛总结

比赛目标:

本次竞赛的目标是确定来自几个不同器官的活检切片中每个功能组织单元(FTU)的位置。基础数据包括来自不同来源的图像,这些图像采用不同的协议以不同的分辨率制作,反映了处理医疗数据的典型挑战。本次比赛使用了来自两个不同联盟的数据,即人类蛋白质图谱(HPA)和人类生物分子图谱计划(HuBMAP)。训练数据集由公共HPA数据中的数据组成,公共测试集是私有HPA数据和HuBMAP数据的组合,私有测试集仅包含HuBMAP数据。当使用不同协议下的数据时,调整模型以使其正常工作将是这场竞争的核心挑战之一。开发泛化的模型是这项工作的关键目标。

难点

In this competition we have:

  • Private dataset has a different set of image scales compared to the train (relatively easy to model)
  • Private dataset has a different color domain (different stains which attach to different molecules/tissues) - (Harder to model)
  • Different slice thickness… That's going to be tough to incorporate.

Color Domain

数字病理切片的制作首先需要组织染色。为了突出切片中特定的细胞核和腺体特征,限定并检查组织,通常使用染色剂来增强组织成分间的对比度,,主要包括苏木精-伊红(hematoxylin-eosin, 简称H&E)和免疫组织化学(immuno histo chemical, 简称IHC),H&E是最常用的染色方法。与H&E常规染色相比,IHC染色利用抗原抗体的特异性结合反应来检测和定位组织和细胞中的某些化学物质,具有较高的敏感性,可将形态学改变与功能代谢变化相结合,从而能够鉴别、诊断和治疗恶性肿。H&E的问题在于,在一周的不同日期进行染色时,实验室中的染色变异很大(HPA和HuBMAP的染色标准不一),甚至在同一实验室也是如此。这是因为最终结果很大程度上取决于染料的类型和密度以及组织实际暴露于染色剂的时间。


常用解决办法:

  1. 染色归一化:不同的实验室和扫描仪可以为特定污渍生成具有不同颜色配置文件的图像。染色标准化的目标是标准化这些染色的外观。传统上,使用颜色匹配 [ Reinhard2001 ] 和染色分离 [ Macenko2009 , Khan2014 , Vahadane2016 ] 等方法。但是,这些方法依赖于选择单个参考幻灯片。

病理图像常用的颜色标准化方法
STST: 这个方法是用c-gan做的,但我实际跑的时候效果很差,细胞核细胞质颜色很相近,分不开,可能因为他是对灰度图像染色,所以模型分不太开细胞核跟细胞质,不建议用。
Vahadane: 这个是比较推荐的方法。这个方法是用非负矩阵分解得到两个染料矩阵,然后将reference和source的进行配准,然后再合成新的图像。这个方法比较稳定,得到的图像颜色比较真实。实际的标化速度也可以接受。
Khan: 效果里面看很差,不太建议用。
Macehko: 这个方法是将RGB转成OD,然后再用SVD分解得到两个垂直的颜色矩阵,再将这个颜色矩阵进行归一化,或者归一化到reference上。在实际应用的时候得到的图像颜色不是很自然,会变得比较奇怪。一方面可能是SVD分解并没有非负分解这么好用,另一个可能是reference图像选择的不好。但速度上很快,效果一般。
Reinhard: 最开始并不是用在病理图像上的,它是用在自然图像上的,就是在Lab空间把两个图像的颜色进行统计学上的匹配。这个效果很差,但很快。差的原因很多,Lab颜色空间并不符合病理图像的光学特性,OD和HED更加符合他们的特性。其次,这个方法受输入图像的影响非常大,有些组织少的区域,为了得到相同的统计学分布,会将组织染色非常深,总>之就很不智能。不推荐用。

  1. 色彩增强:通过应用随机仿射变换或添加噪声来增强图像是对抗过度拟合的最常见正则化技术之一。类似地,可以利用染色的变化来增加训练期间呈现给模型的图像外观的多样性。虽然颜色的剧烈变化对于组织学来说是不现实的,但通过对每个颜色通道的随机加法和乘法变化产生的更微妙的变化已被证明可以提高模型性能。颜色增强的强度是一个额外的超参数,应该在训练期间进行试验,并在来自不同实验室或扫描仪的测试集上进行验证。
  2. 无监督域对抗训练:域适应的下一个技术是域对抗训练。这种方法利用来自目标域的未标记图像。域对抗模块被添加到现有模型中。该分类器的目标是预测图像属于源域还是目标域。梯度反转层将此模块连接到现有网络,以便训练优化原始任务并鼓励网络学习域不变特征。
    特征提取器提取的信息会传入域分类器,之后域分类器会判断传入的信息到底是来自源域还是目标域,并计算损失。域分类器的训练目标是尽量将输入的信息分到正确的域类别(源域还是目标域),而特征提取器的训练目标却恰恰相反(由于梯度反转层的存在),特征提取器所提取的特征(或者说映射的结果)目的是是域判别器不能正确的判断出信息来自哪一个域,因此形成一种对抗关系。特征提取器提取的信息也会传入Label predictor (类别预测器)了,因为源域样本是有标记的,所以在提取特征时不仅仅要考虑后面的域判别器的情况,还要利用源域的带标记样本进行有监督训练从而兼顾分类的准确性。

解决方案

采用了染色增强+CutMix+CutOut+像素大小自适应的数据增强策略。采用原始数据集训练CoaT+Daformer和Swin transformer+ UPerNet模型,然后用过采样肺部策略训练Swin transformer+UPerNet作为针对肺部的预测最终形成集成预测。loss采用Focal loss+Dice loss进行难例挖掘。

A Comprehensive Study of Vision Transformers on Dense Prediction Tasks
文章研究了 Vision Transformers (VTs),VTs 和 CNN 作为特征提取器的不同方面,用于对具有挑战性的现实世界数据进行目标检测和语义分割。实验得出的主要结果和主要见解如下:

  • VTs 在分布式数据集中优于 CNN,同时具有较低的推理速度,但计算复杂度较低。因此,如果 GPU 针对 Transformer 架构进行了优化,它们就有可能在计算机视觉领域占据主导地位。
  • VTs 可以更好地泛化到 OOD 数据集。我们的损失情况分析表明,与 CNN 相比,VTs 收敛到更平坦的最小值,这可以解释它们的普遍性。
  • 与 CNN 相比,VTs 对自然损坏和对抗性攻击更稳健。我们认为这可以归因于全局感受域以及自我注意的动态特性。
  • VTs 的纹理偏差比 CNN 少,这可以归因于它们的全局感受野,这使得它们能够更好地关注基于全局形状的线索,而不是基于局部纹理的线索。

数据增强策略

基本数据增强
包括翻转、旋转、对比度、HSV颜色空间、噪声和一些尺度变换。

def do_random_flip(image, mask):
    if np.random.rand()>0.5:
        image = cv2.flip(image,0)
        mask = cv2.flip(mask,0)
    if np.random.rand()>0.5:
        image = cv2.flip(image,1)
        mask = cv2.flip(mask,1)
    if np.random.rand()>0.5:
        image = image.transpose(1,0,2)
        mask = mask.transpose(1,0)
    
    image = np.ascontiguousarray(image)
    mask = np.ascontiguousarray(mask)
    return image, mask

def do_random_rot90(image, mask):
    r = np.random.choice([
        0,
        cv2.ROTATE_90_CLOCKWISE,
        cv2.ROTATE_90_COUNTERCLOCKWISE,
        cv2.ROTATE_180,
    ])
    if r==0:
        return image, mask
    else:
        image = cv2.rotate(image, r)
        mask = cv2.rotate(mask, r)
        return image, mask
    
def do_random_contast(image, mask, mag=0.3):
    alpha = 1 + random.uniform(-1,1)*mag
    image = image * alpha
    image = np.clip(image,0,1)
    return image, mask

def do_random_hsv(image, mask, mag=[0.15,0.25,0.25]):
    image = (image*255).astype(np.uint8)
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

    h = hsv[:, :, 0].astype(np.float32)  # hue
    s = hsv[:, :, 1].astype(np.float32)  # saturation
    v = hsv[:, :, 2].astype(np.float32)  # value
    h = (h*(1 + random.uniform(-1,1)*mag[0]))%180
    s =  s*(1 + random.uniform(-1,1)*mag[1])
    v =  v*(1 + random.uniform(-1,1)*mag[2])

    hsv[:, :, 0] = np.clip(h,0,180).astype(np.uint8)
    hsv[:, :, 1] = np.clip(s,0,255).astype(np.uint8)
    hsv[:, :, 2] = np.clip(v,0,255).astype(np.uint8)
    image = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
    image = image.astype(np.float32)/255
    return image, mask


def do_random_noise(image, mask, mag=0.1):
    height, width = image.shape[:2]
    noise = np.random.uniform(-1,1, (height, width,1))*mag
    image = image + noise
    image = np.clip(image,0,1)
    return image, mask

def do_random_rotate_scale(image, mask, angle=30, scale=[0.8,1.2] ):
    angle = np.random.uniform(-angle, angle)
    scale = np.random.uniform(*scale) if scale is not None else 1
    
    height, width = image.shape[:2]
    center = (height // 2, width // 2)
    
    transform = cv2.getRotationMatrix2D(center, angle, scale)
    image = cv2.warpAffine( image, transform, (width, height), flags=cv2.INTER_LINEAR,
                            borderMode=cv2.BORDER_CONSTANT, borderValue=(0,0,0))
    mask  = cv2.warpAffine( mask, transform, (width, height), flags=cv2.INTER_LINEAR,
                            borderMode=cv2.BORDER_CONSTANT, borderValue=0)
    return image, mask

针对性的数据增强
解决训练集HPA和测试集Hubmap中的图片像素间距层厚不同的自适应策略:

imaging_measurements = {
    'hpa': {
        'pixel_size': {
            'kidney': 0.4,
            'prostate': 0.4,
            'largeintestine': 0.4,
            'spleen': 0.4,
            'lung': 0.4
        },
        'tissue_thickness': {
            'kidney': 4,
            'prostate': 4,
            'largeintestine': 4,
            'spleen': 4,
            'lung': 4
        }
    },
    'hubmap': {
        'pixel_size': {
            'kidney': 0.5,
            'prostate': 6.263,
            'largeintestine': 0.229,
            'spleen': 0.4945,
            'lung': 0.7562
        },
        'tissue_thickness': {
            'kidney': 10,
            'prostate': 5,
            'largeintestine': 8,
            'spleen': 4,
            'lung': 5
        }
    }
}

def pixelSize_tissueThickness_adaptation(image, mask, organ, alpha=0.15):
    image = (image*255).astype(np.uint8)
    domain_pixel_size=imaging_measurements['hpa']['pixel_size'][organ],
    target_pixel_size=imaging_measurements['hubmap']['pixel_size'][organ],
    domain_tissue_thickness=imaging_measurements['hpa']['tissue_thickness'][organ],
    target_tissue_thickness=imaging_measurements['hubmap']['tissue_thickness'][organ],
    
    # Augment tissue thickness
    tissue_thickness_scale_factor = target_tissue_thickness[0] - domain_tissue_thickness[0]
    image_hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV).astype(np.float32)
    image_hsv[:, :, 1] *= (1 + (alpha * tissue_thickness_scale_factor))
    image_hsv[:, :, 2] *= (1 - (alpha * tissue_thickness_scale_factor))
    image_hsv = image_hsv.astype(np.uint8)
    image_scaled = cv2.cvtColor(image_hsv, cv2.COLOR_HSV2RGB)
    
    # Standardize luminosity
    image_scaled = staintools.LuminosityStandardizer.standardize(image_scaled)

    # Augment pixel size
    pixel_size_scale_factor = domain_pixel_size[0] / target_pixel_size[0]
    image_resized = cv2.resize(
        image_scaled,
        dsize=None,
        fx=pixel_size_scale_factor,
        fy=pixel_size_scale_factor,
        interpolation=cv2.INTER_CUBIC
    )
    image_resized = cv2.resize(
        image_resized,
        dsize=(
            image.shape[1],
            image.shape[0]
        ),
        interpolation=cv2.INTER_CUBIC
    )
    
    # Standardize luminosity
    image = staintools.LuminosityStandardizer.standardize(image)
    image_augmented = staintools.LuminosityStandardizer.standardize(image_resized)

    image = image_augmented.astype(np.float32)/255
    
    return image, mask

解决训练集HPA和测试集Hubmap中的染色标准不一的颜色增强策略:

def color_transfer(image, mask):
    image = (image*255).astype(np.uint8)
    hed_lighter_aug = stainlib.augmentation.augmenter.HedLightColorAugmenter()
    hed_lighter_aug.randomize()
    transformed = hed_lighter_aug.transform(image)
    image = image.astype(np.float32)/255
    return image, mask

提高在测试集泛化能力的其他数据增强策略:CutMix+CutOut(数据增强:Mixup,Cutout,CutMix Mosaic),提供高强度的、变化的扰动。消融实验显示CutMix+CutOut的涨点是可观的。

#  CutMix 的切块功能
def rand_bbox(size, lam):
    if len(size) == 4:
        W = size[2]
        H = size[3]
    elif len(size) == 3:
        W = size[0]
        H = size[1]
    else:
        raise Exception

    cut_rat = np.sqrt(1. - lam)
    cut_w = np.int(W * cut_rat)
    cut_h = np.int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

randoms_int = random.randint(0, 99)
if randoms_int < 20:
    rand_index = random.randint(0, len(self.df)-1)
    rand_path = self.df.loc[rand_index, 'image_path']
    rand_img_height = self.df.loc[rand_index, 'img_height']
    rand_img_width = self.df.loc[rand_index, 'img_width']

    if stained:
        rand_image = tifffile.imread(rand_path.replace('train_images', 'train_images_augment'))
    else:
        rand_image = tifffile.imread(rand_path)
    
    rand_rle = self.df.loc[rand_index, 'rle']
    rand_mask = rle2mask(rand_rle, (rand_img_height, rand_img_width))

    rand_image = rand_image.astype(np.float32)/255

    rand_image = cv2.resize(rand_image,dsize=(image_size,image_size),interpolation=cv2.INTER_LINEAR)
    rand_mask  = cv2.resize(rand_mask, dsize=(image_size,image_size),interpolation=cv2.INTER_LINEAR)

    lam = np.random.beta(1,1)
    bbx1, bby1, bbx2, bby2 = rand_bbox(rand_image.shape, lam)

    image[bbx1:bbx2, bby1:bby2, :] = rand_image[bbx1:bbx2, bby1:bby2, :]
    mask[bbx1:bbx2, bby1:bby2] = rand_mask[bbx1:bbx2, bby1:bby2]

if randoms_int >= 80:
    alpha = random.uniform(0.2, 0.5)
    y = np.random.randint(image_size)
    x = np.random.randint(image_size)
    #划出正方形区域,边界处截断
    y1 = np.clip(y - int(alpha*image_size) // 2, 0, image_size)
    y2 = np.clip(y + int(alpha*image_size) // 2, 0, image_size)
    x1 = np.clip(x - int(alpha*image_size) // 2, 0, image_size)
    x2 = np.clip(x + int(alpha*image_size) // 2, 0, image_size)
    #全0填充区域
    image[y1: y2, x1: x2, :] = 0
    mask[y1: y2, x1: x2] = 0

fname = self.fnames[index]
organ = self.organ_to_label[self.df.loc[index].organ]

CutOut和CutMix已被证实可以用于语义分割(Semi-supervised semantic segmentation needs strong, varied perturbations),其在语义分割中的作用符合人的直观理解,但是其背后的数学解释值得探究。本比赛中也发现一个有趣的现象,好像CutOut和CutMix对于transformer模型相比CNN更有价值。

模型主要利用transformer模型作为特征提取器(Encoder),然后采用不用的Decoder构成语义分割网络。



Swin transformer+UPerNet总结:

Swin transformer

性能优于DeiT、ViT和EfficientNet等主干网络,已经替代经典的CNN架构,成为了计算机视觉领域通用的backbone。它基于了ViT模型的思想,创新性的引入了滑动窗口机制,让模型能够学习到跨窗口的信息,同时也。同时通过下采样层,使得模型能够处理超分辨率的图片,节省计算量以及能够关注全局和局部的信息。
目前将 Transformer 从自然语言处理领域应用到计算机视觉领域主要有两大挑战:

  • 视觉实体的方差较大,例如同一个物体,拍摄角度不同,转化为二进制后的图片就会具有很大的差异。同时在不同场景下视觉 Transformer 性能未必很好。
  • 图像分辨率高,像素点多,如果采用ViT模型,自注意力的计算量会与像素的平方成正比。
    针对上述两个问题,论文中提出了一种基于滑动窗口机制,具有层级设计(下采样层) 的 Swin Transformer。

其中滑窗操作包括不重叠的 local window,和重叠的 cross-window。将注意力计算限制在一个窗口(window size固定)中,一方面能引入 CNN 卷积操作的局部性,另一方面能大幅度节省计算量,它只和窗口数量成线性关系。

整体结构

整个模型采取层次化的设计,一共包含 4 个 Stage,除第一个 stage 外,每个 stage 都会先通过 Patch Merging 层缩小输入特征图的分辨率,进行下采样操作,像 CNN 一样逐层扩大感受野,以便获取到全局的信息:

  • 在输入开始的时候做了一个Patch Partition,即ViT中Patch Embedding操作,通过 Patch_size 为4的卷积层将图片切成一个个 Patch ,并嵌入到Embedding,将 embedding_size转变为48(可以将 CV 中图片的通道数理解为NLP中token的词嵌入长度)。
  • 随后在第一个Stage中,通过Linear Embedding调整通道数为C。
  • 在每个 Stage 里(除第一个 Stage ),均由Patch Merging和多个Swin Transformer Block组成。
  • Swin Transformer Block注意这里的Block其实有两种结构W-MSASW-MSA。两个结构是成对使用的,先使用一个W-MSA结构再使用一个SW-MSA结构。
  • Patch Merging模块主要在每个 Stage 一开始降低图片分辨率,进行下采样的操作。
  • Swin Transformer Block具体结构如右图所示,主要是LayerNorm,Window Attention ,Shifted Window Attention和MLP组成 。

Patch Embedding

在输入进 Block 前,我们需要将图片切成一个个 patch,然后嵌入向量。具体做法是对原始图片裁成一个个 window_size * window_size 的窗口大小,然后进行嵌入。这里可以通过二维卷积层,将 stride,kernel_size 设置为 window_size 大小。设定输出通道来确定嵌入向量的大小。最后将 H,W 维度展开,并移动到第一维度。
class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """
    def __init__(self,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 norm_layer=None
                 ):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size
        
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None
    
    def forward(self, x):
        B, C, H, W = x.shape
        
        # padding
        if W % self.patch_size[1] != 0:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
        if H % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
            
            
        x = self.proj(x)  # B C Wh Ww
        if self.norm is not None:
            Wh, Ww = x.size(2), x.size(3)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
        
        return x
#通过程序实现可以发现其并没有使用nn.Linear转换输入通道数,而是使用nn.Conv2d在进行patches转换时同时更换了通道数。

Patch Merging

Patch Merging的作用是分辨率减半,通道数加倍,类似于CNN的作用,在Transformer中实现Hierarchical。用在每个 Stage 开始前做降采样能节省一定运算量。在 CNN 中,则是在每个 Stage 开始前用stride=2的卷积/池化层来降低分辨率。
每次降采样是两倍,因此在行方向和列方向上,间隔 2 选取元素。然后拼接在一起作为一整个张量,最后展开。此时通道维度会变成原先的 4 倍(因为 H,W 各缩小 2 倍),此时再通过一个全连接层再调整通道维度为原来的两倍。

# H,W as input
# padding
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """
    
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)
    
    def forward(self, x, H, W):
        """
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
     
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
 
        
        x = x.view(B, H, W, C)
        # padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
        
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B, H/2, W/2, 4*C
        x = x.view(B, -1, 4 * C)  # B, H/2*W/2, 4*C
        
        x = self.norm(x)
        x = self.reduction(x)
        
        return x

Window Partition/Reverse

window partition函数是用于对张量划分窗口,指定窗口大小。将原本的张量从 N H W C, 划分成 num_windows*B, window_size, window_size, C,其中 num_windows = H*W / (window_size*window_size),即窗口的个数。而window reverse函数则是对应的逆过程。这两个函数会在后面的Window Attention用到。


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

Window Attention

这传统的 Transformer 都是基于全局来计算注意力的,因此计算复杂度十分高。而 Swin Transformer 则将注意力的计算限制在每个窗口内,进而减少了计算量。
先简单看下公式:
主要区别是在原始计算Attention的公式中的Q,K时加入了相对位置编码。后续实验有证明相对位置编码的加入提升了模型性能。

class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim # 输入通道的数量
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0]) # coords_h = tensor([0,1,2,...,self.window_size[0]-1])  维度=Wh
        coords_w = torch.arange(self.window_size[1]) # coords_w = tensor([0,1,2,...,self.window_size[1]-1])  维度=Ww

        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww


        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1

        '''
        后面我们需要将其展开成一维偏移量。而对于(2,1)和(1,2)这两个坐标,在二维上是不同的,但是通过将x\y坐标相加转换为一维偏移的时候
        他们的偏移量是相等的,所以需要对其做乘法操作,进行区分
        '''

        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        # 计算得到相对位置索引
        # relative_position_index.shape = (M2, M2) 意思是一共有这么多个位置
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww 

        '''
        relative_position_index注册为一个不参与网络学习的变量
        '''
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        '''
        使用从截断正态分布中提取的值填充输入张量
        self.relative_position_bias_table 是全0张量,通过trunc_normal_ 进行数值填充
        '''
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            N: number of all patches in the window
            C: 输入通过线性层转化得到的维度C
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        '''
        x.shape = (num_windows*B, N, C)
        self.qkv(x).shape = (num_windows*B, N, 3C)
        self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).shape = (num_windows*B, N, 3, num_heads, C//num_heads)
        self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).shape = (3, num_windows*B, num_heads, N, C//num_heads)
        '''
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        '''
        q.shape = k.shape = v.shape = (num_windows*B, num_heads, N, C//num_heads)
        N = M2 代表patches的数量
        C//num_heads代表Q,K,V的维数
        '''
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # q乘上一个放缩系数,对应公式中的sqrt(d)
        q = q * self.scale

        # attn.shape = (num_windows*B, num_heads, N, N)  N = M2 代表patches的数量
        attn = (q @ k.transpose(-2, -1))

        '''
        self.relative_position_bias_table.shape = (2*Wh-1 * 2*Ww-1, nH)
        self.relative_position_index.shape = (Wh*Ww, Wh*Ww)
        self.relative_position_index矩阵中的所有值都是从self.relative_position_bias_table中取的
        self.relative_position_index是计算出来不可学习的量
        '''
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

        '''
        attn.shape = (num_windows*B, num_heads, M2, M2)  N = M2 代表patches的数量
        .unsqueeze(0):扩张维度,在0对应的位置插入维度1
        relative_position_bias.unsqueeze(0).shape = (1, num_heads, M2, M2)
        num_windows*B 通过广播机制传播,relative_position_bias.unsqueeze(0).shape = (1, nH, M2, M2) 的维度1会broadcast到数量num_windows*B
        表示所有batch通用一个索引矩阵和相对位置矩阵
        '''
        attn = attn + relative_position_bias.unsqueeze(0)

        # mask.shape = (num_windows, M2, M2)
        # attn.shape = (num_windows*B, num_heads, M2, M2)
        if mask is not None:
            nW = mask.shape[0]
            # attn.view(B_ // nW, nW, self.num_heads, N, N).shape = (B, num_windows, num_heads, M2, M2) 第一个M2代表有M2个token,第二个M2代表每个token要计算M2次QKT的值
            # mask.unsqueeze(1).unsqueeze(0).shape =                (1, num_windows, 1,         M2, M2) 第一个M2代表有M2个token,第二个M2代表每个token要计算M2次QKT的值
            # broadcast相加
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            # attn.shape = (B, num_windows, num_heads, M2, M2)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        '''
        v.shape = (num_windows*B, num_heads, M2, C//num_heads)  N=M2 代表patches的数量, C//num_heads代表输入的维度
        attn.shape = (num_windows*B, num_heads, M2, M2)
        attn@v .shape = (num_windows*B, num_heads, M2, C//num_heads)
        '''
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)   # B_:num_windows*B  N:M2  C=num_heads*C//num_heads

        #   self.proj = nn.Linear(dim, dim)  dim = C
        #   self.proj_drop = nn.Dropout(proj_drop)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x  # x.shape = (num_windows*B, N, C)  N:窗口中所有patches的数量

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        #  x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
  1. 首先输入张量形状为 [numWindows*B, window_size* window_size, C]
  2. 然后经过self.qkv这个全连接层后,进行 reshape,调整轴的顺序,得到形状为[3, numWindows*B, num_heads, window_size*window_size, c//num_heads],并分配给q,k,v。
  3. 根据公式,我们对q乘以一个scale缩放系数,然后与k(为了满足矩阵乘要求,需要将最后两个维度调换)进行相乘。得到形状为[numWindows*B, num_heads, window_size*window_size, window_size*window_size]的attn张量。
  4. 之前我们针对位置编码设置了个形状为(2*window_size-1*2*window_size-1, numHeads)的可学习变量。我们用计算得到的相对编码位置索引self.relative_position_index.vew(-1)选取,得到形状为(window_size*window_size, window_size*window_size, numHeads)的编码,再permute(2,0,1)后加到attn张量上。
  5. 暂不考虑 mask 的情况,剩下就是跟 transformer 一样的 softmax,dropout,与V矩阵乘,再经过一层全连接层和dropout。

相关位置编码的代码详解

绝对位置编码是在进行self-attention计算之前为每一个token添加一个可学习的参数,相对位置编码如上式所示,是在进行self-attention计算时,在计算过程中添加一个可学习的相对位置参数。
假设window_size = 2*2即每个窗口有4个token(M=2) ,如图所示,在计算self-attention时,每个token都要与所有的token计算QK值,如图6所示,当位置1的token计算self-attention时,要计算位置1与位置(1,2,3,4)的QK值,即以位置1的token为中心点,中心点位置坐标(0,0),其他位置计算与当前位置坐标的偏移量。

相对位置索引求解流程图

最后生成的是相对位置索引,relative_position_index.shape = (M^{2}*M^{2}) ,在网络中注册成为一个不可学习的变量,relative_position_index的作用就是根据最终的索引值找到对应的可学习的相对位置编码。relative_position_index的数值范围(0~8),即 (2M-1,2M-1),所以相对位置编码可以由一个3*3的矩阵表示,如图s所示:
相对位置编码
图中的0-8为索引值,每个索引值都对应了 维可学习数据(根据图1,每个token都要计算 个QK值,每个QK值都要加上对应的相对位置编码)

继续以图中 M=2 的窗口为例,当计算位置1对应的 M^2个QK值时,应用的relative_position_index = [ 4, 5, 7, 8] 个 ,对应的数据就是相对位置编码]图中位置索引4,5,7,8位置对应的 M^2 维数据,即relative_position.shape = (M^{2}*M^{2})

相对位置编码在源码WindowAttention中应用,了解原理之后就很容易能够读懂程序:


Shifted Window Attention

采用W-MSA模块时,只会在每个窗口内进行自注意力计算,所以窗口与窗口之间是无法进行信息传递的。为了解决这个问题,作者引入了Shifted Windows Multi-Head Self-Attention(SW-MSA)模块,即进行偏移的W-MSA。如下图所示,左侧使用的是刚刚讲的W-MSA(假设是第L层),那么根据之前介绍的W-MSA和SW-MSA是成对使用的,那么第L+1层使用的就是SW-MSA(右侧图)。根据左右两幅图对比能够发现窗口(Windows)发生了偏移(可以理解成窗口从左上角分别向右侧和下方各偏移了⌊ M/2 ⌋ 个像素)。看下偏移后的窗口(右侧图),比如对于第一行第2列的2x4的窗口,它能够使第L层的第一排的两个窗口信息进行交流。再比如,第二行第二列的4x4的窗口,他能够使第L层的四个窗口信息进行交流,其他的同理。那么这就解决了不同窗口之间无法进行信息交流的问题。

根据上图,可以发现通过将窗口进行偏移后,由原来的4个窗口变成9个窗口了。后面又要对每个窗口内部进行MSA,这样做感觉又变麻烦了。为了解决这个麻烦,作者又提出而了Efficient batch computation for shifted configuration,一种更加高效的计算方法。下面是原论文给的示意图。

在进行cyclic shift之前,需要给子窗口进行编码,编码之后通过torch.roll对窗口进行滚动,达到cyclic shift的效果
计算 Attention 的时候,让具有相同 index QK 进行计算,而忽略不同 index QK 计算结果。

想在原始四个窗口下得到正确的结果,我们就必须给Attention的结果加入一个mask。

if self.shift_size > 0:
    # calculate attention mask for SW-MSA
    H, W = self.input_resolution
    # 生成全零张量
    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
    # 按区域划分mask
    h_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    w_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
    # tensor([[[0., 0., 0., 0., 1., 1., 2., 2.],
                 # [0., 0., 0., 0., 1., 1., 2., 2.],
                 # [0., 0., 0., 0., 1., 1., 2., 2.],
                 # [0., 0., 0., 0., 1., 1., 2., 2.],
                 # [3., 3., 3., 3., 4., 4., 5., 5.],
                 # [3., 3., 3., 3., 4., 4., 5., 5.],
                 # [6., 6., 6., 6., 7., 7., 8., 8.],
                 # [6., 6., 6., 6., 7., 7., 8., 8.]]])
    mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
    mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
    attn_mask = None

SwinTransformerBlock调用了上面介绍到的WindowAttention,有mask的情况下,在WindowAttention中应用mask对self-attention结果进行调整。SwinTransformerBlock主要就是W-MSA/SW-MSA的实现,其结构为:LN+(W−MSA/SW−MSA)+LN+MLP。要注意的是shifted的特征图最后会还原。这里的LN为nn.LayerNorm;MLP为作者自己的实现。

class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resulotion. 这里的分辨率是转换成patches之后的分辨率,不是原图像素的分辨率
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. dim*mlp_ration=隐藏层神经元个数
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim  # 输入通道C
        self.input_resolution = input_resolution # 输入分辨率
        self.num_heads = num_heads # self_attention head
        self.window_size = window_size # 窗口大小
        self.shift_size = shift_size # sw-window
        self.mlp_ratio = mlp_ratio #
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim) # nn.LayerNorm
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1

            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            # 上述操作是为了给每个窗口给上索引

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size) # nW, window_size*window_size
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # nW, window_size*window_size, window_size*window_size
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):

        H, W = self.input_resolution  # H,W不是像素的分辨率,而是转化成patches之后的分辨率
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)  # self.norm1 = norm_layer(dim) = nn.LayerNorm(dim)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            # torch.roll shifts为正则向下滚动,为负则向上滚动,可以是一个数组也可以是一个元组
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        '''
        经过torch.roll之后计算self.attn是SW-MSA
        不经过torch.roll计算的self.attn是W-MSA
        '''
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        '''
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        '''
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops

Block整体结构如下:

  1. 先对特征图进行LayerNorm
  2. 通过self.shift_size决定是否需要对特征图进行shift
  3. 然后将特征图切成一个个窗口
  4. 计算Attention,通过self.attn_mask来区分Window Attention还是Shift Window Attention
  5. 将各个窗口合并回来
  6. 如果之前有做shift操作,此时进行reverse shift,把之前的shift操作恢复
  7. 做dropout和残差连接
  8. 再通过一层LayerNorm+全连接层,以及dropout和残差连接

SwinTransformerBlock,此程序中调用了上面介绍到的WindowAttention,有mask的情况下,在WindowAttention中应用mask对self-attention结果进行调整。

class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
    """
    def __init__(self,
        dim,
        depth,
        num_heads,
        window_size,
        mlp_ratio=4.,
        qkv_bias=True,
        qk_scale=None,
        drop=0.,
        attn_drop=0.,
        drop_path=0.,
        norm_layer=nn.LayerNorm,
        downsample=None,
        #use_checkpoint=False,
    ):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
            )
            for i in range(depth)
        ])
        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None
    
    def forward(self, x, H, W):
        """
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
        
        # calculate attention mask for SW-MSA ----
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1
        
        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        #------
     
    
        for blk in self.blocks:
            x = blk(x, H, W, attn_mask)
            
        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W

UperNet

主要使用的是UPerNet的decoder结构,Decoder = FPN+PPM:理论上讲,深度卷积网络的感受野足够大,但实际可用的要小很多。为克服这一问题,本文把 PSPNet 中的金字塔池化模块(PPM)用于骨干网络的最后一层,在其被馈送至 FPN 自上而下的分支之前。结果实验证明,在带来有效的全局先验表征方面,PPM 和 FPN 架构是高度一致的。

UPerNet 架构图

class UPerDecoder(nn.Module):
    def __init__(self,
        in_dim=[256, 512, 1024, 2048],
        ppm_pool_scale=[1, 2, 3, 6],
        ppm_dim=512,
        fpn_out_dim=256
    ):
        super(UPerDecoder, self).__init__()
        
        # PPM ----
        dim = in_dim[-1]
        ppm_pooling = []
        ppm_conv = []
        
        for scale in ppm_pool_scale:
            ppm_pooling.append(
                nn.AdaptiveAvgPool2d(scale)
            )
            ppm_conv.append(
                nn.Sequential(
                    nn.Conv2d(dim, ppm_dim, kernel_size=1, bias=False),
                    nn.BatchNorm2d(ppm_dim),
                    nn.ReLU(inplace=True)
                )
            )
        self.ppm_pooling   = nn.ModuleList(ppm_pooling)
        self.ppm_conv      = nn.ModuleList(ppm_conv)
        self.ppm_out = conv3x3_bn_relu(dim + len(ppm_pool_scale)*ppm_dim, fpn_out_dim, 1)
        
        # FPN ----
        fpn_in = []
        for i in range(0, len(in_dim)-1):  # skip the top layer
            fpn_in.append(
                nn.Sequential(
                    nn.Conv2d(in_dim[i], fpn_out_dim, kernel_size=1, bias=False),
                    nn.BatchNorm2d(fpn_out_dim),
                    nn.ReLU(inplace=True)
                )
            )
        self.fpn_in = nn.ModuleList(fpn_in)
        
        fpn_out = []
        for i in range(len(in_dim) - 1):  # skip the top layer
            fpn_out.append(
                conv3x3_bn_relu(fpn_out_dim, fpn_out_dim, 1),
            )
        self.fpn_out = nn.ModuleList(fpn_out)
        
        self.fpn_fuse = nn.Sequential(
            conv3x3_bn_relu(len(in_dim) * fpn_out_dim, fpn_out_dim, 1),
        )
    
    def forward(self, feature):
        f = feature[-1]
        pool_shape = f.shape[2:]
        
        ppm_out = [f]
        for pool, conv in zip(self.ppm_pooling, self.ppm_conv):
            p = pool(f)
            p = F.interpolate(p, size=pool_shape, mode='bilinear', align_corners=False)
            p = conv(p)
            ppm_out.append(p)
        ppm_out = torch.cat(ppm_out, 1)
        down = self.ppm_out(ppm_out)
        
        
        #--------------------------------------
        fpn_out = [down]
        for i in reversed(range(len(feature) - 1)):
            lateral = feature[i]
            lateral = self.fpn_in[i](lateral) # lateral branch
            down = F.interpolate(down, size=lateral.shape[2:], mode='bilinear', align_corners=False) # top-down branch
            down = down + lateral
            fpn_out.append(self.fpn_out[i](down))
        
        fpn_out.reverse() # [P2 - P5]
        fusion_shape = fpn_out[0].shape[2:]
        fusion = [fpn_out[0]]
        for i in range(1, len(fpn_out)):
            fusion.append(
                F.interpolate( fpn_out[i], fusion_shape, mode='bilinear', align_corners=False)
            )
        x = self.fpn_fuse( torch.cat(fusion, 1))
        
        return x, fusion

Swin+UPerNet

参数设置:

cfg = dict(

        #configs/_base_/models/upernet_swin.py
        basic = dict(
            swin=dict(
                embed_dim=96,
                depths=[2, 2, 6, 2],
                num_heads=[3, 6, 12, 24],
                window_size=7,
                mlp_ratio=4.,
                qkv_bias=True,
                qk_scale=None,
                drop_rate=0.,
                attn_drop_rate=0.,
                drop_path_rate=0.3,
                ape=False,
                patch_norm=True,
                out_indices=(0, 1, 2, 3),
                use_checkpoint=False
            ),

        ),

        #configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k.py
        swin_tiny_patch4_window7_224=dict(
            checkpoint = pretrain_dir+'swin_tiny_patch4_window7_224_22k.pth',

            swin = dict(
                embed_dim=96,
                depths=[2, 2, 6, 2],
                num_heads=[3, 6, 12, 24],
                window_size=7,
                ape=False,
                drop_path_rate=0.3,
                patch_norm=True,
                use_checkpoint=False,
            ),
            upernet=dict(
                in_channels=[96, 192, 384, 768],
            ),
        ),

        #/configs/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k.py
        swin_small_patch4_window7_224_22k=dict(
            checkpoint = pretrain_dir+'swin_small_patch4_window7_224_22k.pth',

            swin = dict(
                embed_dim=96,
                depths=[2, 2, 18, 2],
                num_heads=[3, 6, 12, 24],
                window_size=7,
                ape=False,
                drop_path_rate=0.3,
                patch_norm=True,
                use_checkpoint=False
            ),
            upernet=dict(
                in_channels=[96, 192, 384, 768],
            ),
        ),
    )

整体结构:

class Net(nn.Module):
    
    def load_pretrain(self,):

        checkpoint = cfg[self.arch]['checkpoint']
        print('loading %s ...'%checkpoint)
     
        checkpoint = torch.load(checkpoint, map_location=lambda storage, loc: storage)['model']
        if 0:
            skip = ['relative_coords_table','relative_position_index']
            filtered={}
            for k,v in checkpoint.items():
                if any([s in k for s in skip ]): continue
                filtered[k]=v
            checkpoint = filtered
        print(self.encoder.load_state_dict(checkpoint,strict=False))  #True


    def __init__( self,):
        super(Net, self).__init__()
        self.output_type = ['inference', 'loss']

        self.rgb = RGB()
        self.arch = 'swin_small_patch4_window7_224_22k'#'swin_tiny_patch4_window7_224'

        self.encoder = SwinTransformerV1(
            ** {**cfg['basic']['swin'], **cfg[self.arch]['swin'],
                **{'out_norm' : LayerNorm2d} }
        )
        encoder_dim =cfg[self.arch]['upernet']['in_channels']
        #[96, 192, 384, 768]

        self.decoder = UPerDecoder(
            in_dim=encoder_dim,
            ppm_pool_scale=[1, 2, 3, 6],
            ppm_dim=512,
            fpn_out_dim=256
        )

        self.logit = nn.Sequential(
            nn.Conv2d(256, 1, kernel_size=1)
        )
        self.aux = nn.ModuleList([
            nn.Conv2d(256, 1, kernel_size=1, padding=0) for i in range(4)
        ])



    def forward(self, batch):
        x = batch['image']
        B,C,H,W = x.shape
        x = self.rgb(x)
        encoder = self.encoder(x)
        last, decoder = self.decoder(encoder)
        logit = self.logit(last)
        logit = F.interpolate(logit, size=None, scale_factor=4, mode='bilinear', align_corners=False)

        output = {}
        if 'loss' in self.output_type:
            output['bce_loss'] = F.binary_cross_entropy_with_logits(logit, batch['mask'])
            output['dice_loss'] = DiceLoss()(logit, batch['mask'])
            output['focal_loss'] = FocalLoss(logits=True, reduce=False)(logit, batch['mask'])
            for i in range(4):
                output['aux%d_loss'%i] = criterion_aux_loss(self.aux[i](decoder[i]),batch['mask'])

        if 'inference' in self.output_type:
            output['probability'] = torch.sigmoid(logit)

        return output

CoaT+Daformer总结:

CoaT

Co-scale conv-attentional image Transformers(CoaT),这是一种基于Transformer的图像分类器,其主要包含Co-scale和conv-attentional机制设计。

  • 首先,Co-scale机制在各个尺度上都保持了Transformers编码器分支的完整性,同时允许在不同尺度下学习的表示形式能够有效地进行彼此间的通信。同时,作者还设计了一系列的串行和并行块用来实现Co-scale Attention机制。
  • 其次,本文通过一种类似于卷积的实现方式设计了一种Factorized Attention机制,可以使得在因式注意力模块中实现相对位置的嵌入。CoaT为 Vision Transformer提供了丰富的多尺度和上下文建模功能。
    尽管CNN和Self-Attention操作都执行一个加权和,但它们的权值计算方式不同:在CNN中权值在训练过程中学习,但在测试过程中固定;而在Self-Attention中,根据每对Token之间的相似度或亲和度动态计算权重。因此,Self-Attention中的自相似操作提供了比卷积操作更具有潜在适应性和通用性的建模手段。此外,位置编码和位置嵌入的引入为Transformer建模提供了灵活性。
    模型详解:https://jishuin.proginn.com/p/763bfbd566f3

Daformer

DAFormer是使用Transformer进行语义分割无监督域自适应的开篇之作。DAFormer的网络结构包括一个Transformer编码器和一个多级上下文感知特征融合解码器。它是由3个简单但很关键的训练策略来稳定训练和避免对源域的过拟合:

  • 源域上的罕见类采样通过减轻Self-training对普通类的确认偏差提高了Pseudo-labels的质量
  • Thing-Class ImageNet Feature Distance
  • Learning rate warmup促进了预训练的特征迁移

这里我们只是借用decoder模块,完整的Daformer:https://blog.csdn.net/amusi1994/article/details/124833996

class DaformerDecoder(nn.Module):
    def __init__(
            self,
            encoder_dim = [32, 64, 160, 256],
            decoder_dim = 256,
            dilation = [1, 6, 12, 18],
            use_bn_mlp  = True,
            fuse = 'conv3x3',
    ):
        super().__init__()
        self.mlp = nn.ModuleList([
            nn.Sequential(
                # Conv2dBnReLU(dim, decoder_dim, 1, padding=0), #follow mmseg to use conv-bn-relu
                *(
                  ( nn.Conv2d(dim, decoder_dim, 1, padding= 0,  bias=False),
                    nn.BatchNorm2d(decoder_dim),
                    nn.ReLU(inplace=True),
                )if use_bn_mlp else
                  ( nn.Conv2d(dim, decoder_dim, 1, padding= 0,  bias=True),)
                ),
                
                MixUpSample(2**i) if i!=0 else nn.Identity(),
            ) for i, dim in enumerate(encoder_dim)])
      
        if fuse=='conv1x1':
            self.fuse = nn.Sequential(
                nn.Conv2d(len(encoder_dim) * decoder_dim, decoder_dim, 1, padding=0, bias=False),
                nn.BatchNorm2d(decoder_dim),
                nn.ReLU(inplace=True),
            )
        
        if fuse=='conv3x3':
            self.fuse = nn.Sequential(
                nn.Conv2d(len(encoder_dim) * decoder_dim, decoder_dim, 3, padding=1, bias=False),
                nn.BatchNorm2d(decoder_dim),
                nn.ReLU(inplace=True),
            )
        
        if fuse=='aspp':
            self.fuse = ASPP(
                decoder_dim*len(encoder_dim),
                decoder_dim,
                dilation,
            )
            
        if fuse=='ds-aspp':
            self.fuse = DSASPP(
                decoder_dim*len(encoder_dim),
                decoder_dim,
                dilation,
            )
        
    
    def forward(self, feature):
        
        out = []
        for i,f in enumerate(feature):
            f = self.mlp[i](f)
            out.append(f)
            #print(f.shape)
        x = self.fuse(torch.cat(out, dim = 1))
        return x, out


class daformer_conv3x3 (DaformerDecoder):
    def __init__(self, **kwargs):
        super(daformer_conv3x3, self).__init__(
            fuse = 'conv3x3',
            **kwargs
        )
class daformer_conv1x1 (DaformerDecoder):
    def __init__(self, **kwargs):
        super(daformer_conv1x1, self).__init__(
            fuse = 'conv1x1',
            **kwargs
        )

class daformer_aspp (DaformerDecoder):
    def __init__(self, **kwargs):
        super(daformer_aspp, self).__init__(
            fuse = 'aspp',
            **kwargs
        )

CoaT+Daformer

class Net(nn.Module):
    
    def __init__(self,
                 encoder=coat_lite_medium,
                 decoder=daformer_conv3x3,
                 encoder_cfg={},
                 decoder_cfg={},
                 ):
        
        super(Net, self).__init__()
        self.output_type = ['inference', 'loss']
        decoder_dim = decoder_cfg.get('decoder_dim', 320)
        self.encoder = encoder
        self.rgb = RGB()
        encoder_dim = self.encoder.embed_dims
        # [64, 128, 320, 512]

        self.decoder = decoder(
            encoder_dim=encoder_dim,
            decoder_dim=decoder_dim,
        )
        self.logit = nn.Sequential(
            nn.Conv2d(decoder_dim, 1, kernel_size=1),
            nn.Upsample(scale_factor = 4, mode='bilinear', align_corners=False),
        )

        self.aux = nn.ModuleList([
            nn.Conv2d(decoder_dim, 1, kernel_size=1, padding=0) for i in range(4)
        ])

    def forward(self, batch):
        x = batch['image']
        x = self.rgb(x)

        B, C, H, W = x.shape
        encoder = self.encoder(x)

        last, decoder = self.decoder(encoder)
        logit = self.logit(last)

        output = {}
        if 'loss' in self.output_type:
            output['bce_loss'] = F.binary_cross_entropy_with_logits(logit, batch['mask'])
            output['dice_loss'] = DiceLoss()(logit, batch['mask'])
            output['focal_loss'] = FocalLoss(logits=True, reduce=False)(logit, batch['mask'])
            for i in range(4):
                output['aux%d_loss'%i] = criterion_aux_loss(self.aux[i](decoder[i]),batch['mask'])

        if 'inference' in self.output_type:
            output['probability'] = torch.sigmoid(logit)

        return output

Badcase分析

Yellow: True Positive, Red: False Negative, Green: False Positive

在肺部区域表现较差,解决思路使用过采样肺部数据的方式重新训练Swin transformer+UPerNet,将其作为肺部预测模型,与原数据模型构成集成预测。

if self.train:
   self.df = df.append(df.loc[df[df['organ']=="lung"].index.repeat(5)]).reset_index(drop=True)
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 216,287评论 6 498
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,346评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 162,277评论 0 353
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,132评论 1 292
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,147评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,106评论 1 295
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,019评论 3 417
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,862评论 0 274
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,301评论 1 310
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,521评论 2 332
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,682评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,405评论 5 343
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,996评论 3 325
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,651评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,803评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,674评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,563评论 2 352

推荐阅读更多精彩内容