PyTorch新手也能懂:手把手拆解Mamba-minimal中的selective_scan实现

发布时间:2026/6/2 5:46:07

PyTorch新手也能懂:手把手拆解Mamba-minimal中的selective_scan实现 PyTorch新手也能懂手把手拆解Mamba-minimal中的selective_scan实现在深度学习领域状态空间模型State Space Models, SSM正逐渐成为处理序列数据的新范式。而Mamba作为SSM家族的最新成员凭借其选择性扫描机制selective_scan在长序列建模任务中展现出惊人潜力。本文将带您深入Mamba-minimal实现中最核心的selective_scan函数用PyTorch初学者的视角逐行解析这个看似复杂实则精妙的设计。1. 理解状态空间模型的基础状态空间模型本质上描述了一个动态系统的演变过程。在离散时间步中系统状态$x_k$和输出$y_k$由以下方程决定x_k A * x_{k-1} B * u_k y_k C * x_k D * u_k其中$A$是状态转移矩阵$B$是输入矩阵$C$是输出矩阵$D$是前馈矩阵$u_k$是当前输入Mamba的创新之处在于让这些矩阵参数动态依赖于输入数据而传统SSM如S4使用固定参数。这种数据依赖性使得模型能够根据输入内容自适应地调整状态转移方式。2. selective_scan的输入参数解析让我们先看看selective_scan函数的完整签名def selective_scan(self, u, delta, A, B, C, D):各参数含义及维度如下表所示参数维度说明u(b, l, d_in)输入序列b为batch大小l为序列长度delta(b, l, d_in)数据依赖的时间步长参数A(d_in, n)状态转移矩阵B(b, l, n)输入矩阵数据依赖C(b, l, n)输出矩阵数据依赖D(d_in)前馈矩阵关键点在于与传统SSM不同B和C矩阵每个时间步都有不同值delta参数控制着离散化的时间步长也是输入依赖的3. 离散化过程详解Mamba采用两种离散化方法的组合**零阶保持ZOH**用于状态矩阵A前向欧拉用于输入矩阵B对应的离散化公式实现如下deltaA torch.exp(einsum(delta, A, b l d_in, d_in n - b l d_in n)) deltaB_u einsum(delta, B, u, b l d_in, b l n, b l d_in - b l d_in n)这里使用了爱因斯坦求和约定einsum进行高效张量运算。让我们拆解这两行代码3.1 状态矩阵的ZOH离散化deltaA的计算对应ZOH离散化中的$e^{ΔA}$项delta形状(b, l, d_in)A形状(d_in, n)通过einsum在d_in维度上做乘法然后取指数3.2 输入矩阵的欧拉离散化deltaB_u对应欧拉离散化的$ΔB·u$项将delta、B和u三个张量在多个维度上相乘结果形状为(b, l, d_in, n)提示einsum操作虽然高效但对初学者可能不太直观。可以想象它是在指定维度上进行乘法求和的组合操作。4. 选择性扫描的核心循环真正的扫描过程由一个简单的for循环实现x torch.zeros((b, d_in, n), devicedeltaA.device) ys [] for i in range(l): x deltaA[:, i] * x deltaB_u[:, i] # 状态更新 y einsum(x, C[:, i, :], b d_in n, b n - b d_in) # 输出计算 ys.append(y)这个循环实现了以下功能初始化状态x为零对每个时间步i更新状态$x e^{ΔA}x ΔB·u$计算输出$y Cx$收集所有时间步的输出值得注意的是这与原始论文的并行实现不同这里使用顺序扫描时间复杂度O(l)论文使用并行算法时间复杂度O(log l)教学实现更易理解但实际应用应使用CUDA优化版本5. 完整流程与输出处理扫描完成后我们需要处理输出结果y torch.stack(ys, dim1) # 将列表转为张量 (b, l, d_in) y y u * D # 添加前馈连接 return y最后一步y u * D体现了状态空间模型的完整输出方程$y Cx Du$D矩阵提供了从输入到输出的直接路径这种跳跃连接有助于梯度传播6. 与理论公式的对照让我们将代码与离散化理论公式做对比前向欧拉离散化用于B矩阵x_k (I Δ_k A)x_{k-1} Δ_k B u_k零阶保持离散化用于A矩阵x_k e^{Δ_k A}x_{k-1} (Δ_k A)^{-1}(e^{Δ_k A} - I)Δ_k B u_k在Mamba-minimal实现中对A矩阵采用完整ZOH离散化包含指数项对B矩阵做了简化相当于只保留一阶泰勒展开这种混合策略在效果和效率间取得了平衡7. 实际应用中的注意事项在您自己的项目中实现或修改selective_scan时需要注意数值稳定性指数运算可能导致数值爆炸实际实现可能需要添加归一化初始化策略A矩阵初始化为对数空间使用softplus确保delta为正性能考量序列较长时顺序扫描会成为瓶颈实际应用应参考论文的并行实现以下是一个简化的初始化示例# A矩阵的初始化对数空间 A repeat(torch.arange(1, args.d_state 1), n - d n, dargs.d_inner) self.A_log nn.Parameter(torch.log(A)) # delta的处理 delta F.softplus(self.dt_proj(delta)) # 确保为正理解selective_scan的实现是掌握Mamba架构的关键。虽然这个简化版本牺牲了部分效率但它清晰地展现了选择性状态空间的核心思想——通过输入相关的参数实现动态序列建模。

相关新闻