小白也能看懂的知识蒸馏:大模型变小模型的“浓缩咖啡“术

发布时间:2026/6/30 14:52:03

小白也能看懂的知识蒸馏:大模型变小模型的“浓缩咖啡“术 系列文章AI大模型知识体系 | 第三周·第三篇引言大模型太贵了能不能浓缩精华上一篇我们聊了量化——把模型的数字精度从 FP16 压到 INT8/INT4就像把高清视频压缩成标清体积小了但画面还在。但量化有个天花板再怎么压精度模型的结构没变参数量还是那么多。那有没有一种方法能让一个 700 亿参数的大模型浓缩成一个 70 亿参数的小模型还能保留大部分能力有这就是今天的主角——知识蒸馏Knowledge Distillation。打个比方你有一个教了 30 年书的老教授大模型知识渊博但讲课慢、出场费贵。你能不能让老教授把毕生所学传授给一个年轻讲师小模型让年轻讲师也能讲出 80% 的水平但出场费只要十分之一知识蒸馏就是让大模型当老师小模型当学生把知识从老师脑子里蒸馏到学生脑子里的过程。一、知识蒸馏的核心思想软标签 vs 硬标签1.1 硬标签非黑即白的标准答案传统训练中模型学习的是硬标签Hard Labels。比如一张猫的图片标签就是[猫1, 狗0, 鸟0]——非黑即白只有正确答案是 1其他全是 0。这就像考试只告诉你正确答案是 C但没告诉你 A 和 B 错在哪里、D 离正确答案有多远。信息量很有限。1.2 软标签老师给的参考分布大模型输出的不是非黑即白的答案而是一个概率分布。比如对同一张图片大模型可能输出猫0.85, 狗0.10, 鸟0.03, 老虎0.02这个概率分布就是软标签Soft Labels。它不仅告诉你答案是猫还告诉你这图跟狗也有点像跟鸟和老虎几乎没关系。软标签的信息量远大于硬标签它包含了类别之间的相似性关系——这种关系就是大模型的暗知识Dark Knowledge。1.3 一个直观的例子假设我们要识别一个数字3硬标签和软标签的区别信息类型数字0数字1数字2数字3数字4数字5硬标签000100软标签0.010.010.050.850.040.04硬标签只说这是3软标签还说3 跟 2 和 5 比较像跟 0 和 1 不太像。学生模型从软标签中学到的远比硬标签多。二、蒸馏的损失函数KL散度与温度参数2.1 温度参数 T让分布更软大模型输出的原始分数叫 logits经过 softmax 转成概率。标准的 softmax 公式$$ \text{softmax}(z_i) \frac{e^{z_i}}{\sum_j e^{z_j}} $$引入温度参数 T后$$ \text{softmax}(z_i, T) \frac{e^{z_i/T}}{\sum_j e^{z_j/T}} $$温度 T 越大概率分布越平滑越软T 越小分布越尖锐越硬。打个比方温度就像照片的对比度。T1 是原始对比度T10 就像把对比度调低——原本黑的地方变灰了白的地方也变灰了细节反而更清楚了。T1 (原始): 猫0.85 狗0.10 鸟0.03 老虎0.02 T5 (升温): 猫0.45 狗0.25 鸟0.18 老虎0.12 T20 (高温): 猫0.28 狗0.26 鸟0.24 老虎0.22温度越高暗知识暴露得越多学生模型能学到的类别间关系越丰富。2.2 KL散度衡量两个分布的距离蒸馏的核心损失函数是KL 散度Kullback-Leibler Divergence用来衡量学生模型的输出分布和老师模型的输出分布之间的差距$$ D_{KL}(p | q) \sum_i p_i \log \frac{p_i}{q_i} $$其中 p 是老师的软标签分布q 是学生的软标签分布。KL 散度越小说明学生越像老师。2.3 蒸馏的总损失实际训练中学生模型的损失由两部分组成┌─────────────────────────────────────────────────────┐ │ 蒸馏总损失 α × 蒸馏损失 (1-α) × 学生损失 │ │ │ │ 蒸馏损失 KL散度(老师软标签, 学生软标签) × T² │ │ 学生损失 交叉熵(硬标签, 学生原始输出) │ │ │ │ α 通常取 0.5~0.9T 通常取 2~20 │ └─────────────────────────────────────────────────────┘为什么要乘 T²因为软标签的梯度会被温度稀释掉 1/T²乘回来才能保持梯度量级一致。2.4 代码蒸馏损失函数import torch import torch.nn as nn import torch.nn.functional as F def distillation_loss(student_logits, teacher_logits, labels, temperature5.0, alpha0.7): 知识蒸馏损失函数 Args: student_logits: 学生模型的原始输出 (logits) teacher_logits: 教师模型的原始输出 (logits) labels: 真实标签 (硬标签) temperature: 温度参数越高软标签越平滑 alpha: 蒸馏损失的权重 # 软标签损失KL散度 soft_student F.log_softmax(student_logits / temperature, dim-1) soft_teacher F.softmax(teacher_logits / temperature, dim-1) loss_soft F.kl_div(soft_student, soft_teacher, reductionbatchmean) * (temperature ** 2) # 硬标签损失标准交叉熵 loss_hard F.cross_entropy(student_logits, labels) # 总损失 total_loss alpha * loss_soft (1 - alpha) * loss_hard return total_loss # 使用示例 student_out torch.randn(32, 1000) # 学生输出: batch32, classes1000 teacher_out torch.randn(32, 1000) # 老师输出 labels torch.randint(0, 1000, (32,)) # 真实标签 loss distillation_loss(student_out, teacher_out, labels, temperature5.0, alpha0.7) print(f蒸馏损失: {loss.item():.4f})三、大模型蒸馏的三种方式知识蒸馏不是只有一种玩法。根据你能看到老师模型的多少信息蒸馏可以分为三种方式┌──────────────┐ │ 大模型(老师) │ └──────┬───────┘ │ ┌──────────────┼──────────────┐ │ │ │ ▼ ▼ ▼ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ 白盒蒸馏 │ │ 黑盒蒸馏 │ │ 特征蒸馏 │ │(Logit级) │ │(输出级) │ │(中间层) │ └──────────┘ └──────────┘ └──────────┘ 能看到logits 只能看到输出 能看到中间层 和概率分布 文本/结果 的特征表示3.1 白盒蒸馏Logit蒸馏你能访问老师模型的内部拿到 logits 或概率分布。这是最正统的蒸馏方式也就是 Hinton 2015 年论文里提出的方法。学生模型直接模仿老师模型的软标签分布。优点信息最丰富效果最好缺点需要老师模型的完整访问权限适用场景你自己训练的大模型想蒸馏成小模型部署# 白盒蒸馏训练循环简化版 def train_distillation(teacher, student, dataloader, optimizer, T5.0, alpha0.7): teacher.eval() # 老师冻结只做推理 student.train() for batch in dataloader: inputs, labels batch # 老师推理不计算梯度 with torch.no_grad(): teacher_logits teacher(inputs) # 学生推理 student_logits student(inputs) # 计算蒸馏损失 loss distillation_loss(student_logits, teacher_logits, labels, T, alpha) # 反向传播 optimizer.zero_grad() loss.backward() optimizer.step()3.2 黑盒蒸馏数据蒸馏 / 模仿学习你只能调用老师模型的 API拿到它的输出文本看不到内部 logits。这就像你请不到老教授来亲自带学生但你可以把老教授的讲义和答题记录拿来给学生学习。黑盒蒸馏的流程┌────────────┐ 提问 ┌────────────┐ │ 数据集/ │ ──────────→ │ 大模型API │ │ Prompt集合 │ │ (黑盒老师) │ └────────────┘ └─────┬──────┘ │ 返回回答 ▼ ┌────────────┐ │ 生成的 │ │ 指令数据 │ └─────┬──────┘ │ 作为训练数据 ▼ ┌────────────┐ │ 小模型 │ │ (学生) │ └────────────┘具体做法准备一批 prompt问题/指令用大模型如 GPT-4生成回答把 (prompt, 回答) 组成训练数据用 SFT监督微调的方式训练小模型优点不需要模型内部权限只要有 API 就行缺点信息损失大只能学到输入→输出的映射学不到推理过程适用场景用闭源大模型GPT-4、Claude蒸馏开源小模型3.3 特征蒸馏中间层蒸馏不仅看 logits还看老师模型中间层的特征表示。如果说 Logit 蒸馏是让学生模仿老师的最终答案那特征蒸馏就是让学生模仿老师的解题过程——每一步中间结果都要对齐。老师模型: 学生模型: Input → Layer1 → Layer2 → Layer3 → Output │ │ │ │ │ 对齐 │ 对齐 │ 对齐 │ 对齐 ▼ ▼ ▼ ▼ Input → Layer1 → Layer2 → Layer3 → Output具体做法在老师和学生的对应层之间加一个MSE 损失让学生的中间层特征尽量接近老师的中间层特征。def feature_distillation_loss(student_features, teacher_features, student_logits, teacher_logits, labels, T5.0, alpha0.5, beta0.3): 特征蒸馏损失 beta: 特征对齐损失的权重 # Logit蒸馏损失 loss_kd distillation_loss(student_logits, teacher_logits, labels, T, alpha) # 特征对齐损失MSE loss_feature 0.0 for s_feat, t_feat in zip(student_features, teacher_features): # 如果维度不同需要投影层对齐 loss_feature F.mse_loss(s_feat, t_feat) # 总损失 total_loss loss_kd beta * loss_feature return total_loss优点信息最充分学生能学到老师的思维方式缺点实现复杂需要设计层间映射关系适用场景同架构模型间的蒸馏如 BERT → 小BERT三种蒸馏方式对比对比维度白盒蒸馏黑盒蒸馏特征蒸馏需要模型内部权限是需 logits否只需 API是需中间层信息丰富度中等低高实现难度低最低高蒸馏效果好一般最好典型应用DistilBERTAlpaca/VicunaTinyBERT是否需要同架构不需要不需要最好同架构四、经典案例那些成功的蒸馏实践4.1 DistilBERTBERT 的浓缩版2019 年HuggingFace 发布了DistilBERT这是知识蒸馏在 NLP 领域最经典的案例之一。指标BERT-baseDistilBERT保留率参数量110M66M60%层数12650%推理速度1x1.6x60%GLUE 得分79.577.097%关键做法白盒蒸馏 特征蒸馏学生只取老师的一半层数12层→6层损失函数 蒸馏损失 掩码语言模型损失 余弦嵌入损失参数量减少 40%性能只掉 3%4.2 Alpaca斯坦福的穷人版 GPT2023 年斯坦福团队用GPT-3.5text-davinci-003作为老师生成了 52K 条指令数据然后用这些数据微调LLaMA-7B得到了 Alpaca。这是典型的黑盒蒸馏花 500 美元调用 GPT-3.5 API 生成数据用 SFT 训练 LLaMA-7B效果惊人在很多任务上接近 GPT-3.5Alpaca 证明了黑盒蒸馏虽然粗糙但性价比极高。4.3 Vicuna聊天能力的蒸馏Vicuna 的做法和 Alpaca 类似但数据来源不同——它从ShareGPT收集了 7 万条用户与 ChatGPT 的对话记录用这些对话数据微调 LLaMA。关键创新点用多轮对话数据而非单轮指令蒸馏出来的模型在聊天场景下表现更好在 GPT-4 评估中Vicuna-13B 达到了 ChatGPT 90% 的质量4.4 DeepSeek-R1 蒸馏推理能力的迁移2025 年初DeepSeek-R1 的蒸馏实践引起了广泛关注。DeepSeek-R1 是一个强大的推理模型类似 OpenAI o1团队将其推理能力蒸馏到了多个小模型上学生模型参数量数学(AIME)代码(LiveCodeBench)Qwen-1.5B1.5B28.9%18.4%Qwen-7B7B55.5%43.2%Qwen-32B32B72.6%57.2%Llama-70B70B78.9%63.1%关键发现用 R1 生成的 80 万条推理数据做 SFT小模型的推理能力大幅提升蒸馏比直接在小模型上做 RL 训练更高效推理能力是可以迁移的——这是之前很多人怀疑的五、实操用 HuggingFace 做一次简单的蒸馏下面我们用 HuggingFace Transformers 实现一个完整的蒸馏流程用 BERT-base 蒸馏出一个更小的模型。5.1 环境准备pip install transformers datasets torch accelerate5.2 完整蒸馏代码import torch import torch.nn as nn import torch.nn.functional as F from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments from datasets import load_dataset # 1. 加载老师和模型 teacher_name bert-base-uncased student_name distilbert-base-uncased # 也可以用 prajjwal1/bert-tiny 做更小的学生 tokenizer AutoTokenizer.from_pretrained(teacher_name) teacher AutoModelForSequenceClassification.from_pretrained(teacher_name, num_labels2) student AutoModelForSequenceClassification.from_pretrained(student_name, num_labels2) # 冻结老师参数 for param in teacher.parameters(): param.requires_grad False # 2. 加载数据集 dataset load_dataset(imdb) # 电影评论情感分类 def tokenize(batch): return tokenizer(batch[text], paddingTrue, truncationTrue, max_length256) dataset dataset.map(tokenize, batchedTrue) dataset.set_format(torch, columns[input_ids, attention_mask, label]) # 3. 自定义蒸馏 Trainer class DistillationTrainer(Trainer): def __init__(self, teacher_model, temperature5.0, alpha0.7, *args, **kwargs): super().__init__(*args, **kwargs) self.teacher teacher_model self.temperature temperature self.alpha alpha def compute_loss(self, model, inputs, return_outputsFalse, **kwargs): # 学生前向传播 student_outputs model(**inputs) student_logits student_outputs.logits # 老师前向传播不计算梯度 with torch.no_grad(): teacher_outputs self.teacher(**inputs) teacher_logits teacher_outputs.logits # 蒸馏损失 loss_soft F.kl_div( F.log_softmax(student_logits / self.temperature, dim-1), F.softmax(teacher_logits / self.temperature, dim-1), reductionbatchmean ) * (self.temperature ** 2) # 硬标签损失 loss_hard F.cross_entropy(student_logits, inputs[labels]) # 总损失 loss self.alpha * loss_soft (1 - self.alpha) * loss_hard return (loss, student_outputs) if return_outputs else loss # 4. 训练 training_args TrainingArguments( output_dir./distilled-model, num_train_epochs3, per_device_train_batch_size16, learning_rate5e-5, logging_steps100, eval_strategyepoch, save_strategyepoch, ) trainer DistillationTrainer( teacher_modelteacher, temperature5.0, alpha0.7, modelstudent, argstraining_args, train_datasetdataset[train], eval_datasetdataset[test], tokenizertokenizer, ) trainer.train() # 5. 保存模型 student.save_pretrained(./my-distilled-bert) tokenizer.save_pretrained(./my-distilled-bert) print(蒸馏完成模型已保存到 ./my-distilled-bert)5.3 效果对比训练完成后你可以对比老师和学生模型的效果from transformers import pipeline teacher_pipe pipeline(sentiment-analysis, modelteacher, tokenizertokenizer) student_pipe pipeline(sentiment-analysis, modelstudent, tokenizertokenizer) test_text This movie is absolutely fantastic! Best film Ive ever seen. print(老师模型:, teacher_pipe(test_text)) print(学生模型:, student_pipe(test_text))六、蒸馏 vs 量化 vs 剪枝三大压缩方式大对比上一篇我们讲了量化今天讲了蒸馏还有一个常见的压缩方式是剪枝Pruning。三者有什么区别6.1 三种方式一句话总结量化把模型里的数字从高精度变成低精度FP16→INT8就像把高清图压缩成标清图蒸馏让大模型教小模型小模型从头学就像老教授带研究生剪枝把模型里不重要的参数直接删掉就像修剪树枝——去掉枯枝保留精华6.2 全面对比对比维度量化蒸馏剪枝核心思想降低数值精度大模型教小模型删除不重要参数是否需要训练不需要PTQ/ 需要QAT需要需要模型架构变化不变变换成小架构不变但参数变稀疏压缩比2x~4x任意取决于学生大小2x~10x精度损失较小可控较大容易过度剪枝实现难度低中高是否需要数据PTQ 不需要 / QAT 需要需要需要典型工具GPTQ、AWQ、bitsandbytesHuggingFace、DistilBERTTorch pruning、Lottery Ticket生活类比视频压缩名师出高徒修剪盆栽6.3 三者可以组合使用实际项目中这三种方式经常组合使用原始大模型 (70B FP16) │ ├──→ 蒸馏 → 小模型 (7B FP16) ← 先蒸馏缩小架构 │ │ │ ├──→ 量化 → 小模型 (7B INT4) ← 再量化降低精度 │ │ │ └──→ 剪枝量化 → 更小模型 (5B INT4) ← 剪枝量化 │ └──→ 量化 → 大模型 (70B INT4) ← 直接量化保持架构最佳实践先蒸馏再量化效果通常最好。因为蒸馏后的模型更小量化带来的精度损失也更容易控制。七、蒸馏的常见坑与避坑指南坑1温度参数设太大或太小T 太小如 T1软标签太硬跟硬标签差不多蒸馏效果差T 太大如 T100分布太平坦所有类别概率差不多学生学不到有用信息建议T 通常取 2~20从 T5 开始调坑2学生模型太小学生模型太小容量不够再怎么蒸馏也学不到老师的知识。就像让小学生学微积分——老师再厉害学生也吸收不了。建议学生模型参数量至少是老师的 1/10~1/5。坑3黑盒蒸馏数据质量差用大模型 API 生成数据时如果 prompt 设计不好生成的数据质量会很差学生模型学到的就是垃圾知识。建议精心设计 prompt 模板对生成数据进行质量过滤数据多样性要够不能全是同一类型的问题坑4只蒸馏不验证蒸馏完直接上线结果发现模型在某些场景下表现很差。建议蒸馏后一定要在目标场景的测试集上做全面评估不能只看通用 benchmark。八、总结今天我们聊了知识蒸馏——让大模型变小模型的浓缩咖啡术。核心要点回顾核心思想大模型当老师小模型当学生通过软标签传递暗知识关键机制温度参数 T 控制软标签的平滑度KL 散度衡量师生差距三种方式白盒蒸馏看 logits、黑盒蒸馏看输出、特征蒸馏看中间层经典案例DistilBERT、Alpaca、Vicuna、DeepSeek-R1 都证明了蒸馏的有效性组合使用蒸馏 量化是最佳拍档先缩小架构再降低精度一句话总结蒸馏是让小模型站在巨人的肩膀上用更少的参数达到接近大模型的效果。下篇预告模型压缩完了量化、蒸馏都搞定了但推理还是不够快下一篇我们聊推理加速框架——vLLM、TensorRT-LLM、TGI 这些框架是怎么让模型推理速度翻倍的PagedAttention、Continuous Batching、KV Cache 优化……这些黑科技到底是什么系列文章AI大模型知识体系 | 第三周·第四篇 ——推理加速框架让大模型飞起来

相关新闻