Vision Mamba Core Module Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class VimBlock(nn.Module):
    """Vim block: bidirectional SSM with a gating branch."""

    def __init__(self, hidden_dim, expand_dim, ssm_dim):
        super().__init__()
        self.hidden_dim = hidden_dim  # D
        self.expand_dim = expand_dim  # E
        self.ssm_dim = ssm_dim  # N

        # Input normalization
        self.norm = nn.LayerNorm(hidden_dim)

        # Linear projections: map the input to the expanded dimension
        self.proj_x = nn.Linear(hidden_dim, expand_dim)
        self.proj_z = nn.Linear(hidden_dim, expand_dim)

        # Forward/backward 1D convolutions (for local context modeling)
        self.conv1d_forward = nn.Conv1d(
            in_channels=expand_dim,
            out_channels=expand_dim,
            kernel_size=3,
            padding=1
        )
        self.conv1d_backward = nn.Conv1d(
            in_channels=expand_dim,
            out_channels=expand_dim,
            kernel_size=3,
            padding=1
        )

        # Projections for the data-dependent SSM parameters
        self.proj_B = nn.Linear(expand_dim, ssm_dim)
        self.proj_C = nn.Linear(expand_dim, ssm_dim)
        self.proj_Delta = nn.Linear(expand_dim, expand_dim)

        # SSM state matrix A (input-independent but learned; B/C/Delta below are data-dependent).
        # Stored as A_log so that A = -exp(A_log) stays negative, keeping exp(Delta*A) < 1 and the
        # recurrence stable (standard Mamba-style S4D-real initialization).
        A_init = torch.arange(1, ssm_dim + 1, dtype=torch.float32).repeat(expand_dim, 1)  # (E, N)
        self.A_log = nn.Parameter(torch.log(A_init))

        # Output projection (map the expanded dimension back to the hidden dimension)
        self.proj_out = nn.Linear(expand_dim, hidden_dim)

    def forward(self, x):
        """
        输入: (B, M, D) 其中B=批次大小, M=序列长度, D=隐藏维度
        输出: (B, M, D)
        """
        residual = x
        x = self.norm(x)  # (B, M, D)

        # Linear projection to the expanded dimension
        x_proj = self.proj_x(x)  # (B, M, E)
        z = self.proj_z(x)  # (B, M, E) gating vector

        # Bidirectional SSM processing
        def process_direction(x_dir, conv1d):
            """Run the SSM over a single direction (forward or backward)."""
            # 1D convolution + activation
            x_dir = rearrange(x_dir, 'b m e -> b e m')  # to channel-first layout for Conv1d
            x_dir = conv1d(x_dir)  # (B, E, M)
            x_dir = rearrange(x_dir, 'b e m -> b m e')  # back to sequence layout
            x_dir = F.silu(x_dir)  # SiLU activation

            # Project to the data-dependent SSM parameters
            B = self.proj_B(x_dir)  # (B, M, N)
            C = self.proj_C(x_dir)  # (B, M, N)
            Delta = F.softplus(self.proj_Delta(x_dir) + 0.5)  # softplus keeps Delta positive

            # Discretize A and B (Eq. 2 in the paper): h_t = exp(Delta*A) * h_{t-1} + (Delta*B) * x_t
            A = -torch.exp(self.A_log)  # (E, N), negative so that exp(Delta*A) < 1
            A_discrete = torch.exp(Delta.unsqueeze(-1) * A)  # (B, M, E, N)
            B_discrete = Delta.unsqueeze(-1) * B.unsqueeze(2)  # (B, M, E, N)

            # SSM recurrence (the loop in Algorithm 1 of the paper)
            h = torch.zeros(x_dir.size(0), self.expand_dim, self.ssm_dim, device=x.device)  # (B, E, N)
            y_dir = []
            for i in range(x_dir.size(1)):  # iterate over the sequence length M
                h = A_discrete[:, i] * h + B_discrete[:, i] * x_dir[:, i].unsqueeze(-1)  # (B, E, N)
                y_i = torch.sum(h * C[:, i].unsqueeze(1), dim=-1)  # (B, E)
                y_dir.append(y_i)
            return torch.stack(y_dir, dim=1)  # (B, M, E)

        # Forward direction
        y_forward = process_direction(x_proj, self.conv1d_forward)

        # Backward direction (process the reversed sequence)
        x_backward = x_proj.flip(dims=[1])  # reverse the sequence
        y_backward = process_direction(x_backward, self.conv1d_backward)
        y_backward = y_backward.flip(dims=[1])  # restore the original token order

        # Gating + output projection, then residual connection
        z_gate = F.silu(z)
        y = self.proj_out(y_forward * z_gate + y_backward * z_gate)
        return y + residual
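
As a quick sanity check, the snippet below runs one forward pass through the block. The dimensions (192/384/16) and the 196-token sequence length are illustrative values chosen here, not the settings from the paper.

# Toy shape check (illustrative dimensions, not the paper's configuration)
block = VimBlock(hidden_dim=192, expand_dim=384, ssm_dim=16)
tokens = torch.randn(2, 196, 192)  # 2 images, 14x14 = 196 patch tokens, D = 192
out = block(tokens)
print(out.shape)  # torch.Size([2, 196, 192]): same shape as the input, so blocks can be stacked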

I have recently been studying Mamba, and the VimBlock above is the core of the Vision Mamba algorithm. VimBlock replaces ViT's self-attention with a lightweight, data-dependent bidirectional SSM: it keeps a global receptive field while reducing the cost in sequence length from quadratic to linear, and the gating branch, residual connection, and LayerNorm keep training stable and the block expressive. At the macro level, Vision Mamba follows essentially the same recipe as ViT: the image is split into patch tokens that flow through a stack of identical blocks, and only the token mixer inside each block changes (see the sketch below).
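
To make that ViT-like overall structure concrete, here is a minimal sketch of how VimBlocks could be stacked into an image classifier. The wrapper name TinyVim, the patch/embedding sizes, and mean pooling over tokens are my own illustrative choices; as far as I recall, the released Vim model instead inserts a class token in the middle of the token sequence and differs in other details, so treat this only as a structural outline.

class TinyVim(nn.Module):
    """Illustrative ViT-style wrapper around VimBlock (hypothetical configuration)."""

    def __init__(self, image_size=224, patch_size=16, hidden_dim=192,
                 expand_dim=384, ssm_dim=16, depth=12, num_classes=1000):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        # Patch embedding: a strided conv turns the image into a sequence of patch tokens
        self.patch_embed = nn.Conv2d(3, hidden_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_dim))
        self.blocks = nn.ModuleList(
            [VimBlock(hidden_dim, expand_dim, ssm_dim) for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(hidden_dim)
        self.head = nn.Linear(hidden_dim, num_classes)

    def forward(self, images):  # images: (B, 3, H, W)
        x = self.patch_embed(images)          # (B, D, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)      # (B, M, D) token sequence
        x = x + self.pos_embed
        for block in self.blocks:
            x = block(x)                      # each VimBlock keeps the (B, M, D) shape
        x = self.norm(x).mean(dim=1)          # mean-pool the tokens
        return self.head(x)

model = TinyVim(depth=4)  # small depth just for a quick test
logits = model(torch.randn(1, 3, 224, 224))  # (1, 1000)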
