33、STVit模块
论文《Vision Transformer with Super Token Sampling》
1、作用
STVit(Super Token Vision Transformer)旨在解决视觉Transformer在浅层(早期层)捕捉局部特征时常见的计算冗余问题。该模型尝试减少早期层次捕捉局部特征时的冗余计算,从而减少不必要的计算成本,同时仍保留在网络早期对全局上下文建模的能力。
2、机制
STVit引入了一种类似于图像处理中“超像素”的概念,称为“超级令牌”(super tokens),以减少自注意力计算中元素的数量,同时保留对全局关系建模的能力。该过程涉及从视觉令牌中采样超级令牌,对这些超级令牌执行自注意力操作,并将它们映射回原始令牌空间。
3、独特优势
STVit在不同的视觉任务中展示了强大的性能,包括图像分类、目标检测和分割,同时拥有更少的参数和较低的计算成本。例如,STVit在没有额外训练数据的情况下,在ImageNet-1K分类任务上达到了86.4%的Top-1准确率,且参数少于100M。
4、代码
import torch
import torch.nn as nn
import torch.nn.functional as F
# Unfold模块使用给定的kernel_size对输入进行展开
class Unfold(nn.Module):
    """Gather the ``kernel_size`` x ``kernel_size`` neighbourhood of every pixel.

    The operation is implemented as a convolution with fixed one-hot kernels:
    output channel ``i`` copies the input value found at window position ``i``
    (row-major). The input is zero-padded by ``kernel_size // 2`` on each side,
    so the spatial size is preserved for odd ``kernel_size``.

    Args:
        kernel_size: Side length of the square window. Expected to be odd;
            an even value changes the spatial size and the final reshape in
            ``forward`` will raise.
    """

    def __init__(self, kernel_size=3):
        super().__init__()
        self.kernel_size = kernel_size
        # Identity matrix reshaped into k*k one-hot kernels, one per window position.
        weights = torch.eye(kernel_size ** 2)
        weights = weights.reshape(kernel_size ** 2, 1, kernel_size, kernel_size)
        # Frozen parameter: kept in the state dict / moved with the module, never trained.
        self.weights = nn.Parameter(weights, requires_grad=False)

    def forward(self, x):
        """Unfold ``x`` of shape (B, C, H, W) into (B, C * kernel_size**2, H * W)."""
        b, c, h, w = x.shape
        # Fold channels into the batch so one single-channel convolution handles all of them.
        x = F.conv2d(x.reshape(b * c, 1, h, w), self.weights, stride=1, padding=self.kernel_size // 2)
        # Use the actual window area instead of a hard-coded 9 so kernel_size != 3 works.
        return x.reshape(b, c * self.kernel_size ** 2, h * w)
# Fold模块与Unfold相反,用于将展开的特征图折叠回原始形状
class Fold(nn.Module):
    """Sum per-window-position maps back onto the grid (inverse layout of ``Unfold``).

    Implemented as a transposed convolution with fixed one-hot kernels: input
    channel ``i`` is scattered to the window position ``i`` (row-major) around
    each location, and overlapping contributions are added together.

    Args:
        kernel_size: Side length of the square window.
    """

    def __init__(self, kernel_size=3):
        super().__init__()
        self.kernel_size = kernel_size
        # Same one-hot kernels as Unfold: one kernel per window position.
        weights = torch.eye(kernel_size ** 2)
        weights = weights.reshape(kernel_size ** 2, 1, kernel_size, kernel_size)
        # Frozen parameter: kept in the state dict / moved with the module, never trained.
        self.weights = nn.Parameter(weights, requires_grad=False)

    def forward(self, x):
        """Fold ``x`` of shape (N, kernel_size**2, H, W) into a single-channel map.

        With the padding of ``kernel_size // 2`` used here, an odd
        ``kernel_size`` yields an output of shape (N, 1, H, W).
        """
        return F.conv_transpose2d(x, self.weights, stride=1, padding=self.kernel_size // 2)
# Attention模块实现自注意力机制
class Attention(nn.Module):
    """Multi-head self-attention over the spatial positions of a (B, C, H, W) map.

    Queries, keys and values are produced by a single 1x1 convolution and the
    result is projected by another 1x1 convolution, so the module works
    directly on channel-first feature maps.

    Args:
        dim: Number of input/output channels. Must be divisible by ``num_heads``.
        window_size: Stored on the module; not used by this class itself.
        num_heads: Number of attention heads.
        qkv_bias: Whether the qkv convolution has a bias term.
        qk_scale: Scale applied to the attention logits. Any falsy value
            (``None`` or ``0``) falls back to ``head_dim ** -0.5``.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, window_size=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        # Fail early: otherwise the reshape in forward() fails with an opaque size error.
        if dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.window_size = window_size
        self.scale = qk_scale or head_dim ** -0.5
        # One 1x1 convolution emits q, k and v for all heads at once.
        self.qkv = nn.Conv2d(dim, dim * 3, 1, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv2d(dim, dim, 1)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, C, H, W); the output has the same shape."""
        B, C, H, W = x.shape
        N = H * W
        # Channels are laid out head-major: each head owns 3 * head_dim consecutive channels.
        qkv = self.qkv(x).reshape(B, self.num_heads, C // self.num_heads * 3, N)
        q, k, v = qkv.chunk(3, dim=2)  # each (B, num_heads, head_dim, N)
        # attn[..., i, j] is the scaled dot product of key i with query j.
        attn = (k.transpose(-1, -2) @ q) * self.scale  # (B, num_heads, N, N)
        # Normalise over the key axis (dim=-2) because keys index the rows here.
        attn = attn.softmax(dim=-2)
        attn = self.attn_drop(attn)
        x = (v @ attn).reshape(B, C, H, W)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
# StokenAttention模块通过迭代地细化空间Token以增强特征表示
class StokenAttention(nn.Module):
    """Super-token attention on a (B, C, H, W) feature map.

    Pixels are grouped into a grid of cells of ``stoken_size`` pixels. Each
    cell starts with an average-pooled "super token"; a soft pixel-to-token
    affinity is computed against the 3x3 neighbourhood of super tokens around
    the pixel's own cell, the super tokens are rebuilt from that affinity,
    refined with self-attention, and finally scattered back to the pixels.

    Args:
        dim: Number of channels ``C``.
        stoken_size: ``(h, w)`` size of one cell in pixels. A single int is
            treated as a square cell. When both sides are 1 the module applies
            plain self-attention to the input instead.
        n_iter: Number of affinity estimation iterations; must be at least 1
            whenever the super-token path is used.
        num_heads: Number of heads of the refining attention.
        qkv_bias: Forwarded to ``Attention``.
        qk_scale: Forwarded to ``Attention``.
        attn_drop: Forwarded to ``Attention``.
        proj_drop: Forwarded to ``Attention``.
    """

    def __init__(self, dim, stoken_size, n_iter=1, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        # Accept a bare int as shorthand for a square cell; sequences are stored unchanged.
        if isinstance(stoken_size, int):
            stoken_size = (stoken_size, stoken_size)
        self.n_iter = n_iter
        self.stoken_size = stoken_size
        self.scale = dim ** -0.5
        # 3x3 neighbourhood gather / scatter on the super-token grid.
        self.unfold = Unfold(3)
        self.fold = Fold(3)
        self.stoken_refine = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                       attn_drop=attn_drop, proj_drop=proj_drop)

    def stoken_forward(self, x):
        """Run the super-token path.

        Args:
            x: Tensor of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, C, H, W).

        Raises:
            ValueError: If ``n_iter`` is smaller than 1 (no affinity would be computed).
        """
        if self.n_iter < 1:
            raise ValueError(f"n_iter must be >= 1 for super-token attention, got {self.n_iter}")

        B, C, H0, W0 = x.shape
        h, w = self.stoken_size

        # Pad on the right/bottom so that H and W become multiples of the cell size.
        pad_l = pad_t = 0
        pad_r = (w - W0 % w) % w
        pad_b = (h - H0 % h) % h
        if pad_r > 0 or pad_b > 0:
            x = F.pad(x, (pad_l, pad_r, pad_t, pad_b))

        _, _, H, W = x.shape
        hh, ww = H // h, W // w  # number of cells along each axis

        # Initial super tokens: the mean of every cell.
        stoken_features = F.adaptive_avg_pool2d(x, (hh, ww))  # (B, C, hh, ww)
        # Pixels grouped by cell.
        pixel_features = x.reshape(B, C, hh, h, ww, w).permute(0, 2, 4, 3, 5, 1).reshape(B, hh * ww, h * w, C)

        # The affinity estimation itself is not differentiated through.
        with torch.no_grad():
            for idx in range(self.n_iter):
                # The 9 neighbouring super tokens of every cell.
                stoken_features = self.unfold(stoken_features)  # (B, C*9, hh*ww)
                stoken_features = stoken_features.transpose(1, 2).reshape(B, hh * ww, C, 9)
                # Soft assignment of each pixel to those 9 super tokens.
                affinity_matrix = pixel_features @ stoken_features * self.scale  # (B, hh*ww, h*w, 9)
                affinity_matrix = affinity_matrix.softmax(-1)  # (B, hh*ww, h*w, 9)
                # Total assignment mass received by every super token (used as normaliser).
                affinity_matrix_sum = affinity_matrix.sum(2).transpose(1, 2).reshape(B, 9, hh, ww)
                affinity_matrix_sum = self.fold(affinity_matrix_sum)
                if idx < self.n_iter - 1:
                    # Re-estimate the super tokens for the next iteration.
                    stoken_features = pixel_features.transpose(-1, -2) @ affinity_matrix  # (B, hh*ww, C, 9)
                    stoken_features = self.fold(
                        stoken_features.permute(0, 2, 3, 1).reshape(B * C, 9, hh, ww)).reshape(B, C, hh, ww)
                    stoken_features = stoken_features / (affinity_matrix_sum + 1e-12)  # (B, C, hh, ww)

        # Final super tokens, computed outside no_grad so gradients reach pixel_features.
        stoken_features = pixel_features.transpose(-1, -2) @ affinity_matrix  # (B, hh*ww, C, 9)
        stoken_features = self.fold(
            stoken_features.permute(0, 2, 3, 1).reshape(B * C, 9, hh, ww)).reshape(B, C, hh, ww)
        stoken_features = stoken_features / (affinity_matrix_sum.detach() + 1e-12)  # (B, C, hh, ww)

        # Self-attention among the super tokens.
        stoken_features = self.stoken_refine(stoken_features)

        # Scatter the refined super tokens back to the pixels through the affinity.
        stoken_features = self.unfold(stoken_features)  # (B, C*9, hh*ww)
        stoken_features = stoken_features.transpose(1, 2).reshape(B, hh * ww, C, 9)  # (B, hh*ww, C, 9)
        pixel_features = stoken_features @ affinity_matrix.transpose(-1, -2)  # (B, hh*ww, C, h*w)
        # Undo the per-cell grouping.
        pixel_features = pixel_features.reshape(B, hh, ww, C, h, w).permute(0, 3, 1, 4, 2, 5).reshape(B, C, H, W)

        # Drop the padding added above.
        if pad_r > 0 or pad_b > 0:
            pixel_features = pixel_features[:, :, :H0, :W0]
        return pixel_features

    def direct_forward(self, x):
        """Apply the refining self-attention directly to ``x`` (no super tokens)."""
        return self.stoken_refine(x)

    def forward(self, x):
        """Dispatch to the super-token path unless the cell size is 1x1."""
        if self.stoken_size[0] > 1 or self.stoken_size[1] > 1:
            return self.stoken_forward(x)
        return self.direct_forward(x)
# 输入 N C H W, 输出 N C H W
if __name__ == '__main__':
    # Smoke test: the module maps (N, C, H, W) to a tensor of the same shape.
    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    x = torch.randn(3, 64, 64, 64).to(device)  # random input
    se = StokenAttention(64, stoken_size=[8, 8]).to(device)  # instantiate the attention module
    output = se(x)
    print(output.shape)  # print the output shape