深度学习模块34-STVit模块

33、STVit模块

论文《Vision Transformer with Super Token Sampling》

1、作用

STVit旨在通过改进视觉Transformer的空间-时间效率,解决在处理视频和图像任务时常见的计算冗余问题。该模型尝试减少早期层次捕捉局部特征时的冗余计算,从而减少不必要的计算成本。

2、机制

STVit引入了一种类似于图像处理中“超像素”的概念,称为“超级令牌”(super tokens),以减少自注意力计算中元素的数量,同时保留对全局关系建模的能力。该过程涉及从视觉令牌中采样超级令牌,对这些超级令牌执行自注意力操作,并将它们映射回原始令牌空间。

3、独特优势

STVit在不同的视觉任务中展示了强大的性能,包括图像分类、对象检测和分割,同时拥有更少的参数和较低的计算成本。例如,STVit在没有额外训练数据的情况下,在ImageNet-1K分类任务上达到了86.4%的顶级1准确率,且参数少于100M。

4、代码

import torch
import torch.nn as nn
import torch.nn.functional as F

# Unfold模块使用给定的kernel_size对输入进行展开
class Unfold(nn.Module):
    def __init__(self, kernel_size=3):
        super().__init__()
# kernel_size定义了展开操作的窗口大小
        self.kernel_size = kernel_size
