Skip to content

知识蒸馏:让小学生学会大学教授的本事

2006 年,还在 Toronto 大学的 Geoffrey Hinton 在 Caruana 的一篇关于模型压缩的工作启发下,开始认真思考一个问题:一个复杂模型的输出里,到底藏了多少被浪费的信息?

我们通常训练分类模型的方式是给它看一张图,然后告诉它 "这是 5"——一个 one-hot 标签,只有正确答案是 1,其他都是 0。但一个训练好的大模型在判断一张手写数字 "5" 的图片时,它的输出可能是这样的:

类别 0: 0.002    类别 3: 0.100    类别 6: 0.050
类别 1: 0.001    类别 4: 0.005    类别 7: 0.003
类别 2: 0.010    类别 5: 0.800    类别 8: 0.020    类别 9: 0.009

你看,虽然模型很确信这是 "5"(80%),但它还告诉你一件有意思的事:这个 "5" 长得有 10% 像 "3",5% 像 "6"。这可不是噪音——如果你自己写一个 "5",有时候上半部分确实容易被误认成 "3"。这种各类别之间的相似性结构信息,被 Hinton 称为 "暗知识"(Dark Knowledge)

暗知识就是蒸馏的秘密武器。如果让小模型(学生)直接学 one-hot 的硬标签,它只知道 "5 是正确答案"——但不知道 "5 长得有点像我 6"、"3 和 5、6、8 经常出现在同一类笔画里"。而这些信息,恰恰是大模型花了几千 GPU 小时学到的世界知识中分子级别的精华。

温度参数:别被名字骗了,它不是什么玄学

知识蒸馏有一个让新手困惑的超参数——温度 T(Temperature)。你可能会想:"什么鬼,神经网络跟热力学有什么关系?"

实际上这是一个极其朴素的设计。标准的 softmax 函数是:

python
# 标准 softmax
prob_i = exp(logit_i) / Σ_j exp(logit_j)

大模型的 logit 分布通常非常尖锐——正确的那个 logit 可能比其他的高几十倍,softmax 之后概率几乎全是 1 和 0。这意味着暗知识被压缩没了——那些 "有点像 3"、"有点像 6" 的信息在概率值上看不出来(全是 0.00000...)。

温度参数的做法是在做 softmax 之前,先用 T 把 logit 缩小一圈

python
# 带温度的 softmax
prob_i = exp(logit_i / T) / Σ_j exp(logit_j / T)
  • T = 1:标准 softmax,不做任何平滑
  • T = 3~10:分布被 "摊平",原本 0.0001 的概率可能变成 0.05——暗知识变得可见了
  • T → +∞:极限情况下所有类别的概率趋近于均等分布,完全失去区分能力

直观地理解:好比你在 500 度的烤箱里烤牛排——所有细节都被烧焦了,你只能看到最突出的轮廓(T=1)。把温度降到 80 度,肉的颜色、纹理、汁水的分布都变得可见了(T=5~10),这些细节就是你希望小模型学到的。

训练时,教师和学生都使用同一个 T 产生软标签,计算 KL 散度作为软损失;学生同时在硬标签(one-hot)上算交叉熵作为硬损失。总损失是两者的加权:

L_total = α × L_hard(student_logits, one_hot_labels) + (1-α) × T² × L_soft(student_logits/T, teacher_logits/T)

那个 乘子(梯度缩放修正)很多人会漏掉——因为 softmax 除以 T 让梯度也缩了 T 倍,必须乘回来才能保持软目标和硬目标在损失中的比例关系。

实际调参经验:α 通常在 0.1~0.5 之间(软损失占主导),T 在 3~10 之间。T 太高会导致模型过度关注 "无关紧要的相似性",T 太低软标签退化为硬标签——蒸馏就没意义了。 我一般从 T=4,α=0.3 开始,跑几个 epoch 看验证集表现再调。

不止一种蒸馏方式

Logit 蒸馏(输出层)

