PlainMamba---Direction-Aware Updating实现

class DirAwareSSM(nn.Module):
    """
    极简 SSM:仅保留线性更新 + 方向偏移
    方向:0→1→2→3 对应 4 条扫描;4 是 Begin
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        # 5 个方向:右 左 下 上 Begin
        self.delta_dir = nn.Parameter(torch.randn(5, dim))
        self.A  = nn.Parameter(torch.randn(dim, dim) * 0.02)
        self.Bp = nn.Linear(dim, dim, bias=False)
        self.Cp = nn.Linear(dim, dim, bias=False)

    def single_forward(self, x, dir_id):
        """
        x: (L, dim) 一条扫描序列
        dir_id: 0-3 扫描方向;Begin 方向固定用 4
        return: (L, dim)
        """
        L, d = x.shape
        h = torch.zeros(d, device=x.device)
        outs = []
        for t in range(L):
            delta = self.delta_dir[4] if t == 0 else self.delta_dir[dir_id] ##对每个序列开头,都使用Begin的向量
            B = self.Bp(x[t]) + delta
            h = h @ self.A + B * x[t]
            outs.append(self.Cp(h))
        return torch.stack(outs)

之前一直不太理解5个方向计算4条路径要怎么操作,看了代码才知道,Begin这个方向是在每条路径的开头都计算的,对于每条路径的开头都使用Begin的向量计算,后续序列都固定使用当下的方向的向量。

©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容