让大模型跑得快一点:Speculative Decoding 实战与加速比分析

发布时间:2026/6/14 11:36:07

让大模型跑得快一点:Speculative Decoding 实战与加速比分析 让大模型跑得快一点Speculative Decoding 实战与加速比分析一、为什么 Decode 阶段那么慢大模型推理其实就两件事Prefill把提示词塞进去和 Decode一个字一个字往外吐。Prefill 阶段是计算密集型GPU 吃得饱饱的但到了 Decode 阶段情况就变了。每生成一个字GPU 都得把整个模型的权重从显存里读一遍。以 70B 模型为例一次 Decode 的计算量只有 0.01 TFLOPS但光读权重就要 140GB 带宽。这就导致 GPU 大部分时间都在等数据计算单元空转。更麻烦的是自回归的串行特性第 N 个字必须等第 N-1 个字生成完才能开始。不管 GPU 算力多强Decode 阶段的延迟始终被内存带宽和串行步数卡着。Speculative Decoding投机解码就是为了解决这个问题它试图打破这种串行依赖一次吐出多个字。二、投机解码是怎么工作的核心逻辑很简单找个“小机灵鬼”Draft Model先猜几个字然后让“老大哥”Target Model一次性验证这些猜测对不对。sequenceDiagram participant Draft as Draft Modelbr/(小模型, 快速) participant Target as Target Modelbr/(大模型, 精确) participant Output as 输出 Note over Draft: Step 1: Draft 模型快速生成 K 个候选 Token Draft-Draft: 生成 t1, t2, t3, t4, t5 Draft-Target: 传入候选序列 Note over Target: Step 2: Target 模型单次前向验证 Target-Target: 并行计算 P(t1), P(t2), P(t3), P(t4), P(t5) Target-Target: 同时生成 Target 的 P(t|prefix) Note over Target: Step 3: 逐个验证候选 Token Target-Target: t1: P_draft0.8, P_target0.9 → 接受 Target-Target: t2: P_draft0.7, P_target0.6 → 接受 Target-Target: t3: P_draft0.5, P_target0.3 → 拒绝 Note over Output: 输出 t1, t2 Target 修正的 t3 Target-Output: t1, t2, t3 Note over Draft: 从 t3 重新开始投机具体流程分三步Draft 模型快速生成 K 个候选 Token小模型速度快几秒钟就能吐出几个字。Target 模型单次前向验证大模型把 Draft 生成的序列一次性跑一遍并行计算每个位置的概率。逐个验证候选 Token对比 Draft 和大模型的概率决定接受还是拒绝。数学上的保证通过特定的接受概率计算Speculative Decoding 能保证输出分布和原始自回归解码完全一致。也就是说它只是加速不会改变生成质量。接受概率的计算逻辑如果 Draft 的概率 $p_d(t_i) \leq p_t(t_i)$直接接受。如果 $p_d(t_i) p_t(t_i)$以 $p_t(t_i) / p_d(t_i)$ 的概率接受。拒绝时从修正分布 $\max(0, p_t - p_d) / \sum \max(0, p_t - p_d)$ 中采样一个 Token 作为修正。三、工程实现代码# speculative_decoding.py — 投机解码引擎 import time from dataclasses import dataclass, field from typing import Optional import numpy as np dataclass class DraftResult: Draft 模型的生成结果 tokens: list[int] log_probs: list[float] # 每个 Token 的 log 概率 latency_ms: float dataclass class VerifyResult: Target 模型的验证结果 accepted_count: int # 被接受的 Token 数 rejected_at: int # 拒绝位置-1 表示全部接受 corrected_token: int # 拒绝时修正的 Token corrected_log_prob: float # 修正 Token 的 log 概率 target_log_probs: list[float] # Target 模型对每个位置的 log 概率 latency_ms: float class SpeculativeDecoder: 投机解码引擎 def __init__(self, draft_model_fn, target_model_fn, speculate_length: int 5, temperature: float 1.0): self._draft_fn draft_model_fn self._target_fn target_model_fn self.speculate_length speculate_length self.temperature temperature # 统计信息 self._stats { total_tokens_generated: 0, total_draft_tokens: 0, total_accepted_tokens: 0, total_target_calls: 0, total_draft_calls: 0, } def generate(self, prompt_tokens: list[int], max_tokens: int 256) - list[int]: 使用投机解码生成文本 generated list(prompt_tokens) while len(generated) - len(prompt_tokens) max_tokens: remaining max_tokens - (len(generated) - len(prompt_tokens)) k min(self.speculate_length, remaining) # Step 1: Draft 模型快速生成 K 个候选 Token draft_result self._draft_generate(generated, k) self._stats[total_draft_calls] 1 self._stats[total_draft_tokens] len(draft_result.tokens) # Step 2: Target 模型单次前向验证 verify_result self._verify(generated, draft_result) self._stats[total_target_calls] 1 # Step 3: 根据验证结果更新生成序列 if verify_result.rejected_at -1: # 全部接受 generated.extend(draft_result.tokens) self._stats[total_accepted_tokens] len(draft_result.tokens) # Target 模型在最后一个位置也生成了一个 Token # 可以额外获取 1 个 Token bonus_token self._sample_from_logits( verify_result.target_log_probs[-1] ) generated.append(bonus_token) self._stats[total_accepted_tokens] 1 else: # 部分接受 accepted_tokens draft_result.tokens[:verify_result.rejected_at] generated.extend(accepted_tokens) generated.append(verify_result.corrected_token) self._stats[total_accepted_tokens] len(accepted_tokens) 1 self._stats[total_tokens_generated] \ len(generated) - len(prompt_tokens) return generated def _draft_generate(self, prefix: list[int], k: int) - DraftResult: Draft 模型生成 K 个候选 Token start time.time() tokens [] log_probs [] current list(prefix) for _ in range(k): # 调用 Draft 模型获取下一个 Token 的分布 logits self._draft_fn(current) probs self._softmax(logits / self.temperature) token np.random.choice(len(probs), pprobs) log_prob float(np.log(probs[token] 1e-10)) tokens.append(int(token)) log_probs.append(log_prob) current.append(token) latency (time.time() - start) * 1000 return DraftResult(tokenstokens, log_probslog_probs, latency_mslatency) def _verify(self, prefix: list[int], draft: DraftResult) - VerifyResult: Target 模型验证候选 Token start time.time() # 构造验证输入prefix draft tokens verify_input prefix draft.tokens # Target 模型单次前向传播 # 返回每个位置的 logits包括 draft tokens 的位置 all_logits self._target_fn(verify_input) # 提取 draft tokens 位置的 logits # 位置 i 对应 prefix 长度 i - 1 处的 logits # 因为位置 i 的 Token 是由前 i-1 个 Token 预测的 prefix_len len(prefix) target_log_probs [] for i in range(len(draft.tokens)): pos prefix_len i - 1 if pos 0: pos 0 logits all_logits[pos] probs self._softmax(logits / self.temperature) token draft.tokens[i] target_log_probs.append(float(np.log(probs[token] 1e-10))) # 逐个验证候选 Token rejected_at -1 corrected_token -1 corrected_log_prob 0.0 accepted_count 0 for i in range(len(draft.tokens)): draft_lp draft.log_probs[i] target_lp target_log_probs[i] # 计算接受概率 # p_accept min(1, p_target / p_draft) # 在 log 空间: log(p_accept) min(0, target_lp - draft_lp) log_accept_ratio target_lp - draft_lp if log_accept_ratio 0: # p_target p_draft一定接受 accepted_count 1 else: # p_target p_draft以概率 p_target/p_draft 接受 accept_prob np.exp(log_accept_ratio) if np.random.random() accept_prob: accepted_count 1 else: # 拒绝从修正分布中采样 rejected_at i pos prefix_len i - 1 if pos 0: pos 0 corrected_token self._sample_corrected( all_logits[pos], draft.tokens[i], draft_lp, target_lp, ) corrected_log_prob target_log_probs[i] break latency (time.time() - start) * 1000 return VerifyResult( accepted_countaccepted_count, rejected_atrejected_at, corrected_tokencorrected_token, corrected_log_probcorrected_log_prob, target_log_probstarget_log_probs, latency_mslatency, ) def _sample_corrected(self, logits: np.ndarray, draft_token: int, draft_lp: float, target_lp: float) - int: 从修正分布中采样 Token probs self._softmax(logits / self.temperature) draft_prob np.exp(draft_lp) # 修正分布: max(0, p_target - p_draft) / Z corrected np.maximum(0, probs - draft_prob) total np.sum(corrected) if total 1e-10: # 修正分布退化为均匀分布使用 Target 分布 return int(np.random.choice(len(probs), pprobs)) corrected / total return int(np.random.choice(len(corrected), pcorrected)) def _softmax(self, logits: np.ndarray) - np.ndarray: 数值稳定的 Softmax shifted logits - np.max(logits) exp_vals np.exp(shifted) return exp_vals / (np.sum(exp_vals) 1e-10) def _sample_from_logits(self, logits) - int: 从 logits 中采样 if isinstance(logits, (int, float)): return int(logits) probs self._softmax(np.array(logits) / self.temperature) return int(np.random.choice(len(probs), pprobs)) def get_stats(self) - dict: 获取加速统计 total_draft self._stats[total_draft_tokens] total_accepted self._stats[total_accepted_tokens] acceptance_rate ( total_accepted / total_draft * 100 if total_draft 0 else 0 ) # 加速比 平均每次 Target 调用生成的 Token 数 target_calls self._stats[total_target_calls] avg_tokens_per_call ( total_accepted / target_calls if target_calls 0 else 1.0 ) return { **self._stats, acceptance_rate: round(acceptance_rate, 1), avg_tokens_per_target_call: round(avg_tokens_per_call, 2), effective_speedup: round(avg_tokens_per_call, 2), }四、实际加速效果与 Draft 模型选择投机解码能不能提速主要看两点Draft 模型的接受率以及 Target 模型的验证开销。接受率是关键Draft 模型和大模型的输出分布越接近接受率越高。实测中如果用同一模型系列的 Draft比如 LLaMA-7B 给 LLaMA-70B 当小弟5-Token 投机的接受率大概在 70%-80%平均加速比 2-2.5x。如果 Draft 模型和大模型完全不搭界接受率可能掉到 40%-50%加速比只有 1.3-1.5x甚至不如直接跑自回归解码。验证开销不能忽视Target 模型验证时需要一次前向传播处理 K1 个 TokenK 个候选 1 个额外位置。虽然比 K 次单步 Decode 高效KV Cache 可以复用但 Prefill 阶段的计算量是随序列长度增长的。如果 K 设得太大验证的 Prefill 开销可能会把投机带来的加速抵消掉。经验上K 值在 4-7 之间比较合适。Draft 模型怎么选有两个硬指标——推理速度至少比 Target 快 3x且输出分布要接近。最佳实践是用同一模型系列的较小版本。比如 LLaMA-70B 的 Draft 模型用 LLaMA-7B 或 LLaMA-13B 都不错。如果 Draft 太小比如 1B速度快但接受率低如果太大比如 30B接受率高但速度优势不明显。适用场景投机解码适合单请求低延迟场景比如实时对话这时候 Decode 阶段的串行性是主要瓶颈。如果是高并发批量推理GPU 本来就能通过批量并行把算力吃满投机解码的加速效果就有限了。五、总结Speculative Decoding 通过 Draft-Verify 机制打破了自回归解码的串行瓶颈在保证输出分布无偏的前提下通常能实现 2-2.5x 的加速。加速比的核心决定因素是 Draft 模型的接受率而接受率取决于 Draft 与 Target 模型的分布匹配度。建议用同一模型系列的较小版本作为 Draft 模型投机长度设为 5。投机解码最适用于单请求低延迟场景在高并发批量场景下收益有限。

相关新闻