最经典、最简单、用得最多。只匹配教师和学生的最终输出概率分布,不需要了解内部结构。

python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.3):
    # 硬损失:学生 vs one-hot 标签
    hard_loss = F.cross_entropy(student_logits, labels)

    # 软损失:学生 vs 教师的软标签(都经过温度缩放)
    soft_student = F.log_softmax(student_logits / T, dim=-1)
    soft_teacher = F.softmax(teacher_logits / T, dim=-1)
    soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (T ** 2)

    return alpha * hard_loss + (1 - alpha) * soft_loss

特征蒸馏(中间层)

只匹配输出太粗糙了——就好比你只看了厨师最后的摆盘,却不知道他是怎么切菜、怎么控制火候的。特征蒸馏让学生的中间层表示去逼近教师的对应层表示。

常见做法包括:

  • Attention Transfer:让学生第 N 层的注意力矩阵去匹配教师第 N 层(或按比例对应)的注意力矩阵,损失用 MSE 或 KL
  • Hidden State MSE:对学生某一层的隐状态做线性投影后,跟教师对应层的隐状态算 MSE
  • FitNet:先训练学生的中间层去回归教师的中间层(Stage 1),再在整个模型上做 logit 蒸馏(Stage 2)

我个人觉得特征蒸馏在视觉任务(ResNet蒸馏MobileNet之类)上收益比 NLP 更大。LLM 的蒸馏,目前最有效的方式还是 数据蒸馏——用大模型生成高质量训练数据喂给小模型,这个我们在下一节展开。

数据蒸馏 / 指令蒸馏

这是 2023 年以来 LLM 蒸馏最主流的方式。思路简单粗暴:拿 GPT-4 或 DeepSeek-R1 这种顶级模型,在大量 prompt 上生成回答,然后用这些问答对去微调小模型。

Alpaca(斯坦福,2023年3月)是最早的尝试:用 GPT-3.5 生成 52000 条指令数据训练 LLaMA 7B,花了不到 $600 的 API 费,得到了类似 text-davinci-003 水平的模型。虽然后续评估发现实际能力折扣不小,但这个思路彻底开了 "穷人也能玩 LLM" 的大门。

Vicuna、Orca、WizardLM 都是这条路上的重要节点。Orca 的特别之处在于它不是只蒸馏最终回答——它让 GPT-4 在 system prompt 里输出 "思维过程"(explain your reasoning step by step),把隐式的推理过程直接暴露出来,小模型因此学到了 "怎么思考" 而不只是 "答案是什么"。

DeepSeek-R1 蒸馏:小模型推理能力爆炸

2025 年 1 月 DeepSeek 发布的 R1 论文,是蒸馏领域最近最震撼的一个案例。

DeepSeek-R1 本身是一个 671B 的 MoE 模型(激活 37B),通过纯强化学习训练出了惊艳的推理能力——它在 AIME 2024 数学竞赛题上拿到 79.8% 的准确率,跟 OpenAI o1 处于同一量级。但真正让社区震动的不是 R1 本身,而是他们对 R1 做了蒸馏后产出的那批小模型:

DeepSeek-R1-Distill-Qwen-1.5B 在 AIME 2024 上拿了 28.9%,DeepSeek-R1-Distill-Qwen-7B 拿了 55.5%,而 DeepSeek-R1-Distill-Qwen-32B 拿了 72.6%——一个 32B 的蒸馏模型几乎比肩原始的 671B 大模型。 更夸张的是,1.5B 这个小不点,在某些数学基准上甚至超过了未经推理优化的 GPT-4o-mini。

这里的关键发现是什么?蒸馏的不是知识,是推理模式。 R1 生成的数据包含了大量的 self-reflection(自我反思)、verification(验证步骤)、alternative exploration(备选方案探索)等推理行为——这些 "思考痕迹" 被蒸馏数据编码后,小模型在做 SFT 的过程中自动学会了这些模式,而不需要从头用 RL 去探索。

