import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class VimBlock(nn.Module):
    """Vim block: bidirectional SSM with a gating mechanism."""
    def __init__(self, hidden_dim, expand_dim, ssm_dim):
        super().__init__()
        self.hidden_dim = hidden_dim  # D
        self.expand_dim = expand_dim  # E
        self.ssm_dim = ssm_dim        # N
        # Input normalization
        self.norm = nn.LayerNorm(hidden_dim)
        # Linear projections: map the input to the expanded dimension
        self.proj_x = nn.Linear(hidden_dim, expand_dim)
        self.proj_z = nn.Linear(hidden_dim, expand_dim)
        # Forward/backward 1D convolutions (local context modeling)
        self.conv1d_forward = nn.Conv1d(
            in_channels=expand_dim,
            out_channels=expand_dim,
            kernel_size=3,
            padding=1
        )
        self.conv1d_backward = nn.Conv1d(
            in_channels=expand_dim,
            out_channels=expand_dim,
            kernel_size=3,
            padding=1
        )
        # Projections for the input-dependent SSM parameters
        self.proj_B = nn.Linear(expand_dim, ssm_dim)
        self.proj_C = nn.Linear(expand_dim, ssm_dim)
        self.proj_Delta = nn.Linear(expand_dim, expand_dim)
        # A is learned but input-independent; B/C/Delta are data-dependent.
        # Initialized negative so that exp(Delta * A) < 1 and the recurrence stays stable
        # (the official Mamba parameterizes this as A = -exp(A_log)).
        self.A = nn.Parameter(-torch.rand(expand_dim, ssm_dim) - 0.5)
        # Output projection (map the expanded dimension back to the hidden dimension)
        self.proj_out = nn.Linear(expand_dim, hidden_dim)

    def forward(self, x):
        """
        Input:  (B, M, D), where B = batch size, M = sequence length, D = hidden dimension
        Output: (B, M, D)
        """
        residual = x
        x = self.norm(x)  # (B, M, D)
        # Linear projection to the expanded dimension
        x_proj = self.proj_x(x)  # (B, M, E)
        z = self.proj_z(x)       # (B, M, E), gating vector
        # Bidirectional SSM processing
        def process_direction(x_dir, conv1d):
            """Run the SSM over a single direction (forward/backward)."""
            # 1D convolution + activation
            x_dir = rearrange(x_dir, 'b m e -> b e m')  # to conv layout
            x_dir = conv1d(x_dir)                       # (B, E, M)
            x_dir = rearrange(x_dir, 'b e m -> b m e')  # back to sequence layout
            x_dir = F.silu(x_dir)                       # SiLU activation
            # Project to the input-dependent SSM parameters
            B = self.proj_B(x_dir)  # (B, M, N)
            C = self.proj_C(x_dir)  # (B, M, N)
            Delta = F.softplus(self.proj_Delta(x_dir) + 0.5)  # softplus keeps Delta positive; +0.5 shifts the initial step size
            # Discretized A and B (Eq. 2 in the paper):
            #   A_bar = exp(Delta * A),  B_bar = Delta * B  (simplified first-order rule)
            A_discrete = torch.exp(Delta.unsqueeze(-1) * self.A)  # (B, M, E, N)
            B_discrete = Delta.unsqueeze(-1) * B.unsqueeze(2)     # (B, M, E, N)
            # SSM recurrence (the loop in Algorithm 1 of the paper):
            #   h_t = A_bar_t * h_{t-1} + B_bar_t * x_t,   y_t = <C_t, h_t>
            h = torch.zeros(x_dir.size(0), self.expand_dim, self.ssm_dim, device=x.device)  # (B, E, N)
            y_dir = []
            for i in range(x_dir.size(1)):  # iterate over the sequence length M
                h = A_discrete[:, i] * h + B_discrete[:, i] * x_dir[:, i].unsqueeze(-1)  # (B, E, N)
                y_i = torch.sum(h * C[:, i].unsqueeze(1), dim=-1)  # (B, E)
                y_dir.append(y_i)
            return torch.stack(y_dir, dim=1)  # (B, M, E)
        # Forward direction
        y_forward = process_direction(x_proj, self.conv1d_forward)
        # Backward direction (flip the sequence)
        x_backward = x_proj.flip(dims=[1])
        y_backward = process_direction(x_backward, self.conv1d_backward)
        y_backward = y_backward.flip(dims=[1])  # restore the original order
        # Gating + residual connection
        z_gate = F.silu(z)
        y = self.proj_out(y_forward * z_gate + y_backward * z_gate)
        return y + residual
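
As a quick sanity check, the block can be applied to a random ViT-style token sequence; it returns a tensor of the same shape, so it can be stacked like a Transformer block. The sizes below are arbitrary and only for illustration.

block = VimBlock(hidden_dim=192, expand_dim=384, ssm_dim=16)
tokens = torch.randn(2, 196, 192)  # 2 images, 196 patch tokens, hidden dim 192
out = block(tokens)
print(out.shape)  # torch.Size([2, 196, 192]), same shape as the input
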
I have recently been studying Mamba, and the core of the Vision Mamba (Vim) algorithm is the VimBlock module above. VimBlock replaces the self-attention of ViT with a lightweight, data-dependent bidirectional SSM: it keeps the global receptive field while reducing the complexity from quadratic to linear in the sequence length, and it relies on gating, a residual connection, and LayerNorm for training stability and expressiveness. At the macro level, Vision Mamba follows essentially the same recipe as ViT (patch embedding, position embeddings, a stack of identical blocks, a classification head); only the token-mixing operation inside each block changes, as the sketch below shows.
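
To make that point concrete, here is a minimal sketch of how VimBlock could be stacked into an image classifier. The patch size, dimensions, mean pooling, and the SimpleVim name are my own illustrative choices, not the configuration from the Vim paper (which, for example, inserts a class token in the middle of the token sequence).

class SimpleVim(nn.Module):
    """Minimal ViT-style wrapper around VimBlock (illustrative, not the official Vim)."""
    def __init__(self, image_size=224, patch_size=16, in_chans=3,
                 hidden_dim=192, expand_dim=384, ssm_dim=16, depth=4, num_classes=1000):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        # Patch embedding: a strided conv, exactly as in ViT
        self.patch_embed = nn.Conv2d(in_chans, hidden_dim,
                                     kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_dim))
        # Token mixing: a stack of VimBlocks instead of self-attention blocks
        self.blocks = nn.ModuleList(
            [VimBlock(hidden_dim, expand_dim, ssm_dim) for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(hidden_dim)
        self.head = nn.Linear(hidden_dim, num_classes)

    def forward(self, images):              # (B, 3, H, W)
        x = self.patch_embed(images)        # (B, D, H/p, W/p)
        x = x.flatten(2).transpose(1, 2)    # (B, M, D) token sequence
        x = x + self.pos_embed
        for block in self.blocks:
            x = block(x)
        x = self.norm(x).mean(dim=1)        # mean pooling over tokens (the paper uses a class token instead)
        return self.head(x)                 # (B, num_classes)

images = torch.randn(2, 3, 224, 224)
logits = SimpleVim()(images)  # (2, 1000)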