# 初始化权重为单位矩阵,使得每个窗口内的元素直接复制到输出
        weights = torch.eye(kernel_size ** 2)
        weights = weights.reshape(kernel_size ** 2, 1, kernel_size, kernel_size)
        # 将权重设置为不需要梯度,因为它们不会在训练过程中更新
        self.weights = nn.Parameter(weights, requires_grad=False)

    def forward(self, x):   # 获取输入的批量大小、通道数、高度和宽度
        b, c, h, w = x.shape
         # 使用定义好的权重对输入进行卷积操作,实现展开功能
        x = F.conv2d(x.reshape(b * c, 1, h, w), self.weights, stride=1, padding=self.kernel_size // 2)
         # 调整输出的形状,使其包含展开的窗口
        return x.reshape(b, c * 9, h * w)

# Fold模块与Unfold相反,用于将展开的特征图折叠回原始形状
class Fold(nn.Module):
    def __init__(self, kernel_size=3):
        super().__init__()

        self.kernel_size = kernel_size
         # 与Unfold相同,初始化权重为单位矩阵
        weights = torch.eye(kernel_size ** 2)
        weights = weights.reshape(kernel_size ** 2, 1, kernel_size, kernel_size)
         # 权重不需要梯度
        self.weights = nn.Parameter(weights, requires_grad=False)

    def forward(self, x):
        b, _, h, w = x.shape
        # 使用转置卷积(逆卷积)操作恢复原始大小的特征图
        x = F.conv_transpose2d(x, self.weights, stride=1, padding=self.kernel_size // 2)
        return x

# Attention模块实现自注意力机制
class Attention(nn.Module):
    def __init__(self, dim, window_size=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()

        self.dim = dim# dim定义了特征维度,num_heads定义了注意力头的数量
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.window_size = window_size
# 根据给定的尺度因子或自动计算的尺度进行缩放
        self.scale = qk_scale or head_dim ** -0.5
 # qkv用一个卷积层同时生成查询、键和值
        self.qkv = nn.Conv2d(dim, dim * 3, 1, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv2d(dim, dim, 1)  # proj是输出的投影层
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, C, H, W = x.shape # 获取输入的形状
        N = H * W
# 将qkv的输出重塑为适合自注意力计算的形状
        q, k, v = self.qkv(x).reshape(B, self.num_heads, C // self.num_heads * 3, N).chunk(3,
                                                                                           dim=2)  # (B, num_heads, head_dim, N)
 # 计算注意力分数,注意力分数乘以尺度因子
        attn = (k.transpose(-1, -2) @ q) * self.scale
  # 应用softmax获取注意力权重
        attn = attn.softmax(dim=-2)  # (B, h, N, N)
    # 应用注意力dropout
        attn = self.attn_drop(attn)

        x = (v @ attn).reshape(B, C, H, W)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x

# StokenAttention模块通过迭代地细化空间Token以增强特征表示
class StokenAttention(nn.Module):
    def __init__(self, dim, stoken_size, n_iter=1, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
                 proj_drop=0.):
        super().__init__()

        self.n_iter = n_iter
        self.stoken_size = stoken_size

        self.scale = dim ** - 0.5

        self.unfold = Unfold(3)
        self.fold = Fold(3)

        self.stoken_refine = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                       attn_drop=attn_drop, proj_drop=proj_drop)

    def stoken_forward(self, x):
        '''
           x: (B, C, H, W)
        '''
        B, C, H0, W0 = x.shape
        h, w = self.stoken_size
 # 计算padding
        pad_l = pad_t = 0
        pad_r = (w - W0 % w) % w
        pad_b = (h - H0 % h) % h
        if pad_r > 0 or pad_b > 0:
            x = F.pad(x, (pad_l, pad_r, pad_t, pad_b))

        _, _, H, W = x.shape

        hh, ww = H // h, W // w
 # 使用自适应平均池化得到空间Token的特征
        stoken_features = F.adaptive_avg_pool2d(x, (hh, ww))  # (B, C, hh, ww)
  # 展开特征以进行精细化处理
        pixel_features = x.reshape(B, C, hh, h, ww, w).permute(0, 2, 4, 3, 5, 1).reshape(B, hh * ww, h * w, C)
 # 使用没有梯度的操作进行迭代精细化
        with torch.no_grad():
            for idx in range(self.n_iter):
                stoken_features = self.unfold(stoken_features)  # (B, C*9, hh*ww)
                stoken_features = stoken_features.transpose(1, 2).reshape(B, hh * ww, C, 9)
                affinity_matrix = pixel_features @ stoken_features * self.scale  # (B, hh*ww, h*w, 9)

                affinity_matrix = affinity_matrix.softmax(-1)  # (B, hh*ww, h*w, 9)

                affinity_matrix_sum = affinity_matrix.sum(2).transpose(1, 2).reshape(B, 9, hh, ww)

                affinity_matrix_sum = self.fold(affinity_matrix_sum)
                if idx < self.n_iter - 1:
                    stoken_features = pixel_features.transpose(-1, -2) @ affinity_matrix  # (B, hh*ww, C, 9)

                    stoken_features = self.fold(stoken_features.permute(0, 2, 3, 1).reshape(B * C, 9, hh, ww)).reshape(
                        B, C, hh, ww)

                    stoken_features = stoken_features / (affinity_matrix_sum + 1e-12)  # (B, C, hh, ww)

        stoken_features = pixel_features.transpose(-1, -2) @ affinity_matrix  # (B, hh*ww, C, 9)

        stoken_features = self.fold(stoken_features.permute(0, 2, 3, 1).reshape(B * C, 9, hh, ww)).reshape(B, C, hh, ww)

        stoken_features = stoken_features / (affinity_matrix_sum.detach() + 1e-12)  # (B, C, hh, ww)

        stoken_features = self.stoken_refine(stoken_features)

        stoken_features = self.unfold(stoken_features)  # (B, C*9, hh*ww)
        stoken_features = stoken_features.transpose(1, 2).reshape(B, hh * ww, C, 9)  # (B, hh*ww, C, 9)
# 通过affinity_matrix将精细化的特征映射回原始像素级别
        pixel_features = stoken_features @ affinity_matrix.transpose(-1, -2)  # (B, hh*ww, C, h*w)
 # 折叠特征,恢复原始形状
        pixel_features = pixel_features.reshape(B, hh, ww, C, h, w).permute(0, 3, 1, 4, 2, 5).reshape(B, C, H, W)

        if pad_r > 0 or pad_b > 0:
            pixel_features = pixel_features[:, :, :H0, :W0]

        return pixel_features

    def direct_forward(self, x): # 直接对x应用Attention进行细化
        B, C, H, W = x.shape
        stoken_features = x
        stoken_features = self.stoken_refine(stoken_features)
        return stoken_features

    def forward(self, x):
        if self.stoken_size[0] > 1 or self.stoken_size[1] > 1:
            return self.stoken_forward(x)
        else:
            return self.direct_forward(x)


#  输入 N C H W,  输出 N C H W
if __name__ == '__main__':
    input = torch.randn(3, 64, 64, 64).cuda() # 创建一个随机输入 
    se = StokenAttention(64, stoken_size=[8,8]).cuda()# 实例化注意力模块 
    output = se(input)
    print(output.shape) # 打印输出形状
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容