这跟 Orca 的思路一致但效果差了十万八千里——因为 R1 的推理质量远高于 GPT-4 随便说说的 "let me think step by step"。

蒸馏的边界:什么不能蒸馏?

蒸馏能力极强,但不是万能的。以下几点是实际工作中容易碰壁的地方:

第一,学生永远无法超越教师。 如果一个 7B 学生模型去学一个 70B 教师模型,它最理想的情况是达到教师的 90-95% 的能力——但不可能超过。因为学生的所有知识都来源于教师,而教师自己也有盲区和错误(比如偏见、幻觉模式),这些会被一并蒸馏。

第二,复杂的多跳推理最容易在蒸馏中 "坍缩"。 我见过不少蒸馏后的模型,单轮问答很溜,但三步以上的逻辑推理就崩了。原因很简单:多跳推理的搜索空间是指数级的,教师模型在第 3 步的概率分布非常分散(有很多可能的推理路径),而学生模型容量有限,被迫把这些路径 "平均化",最后输出一个安全的、模糊的、但很可能错误的结论。

第三,知识截止日期无法被蒸馏。 如果教师模型的训练数据只到 2024 年 6 月,蒸馏出的学生模型也只能知道到 2024 年 6 月——你想往里面 "灌入" 2025 年的新知识是不可能的,除非重新预训练。

第四,蒸馏可以作为 "病毒传播" 的载体。 教师模型如果内化了某些系统性偏见(比如 "医生都是男性"、"工程师都是男性"),蒸馏会忠实地把这些偏见传给所有学生模型——而且比原始模型更难察觉和修正。

实战 tip:用 Hugging Face 做蒸馏

实际工作中,如果你要蒸馏一个 LLM,最保险的路径是用 transformers + trl 库。伪代码:

python
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer
from datasets import Dataset

# 1. 用大模型 API 生成高质量数据(这一步是蒸馏的灵魂)
def generate_teacher_data(prompts, teacher_model):
    teacher_outputs = []
    for p in prompts:
        response = teacher_model.chat(p, temperature=0.7, max_tokens=2048)
        teacher_outputs.append({
            "prompt": p,
            "response": response
        })
    return Dataset.from_list(teacher_outputs)

# 2. 按 ChatML 格式组织
def format_chat(example):
    return {"text": f"<|user|>\n{example['prompt']}\n<|assistant|>\n{example['response']}"}

# 3. 用 SFTTrainer 微调小模型
trainer = SFTTrainer(
    model=student_model,
    train_dataset=distill_data.map(format_chat),
    tokenizer=tokenizer,
    max_seq_length=4096,
    args=TrainingArguments(
        output_dir="./distilled-model",
        per_device_train_batch_size=4,
        gradient_accumulation_steps=8,
        learning_rate=2e-5,
        warmup_ratio=0.03,
        num_train_epochs=3,
        fp16=True,
    ),
)
trainer.train()

注意:这里用的是 SFT(监督微调),而不是经典的 "一边 soft label 一边 hard label" 在线蒸馏——因为 LLM 蒸馏的核心瓶颈是生成高质量训练数据的过程(教师模型推理成本高),而不是训练时刻的损失函数设计。你花 1000 美元让 GPT-4 生成 100 万条高质量数据,比你在训练时做什么 KL 散度的花活重要得多。

一个值得记住的隐喻

知识蒸馏可以理解成:让一个世界级大厨(教师模型)写一本食谱(蒸馏数据),然后一个聪明但经验不足的年轻厨师(学生模型)反复照着做。年轻厨师学不会大厨 30 年积累的 "手感"——那种凭直觉知道什么时候加一撮盐、多翻一次锅的神秘直觉。但他能学会食谱上的每一道菜,而且做得相当不错。

蒸馏不会让草鸡变凤凰,但它能让草鸡学会凤凰的食谱——对大多数应用来说,这已经足够了。

基于 VitePress 构建 | 部署于 Cloudflare Pages