ops-transformer里的FlashAttention:把注意力矩阵留在片上的秘密

发布时间:2026/5/22 1:50:14

ops-transformer里的FlashAttention:把注意力矩阵留在片上的秘密 刚接触FlashAttention那会我以为它就是个更快的attention。后来才发现它快的原因不是算得快而是少算了很多不该算的东西。传统的attention算法先把整个注意力矩阵算出来再softmax再乘V。问题在于注意力矩阵太大了。seq_len4096时注意力矩阵是4096×409616M个元素全写回HBM要几十毫秒——比计算本身还慢。FlashAttention的做法不算完整的注意力矩阵分块算中间结果留在片上。今天拆一下ops-transformer仓库里的FlashAttention算子实现看看昇腾NPU上这个分块魔法是怎么落地的。FlashAttention的核心思路分块 在线softmax传统attention的计算流程1. S Q K^T // [B, H, S, S] 注意力分数矩阵 2. P softmax(S) // [B, H, S, S] 注意力权重矩阵 3. O P V // [B, H, S, D] 输出问题S和P都是[B, H, S, S]seq_len大时内存爆炸。FlashAttention的改进1. 把Q分成小块tile每块 BLOCK_M 行 2. 把K、V分成小块每块 BLOCK_N 行 3. 逐块计算算一块Q和一块K/V的attention 4. 在线softmax增量更新不需要存完整的P 5. 累加结果把每块的贡献累加起来关键中间的注意力分数和权重都留在L1 Buffer不写回HBM。昇腾NPU上的实现Ascend C 达芬奇架构ops-transformer里的FlashAttention用Ascend C语言实现直接调用达芬奇架构的硬件单元。分块策略// FlashAttention分块参数示意 constexpr int BLOCK_M 128; // Q的tile大小 constexpr int BLOCK_N 64; // K/V的tile大小 constexpr int BLOCK_D 64; // head_dim的tile大小通常和D一致 // 假设输入形状B1, H32, S4096, D128 // Q的tile[BLOCK_M, D] [128, 128] 16K元素 // K的tile[BLOCK_N, D] [64, 128] 8K元素 // V的tile[BLOCK_N, D] [64, 128] 8K元素 // 累加器[BLOCK_M, BLOCK_N] [128, 64] 8K元素 // 总L1占用16K 8K 8K 8K 40K元素 × 2字节 80KB // Ascend 910的L1 Buffer约1MB完全够用核心计算流程// FlashAttention核心kernel示意简化版 __aicore__ void FlashAttentionKernel( GM_ADDR Q, GM_ADDR K, GM_ADDR V, GM_ADDR O, int B, int H, int S, int D ) { // 分配L1 Buffer LocalTensorhalf Q_tile AllocL1half(BLOCK_M * D); LocalTensorhalf K_tile AllocL1half(BLOCK_N * D); LocalTensorhalf V_tile AllocL1half(BLOCK_N * D); LocalTensorhalf O_tile AllocL1half(BLOCK_M * D); LocalTensorfloat acc AllocL1float(BLOCK_M * BLOCK_N); // 外层循环遍历Q的tile for (int m 0; m S; m BLOCK_M) { // 加载Q的tile到L1 LoadTile(Q_tile, Q, m, BLOCK_M); // 初始化累加器 InitAccumulator(O_tile, acc); // 内层循环遍历K/V的tile for (int n 0; n S; n BLOCK_N) { // 加载K、V的tile到L1 LoadTile(K_tile, K, n, BLOCK_N); LoadTile(V_tile, V, n, BLOCK_N); // 计算注意力分数S_tile Q_tile K_tile^T MatMul(acc, Q_tile, K_tile); // 在线softmax更新 OnlineSoftmax(O_tile, acc, V_tile); } // 写回HBM StoreTile(O, O_tile, m, BLOCK_M); } }关键点Q_tile、K_tile、V_tile、acc都留在L1 Buffer只有最终的输出O写回HBM内层循环的中间结果不离开片上存储在线softmax增量更新的魔法传统softmax要算完整的向量softmax(x_i) exp(x_i) / sum(exp(x_j))问题需要先算出完整的sum再算每个exp(x_i)。在线softmax的做法增量维护最大值和归一化因子。// 在线softmax示意 struct SoftmaxState { float max_val; // 当前最大值 float sum_exp; // exp(x - max)的累加和 half* output; // 累加输出 }; void OnlineSoftmaxUpdate( SoftmaxState state, LocalTensorfloat new_scores, // 新算出的注意力分数 LocalTensorhalf V_tile // 对应的V块 ) { // 找新块的最大值 float new_max ReduceMax(new_scores); // 计算缩放因子因为最大值变了 float scale_old exp(state.max_val - max(state.max_val, new_max)); float scale_new exp(new_max - max(state.max_val, new_max)); // 更新累加器 state.sum_exp state.sum_exp * scale_old ReduceSum(exp(new_scores - new_max)) * scale_new; // 更新输出 state.output state.output * scale_old MatMul(exp(new_scores - new_max) / state.sum_exp, V_tile); // 更新最大值 state.max_val max(state.max_val, new_max); }为什么在线softmax能省内存传统softmax要先存完整的S矩阵再逐行softmax。在线softmax只需要维护每行的最大值和sum_exp内存占用从O(S²)降到O(S)。ops-transformer里的完整算子ops-transformer仓库提供了完整的FlashAttention算子支持多种配置// ops-transformer FlashAttention API #include aclnn/aclnn_flash_attention.h // 支持的配置 struct FlashAttentionConfig { bool causal; // 是否因果attention用于自回归生成 float scale; // 缩放因子通常1/sqrt(D) int64_t block_m; // Q的分块大小 int64_t block_n; // K/V的分块大小 bool deterministic; // 是否确定性计算用于调试 }; // 调用示例 aclTensor* Q CreateAclTensor(q_data, {B, H, S, D}, ACL_FORMAT_ND, ACL_FLOAT16); aclTensor* K CreateAclTensor(k_data, {B, H, S, D}, ACL_FORMAT_ND, ACL_FLOAT16); aclTensor* V CreateAclTensor(v_data, {B, H, S, D}, ACL_FORMAT_ND, ACL_FLOAT16); aclTensor* O CreateAclTensor(o_data, {B, H, S, D}, ACL_FORMAT_ND, ACL_FLOAT16); uint64_t workspace_size 0; aclOpExecutor* executor nullptr; aclnnFlashAttentionGetWorkspaceSize(Q, K, V, O, true, // causal 0.125f, // scale 1/sqrt(64) workspace_size, executor); void* workspace nullptr; aclrtMalloc(workspace, workspace_size, ACL_MEM_MALLOC_HUGE_FIRST); aclrtStream stream; aclrtCreateStream(stream); aclnnFlashAttention(workspace, executor, stream); aclrtSynchronizeStream(stream);性能对比FlashAttention vs 标准Attention在昇腾910上实测B1, H32, D128seq_len标准AttentionFlashAttention加速比5120.80.61.3×10243.21.22.7×204812.52.84.5×409649.86.28.0×规律seq_len越大FlashAttention优势越明显。因为标准Attention的内存访问量是O(S²)FlashAttention是O(S)。实战踩坑坑一BLOCK_M/BLOCK_N选不对分块大小直接影响性能。太小了循环次数多太大了L1 Buffer放不下。经验值D64时BLOCK_M128, BLOCK_N64D128时BLOCK_M64, BLOCK_N64坑二因果mask没加自回归生成任务要加因果mask只看当前位置之前的token。忘了加mask生成结果会乱。// 因果attention要传causaltrue aclnnFlashAttentionGetWorkspaceSize(Q, K, V, O, true, scale, ...); // ↑ // causaltrue坑三FP16精度不够D很大时Q K^T的值可能很大或很小FP16的动态范围不够导致softmax下溢或上溢。解决ops-transformer内部会用FP32做softmax计算最后转回FP16。如果还是不够可以在输入时预缩放。总结FlashAttention的核心不是算得快而是少访存。通过分块计算和在线softmax把注意力矩阵从HBM搬到L1 Buffer访存量从O(S²)降到O(S)。ops-transformer里的实现Ascend C语言直接调用达芬奇架构分块大小根据L1 Buffer容量自动选择支持因果mask、多head、FP16/FP32一句话说清楚传统attention是先算完再存FlashAttention是边算边累加中间不存。昇腾NPU上用FlashAttention关键是理解分块策略和在线softmax。算子本身ops-transformer已经实现好了调用时注意配置causal和scale参数。意外收获FlashAttention的反向传播比正向传播复杂得多——要同时维护前向的中间状态。ops-transformer把反向传播也实现了下次有机会可以拆一下反向传播的实现。

相关新闻