【LLM】OPD

发布时间:2026/5/21 17:39:34

【LLM】OPD 完整流程就是Prompt 喂给 StudentStudent 自回归生成输出序列把Prompt Student生成的序列整个喂给 Teacher 和 Student 各做一次前向传播在每个生成的 token 位置上拿到 Teacher 的 logits 和 Student 的 logits计算 KL 散度反向传播更新 Student一、为什么需要 OPD设计动机要理解 OPD 的设计思路先要理解它解决的是什么问题。传统的知识蒸馏Off-Policy做法是准备一批固定的文本数据Teacher 和 Student 同时在这批数据上计算 logits然后对齐。这个方式有一个根本性的缺陷叫做分布偏移Distribution Shift。想象一下 Student 在推理时的真实状态它是自回归生成的第 t 步的输入 prefix 是它自己在前 t-1 步生成的内容。但训练时prefix 全部来自外部数据集Student 从来没有在自己生成的文字上被训练过。于是就出现了一个矛盾训练时 Student 看到的是完美的 prefix推理时它看到的是自己可能犯错的 prefix。一旦 Student 在某一步生成了一个偏差词后续所有 token 都处于训练时从未遇到过的分布之下错误会滚雪球式累积。OPD 的核心思想就是既然推理时 Student 要面对自己生成的 prefix那训练时就让它在自己生成的 prefix 上学习。让 Teacher 跟着 Student 走而不是让 Student 跟着 Teacher 走。二、具体实现流程整体流程给定一个 promptx xx完整的一个训练步骤如下Step 1 — Student 自回归采样把 prompt 喂给 Student让它自回归地生成完整的输出序列 $hat{y} h a t y 1 , y ^ 2 , … , y ^ T ) hat{y}_1, \hat{y}_2, \ldots, \hat{y}_T)haty1​,y^​2​,…,y^​T​)。注意此时是真实的采样或 greedy decode不是 teacher-forcing。Step 2 — 双向前向传播把完整的序列( x , y ^ ) (x, \hat{y})(x,y^​)同时喂给 Teacher 和 Student各做一次前向传播。由于 Transformer 的 causal mask每个位置 t 只能看到它之前的 token所以一次前向传播就能拿到所有位置的 logits等价于逐个 prefix 单独输入但效率高得多。Step 3 — 计算每个位置的 KL 散度在每个生成 token 的位置 t 上分别拿到Teacher 的概率分布p T p_TpT​cdot \mid x, \hat{y}_{t})$Student 的概率分布p S p_SpS​cdot \mid x, \hat{y}_{t})$计算这两个分布之间的 KL 散度。注意 prompt 部分的位置不参与 loss 计算。Step 4 — 汇总损失更新 Student把所有位置的 KL 累加对 Student 的参数做反向传播KaTeX parse error: Cant use function $ in math mode at position 83: …hat{y} \sim p_S$̲cdot|x)} \left[…重复以上步骤每次迭代 Student 参数更新后下一轮采样的 $hat{y}$ 自然也会随之变化这就是On-Policy的含义——训练分布始终跟随当前策略。KL 方向的选择KL 散度不对称方向的选择有实质影响前向 KL即D K L ( p T ∥ p S ) D_{KL}(p_T \| p_S)DKL​(pT​∥pS​)也叫 Mode-Covering。它惩罚 Student 对 Teacher 高概率区域给出低概率驱使 Student 尽量覆盖 Teacher 的所有可能性。实现简单训练稳定大多数工作默认使用这个方向。反向 KL即D K L ( p S ∥ p T ) D_{KL}(p_S \| p_T)DKL​(pS​∥pT​)也叫 Mode-Seeking。它惩罚 Student 在 Teacher 低概率的地方给出高概率驱使 Student 专注于模仿 Teacher 最主要的生成模式可以忽略一些次要的长尾分布。MiniLLM 使用的就是这个方向理论上更适合生成任务但由于梯度需要用 REINFORCE 估计训练方差较高。三、与 Off-Policy 蒸馏的本质区别两者的数学形式看起来类似但期望的分布完全不同Off-Policy 计算的是KaTeX parse error: Cant use function $ in math mode at position 53: …ata}}}[\,D_{KL}$̲cdots)\,]OPD 计算的是KaTeX parse error: Cant use function $ in math mode at position 37: …t}\,\sim\, p_S$̲cdot|x)}[\,D_{K…这一个下标的差异决定了 Student 是否能在自己真实会走到的状态上得到指导。四、优点分布一致性强训练和推理时 Student 面对的 prefix 分布相同从根本上解决了 exposure bias 问题在长文本生成任务上效果提升尤为明显。错误恢复能力Student 生成了错误的 prefix 时Teacher 会在这个错误的上下文上给出指导Student 能学到走偏之后如何纠正这是 Off-Policy 完全无法做到的。无需高质量标注数据训练数据完全由 Student 自己生成只需要 prompt不需要人工标注的 ground truth 回答数据准备成本低。兼容性好可以和 SFT、RLHF 等训练范式结合也可以作为任何序列生成模型的压缩手段。五、缺点计算开销大每个训练步都需要 Student 做一次完整的自回归采样然后 Teacher 和 Student 各做一次前向传播。相比 Off-Policy 只需要两次前向传播训练成本大约高 2~3 倍Teacher 越大开销越明显。训练初期不稳定Student 初期质量很差采出来的序列往往是乱的导致 Teacher 打分的上下文也很混乱梯度信号噪声大。常见解决办法是先做 Off-Policy 热身warm-up让 Student 具备基本能力后再切换到 On-Policy。梯度方差问题尤其是使用反向 KL 时梯度需要通过 REINFORCE 估计方差较高需要加 baseline 或 variance reduction 技术来稳定训练。Teacher 必须在线可用Teacher 需要在训练过程中实时做推理无法提前把 Teacher 的输出缓存下来这意味着 Teacher 必须全程驻留在显存中对硬件资源要求较高。六、应用场景LLM 模型压缩这是最主流的应用用 70B、405B 级别的大模型作为 Teacher蒸馏出 7B、13B 级别的小模型使其在对话、推理等任务上尽量接近大模型的能力。MiniLLM、DistiLLM 都是这个场景下的代表工作。长文本生成任务摘要、故事续写、代码生成等任务中输出序列很长Off-Policy 的分布偏移问题会随序列长度成倍放大OPD 的优势在这类任务上尤为突出。推理与规划任务对于数学解题、逻辑推理这类需要多步骤连贯输出的任务中间步骤的错误会直接影响后续步骤OPD 能让 Student 在自己真实会生成的推理链上得到纠正。领域适配蒸馏在医疗、法律、代码等垂直领域收集高质量标注数据成本极高而 OPD 只需要 prompt由 Student 自己生成回答后再由领域 Teacher 打分指导极大降低了数据门槛。强化学习结合OPD 和 RLHF/PPO 的结构非常相似都是 on-policy 的策略优化可以把 KL 蒸馏损失作为 reward shaping 的一部分约束 RL 训练时 Student 不要偏离 Teacher 太远防止 reward hacking。

相关新闻