class DirAwareSSM(nn.Module):
"""
极简 SSM:仅保留线性更新 + 方向偏移
方向:0→1→2→3 对应 4 条扫描;4 是 Begin
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
# 5 个方向:右 左 下 上 Begin
self.delta_dir = nn.Parameter(torch.randn(5, dim))
self.A = nn.Parameter(torch.randn(dim, dim) * 0.02)
self.Bp = nn.Linear(dim, dim, bias=False)
self.Cp = nn.Linear(dim, dim, bias=False)
def single_forward(self, x, dir_id):
"""
x: (L, dim) 一条扫描序列
dir_id: 0-3 扫描方向;Begin 方向固定用 4
return: (L, dim)
"""
L, d = x.shape
h = torch.zeros(d, device=x.device)
outs = []
for t in range(L):
delta = self.delta_dir[4] if t == 0 else self.delta_dir[dir_id] ##对每个序列开头,都使用Begin的向量
B = self.Bp(x[t]) + delta
h = h @ self.A + B * x[t]
outs.append(self.Cp(h))
return torch.stack(outs)
之前一直不太理解5个方向计算4条路径要怎么操作,看了代码才知道,Begin这个方向是在每条路径的开头都计算的,对于每条路径的开头都使用Begin的向量计算,后续序列都固定使用当下的方向的向量。