Transformer:一个改变游戏规则的注意力赌注
如果你在 2016 年告诉一个 NLP 研究员 "RNN 可以被完全抛弃",他大概率会觉得你疯了。毕竟从 Elman Network(1990)到 LSTM(1997)再到 GRU,二十多年来递归神经网络一直是序列建模的绝对王者。当时的共识是:序列天然有顺序,按顺序处理不是最合理的吗?
Jakob Uszkoreit 不这么想。他在 2016 年的 ACL 上听了一场关于 "分解注意力"(decomposable attention)的报告后,冒出了一个在当时看来近乎异端的想法——如果没有递归,纯靠注意力,能不能做翻译? 这个想法连他老爹 Hans Uszkoreit(知名计算语言学家)都觉得不太靠谱。但就是这个赌注,催生了 2017 年那篇题为 "Attention Is All You Need" 的论文,和整个现代 LLM 时代的基石架构。
先忘掉公式,理解直觉
大多数教程一上来就甩公式:Attention(Q,K,V) = softmax(QK^T/√d_k)V。但这对真正理解一点帮助都没有。我们换个方式。
假设你读这句话:"我把钥匙放在桌子上,后来找不到__了。"
你怎么知道空白处应该填 "它"(指代钥匙)还是 "桌子"?因为你的大脑在读到每个词时,会下意识地关注跟它相关的其他词。当你处理 "找不到" 的时候,你的注意力会自动回溯到 "钥匙"——因为 "钥匙" 跟 "找不到" 的语义关联最强。
Transformer 做的就是这个,只不过用的是向量运算。具体来说:
- 序列中的每个 token "我"、"把"、"钥匙"... 先被转成嵌入向量(embedding vector),比如 768 维
- 对每个 token,我们通过三个不同的权重矩阵生成三个向量:
- Query(查询):你在找什么?——"我的上下文里需要什么信息?"
- Key(键):你是什么?——"我能提供什么信息?"
- Value(值):如果你被选中,你提供什么内容?——"我携带的实际信息是什么?"
- 用当前 token 的 Q 去跟所有 token 的 K 做点积,得到 "注意力分数"——分数越高,说明这两个 token 越相关
- 用 softmax 把分数转成概率分布,然后按概率加权取所有 token 的 V,得到当前 token 的上下文表示
以 "I love NLP" 为例(假设嵌入维度 d=4,实际当然远不止这么小):
Token: "I" "love" "NLP"
Embed: [0.1, [0.3, [0.5,
0.2, 0.1, 0.3,
0.4, 0.6, 0.2,
0.3] 0.5] 0.1]
# 经过 W_Q, W_K, W_V 三个矩阵变换后:
Q("love") = [0.8, 0.2, 0.5, 0.1] # "love" 在寻找什么
K("I") = [0.3, 0.4, 0.2, 0.9] # "I" 在提供什么特征
K("love") = [0.7, 0.1, 0.6, 0.2] # "love" 自己的特征
K("NLP") = [0.4, 0.6, 0.1, 0.8] # "NLP" 的特征
# 计算 "love" 对每个 token 的注意力分数(点积 / √d)
score("love","I") = dot([0.8,0.2,0.5,0.1], [0.3,0.4,0.2,0.9]) / 2
= (0.24+0.08+0.10+0.09) / 2 = 0.255
score("love","love")= (0.56+0.02+0.30+0.02) / 2 = 0.450 # 最高!
score("love","NLP") = (0.32+0.12+0.05+0.08) / 2 = 0.285
# softmax 后: ["I":0.28, "love":0.39, "NLP":0.33]
# 这意味着在理解 "love" 时,39% 注意力给自己,33% 给 "NLP",28% 给 "I"注意 √d_k 这个除法的巧妙之处:当维度 d 很大(比如 128),点积的值会随维度线性增大,导致 softmax 输出极度尖锐(几乎所有概率集中在一个 token 上),梯度接近零,模型学不动。除以 √d_k 就是把这个 "过热" 的分布拉平滑一点。这是一个工程细节,但没它训练不了大模型。
多头注意力:每个脑袋都在看什么?
单头注意力只能关注一种 "相关性模式"。但语言是多层次的——一个词在同一句话里可以跟主语有语法关系,跟另一个词有语义关系,跟第三个词有共指关系。多头注意力就是同时运行多组独立的 QKV 投影,让每个头去学不同的关系模式。
从可解释性研究来看(参考 Clark et al. 2019 的 BERT 注意力分析),不同头确实分化出不同功能:
- 语法头:固定关注依存关系中的 head 词(比如形容词总是关注它修饰的名词)
- 位置头:固定关注前一个或后一个 token(形成类似 CNN 的感受野)
- 分隔符头:对
[SEP]、句号等边界 token 有高注意力 - 全局头:均匀关注所有 token(类似 mean pooling)
我个人在可视化一些小的 Transformer 模型时,确实看到了这种分化,但也发现一个有趣的现象:浅层(前几层)的注意力模式相对分散、可解释性较强;深层(后几层)的注意力分布越来越 "奇怪"——它们学到的东西对 loss 下降很有用,但对人类来说很难描述。这其实暗示了一个问题:注意力矩阵未必是可解释的特征载体,更多时候它只是一个高效的加权信息路由机制。
实际代码中,多头注意力非常直观。以下是一个简化但可运行的 PyTorch 实现:
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadSelfAttention(nn.Module):
def __init__(self, d_model=512, n_heads=8, dropout=0.1):
super().__init__()
assert d_model % n_heads == 0
self.d_k = d_model // n_heads
self.n_heads = n_heads
# 一次性投影所有头的 Q, K, V
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
B, T, D = x.shape # batch, seq_len, d_model
# 投影并拆分为多头: (B, T, D) → (B, n_heads, T, d_k)
Q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
K = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
V = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
# Scaled Dot-Product Attention
scores = (Q @ K.transpose(-2, -1)) / (self.d_k ** 0.5)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = self.dropout(F.softmax(scores, dim=-1))
# 合并多头: (B, n_heads, T, d_k) → (B, T, D)
out = (attn @ V).transpose(1, 2).contiguous().view(B, T, D)
return self.W_o(out)位置编码:从正弦波到旋转门的进化
Self-Attention 有个天生缺陷——它对位置完全不敏感。"我打你" 和 "你打我" 在纯注意力眼里是一样的一堆 token,只是排列不同而已。所以必须注入位置信息。
原始方案:正弦余弦编码
Vaswani 的原始方案用固定频率的正弦余弦函数,按公式 PE(pos, 2i)=sin(pos/10000^(2i/d)) 为偶数维度,余弦为奇数维度。为什么选正弦余弦?两个原因:第一,它们是连续的、有界的([-1,1]),数值稳定;第二,也是最精妙的地方——正弦/余弦的位置编码具有线性平移不变性:pos+k 的位置编码可以通过 pos 的编码经过一个简单的线性变换得到。这意味着模型可以通过学习这个线性变换,轻易地推导出两个 token 之间的相对距离。
但实际上,大部分现代 LLM(如 GPT-2、GPT-3)压根不用正弦余弦,而是直接用可学习的位置嵌入——每个位置 0,1,2,...,2048 直接分配一个可训练的向量,让模型自己学。效果差不多,实现更简单。
RoPE:一场位置编码的革命
2023 年,苏剑林(追一科技)提出的 RoPE(Rotary Position Embedding)成了现代开源 LLM 的标配——LLaMA、Qwen、Mistral、DeepSeek 全在用。它的核心思想非常优雅:不改变注意力矩阵的计算结构,而是在 Q 和 K 向量上直接施加旋转变换,使点积结果天然包含相对位置信息。
数学上,RoPE 将 Q 和 K 的每一对维度(2i, 2i+1)视为一个 2D 平面的坐标,然后按位置 m 施加一个旋转角 mθ_i,其中 θ_i = 10000^(-2i/d):
f(q, m) = q * e^(imθ) # 复数表示,本质是旋转做完这个旋转后,两个 token 的 Q 和 K 做点积时,结果只跟它们的相对位置差 (m-n) 有关,跟它们的绝对位置 m、n 无关。这个性质太重要了——它意味着模型在训练时见到的位置关系(比如 "前后两个词距离 3 步")可以在推理时泛化到更长的序列上。这就解释了为什么使用 RoPE 的模型(如 LLaMA)的外推能力(extrapolation)远好于使用绝对位置编码的 GPT-3。
RoPE 还有一个工程上的变体——NTK-aware 缩放(也称为 "动态 NTK"),它通过微调高频维度的旋转频率,可以在不重新训练的情况下将上下文窗口从 4K 扩展到 32K 甚至 128K。这是 2023-2024 年开源社区最 clever 的 "白嫖" 技巧之一。
为什么 Transformer 干掉了 RNN/LSTM
这个问题值得认真回答,因为它决定了整个 AI 产业的走向。
第一,并行性。 RNN 的核心计算是 h_t = f(W·x_t + U·h_{t-1})——每个时间步的计算必须等上一步完成。一个 1000 token 的序列,RNN 必须算 1000 步。而 Transformer 的 Self-Attention 是一次性对所有 token 做矩阵乘法——QK^T 这一步同时计算了所有 token 对之间的关系。这对 GPU 意味着什么?GPU 是为矩阵乘法设计的——这就是它的 "母语"。RNN 在 GPU 上利用率可能只有 30%,而 Transformer 能轻松跑到 70%+。
第二,长距离依赖。 LSTM 理论上可以通过门控机制保留长距离信息,但实践中的信息衰减是真实存在的。在一个 2000 token 的上下文中,LSTM 要记住开头的内容,必须经过 2000 次隐藏状态更新——每一步都是信息丢失的机会。而 Transformer 中任意两个 token 的距离是常数 O(1)——无论它们相隔多远,Self-Attention 一步到位。
第三,可扩展性(Scalability)。 这是最关键的。Transformer 的结构极其规整:同一层 Self-Attention + FFN,重复 N 次。这就意味着,你想让模型变强,几乎只需要调大几个数字:层数加倍、宽度加倍、数据加倍。这种 "暴力美学" 式的扩展性跟 Scaling Law 完美契合。而 RNN/LSTM 因为梯度传播路径长、训练不稳定,很难 scale 到千亿参数。
工程实战:没有这些优化你根本跑不动大模型
学术界讲 Transformer 原理,工业界讲 Transformer 优化。下面四个技术决定了你能不能只用 8 张 A100 跑一个像样的 LLM。
KV Cache
在自回归生成时,每生成一个新 token,我们需要拿它去关注之前所有的 token。但之前 token 的 K 和 V 已经在上一轮算过了——重新算一遍是巨大的浪费。 KV Cache 就是把已经算过的 K 和 V 存在显存里,生成新 token 时只算新 token 的 Q、K、V,然后用缓存的旧 K、V 和新 K、V 一起做注意力。
问题:一个 70B 的模型,batch=1,序列长度 4096 时,KV Cache 吃多少显存?大约 2 × 层数 × kv_head数 × d_head × seq_len × 2字节(fp16) = 2 × 80 × 8 × 128 × 4096 × 2 ≈ 2.6GB。批量推理时这个数字会乘以 batch size,很快成为显存瓶颈。
Flash Attention
Tri Dao 在 2022 年提出的 Flash Attention 可能是过去五年里对 Transformer 推理效率影响最大的算法。它的核心洞察是:注意力计算的瓶颈不是浮点运算,而是显存读写。 标准实现中,N×N 的注意力矩阵需要先从 HBM(高带宽显存)读到 SRAM,softmax 完再写回去——HBM 的带宽严重跟不上。
Flash Attention 用 tiling(分块计算)+ online softmax(在线归一化)两个技巧,把完整的注意力矩阵永远留在 SRAM 里,每次只读一个 tile,算完 softmax 后立即丢弃中间结果。最终结果是:显存占用从 O(N²) 降到 O(N),速度提升 2-4 倍,而且数学上完全等价,没有任何近似。
Flash Attention-2(2023 年)进一步优化了 GPU 上的 work partitioning,把硬件利用率从 25-40% 提到了 50-73%。Flash Attention-3(2024 年,针对 H100)用上了 Hopper 架构的异步指令和 FP8 数据类型,在 H100 上达到 740 TFLOPS——已经非常接近理论峰值。
GQA / MQA:多头注意力的瘦身手术
标准的多头注意力(MHA)中,每个头都有独立的 Q、K、V 投影——这意味着 KV Cache 的大小 = 层数 × 头数 × d_head × seq_len × 2。MQA(Multi-Query Attention,2019)的激进方案:所有头共享同一组 K 和 V,只有 Q 保持多头。KV Cache 直接除以头数(比如 8 头,显存少 8 倍)。但代价是注意力表达能力下降——所有头只能从相同的 Key 里抽取信息。
GQA(Grouped-Query Attention,2023,Ainslie et al.)是一个折中方案:把 Q 头分成若干组(比如 8 个 Q 头分 4 组),每组内共享一组 K 和 V。LLaMA 2 70B 就用了 8 组(原来 64 个 Q 头,K/V 头减少到 8 个),既节省了显存,又保持了足够的注意力多样性。实际用下来,GQA 的 perplexity 跟 MHA 几乎没有差距,是当前最实用的方案。
最后说两句
Transformer 在 2017 年出现时,没人想到它会在六年后变成万亿美元产业的基础设施。它的美不在于复杂——恰恰相反,它把序列建模这件事提炼到了最简单的形式:所有 token 同时关注所有其他 token,然后叠加 FFN 做非线性变换,周而复始。这种简单性就是它可扩展的秘密。
如果你现在才刚开始学 Transformer,建议不要只读书——亲手用 PyTorch 从零实现一个,跑一个简单的中文分词加序列标注任务,你会发现很多 "直觉" 是在跑代码的时候建立起来的。那些公式看起来复杂,写出来不过几十行。