「微信、中科院」FlashPrefill:长上下文加速 28 倍

发布时间:2026/6/29 22:20:04

「微信、中科院」FlashPrefill:长上下文加速 28 倍 一句话总结把长上下文 prefill 阶段的寻找重要块和筛选策略做到了近乎零成本用动态阈值彻底剪掉注意力长尾再用索引压紧实现物理跳转FlashPrefill 在 256K 上下文下实现 27.78 倍算子加速论文标题FlashPrefill: Instantaneous Pattern Discovery and Thresholding for Ultra-Fast Long-Context Prefilling论文地址https://arxiv.org/abs/2603.06199作者背景微信、中科院自动化所代码地址https://github.com/qhfan/FlashPrefill一、动机大模型推理分两个阶段Prefill预填充和Decoding逐 token 生成。Prefill 阶段需要把用户给的整段 prompt 过一遍 Transformer计算出每一层的 KV Cache“读完全文后做的笔记索引”Decoding 阶段每生成一个新 token都需查阅这份笔记Attention 的计算复杂度是 O(L²)意味着 Prefill 耗时会随着上下文爆炸式增长。相比于 Decoding 阶段好歹能看到一个字一个字往外蹦Prefill 期间用户面对的是完全卡住的界面因此成为了长上下文场景下用户体验的真正瓶颈。于是作者设计了 FlashPrefill旨在加速 Prefill降低 TTFTTime To First Token首 token 等待时间二、现有方法痛点加速 Prefill 的主流思路是稀疏注意力既然注意力图里大部分区域的分数很低那就只在重要区域做精细计算。在实际的注意力热力图中可以观察到三类典型的稀疏模式竖条某些锚点 token如bos、标题、分隔符被几乎所有位置关注斜条沿对角线的亮带反映局部依赖即每个位置更关注邻近上下文块状热点一块块的亮方块对应段落级的语义关联现有方法MInference、FlexPrefill、XAttention、MoBA 等的通用流水线是先发现这些模式 → 再选择要计算的块 → 最后做稀疏注意力。但 FlashPrefill 的作者指出真正的瓶颈往往不在第三步算注意力而在前两步找哪里要算模式发现有延迟估计每个 block 的重要性本身需要不少计算选择策略很贵Top-k 需要排序Top-p 需要累加概率做阈值截断这些操作在 GPU 上难并行长尾分布导致稀疏不彻底注意力分数经常呈长尾分布——少量块分数很高但大量块分数极低却数量庞大。Top-k/Top-p 为了凑够 k 个或凑够 p 概率不得不把很多贡献微乎其微的块也带上三、FlashPrefill 方法3.1 瞬时模式发现近似运算快速得到一个 block 级别的“注意力地图”而不是按 token 全量运算同一个 block 内的 token 语义往往很相似注意力 logit 的方差很小。论文通过 AM-GM 关系证明在低方差假设下用均值做代理时跨块的相对排序基本保持不变显存优化GPU 计算缓慢的主要原因在于显存读写数据太多而不是计算量太大需要减少数据搬运key 池化对 Key 分块然后平均池化理论证明不影响排序query 池化对 Query 分块计算完每个 block 中的 Q·K 分数后用块内最大值对它们做归一化再取 exp 汇总得到一个代表“query块对key块的总体吸引力”的分数在一个 kernel 中实现 全局再加权上一步的结果是块对块的局部分数只在同行、同列上可比。为了得到全局可比的“注意力地图”还需要再做一次全局的归一化固定 query block以分数最高的 key block 为基准再次对齐上述步骤存储的都是块级别的信息相比直接存储 token 级别的信息数据量小很多避免巨大中间矩阵的显存读写大幅减少 Memory IO。如下图所示左边是仅使用 key 池化的注意力计算流程它需要存储【token数(L) * 块数(L/B)】的数据而右边只需要存【块数(L/B) * 块数(L/B)】3.2 基于最大值的动态阈值得到块级分数后下一步要决定保留哪些块。FlashPrefill 不用 Top-k 也不用 Top-p而是设定一个动态阈值thresh I α ⋅ max ⁡ J ≤ I ( Score I , J ) \text{thresh}_I \alpha \cdot \max_{J \le I}(\text{Score}_{I,J})threshI​α⋅J≤Imax​(ScoreI,J​)即以每个 Query Block 的最大块分数为标杆乘以系数 α“相对于最重要块的最低门槛比例”低于阈值的块一律丢弃。这一算法的优势在于计算极轻只需一次 max-reduction不用排序也不用累加避免长尾堆积不会为了凑数而保留长尾中大量无效块实际效果在密度数据上非常明显序列长度FlashPrefillFlexPrefillXAttention4K70.4%82.0%90.3%64K7.9%18.0%37.5%256K3.5%8.4%18.5%FlashPrefill 在 256K 时只保留 3.5% 的块——越长越稀疏。安全网阈值过滤之外FlashPrefill 还显式保留两类块Attention Sinks前 256 tokens很多模型天然会高频关注序列开头Local Window最近 512 tokens保证局部连贯性3.3 索引驱动的物理跳转逻辑跳过的缺陷很多块稀疏实现的内层循环仍遍历所有候选块遇到 mask0 的就用if跳过。但在 GPU 上循环次数没减少分支判断、线程同步、分支发散等开销仍然存在FlashPrefill 的做法——物理跳转把稀疏掩码中需要计算的块索引压紧compact成一个连续列表Kernel 的内层循环只遍历这个列表中的索引直接跳转到对应的 K/V 位置循环次数从所有候选块缩减为真正要算的块内存访问更集中将逻辑跳过但仍然存在物理遍历的 Block-Sparse-Attention优化成索引驱动的物理跳转带来的效率提升对比四、实验结果4.1 算子加速各算子相当于 Flash Attention 的加速比以及算子内部各环节执行耗时可见在 Qwen3-30B-A3B 模型上prefill 在 256K 上下文上加速了 27.78 倍4.2 端到端 TTFT集成到 vLLM 后端到端首 token 等待时间的加速4.3 准确率FlashPrefill 在多个基准上的分数非常接近全量注意力RULERQwen3-30B-A3B 平均 92.68Full: 93.28同时有 18.67× 加速InfiniteBenchQwen2.5-7B 上甚至略超 Full24.93 vs 23.87VideoMMEQwen3-VL-30B-A3B 上 72.00 vs 72.11——几乎无差Needle-in-Haystack2K256K 范围内检索能力几乎不掉4.4 消融阈值策略对比在 Llama-3.1-8B RULER 上的对比Max-based 在更低密度下取得了更高分数充分验证了动态阈值对长尾的优势。

相关新闻