LLM 推理全过程的维度变化与核心公式

把 Llama 3 风格的 dense decoder-only 模型从 embedding 到采样的整条路径一次写清楚,顺便把常见变体(MHA / MQA / GQA / MLA、RMSNorm / LayerNorm、SwiGLU / GeGLU、Flash Attention / Paged Attention)的数学位置都嵌进去 — 读完之后自己能把这张图默写出来。整篇围绕一个事实展开:这套结构二十年没怎么变过,所有”推理优化”都是在同一张骨架的某个位置做局部手术。

参数符号约定 — Llama 3 8B 为基准

全篇沿用同一套符号,举例用 Llama 3 8B:

符号含义示例值(Llama 3 8B)
BBbatch size2
SSprompt 长度10
LL层数32
HHhidden dim4096
VV词表大小128256
nqn_qQ 头数32
nkvn_{kv}KV 头数(GQA)8
dd每头维度 =H/nq= H/n_q128
IIFFN 中间维度14336
TT生成 token 数100
ttdecode 当前步数1..T1..T

所有 shape 标注都按 PyTorch 习惯写成 [B, ..., H];权重矩阵按”输入维 × 输出维”约定写成 WR[in,out]W \in \mathbb{R}^{[\text{in}, \text{out}]}

核心公式速查 — Embedding · Norm · Attn · FFN · LM Head

Embedding

xi=E[token_idi]RH\mathbf{x}_i = E[\text{token\_id}_i] \in \mathbb{R}^{H}
  • xi\mathbf{x}_i — 位置 ii 的 embedding 向量
  • ERV×HE \in \mathbb{R}^{V \times H} — embedding 查找表(V 个 token,每个 H 维)
  • token_idi{0,1,,V1}\text{token\_id}_i \in \{0, 1, \ldots, V-1\} — 第 ii 个输入 token 的整数 ID

许多实现把 EE 与 LM Head 的 WlmW_{\text{lm}} 共享(tied embedding),省显存也略微正则;Llama 系列默认不共享

归一化:LayerNorm vs RMSNorm

标准 LayerNorm:

LN(x)=xμσ2+ϵγ+β,μ=1Hxi, σ2=1H(xiμ)2\text{LN}(\mathbf{x}) = \frac{\mathbf{x} - \mu}{\sqrt{\sigma^{2} + \epsilon}} \odot \boldsymbol{\gamma} + \boldsymbol{\beta}, \quad \mu = \tfrac{1}{H}\sum x_i,\ \sigma^{2} = \tfrac{1}{H}\sum (x_i - \mu)^{2}
  • xRH\mathbf{x} \in \mathbb{R}^{H} — 单 token 的隐状态向量
  • μ,σ2R\mu, \sigma^{2} \in \mathbb{R} — 在 HH 维上算的均值和方差
  • ϵ\epsilon — 防除零小常数(通常 10510^{-5}
  • γ,βRH\boldsymbol{\gamma}, \boldsymbol{\beta} \in \mathbb{R}^{H} — 可学习的 scale / shift
  • \odot — 逐元素乘

RMSNorm(Llama / Mistral / Qwen 主流选择):

RMSNorm(x)=x1Hi=1Hxi2+ϵγ\text{RMSNorm}(\mathbf{x}) = \frac{\mathbf{x}}{\sqrt{\frac{1}{H}\sum_{i=1}^{H} x_i^{2} + \epsilon}} \odot \boldsymbol{\gamma}
  • xRH\mathbf{x} \in \mathbb{R}^{H} — 单 token 的隐状态向量
  • γRH\boldsymbol{\gamma} \in \mathbb{R}^{H} — 可学习的 per-channel scale(无 β\boldsymbol{\beta}
  • ϵ\epsilon — 防除零小常数
  • 分母是 x\mathbf{x} 的均方根(root mean square),故得名 RMSNorm

RMSNorm 省了均值、也省了 β\boldsymbol{\beta},计算量和参数都约减半;实证发现对质量几乎无损。所有主流推理引擎都按 pre-norm 组织:norm 在残差分支内部,残差主干不经过 norm。

Q/K/V 投影 + 位置编码

Q=XWQ,K=XWK,V=XWVQ = X W_Q, \quad K = X W_K, \quad V = X W_V
  • XRB×S×HX \in \mathbb{R}^{B \times S \times H} — 上一步 RMSNorm 的输出
  • WQRH×nqdW_Q \in \mathbb{R}^{H \times n_q d} — Q 投影权重
  • WK,WVRH×nkvdW_K, W_V \in \mathbb{R}^{H \times n_{kv} d} — K、V 投影权重(GQA 下 nkv<nqn_{kv} < n_q,所以更窄)
  • QRB×S×nqdQ \in \mathbb{R}^{B \times S \times n_q d}K,VRB×S×nkvdK, V \in \mathbb{R}^{B \times S \times n_{kv} d} — 投影输出,会再 reshape 出 head 维

RoPE(Rotary Positional Embedding)的核心想法是:与其把”绝对位置 mm“加到 embedding 上(像原始 Transformer 的 sinusoidal),不如让位置以旋转的形式作用在 Q、K 上,使两个向量做内积时只剩下相对位置 nmn - m

把每个头的 dd 维切成 d/2d/2 个二维子空间,第 kk 个子空间(k=0,1,,d/21k = 0, 1, \ldots, d/2 - 1)对应坐标 (q2k,q2k+1)(q_{2k}, q_{2k+1})。位置 mm 上对这一对乘旋转角 mθkm\theta_k 的 2D 旋转矩阵:

(q2k(m)q2k+1(m))=(cos(mθk)sin(mθk)sin(mθk)cos(mθk))(q2kq2k+1),θk=base2k/d\begin{pmatrix} q'^{(m)}_{2k} \\ q'^{(m)}_{2k+1} \end{pmatrix} = \begin{pmatrix} \cos(m\theta_k) & -\sin(m\theta_k) \\ \sin(m\theta_k) & \phantom{-}\cos(m\theta_k) \end{pmatrix} \begin{pmatrix} q_{2k} \\ q_{2k+1} \end{pmatrix}, \quad \theta_k = \text{base}^{-2k/d}
  • m{0,1,,S1}m \in \{0, 1, \ldots, S-1\} — 当前 token 的绝对位置
  • k{0,1,,d/21}k \in \{0, 1, \ldots, d/2 - 1\} — 二维子空间索引(在单头维度 dd 上每两维一组)
  • (q2k,q2k+1)(q_{2k}, q_{2k+1}) — 投影后 Q 向量的第 kk 个 2D pair
  • (q2k(m),q2k+1(m))(q'^{(m)}_{2k}, q'^{(m)}_{2k+1}) — 在位置 mm 旋转后的 pair
  • θk\theta_k — 第 kk 个子空间的基础角速度
  • base\text{base} — 频率衰减基(Llama 默认 1000010000;YaRN 等会动态拉大)

展开成标量:

q2k(m)=q2kcos(mθk)q2k+1sin(mθk),q2k+1(m)=q2ksin(mθk)+q2k+1cos(mθk)q'^{(m)}_{2k} = q_{2k}\cos(m\theta_k) - q_{2k+1}\sin(m\theta_k), \qquad q'^{(m)}_{2k+1} = q_{2k}\sin(m\theta_k) + q_{2k+1}\cos(m\theta_k)

符号同上式 — 这只是把 2×2 矩阵展开成两条标量等式,方便对照实现。

这就是标准的逆时针旋转矩阵 R(ϕ)=(cosϕsinϕsinϕcosϕ)R(\phi) = \begin{pmatrix} \cos\phi & -\sin\phi \\ \sin\phi & \phantom{-}\cos\phi \end{pmatrix},作用在 2D 向量上等价于在复平面上乘 eiϕe^{i\phi} — 也是为什么 Llama 官方 repo 直接把 (q2k,q2k+1)(q_{2k}, q_{2k+1}) reshape 成 complex64 然后乘 cos(mθk)+isin(mθk)\cos(m\theta_k) + i\sin(m\theta_k)(HuggingFace 则等价地拆成 (qk,qk+d/2)(q_k, q_{k+d/2}) 的半旋转形式,二者只差一个坐标排列)。K 用同样的角度做同样的旋转,得到 k(n)\mathbf{k}'^{(n)}

为什么旋转能编码相对位置? 把整块 dd 维旋转记作 Rm\mathbf{R}_m(分块对角,第 kk2×22\times 2 块用角度 mθkm\theta_k)。它正交,且 RmRn=Rnm\mathbf{R}_m^{\top}\mathbf{R}_n = \mathbf{R}_{n-m},因此:

Rmq,  Rnk=qRmRnk=qRnmk\langle \mathbf{R}_m \mathbf{q},\; \mathbf{R}_n \mathbf{k} \rangle = \mathbf{q}^{\top} \mathbf{R}_m^{\top} \mathbf{R}_n \mathbf{k} = \mathbf{q}^{\top} \mathbf{R}_{n-m} \mathbf{k}
  • q,kRd\mathbf{q}, \mathbf{k} \in \mathbb{R}^{d} — 单头的 Q、K 向量(未做位置旋转)
  • Rm,RnRd×d\mathbf{R}_m, \mathbf{R}_n \in \mathbb{R}^{d \times d} — 位置 mmnn 对应的分块对角旋转矩阵
  • ,\langle \cdot, \cdot \rangle — 标准内积
  • 最后一步用了 RmRn=Rnm\mathbf{R}_m^{\top} \mathbf{R}_n = \mathbf{R}_{n-m}:旋转矩阵正交且角度可加

attention 算 QKQK^{\top} 时,每对 (qi,kj)(q_i, k_j) 的分数只依赖差 jij - i — 绝对位置在内积里被自动消掉,留下相对位置。这是 RoPE 的关键性质,也是它比加性位置编码更稳的根本原因。

频率谱设计θk=base2k/d\theta_k = \text{base}^{-2k/d}base\text{base} 一般取 1000010000dd单头维度,不是模型 hidden dim)让 d/2d/2 个子空间分到从快到慢的角速度:

  • k=0k = 0θ0=1\theta_0 = 1,周期 2π6.282\pi \approx 6.28 token,承担”近邻”信号。
  • k=d/21k = d/2 - 1θbase(d2)/d104\theta \approx \text{base}^{-(d-2)/d} \approx 10^{-4},周期 2π100006.28\approx 2\pi \cdot 10000 \approx 6.28\text{万} token,承担”远距离”信号。

几何级数分布让一个 head 内同时携带不同尺度的位置信号 — 与原始 Transformer 的 sinusoidal 同构,只是从加法搬到了乘法。

长上下文缩放:训练只见过 mLtrainm \le L_{\text{train}},低频子空间在 LtrainL_{\text{train}} 内甚至跑不完一个周期;推理一旦 m>Ltrainm > L_{\text{train}},低频角度进入训练分布外,attention 立刻退化。三类常见解法都在改 θk\theta_k

  • Position Interpolation(Chen et al. 2023):mm/sm \to m/s,等价 θkθk/s\theta_k \to \theta_k/s 同比例压缩所有频率。简单但牺牲高频精度。
  • NTK-aware scaling:只缩放低频、保留高频,等价于 basebasesd/(d2)\text{base} \to \text{base} \cdot s^{d/(d-2)}
  • YaRN(Peng et al. 2023):分频段处理 — 高频(周期 Ltrain\ll L_{\text{train}},已学过完整周期)不变,低频(周期 Ltrain\gg L_{\text{train}},没学过完整周期)按 PI 缩放,中间段平滑过渡,再叠加一个 1/t1/\sqrt{t} 的温度修正抵消缩放带来的注意力熵漂移。Llama 3.1 / 3.2 从 8K → 128K 用的就是 YaRN。

RoPE 只作用于 Q、K,不作用于 V — V 是被加权求和的值本身,不需要位置信号。

Scaled Dot-Product Attention

Attention(Q,K,V)=softmax ⁣(QKd+M)V\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d}} + M\right) V
  • Q,K,VQ, K, V — 投影并 reshape 出 head 维后的张量,最后两维为 [S,d][S, d](多头维隐式 broadcast)
  • QKRS×SQK^{\top} \in \mathbb{R}^{S \times S} — 每对 (qi,kj)(q_i, k_j) 的相似度矩阵
  • d\sqrt{d} — 缩放因子,防止 softmax 进入梯度近乎为零的区域
  • MRS×SM \in \mathbb{R}^{S \times S} — causal mask,上三角为 -\infty(位置 ii 只能看 i\le i 的位置)
  • softmax 沿 K 维归一化,输出 attention 权重,再与 V 加权求和回到 [S,d][S, d]

多头变体:MHA / MQA / GQA / MLA

四种变体的差别只在 K、V 这一支:Q 永远是 nqn_q 个独立头,变的是有多少套独立 K、V,以及 K、V 是否被低秩压缩。下面用同一套记号一遍把原版 MHA 和三种变体的 head-level 公式写齐 — 省略 batch 和 layer 维,看位置 ii 上第 hh 个 head。

变体nkvn_{kv}单 token cache(per layer,fp16)代表模型
MHA=nq= n_q2nqd2B2 \cdot n_q \cdot d \cdot 2\text{B}GPT-2/3, Llama 1/2 7B
GQA分组 <nq< n_q2nkvd2B2 \cdot n_{kv} \cdot d \cdot 2\text{B}Llama 3, Qwen2, Mistral
MQA=1= 12d2B2 \cdot d \cdot 2\text{B}PaLM, Falcon
MLA低秩压缩(dc+dr)2B(d_c + d_r) \cdot 2\text{B}DeepSeek V2/V3

MHA — Multi-Head Attention(原版,Vaswani et al. 2017)

每个 Q head 都配一套独立的 K、V,nqn_q 套全独立:

qi(h)=WQ(h)xi,ki(h)=WK(h)xi,vi(h)=WV(h)xi,h=1,,nq\mathbf{q}^{(h)}_i = W_Q^{(h)} \mathbf{x}_i, \quad \mathbf{k}^{(h)}_i = W_K^{(h)} \mathbf{x}_i, \quad \mathbf{v}^{(h)}_i = W_V^{(h)} \mathbf{x}_i, \quad h = 1, \ldots, n_q
head(h)=softmax ⁣(Q(h)K(h)d+M)V(h),Attn=[head(1);;head(nq)]WO\text{head}^{(h)} = \text{softmax}\!\left(\frac{Q^{(h)} {K^{(h)}}^{\top}}{\sqrt{d}} + M\right) V^{(h)}, \qquad \text{Attn} = [\text{head}^{(1)}; \ldots; \text{head}^{(n_q)}] \cdot W_O
  • WQ(h),WK(h),WV(h)RH×dW_Q^{(h)}, W_K^{(h)}, W_V^{(h)} \in \mathbb{R}^{H \times d} — 第 hh 个 head 独立的 Q、K、V 投影(实现上拼成大矩阵一次算完,数学上等价)
  • 每个 token 入 cache 的是 nqn_q(k(h),v(h))(\mathbf{k}^{(h)}, \mathbf{v}^{(h)}) — 共 2nqd2 n_q d 个数
  • Llama 2 7B:nq=32,d=128n_q = 32, d = 128,单 token cache = 2321282B=16 KB2 \cdot 32 \cdot 128 \cdot 2\text{B} = 16\text{ KB} / layer

MQA — Multi-Query Attention(Shazeer 2019)

所有 nqn_q 个 Q head 共用同一套 K、V,只剩 1 套:

qi(h)=WQ(h)xi,ki=WKxi,vi=WVxi\mathbf{q}^{(h)}_i = W_Q^{(h)} \mathbf{x}_i, \quad \mathbf{k}_i = W_K \mathbf{x}_i, \quad \mathbf{v}_i = W_V \mathbf{x}_i
head(h)=softmax ⁣(Q(h)Kd+M)V,h=1,,nq\text{head}^{(h)} = \text{softmax}\!\left(\frac{Q^{(h)} K^{\top}}{\sqrt{d}} + M\right) V, \qquad h = 1, \ldots, n_q
  • WK,WVRH×dW_K, W_V \in \mathbb{R}^{H \times d} — 所有 head 共享同一份 K、V 投影
  • cache 收缩到 2d2d,比 MHA 小 nqn_q 倍;但表达能力受限,大模型直接套 MQA 容易掉点 — 工程上很少单独使用,更多被 GQA 替代

GQA — Grouped-Query Attention(Ainslie et al. 2023)

nqn_q 个 Q head 切成 nkvn_{kv} 组,组内共享 K、V — 是 MHA 与 MQA 之间的连续插值:

qi(h)=WQ(h)xi,ki(g)=WK(g)xi,vi(g)=WV(g)xi\mathbf{q}^{(h)}_i = W_Q^{(h)} \mathbf{x}_i, \quad \mathbf{k}^{(g)}_i = W_K^{(g)} \mathbf{x}_i, \quad \mathbf{v}^{(g)}_i = W_V^{(g)} \mathbf{x}_i
head(h)=softmax ⁣(Q(h)K(g(h))d+M)V(g(h)),g(h)=hnkv/nq\text{head}^{(h)} = \text{softmax}\!\left(\frac{Q^{(h)} {K^{(g(h))}}^{\top}}{\sqrt{d}} + M\right) V^{(g(h))}, \qquad g(h) = \lfloor h \cdot n_{kv} / n_q \rfloor
  • WK(g),WV(g)RH×dW_K^{(g)}, W_V^{(g)} \in \mathbb{R}^{H \times d}g=1,,nkvg = 1, \ldots, n_{kv} — 每组一份 K、V 投影
  • g(h)g(h) — 第 hh 个 Q head 所属的组索引
  • nkv=nqn_{kv} = n_q 退化为 MHA,nkv=1n_{kv} = 1 退化为 MQA
  • kernel 里只实际算 nkvn_{kv} 套 K、V,分数计算时把 K、V 沿 group 维 broadcast 到 nqn_q不真的复制张量
  • Llama 3 70B:nq=64,nkv=8,d=128n_q = 64, n_{kv} = 8, d = 128,单 token cache = 281282B=4 KB2 \cdot 8 \cdot 128 \cdot 2\text{B} = 4\text{ KB} / layer — 比同样大小的 MHA 小 8 倍

MLA — Multi-Head Latent Attention(DeepSeek V2 2024)

GQA 只是按 head 数线性缩 cache;MLA 直接把 K、V 共同压成一个低秩潜向量 cKV\mathbf{c}^{KV},再为 RoPE 单独走一条共享的小分支。分成四步看:

(1) 内容分支 — K、V 共享一份 down-projection,入 cache 的只是潜向量 ciKV\mathbf{c}^{KV}_i

ciKV=WDKVxi,kiC,(h)=WKU,(h)ciKV,vi(h)=WVU,(h)ciKV\mathbf{c}^{KV}_i = W^{DKV} \mathbf{x}_i, \quad \mathbf{k}^{C,(h)}_i = W_K^{U,(h)} \mathbf{c}^{KV}_i, \quad \mathbf{v}^{(h)}_i = W_V^{U,(h)} \mathbf{c}^{KV}_i

(2) Q 端同样低秩 — 训练省显存,推理时 Q 不入 cache:

ciQ=WDQxi,qiC,(h)=WQU,(h)ciQ\mathbf{c}^{Q}_i = W^{DQ} \mathbf{x}_i, \quad \mathbf{q}^{C,(h)}_i = W_Q^{U,(h)} \mathbf{c}^{Q}_i

(3) RoPE 解耦分支 — K 侧只算一份共享的 drd_r 维向量,所有 head 共用:

qiR,(h)=RoPE(WQR,(h)ciQ),kiR=RoPE(WKRxi)\mathbf{q}^{R,(h)}_i = \text{RoPE}(W_Q^{R,(h)} \mathbf{c}^{Q}_i), \quad \mathbf{k}^{R}_i = \text{RoPE}(W^{KR} \mathbf{x}_i)

(4) 拼接 + Attention — 内容部分和 RoPE 部分沿 head 维拼起来再算分数:

qi(h)=[qiC,(h);qiR,(h)],ki(h)=[kiC,(h);kiR]\mathbf{q}^{(h)}_i = [\mathbf{q}^{C,(h)}_i; \mathbf{q}^{R,(h)}_i], \quad \mathbf{k}^{(h)}_i = [\mathbf{k}^{C,(h)}_i; \mathbf{k}^{R}_i]
head(h)=softmax ⁣(qi(h)ki(h)d+dr+M)vi(h)\text{head}^{(h)} = \text{softmax}\!\left(\frac{\mathbf{q}^{(h)\top}_i \mathbf{k}^{(h)}_{\le i}}{\sqrt{d + d_r}} + M\right) \mathbf{v}^{(h)}_{\le i}
  • WDKVRH×dcW^{DKV} \in \mathbb{R}^{H \times d_c} — KV 共用 down-projection(DeepSeek V3 取 dc=512d_c = 512
  • WKU,(h),WVU,(h)Rdc×dW_K^{U,(h)}, W_V^{U,(h)} \in \mathbb{R}^{d_c \times d} — 第 hh 个 head 自己的 up-projection
  • WKRRH×drW^{KR} \in \mathbb{R}^{H \times d_r} — 共享的 RoPE K 分支(DeepSeek V3 取 dr=64d_r = 64
  • ciKVRdc,kiRRdr\mathbf{c}^{KV}_i \in \mathbb{R}^{d_c}, \mathbf{k}^{R}_i \in \mathbb{R}^{d_r}MLA 实际缓存的就这两条 — 共 dc+drd_c + d_r 个数
  • DeepSeek V3:nq=128,d=128,dc=512,dr=64n_q = 128, d = 128, d_c = 512, d_r = 64,单 token cache = (512+64)2B=1152(512 + 64) \cdot 2\text{B} = 1152 字节 / layer — 比同等规模 GQA 又小一个量级

为什么 RoPE 要单走一条? 推理时有个”权重折叠”技巧 — 用矩阵结合律把 WKU,(h)W_K^{U,(h)} 折进 WQU,(h)W_Q^{U,(h)}

qiC,(h)kjC,(h)=ciQ(WQU,(h))WKU,(h)推理前预乘cjKV{\mathbf{q}^{C,(h)}_i}^{\top} \mathbf{k}^{C,(h)}_j = {\mathbf{c}^{Q}_i}^{\top} \underbrace{(W_Q^{U,(h)})^{\top} W_K^{U,(h)}}_{\text{推理前预乘}} \mathbf{c}^{KV}_j

这样 K 内容部分根本不用真的重建,attention 直接在缓存的 cKV\mathbf{c}^{KV} 上算。但 RoPE 的旋转角依赖绝对位置 mm,无法折叠进固定权重 — 一旦把 RoPE 加在重建后的 K 上,前面的折叠就失效。DeepSeek 的办法是把 RoPE 拆成独立的 drd_r 维浅分支,让”可折叠的内容部分”和”必须实时旋转的位置部分”互不干扰,KV-cache 压缩和相对位置信号才能同时保留。

对比总览:参数量 / 计算量 / KV cache

下面把每种变体的开销按步骤拆开。约定:单层 attention 子层,prefill 长度 SS,忽略 norm / bias / softmax 等非矩阵乘项;linear 投影 XWX WXRS×m,WRm×nX \in \mathbb{R}^{S \times m}, W \in \mathbb{R}^{m \times n})FLOPs 记 2Smn2Smn(每个 MAC 算 2 FLOPs),attention 的 QKQK^{\top}AVAV 不扣 causal mask 的一半。MLA 额外用到 dqd_q' = Q 端潜向量维度(DeepSeek V3 取 15361536)。

参数量(单层)

步骤MHAMQAGQAMLA
Q 投影HnqdH \cdot n_q dHnqdH \cdot n_q dHnqdH \cdot n_q dHdq+dqnqdH d_q' + d_q' \cdot n_q d
K 投影HnqdH \cdot n_q dHdH \cdot dHnkvdH \cdot n_{kv} dHdc+dcnqdH d_c + d_c \cdot n_q d
V 投影HnqdH \cdot n_q dHdH \cdot dHnkvdH \cdot n_{kv} ddcnqdd_c \cdot n_q d(与 K 共用 WDKVW^{DKV}
RoPE 分支dqnqdr+Hdrd_q' \cdot n_q d_r + H d_r
WOW_OnqdHn_q d \cdot HnqdHn_q d \cdot HnqdHn_q d \cdot HnqdHn_q d \cdot H
合计4Hnqd4 H n_q d2Hnqd+2Hd2 H n_q d + 2 H d2Hnqd+2Hnkvd2 H n_q d + 2 H n_{kv} d上述各行之和

Prefill FLOPs(单层,序列长度 SS

步骤MHAMQAGQAMLA
Q 投影2SHnqd2 S H n_q d2SHnqd2 S H n_q d2SHnqd2 S H n_q d2SHdq+2Sdqnqd2 S H d_q' + 2 S d_q' n_q d
K 投影2SHnqd2 S H n_q d2SHd2 S H d2SHnkvd2 S H n_{kv} d2SHdc+2Sdcnqd2 S H d_c + 2 S d_c n_q d
V 投影2SHnqd2 S H n_q d2SHd2 S H d2SHnkvd2 S H n_{kv} d2Sdcnqd2 S d_c n_q d
RoPE 分支O(Snqd)\mathcal{O}(S n_q d)O(Snqd)\mathcal{O}(S n_q d)O(Snqd)\mathcal{O}(S n_q d)2Sdqnqdr+2SHdr2 S d_q' n_q d_r + 2 S H d_r
QKQK^{\top}2nqS2d2 n_q S^2 d2nqS2d2 n_q S^2 d2nqS2d2 n_q S^2 d2nqS2(d+dr)2 n_q S^2 (d + d_r)
softmax · VV2nqS2d2 n_q S^2 d2nqS2d2 n_q S^2 d2nqS2d2 n_q S^2 d2nqS2d2 n_q S^2 d
WOW_O2SHnqd2 S H n_q d2SHnqd2 S H n_q d2SHnqd2 S H n_q d2SHnqd2 S H n_q d

表里 “RoPE 分支” 在 MHA/MQA/GQA 指的是把旋转矩阵作用在 Q、K 上的逐元素 mul/add,量级 O(Snqd)\mathcal{O}(S n_q d),相对矩阵乘可忽略;在 MLA 里专指那条额外的 WQRW_Q^R / WKRW^{KR} 投影,属于真矩阵乘,不能忽略

KV cache(单 token,单层,fp16)

变体缓存内容字节数
MHAnqn_q(k,v)Rd(\mathbf{k}, \mathbf{v}) \in \mathbb{R}^d2nqd2B2 \cdot n_q \cdot d \cdot 2\text{B}
MQA1 套 (k,v)(\mathbf{k}, \mathbf{v})2d2B2 \cdot d \cdot 2\text{B}
GQAnkvn_{kv}(k,v)(\mathbf{k}, \mathbf{v})2nkvd2B2 \cdot n_{kv} \cdot d \cdot 2\text{B}
MLAcKVRdc\mathbf{c}^{KV} \in \mathbb{R}^{d_c} + kRRdr\mathbf{k}^R \in \mathbb{R}^{d_r}(dc+dr)2B(d_c + d_r) \cdot 2\text{B}

把三张表横着读,能看出三件事:

  1. 变体只动 K/V 一侧。Q 投影、WOW_OAVAV 在四种变体里完全相同 — 真正动手术的只有 K、V 这一支,所以同尺寸下 attention 的总参数和总 FLOPs 不会差到数量级。
  2. MHA → MQA/GQA 是同时省 params、FLOPs、cacheMLA 是用 params/FLOPs 换 cache。GQA 通过减小 nkvn_{kv} 把 K、V 投影、cache 同步线性缩小;MLA 反过来,加了 DKV + 两段 UK/UV + 单独的 RoPE 分支,参数和 prefill FLOPs 跟同尺寸 GQA 同量级(不会显著降低),换来的是单 token cache 从 KB 量级压到 ~1 KB。
  3. S2S^2 项在 MHA/MQA/GQA 三者间没差QKQK^{\top}AVAV 都是 2nqS2d2 n_q S^2 d(K、V 共享只是 broadcast,不省二次项)。也就是说当 SH/dS \gg H / d 时,三种变体的 prefill FLOPs 会趋同 — 它们的差距全部体现在 decode 阶段每生成一 token 要读多少 KV cache,而那是 memory I/O,不是 compute。

Output 投影 + 残差

h=x+Attn(Q,K,V)WO\mathbf{h} = \mathbf{x} + \text{Attn}(Q', K', V') \cdot W_O
  • xRH\mathbf{x} \in \mathbb{R}^{H} — attention 子层的输入(pre-norm 之前的值)
  • Q,K,VQ', K', V' — Q、K 经过 RoPE 旋转后的版本;V=VV' = V(V 不旋转,记号统一而已)
  • WORnqd×HW_O \in \mathbb{R}^{n_q d \times H} — output projection,把多头 concat 后投回 HH
  • hRH\mathbf{h} \in \mathbb{R}^{H} — attention 子层输出 + 残差

FFN 变体

经典双线性 FFN(GPT-2)

FFN(x)=ϕ(xW1)W2\text{FFN}(\mathbf{x}) = \phi(\mathbf{x} W_1) W_2
  • xRH\mathbf{x} \in \mathbb{R}^{H} — FFN 输入
  • W1RH×IW_1 \in \mathbb{R}^{H \times I} — 升维投影
  • W2RI×HW_2 \in \mathbb{R}^{I \times H} — 降维投影
  • ϕ\phi — 逐元素的非线性激活(见下节)

激活函数

FFN 升维之后逐元素套的标量非线性。输入输出都是标量 xx;它在每个 token、每个隐藏维度上独立应用。下面四个函数覆盖了 Transformer 时代用过的所有主流选择。

ReLU(Nair & Hinton 2010)

ReLU(x)=max(0,x)\text{ReLU}(x) = \max(0, x)
  • x>0x > 0 时恒等映射,x0x \le 0 时输出 0
  • 计算最便宜,但负区梯度恒为 0 — “dead neuron” 问题
  • 用于原始 Transformer 和早期 BERT 实现

Sigmoid

σ(x)=11+ex\sigma(x) = \frac{1}{1 + e^{-x}}
  • R\mathbb{R} 压到 (0,1)(0, 1),天然适合做 “门” 信号 — GLU 原版的 ϕ\phi 就是它
  • 单独当 FFN 激活已经基本不用 — 两端饱和导致梯度消失

GeLU(Gaussian Error Linear Unit,Hendrycks & Gimpel 2016)

GeLU(x)=xΦ(x),Φ(x)=12(1+erf(x/2))\text{GeLU}(x) = x \cdot \Phi(x), \quad \Phi(x) = \tfrac{1}{2}\big(1 + \text{erf}(x/\sqrt{2})\big)

实现里常用 OpenAI 给的 tanh 近似(数值上误差 <103< 10^{-3},但省去 erf 调用):

GeLU(x)0.5x(1+tanh ⁣[2π(x+0.044715x3)])\text{GeLU}(x) \approx 0.5\, x \left(1 + \tanh\!\left[\sqrt{\tfrac{2}{\pi}}\left(x + 0.044715\, x^{3}\right)\right]\right)
  • Φ\Phi 是标准正态 CDF — 直观上 “按 xx 的尾部概率加权地放行 xx
  • 处处可导、非单调(在 x<0x < 0 一侧有一段轻微的负向凹陷),比 ReLU 平滑
  • 用于 GPT-2/3、BERT、ViT

SiLU / Swish(Ramachandran et al. 2017)

SiLU(x)=xσ(x)=x1+ex\text{SiLU}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}
  • 形状非常接近 GeLU(同样平滑、非单调、过原点),但表达式更简单 — 不用 erf\text{erf} 也不用三次项
  • 自门控(self-gated):用自身的 sigmoid 控制信号通过率
  • 用于 PaLM、Llama 全系(在 SwiGLU 里担任门控 ϕ\phi

这四个里 ReLU 和 sigmoid 已经退出主流 FFN,活跃的是 GeLU(GPT 时代)和 SiLU(Llama / PaLM 之后)。两者在 bf16 下数值差别 < 1%,论文里挑哪个多是路径依赖;真正的工程拐点是把激活从 “套在投影上”(经典 FFN)换成 “套在门控上”(GLU 家族)。

GLU 家族(Llama、PaLM、Mistral 都在用)

GLU(x)=(ϕ(xWgate)(xWup))Wdown\text{GLU}(\mathbf{x}) = \big(\phi(\mathbf{x} W_{\text{gate}}) \odot (\mathbf{x} W_{\text{up}})\big) W_{\text{down}}
  • xRH\mathbf{x} \in \mathbb{R}^{H} — FFN 输入

  • Wgate,WupRH×IW_{\text{gate}}, W_{\text{up}} \in \mathbb{R}^{H \times I} — 两个独立的升维投影

  • WdownRI×HW_{\text{down}} \in \mathbb{R}^{I \times H} — 降维投影

  • \odot — 逐元素乘

  • ϕ\phi — 门控激活,按选择决定变体名:

    变体ϕ\phi代表模型
    GLUσ\sigmaDauphin et al. 2017 原版
    ReGLUReLU\text{ReLU}
    GeGLUGeLU\text{GeLU}T5 v1.1
    SwiGLUSiLU\text{SiLU}PaLM, Llama 1/2/3

SwiGLU 比经典 GeLU-FFN 多一个投影(三矩阵 vs 两矩阵),为了参数预算对齐,实现里一般把 II 设成 234H\tfrac{2}{3} \cdot 4H4H4H 是 GPT-2 的惯例),Llama 3 8B 的 I=1433623440961.3I = 14336 \approx \tfrac{2}{3}\cdot 4\cdot 4096 \cdot 1.3

Mixture-of-Experts(MoE)

经典 dense FFN 里每个 token 都要过同一对 Wup/WdownW_{\text{up}} / W_{\text{down}} — 参数全用上,FLOPs 也全付。MoE(Shazeer et al. 2017)的核心是把 FFN 复制成 NN 份(“expert”),每个 token 只挑前 kk 份算,总参数线性放大,单 token 激活参数和 FLOPs 几乎不变 — 用容量换知识量,不换 compute。

形式上把 FFN 子层换成:

MoE(x)=iTk(x)gi(x)FFNi(x)\text{MoE}(\mathbf{x}) = \sum_{i \in \mathcal{T}_k(\mathbf{x})} g_i(\mathbf{x}) \cdot \text{FFN}_i(\mathbf{x})

Router(gating network) 决定每个 token 走哪几个 expert:

s(x)=softmax(xWg),Tk(x)=TopK(s(x),k)\mathbf{s}(\mathbf{x}) = \text{softmax}(\mathbf{x} W_g), \quad \mathcal{T}_k(\mathbf{x}) = \text{TopK}(\mathbf{s}(\mathbf{x}), k)
gi(x)=si(x)1[iTk(x)]jTk(x)sj(x)g_i(\mathbf{x}) = \frac{s_i(\mathbf{x}) \cdot \mathbb{1}[i \in \mathcal{T}_k(\mathbf{x})]}{\sum_{j \in \mathcal{T}_k(\mathbf{x})} s_j(\mathbf{x})}
  • xRH\mathbf{x} \in \mathbb{R}^{H} — 当前 token 的隐状态(FFN 子层输入)
  • WgRH×NW_g \in \mathbb{R}^{H \times N} — router 投影,把 hidden state 映到 NN 个 expert 的 logit
  • s(x)RN\mathbf{s}(\mathbf{x}) \in \mathbb{R}^{N} — 所有 expert 的 router 概率
  • Tk(x)\mathcal{T}_k(\mathbf{x}) — 被选中的 top-kk expert 索引集合
  • gi(x)g_i(\mathbf{x}) — combine 权重;Mixtral / DeepSeek 都对 top-kksis_i 再做一次归一化让它们和为 1
  • FFNi\text{FFN}_i — 第 ii 个 expert,结构通常就是 SwiGLU,仅权重不共享

MoE vs 传统 FFN

维度Dense FFN (SwiGLU)MoE (top-kk of NN)
参数量(FFN 段)3HI3 H I3HIN3 H I \cdot N + HNH N (router)
每 token 激活 FLOPs6HI6 H I6HIk6 H I \cdot k + 2HN2 H N (router)
每 token 权重带宽(decode)3HI3 H I bytes3HIk3 H I \cdot k bytes
显存占用3HI3 H I3HIN3 H I \cdot N(必须全装下)
Kernel 形状固定 GEMMgrouped GEMM / token permutation
多卡通信expert-parallel 时跨卡 all-to-all
训练稳定性直接router 易塌缩,要 load-balance

关键 trade-off:容量与 FLOPs 解耦。同样激活参数(即同 FLOPs 预算)下,MoE 可以塞 8–32 × 总参数 — 知识容量上去了。代价是显存(要装下所有 expert)、路由稳定性(避免热门 expert)、推理 batching(每 token 路径不同)。

Load Balance

最朴素的 router 训着训着会”塌缩”到少数热门 expert。GShard / Switch 用辅助损失(auxiliary loss)做软约束:

Laux=αNi=1Nfisˉi\mathcal{L}_{\text{aux}} = \alpha \cdot N \cdot \sum_{i=1}^{N} f_i \cdot \bar{s}_i
  • fif_i — 当前 batch 内被路由到 expert ii 的 token 占比
  • sˉi\bar{s}_i — 当前 batch 内 expert ii 的平均 router 概率
  • α\alpha — 辅助损失权重(Switch 用 0.010.01 量级)
  • 直觉:当 fif_isˉi\bar{s}_i 都大时 = 这个 expert 既被选很多次又信心很高 → 惩罚 → 反向梯度把 router 的这条 logit 推回去

DeepSeek V3 进一步抛弃辅助损失,用无损失 load balance:给每个 expert 维护偏置 bib_i,TopK 选择基于 si+bis_i + b_i;被选少的 bib_i 慢慢调高,被选多的 bib_i 调低 — 只影响选择不污染梯度,避免 aux loss 对主任务造成性能下拉。

现代实现对照

模型总参数 / 激活NN routedtop-kk共享 expert路由特征
Switch Transformer (2021)1.6 T / ~26 B20481第一个能稳定训的大 MoE;硬 top-1 + capacity factor 限流
GLaM (2022)1.2 T / 97 B642Google decoder-only MoE,证明 MoE 推理可比 dense 省一半
Mixtral 8×7B (2023)47 B / 13 B82“8 大 expert” 开源样板;每层独立路由
Mixtral 8×22B (2024)141 B / 39 B828×7B 放大版
Qwen1.5-MoE-A2.7B (2024)14 B / 2.7 B6044阿里第一代细粒度 MoE
DeepSeek V2 (2024)236 B / 21 B16062细粒度 + 共享 expert 范式确立
DeepSeek V3 (2024)671 B / 37 B25681无辅助损失 load balance;与 MLA 组合
Qwen3-MoE 235B-A22B (2025)235 B / 22 B1288DeepSeek 风格细粒度
Llama 4 Scout (2025)109 B / 17 B1611top-1 + 1 共享,极端稀疏
Llama 4 Maverick (2025)400 B / 17 B12811同上,把 expert 数推到 128
GPT-4(社区推测)~1.8 T / ~280 B162从未官方公开;半导体分析师拆解推测

几条规律:

  1. 从粗到细。Mixtral 这代是”少而胖”(8 个接近 dense 大小的 expert);DeepSeek 起把每个 expert 切小、数量推到上百,相当于”多而瘦” — 同样激活 FLOPs 下表达组合数指数级上升。
  2. 共享 expert 成为标配。DeepSeek / Qwen-MoE / Llama 4 都给每层保留 1–2 个”所有 token 必过”的共享 expert,专门吃通用模式;剩下的稀疏 expert 才负责特化。
  3. MoE 总和 KV cache 压缩配套。MoE 把总参数推到几百 B,long-context decode 同样压力大 — DeepSeek V3 同时上 MLA + 细粒度 MoE 不是巧合。
  4. top-kk 走极端。Switch (k=1) → Mixtral (k=2) → DeepSeek V3 (k=8 但 expert 小) → Llama 4 (k=1 + 共享)。top-kk 小好做 batching 和容量约束,但要靠细粒度 / 共享 expert 把表达力补回来。

残差

xout=h+FFN(RMSNorm(h))\mathbf{x}_{\text{out}} = \mathbf{h} + \text{FFN}(\text{RMSNorm}(\mathbf{h}))
  • h\mathbf{h} — attention 子层输出(已含第一道残差)
  • xoutRH\mathbf{x}_{\text{out}} \in \mathbb{R}^{H} — 整层 Transformer 输出,进入下一层
  • RMSNorm 仍在残差分支内(pre-norm),主干 h\mathbf{h} 直接跨过

LM Head + 采样

logits=xfinalWlm,p=softmax(logits/T)\text{logits} = \mathbf{x}_{\text{final}} W_{\text{lm}}, \quad \mathbf{p} = \text{softmax}(\text{logits}/T)
  • xfinalRH\mathbf{x}_{\text{final}} \in \mathbb{R}^{H} — 最后一层 RMSNorm 后、最后一个位置的隐状态
  • WlmRH×VW_{\text{lm}} \in \mathbb{R}^{H \times V} — 输出嵌入矩阵(可与 EE 共享,即 tied embedding)
  • logitsRV\text{logits} \in \mathbb{R}^{V} — 词表上每个 token 的未归一化分数
  • T>0T > 0 — temperature,越大概率分布越平坦(TT \to \infty 趋近均匀,T0T \to 0 趋近 argmax)
  • pRV\mathbf{p} \in \mathbb{R}^{V} — 归一化后的概率分布

采样前通常叠几层 logits 变换:

重复惩罚(repetition / frequency / presence penalty):

logitsv=logitsvα1[vhistory]βcount(v)\text{logits}'_v = \text{logits}_v - \alpha \cdot \mathbb{1}[v \in \text{history}] - \beta \cdot \text{count}(v)
  • vv — 词表中某个 token
  • logitsv\text{logits}_v — 该 token 的原始 logit
  • α0\alpha \ge 0 — presence penalty(只要出现过就扣一次)
  • β0\beta \ge 0 — frequency penalty(按出现次数线性扣)
  • 1[]\mathbb{1}[\cdot] — 指示函数(条件成立为 1,否则为 0)
  • count(v)\text{count}(v) — token vv 在已生成序列中的出现次数

Top-k:只保留最大 kk 个 logits,其它置 -\infty

Top-p(nucleus):按概率降序累加,保留累积概率 p\le p 的集合。

Min-p:保留 pvpminpmaxp_v \ge p_{\min} \cdot p_{\max} 的集合,对低熵分布更友好。

Typical-p:基于与条件熵的偏差做截断,保留 logpvH(p)|-\log p_v - H(\mathbf{p})| 小的集合。

所有截断都作用在概率分布本身之前/之后,不改变公式的骨架

Prefill 阶段维度流转 — S 个 token 一次性走一遍

输入:input_ids [B, S] = [2, 10]。下图是单层 Transformer 的前向,外面再套 32 层;shape 始终保持在 [B, S, H] = [2, 10, 4096],残差边以橙色虚线表示。

× L = 32 层input_ids · 整数张量 [B, S] = [2, 10]Embedding · 查表 E ∈ R^[V, H]x = E[token_id]in [B, S] = [2, 10]out [B, S, H] = [2, 10, 4096]x · [B, S, H] = [2, 10, 4096] ← 残差 #1 起点RMSNorm · pre-normx’ = x / √(mean(x²) + ε) ⊙ γin [B, S, H] = [2, 10, 4096]out [B, S, H] = [2, 10, 4096]Q/K/V 投影 + RoPE · GQA (n_q=32, n_kv=8, d=128)Q = x’·W_Q, K = x’·W_K, V = x’·W_VQ, K ← rotate(m·θ_k) · V 不旋转in [B, S, H] = [2, 10, 4096]out Q [B, n_q, S, d] = [2, 32, 10, 128]out K, V [B, n_kv, S, d] = [2, 8, 10, 128]写入 KV Cache · 一次写入全部 S 个位置cache[l][:, :, 0:S, :] = K, V · shape [B, n_kv, S_max, d]Scaled Dot-Product Attention · causal maskA = softmax(Q · Kᵀ / √d + M) · Vin Q [B, n_q, S, d] = [2, 32, 10, 128]in K, V [B, n_kv, S, d] = [2, 8, 10, 128]scores [B, n_q, S, S] = [2, 32, 10, 10]out A [B, n_q, S, d] = [2, 32, 10, 128]Output Projection · concat 多头 → 投回 Hy = A.reshape(B, S, n_q·d) · W_O, W_O ∈ R^[n_q·d, H]in [B, n_q, S, d] = [2, 32, 10, 128]out [B, S, H] = [2, 10, 4096]+h · [B, S, H] = [2, 10, 4096] ← 残差 #2 起点h = x + (Attn(RMSNorm(x)) · W_O)RMSNorm · pre-normh’ = h / √(mean(h²) + ε) ⊙ γin [B, S, H] = [2, 10, 4096]out [B, S, H] = [2, 10, 4096]FFN · SwiGLU (I = 14336)f = (SiLU(h’·W_gate) ⊙ (h’·W_up)) · W_downin [B, S, H] = [2, 10, 4096]mid [B, S, I] = [2, 10, 14336]out [B, S, H] = [2, 10, 4096]+x_out · [B, S, H] = [2, 10, 4096] → 进入下一层x_out = h + FFN(RMSNorm(h))Final RMSNorm · 32 层之后x_final = x_L / √(mean(x_L²) + ε) ⊙ γ_finalin [B, S, H] = [2, 10, 4096]out [B, S, H] = [2, 10, 4096]只取最后位置 · prefill 只需最后一个 token 的隐状态in [B, S, H] = [2, 10, 4096]out [B, H] = [2, 4096]LM Head · W_lm ∈ R^[H, V]logits = x_final · W_lmin [B, H] = [2, 4096]out [B, V] = [2, 128256]采样 · temperature / top-k / top-pp = softmax(logits / T) → filter (top-k/top-p) → samplenext_token · [B, 1] = [2, 1] ← 第一个输出 tokenattn 残差FFN 残差
Prefill — 每个 step 标注「公式 + 输入 shape → 输出 shape」(蓝色字)。橙色虚线为 pre-norm 残差,⊕ 为残差合并;左侧括号「× 32 层」标注 Transformer block 循环范围。

Prefill 结束后,KV Cache 状态:每层已填入前 10 个位置。

Decode 阶段维度流转 — 第 t 步只走 1 个 token

前置状态:已有 cache_len=S+t1\text{cache\_len} = S + t - 1 个位置。

输入:input_ids [B, 1] = [2, 1](上一步生成的 1 个 token)。

× L = 32 层input_ids · [B, 1] = [2, 1] ← 上一步生成的 1 个 tokenEmbedding · 查表 E ∈ R^[V, H]x = E[token_id]in [B, 1] = [2, 1]out [B, 1, H] = [2, 1, 4096]x · [B, 1, H] = [2, 1, 4096] ← 残差 #1 起点 · S = 1RMSNorm · pre-normx’ = x / √(mean(x²) + ε) ⊙ γin [B, 1, H] = [2, 1, 4096]out [B, 1, H] = [2, 1, 4096]Q/K/V 投影 · 只算 1 个新 token · GQA (n_q=32, n_kv=8, d=128)Q_new = x’·W_Q, K_new = x’·W_K, V_new = x’·W_Vin [B, 1, H] = [2, 1, 4096]out Q_new [B, n_q, 1, d] = [2, 32, 1, 128]out K_new, V_new [B, n_kv, 1, d] = [2, 8, 1, 128]RoPE · pos = cache_lenQ_new, K_new ← rotate by m·θ_k, m = cache_len (shape 不变)写入 KV Cache 的下一个位置cache[l][:, :, cache_len, :] = K_new, V_new · cache_len += 1从 cache 读全部历史 K, VK_full, V_full = cache[l][:, :, :cache_len, :]out K_full [B, n_kv, cache_len, d] = [2, 8, cache_len, 128]out V_full [B, n_kv, cache_len, d] = [2, 8, cache_len, 128]Scaled Dot-Product Attention · 无 causal mask(Q 只有 1 行)A = softmax(Q_new · K_fullᵀ / √d) · V_fullin Q_new [B, n_q, 1, d] = [2, 32, 1, 128]in K_full, V_full [B, n_kv, cache_len, d] = [2, 8, cache_len, 128]scores [B, n_q, 1, cache_len] = [2, 32, 1, cache_len]out A [B, n_q, 1, d] = [2, 32, 1, 128]Output Projection · concat 多头 → 投回 Hy = A.reshape(B, 1, n_q·d) · W_O, W_O ∈ R^[n_q·d, H]in [B, n_q, 1, d] = [2, 32, 1, 128]out [B, 1, H] = [2, 1, 4096]+h · [B, 1, H] = [2, 1, 4096] ← 残差 #2 起点h = x + (Attn(RMSNorm(x)) · W_O)RMSNorm · pre-normh’ = h / √(mean(h²) + ε) ⊙ γin [B, 1, H] = [2, 1, 4096]out [B, 1, H] = [2, 1, 4096]FFN · SwiGLU · 只处理 1 个 token (I = 14336)f = (SiLU(h’·W_gate) ⊙ (h’·W_up)) · W_downin [B, 1, H] = [2, 1, 4096]mid [B, 1, I] = [2, 1, 14336]out [B, 1, H] = [2, 1, 4096]+x_out · [B, 1, H] = [2, 1, 4096] → 进入下一层x_out = h + FFN(RMSNorm(h))Final RMSNorm · 32 层之后x_final = x_L / √(mean(x_L²) + ε) ⊙ γ_finalin [B, 1, H] = [2, 1, 4096]out [B, 1, H] = [2, 1, 4096]LM Head · W_lm ∈ R^[H, V]logits = x_final · W_lmin [B, 1, H] = [2, 1, 4096]out [B, V] = [2, 128256] (squeeze S=1)采样 · temperature / top-k / top-pp = softmax(logits / T) → filter (top-k/top-p) → samplenext_token · [B, 1] = [2, 1] ← 第 t 个输出 tokenattn 残差FFN 残差下一步输入 · t = t + 1
Decode — 第 t 步只走 1 个 token,每个 step 标注「公式 + 输入 shape → 输出 shape」(蓝色字)。橙色虚线为残差;蓝色虚线表示生成的 next_token 回灌为下一步输入;KV Cache 持续累计(attention 处的 K, V 维度随 cache_len 增长)。

Prefill vs Decode 维度对照 — GEMM vs GEMV · 算力 vs 带宽

位置PrefillDecode 每步
input_ids[B,S][B, S][B,1][B, 1]
embedding 后[B,S,H][B, S, H][B,1,H][B, 1, H]
Q[B,nq,S,d][B, n_q, S, d][B,nq,1,d][B, n_q, 1, d]
K_new / V_new[B,nkv,S,d][B, n_{kv}, S, d][B,nkv,1,d][B, n_{kv}, 1, d]
K_full / V_full(从 cache)同 K_new[B,nkv,cache_len,d][B, n_{kv}, \text{cache\_len}, d]
attention scores[B,nq,S,S][B, n_q, S, S][B,nq,1,cache_len][B, n_q, 1, \text{cache\_len}]
attention 输出[B,nq,S,d][B, n_q, S, d][B,nq,1,d][B, n_q, 1, d]
FFN 中间[B,S,I][B, S, I][B,1,I][B, 1, I]
logits[B,V][B, V](取最后位置)[B,V][B, V]
运算性质GEMM(矩 × 矩)GEMV(矩 × 量)
瓶颈算力内存带宽

这张表是理解所有推理加速工作的起点:prefill 像训练的 forward,compute-bound;decode 是一串 GEMV,memory-bound,绝大部分时间在往 SM 里搬权重。两段的优化方向天差地别

KV Cache 的形状与增长 — 单 token 几 KB · 长 context 几百 MB

每层一对 cache:

Kcache,VcacheRB×nkv×Smax×dK_{\text{cache}}, V_{\text{cache}} \in \mathbb{R}^{B \times n_{kv} \times S_{\max} \times d}
  • SmaxS_{\max} — 预分配的最大序列长度(一般 = 模型上下文上限或调度器允许的上限)
  • 其它符号沿用顶部约定(B,nkv,dB, n_{kv}, d);4 个维度的顺序按 PyTorch [B, head, seq, head_dim] 习惯

每个 token、每层的 cache 大小(fp16):

2×nkv×d×2 bytes=2×8×128×2=4 KB2 \times n_{kv} \times d \times 2\ \text{bytes} = 2 \times 8 \times 128 \times 2 = 4\ \text{KB}
  • 最左的 22 — K 和 V 两份
  • 2 bytes2\ \text{bytes} — fp16 每元素 2 字节(用 fp8 / int8 可减半到 1/4)
  • 右侧代入 Llama 3 8B:nkv=8n_{kv} = 8d=128d = 128

每个 token、全模型(32 层):4 KB×32=128 KB/token4\ \text{KB} \times 32 = 128\ \text{KB} / \text{token}。一个 4096-token 请求:128 KB×4096512 MB128\ \text{KB} \times 4096 \approx 512\ \text{MB}

几种工程优化:

  • Paged Attention(vLLM):把 cache 拆成固定大小 block(通常 16 token),用一张 block table 做虚地址到物理地址的映射,消除碎片。对应公式里没变化,只是张量布局和访问模式变了
  • Sliding Window Attention(Mistral):只保留最近 WW 个 token 的 K、V。cache 上限从 SmaxS_{\max} 降到 WW,代价是信息截断,靠层间层叠传递远距离依赖。
  • INT8 / FP8 KV Cache:把 fp16 cache 量化到 int8 甚至 fp8,per-channel 或 per-token 量化,误差可控,cache 占用减半到 1/4。代表工作 KIVI / KVQuant。
  • KV 压缩 / 丢弃(H2O、StreamingLLM、SnapKV):按注意力权重把不重要的位置踢出 cache,上下文超长时用。
  • MLA:前面已经提过,是从模型结构层面改 cache 形状,不是后处理

每步的计算量与内存开销 — Llama 3 8B fp16 · H100 拐点 ~330 FLOPs/byte

前面的维度图只说 shape,不说量级。推理优化 90% 的讨论都在算”这一步要花多少 FLOPs、搬多少字节”,所以这里直接把每步的开销拍成表。

以 Llama 3 8B、fp16、B=1B=1 为基准,prefill 取 S=2048S=2048,decode 取 cache_len=2048\text{cache\_len}=2048(即生成到第 2048 个 token 时的某一步)。

参考硬件拐点:H100 SXM fp16 理论算力 ~989 TFLOPs,HBM 带宽 ~3 TB/s,roofline 拐点 AI330 FLOPs/byte\text{AI}^{*} \approx 330\ \text{FLOPs/byte}。高于它是 compute-bound,低于它是 memory-bound

权重分布

组件Shapefp16 大小全模型(×32 层)
Embedding EE[V,H][V, H]1.0 GB1.0 GB
WQW_Q[H,H][H, H]32 MB1.0 GB
WKW_K[H,nkvd][H, n_{kv}d]8 MB256 MB
WVW_V[H,nkvd][H, n_{kv}d]8 MB256 MB
WOW_O[H,H][H, H]32 MB1.0 GB
WgateW_{\text{gate}}[H,I][H, I]117 MB3.7 GB
WupW_{\text{up}}[H,I][H, I]117 MB3.7 GB
WdownW_{\text{down}}[I,H][I, H]117 MB3.7 GB
RMSNorm γ\boldsymbol{\gamma}(每层 2 份)[H]×2[H]\times 216 KB500 KB
LM head WlmW_{\text{lm}}[H,V][H, V]1.0 GB1.0 GB
合计~432 MB / 层~16 GB

全模型 fp16 权重约 16 GB,每次 forward 的”底价”就是把这 16 GB 从 HBM 里扫一遍。H100 @ 3 TB/s 下 =16/30005.3 ms= 16/3000 \approx 5.3\ \text{ms}这就是单请求 decode 的物理下限

每层每步的计算 / 内存读写

对同一层在 prefill(N=SN=S)和 decode(N=1N=1)下的各个子步做对照。“Weight HBM” 是要从显存搬的权重字节,“KV HBM” 是要读/写的 KV Cache 字节。中间 activation 默认被 kernel 融合,不单独算。

步骤Prefill FLOPs(S=2048)Decode FLOPs(S=1)Weight HBMKV HBM
RMSNorm5BSH5BSH ≈ 42 MF20 KFγ\boldsymbol{\gamma} 8 KB
QprojQ_{\text{proj}}2BSH22BSH^{2} ≈ 68.7 GF33.5 MFWQW_Q 32 MB
KprojK_{\text{proj}}(+写 cache)2BSHnkvd2BSH \cdot n_{kv}d ≈ 17.2 GF8.4 MFWKW_K 8 MBW 4 MB / 2 KB
VprojV_{\text{proj}}(+写 cache)17.2 GF8.4 MFWVW_V 8 MBW 4 MB / 2 KB
RoPE~50 MF25 KF
Attn QKQK^{\top}2BnqNLkd2B n_q N L_k d ≈ 34.4 GF16.8 MFR 4 MB(decode)
softmax~700 MF260 KF
Attn V\cdot V34.4 GF16.8 MFR 4 MB(decode)
WOW_O2BSH22BSH^{2} ≈ 68.7 GF33.5 MFWOW_O 32 MB
RMSNorm42 MF20 KFγ\boldsymbol{\gamma} 8 KB
WgateW_{\text{gate}}2BSHI2BSHI ≈ 241 GF117 MFWgateW_{\text{gate}} 117 MB
WupW_{\text{up}}241 GF117 MFWupW_{\text{up}} 117 MB
SiLU + gate~90 MF45 KF
WdownW_{\text{down}}2BSIH2BSIH ≈ 241 GF117 MFWdownW_{\text{down}} 117 MB
每层合计~960 GFLOPs~470 MFLOPs~432 MBW 8 MB(P)/ R 8 MB(D)

几个直接结论:

  • FFN 是真正的主角Wgate+Wup+WdownW_{\text{gate}} + W_{\text{up}} + W_{\text{down}} 吃掉 ~75% 的 FLOPs 和 ~80% 的权重带宽。MoE、稀疏激活、FFN 量化全都冲着这块去。
  • Attention 的 4 个投影(Q/K/V/O) 占 ~18%,真正的 QKQK^{\top}V\cdot V 只占 ~7% — prefill 里 attention 并不是瓶颈,投影才是
  • Decode 的 KV 读取 每层 8 MB,在 cache_len=2048\text{cache\_len}=2048 时只占权重读取的 ~2%。但当上下文拉到 64K、128K,它会翻几十倍、直接反超权重带宽,成为新瓶颈(这也是 Paged Attention、sliding window、KV 量化出现的原因)。

全模型一次 forward

把 32 层 + embedding + LM head 加起来:

阶段FLOPsHBM I/OArithmetic Intensity瓶颈
Prefill S=2048, B=1~31 TFLOPs~14 GB(权重)+ 256 MB(KV 写)~2200 FLOPs/byte算力
Decode step, cache_len=2048, B=1~15 GFLOPs~14 GB(权重)+ 256 MB(KV 读)~1.05 FLOPs/byte带宽
LM head(prefill 只算最后一位)~1 GFLOP1 GB~1 FLOPs/byte带宽
LM head(decode)~1 GFLOP1 GB~1 FLOPs/byte带宽

Decode 的 1.05 FLOPs/byte 比 H100 拐点 330 低 两个半数量级 — 意味着理想情况下单请求 decode 的算力利用率只有 1.05/3300.3%1.05/330 \approx 0.3\%。这就是 continuous batching 的数学依据:把 BB 拉到 32,同一批权重 read 被 32 个请求共享,arithmetic intensity 直接 ×32,decode 吞吐几乎线性增长,直到 attention 部分或算力先撞墙。

心算口诀

两条规则覆盖 90% 的推理性能估算:

  1. FLOPs ≈ 2PN2 P NPP 是参数量(~8B),NN 是这次 forward 处理的 token 总数。每个参数被每个 token 各用一次 MAC,一次 MAC 算 2 FLOPs。例如 prefill S=2048S=2048: 2×8B×204833 TFLOPs2 \times 8\text{B} \times 2048 \approx 33\ \text{TFLOPs},与分项加总的 31 TFLOPs 吻合。
  2. 权重 HBM I/O ≈ 2P2 P bytes(fp16):一次 forward 就是把模型扫一遍,约 16 GB。

Arithmetic intensity 本质是 2PN2P=N\frac{2 P N}{2 P} = N — forward 里一共参与的 token 数。Prefill 有 SBS \cdot B 个 token,decode 只有 BB 个。这一个数字直接决定了 prefill / decode 瓶颈不同的根源。

计算复杂度总览 — 有 KV Cache vs 没 KV Cache 差三个数量级

Prefill(一次性处理 SS 个 token):

FLOPsO(LSH2)线性层+O(LS2H)attention\text{FLOPs} \sim \underbrace{O(L \cdot S \cdot H^{2})}_{\text{线性层}} + \underbrace{O(L \cdot S^{2} \cdot H)}_{\text{attention}}
  • LSH2L \cdot S \cdot H^{2} — 每层 4 个 H×HH \times H 投影 + FFN 三个 H×IH \times I 投影(I4HI \sim 4H),各作用在 SS 个 token 上
  • LS2HL \cdot S^{2} \cdot H — attention 的 QKQK^{\top}V\cdot V,含 S×SS \times S 的 score 矩阵
  • 短序列下线性层主导;SHS \gtrsim H 后 attention 二次项追上

Decode 每步(处理 1 个 token,历史 cache_len\text{cache\_len}):

FLOPsO(LH2)线性层,恒定+O(Lcache_lenH)attention,随 cache 线性增长\text{FLOPs} \sim \underbrace{O(L \cdot H^{2})}_{\text{线性层,恒定}} + \underbrace{O(L \cdot \text{cache\_len} \cdot H)}_{\text{attention,随 cache 线性增长}}
  • cache_len\text{cache\_len} — 当前已缓存的位置数(=S+t1= S + t - 1
  • 单步只处理 1 个新 token,所以线性层把 SS 替换为 11;attention 仍要扫整段 cache,故随 cache 线性增长

生成 TT 个 token 总复杂度:

FLOPstotalO ⁣(LTH2+LT(S+T)H)\text{FLOPs}_{\text{total}} \sim O\!\left(L \cdot T \cdot H^{2} + L \cdot T \cdot (S + T) \cdot H\right)
  • TT — 生成 token 数(沿用顶部约定)
  • (S+T)(S + T) — 平均 attention 跨距(prompt 长度 + 已生成段)
  • 第一项是 decode 线性层对 TT 步累加;第二项是 attention 部分对 t=1,,Tt = 1, \ldots, T 的求和近似(精确是 t(S+t)\sum_t (S + t)

对比无 KV Cache:O(L(S+T)3)O(L \cdot (S+T)^{3}),差异巨大。

把 FLOPs 和带宽一起看就更直观 — 这正是上一节那张表展示的:prefill 挑算力上限,decode 挑带宽上限;continuous batching 的意义就是把 NN 个请求的 decode 拼成一个大 GEMV,搬权重的成本被 NN 个请求摊薄,吞吐线性上涨直到算力或 attention 部分成为瓶颈。

工程优化怎么嵌进公式 — Flash Attn / Spec Decode / Continuous Batch

Flash Attention:数学上完全等价于标准 attention,公式一字不变。工程上把 softmax 和 matmul 融成一个 kernel,按 block 流式更新 softmax 的运行统计量(max、sum),避免把 S×SS \times S 的 attention 矩阵写回 HBM。复杂度不变,显存占用从 O(S2)O(S^{2}) 降到 O(S)O(S),速度提升主要来自 HBM 访问减少。FA-2 把切分粒度从 head 改到 query block;FA-3 在 H100/H200 上叠加 warpgroup MMA + producer-consumer 异步流水。

Flash Decoding:Flash Attention 在 decode 时 Q 只有 1 行,kernel 并行度不够。Flash Decoding 把 K、V 的 cache_len\text{cache\_len} 维切成多块并行,再做一次 log-sum-exp 归约。公式还是同一个 softmax,只是拆成两段算

Speculative Decoding:用一个小的 draft 模型连续生成 kk 个 token,再用大模型一次 prefill 那 kk 个位置做验证。接受规则是

accept with prob min ⁣(1,ptarget(x)pdraft(x))\text{accept with prob } \min\!\left(1, \frac{p_{\text{target}}(x)}{p_{\text{draft}}(x)}\right)
  • xx — draft 模型生成的某个候选 token
  • ptarget(x)p_{\text{target}}(x) — 大(target)模型在该位置上对 xx 的概率
  • pdraft(x)p_{\text{draft}}(x) — 小(draft)模型同位置对 xx 的概率
  • 该规则配合”拒绝后从 max(0,ptargetpdraft)\max(0, p_{\text{target}} - p_{\text{draft}}) 重采”可证明最终采样分布严格等于 target 模型直采,无质量损失

关键是把 decode 的 kk 次 GEMV 合并成一次 kk 长度的 GEMM,把大模型压在 memory-bound 边界的时间重新变成 compute-bound。期望每步接受 kˉ\bar k 个 token,吞吐放大 kˉ\bar k 倍(减去 draft 开销)。变种:Medusa(多头预测)、EAGLE(特征级 draft)、Lookahead Decoding(无 draft 模型)。

Continuous Batching(vLLM、TGI):不在 prefill 边界 pad 到齐,而是请求级的 step 调度。每个 step 选一批正处于同一 phase(prefill 或 decode)的请求拼 batch,完成一个释放一个。数学上各请求互不影响,只是编排顺序变了。原始论文是 OSDI’22 的 Orca。

Chunked Prefill:把长 prompt 的 prefill 拆成多段,和 decode 请求混进同一个 step,减少 decode 请求的等待抖动。公式无变化。SARATHI / DistServe 等系统的核心调度原语。

一句话串起整个过程 — 从 token ID 到下一个 token

输入 token ID → 查 embedding → 过 LL 层(Pre-RMSNorm → Attention 带 RoPE → 残差 → Pre-RMSNorm → SwiGLU FFN → 残差) → Final RMSNorm → LM Head → logits → 采样。Prefill 时 SS 个 token 并行走一遍,产物是第一个 token + 完整 KV Cache;Decode 时每步只输入 1 个 token,attention 处从 cache 读历史 K、V,其他所有操作都是逐 token 独立的

关键工程不变量 — 值得背下来的 6 条

  1. 主线张量 shape 始终是 [B,Scurrent,H][B, S_{\text{current}}, H]。残差结构保证维度不变,一旦你在某一步看到 HH 变了,要么是在 attention 内部展成多头,要么是在 FFN 内部升到 II出子层就回到 HH
  2. K、V 一旦算出就不再变。因为它们只是线性投影 WK,WVW_K, W_V 作用在已经定了的输入 x\mathbf{x} 上,而 causal 结构保证更晚的位置不会反过来改更早位置的表示。这是 KV Cache 成立的数学基础。
  3. Attention 是唯一跨 token 的操作,其他所有操作(norm、投影、FFN、激活)都逐 token 独立。所以只需要缓存跨 token 操作需要的 K、V;其他都可以即时算完扔掉。
  4. Decode 的 scores shape 是 [B,nq,1,cache_len][B, n_q, 1, \text{cache\_len}]。“1” 是 Q 侧(当前新 token),cache_len\text{cache\_len} 维在和 VV 做加权求和时被消掉,输出又回到 1 行。
  5. Decode 每步非 attention 部分的计算量恒定,只有 attention 随 cache 长度线性增长。所以真正”越生成越慢”的本质,是 attention 的 cache_len\text{cache\_len} 越来越长,加上 KV Cache 把 memory footprint 顶到 HBM 带宽上限。
  6. Prefill 用 GEMM,decode 用 GEMV。这一字之差决定了所有推理引擎都要做两套 kernel、两套调度策略。理解这一条,后面看任何推理优化论文都不会迷路。

如果把这十条装进脑子里,再去读 Flash Attention、PagedAttention、MLA、投机解码这些论文,会发现它们都是在这张骨架的某个位置做局部优化,而骨架本身二十年没怎么变过。

参考资料 — 公式 · 论文 · 工程 blog

架构与核心算子

位置编码

  • Su et al., “RoFormer: Enhanced Transformer with Rotary Position Embedding” (2021) — RoPE。arxiv.org/abs/2104.09864
  • Peng et al., “YaRN: Efficient Context Window Extension of Large Language Models” (2023) — 长上下文 RoPE 缩放。arxiv.org/abs/2309.00071
  • bloc97 & emozilla, “NTK-Aware Scaled RoPE” — 早期 NTK-aware 工作的开源实现讨论。reddit / LocalLLaMA

多头变体(KV 共享 / 压缩)

  • Shazeer, “Fast Transformer Decoding: One Write-Head is All You Need” (2019) — MQA。arxiv.org/abs/1911.02150
  • Ainslie et al., “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints” (EMNLP 2023) — GQA。arxiv.org/abs/2305.13245
  • DeepSeek-AI, “DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model” (2024) — MLA 首次提出。arxiv.org/abs/2405.04434
  • DeepSeek-AI, “DeepSeek-V3 Technical Report” (2024) — MLA + MoE 工程化。arxiv.org/abs/2412.19437

Mixture-of-Experts

  • Shazeer et al., “Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer” (ICLR 2017) — top-kk gated MoE 在深度网络里的奠基论文。arxiv.org/abs/1701.06538
  • Lepikhin et al., “GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding” (ICLR 2021) — 加 aux loss 的标准 load-balance 实现。arxiv.org/abs/2006.16668
  • Fedus et al., “Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity” (JMLR 2022) — top-1 + capacity factor。arxiv.org/abs/2101.03961
  • Du et al., “GLaM: Efficient Scaling of Language Models with Mixture-of-Experts” (ICML 2022) — Google 1.2T MoE。arxiv.org/abs/2112.06905
  • Jiang et al., “Mixtral of Experts” (Mistral AI, 2024) — Mixtral 8×7B 技术报告。arxiv.org/abs/2401.04088
  • DeepSeek-AI, “DeepSeekMoE: Towards Ultimate Expert Specialization in Mixture-of-Experts Language Models” (2024) — 细粒度 + 共享 expert 范式。arxiv.org/abs/2401.06066
  • Wang et al., “Auxiliary-Loss-Free Load Balancing Strategy for Mixture-of-Experts” (2024) — DeepSeek V3 的无损失负载均衡。arxiv.org/abs/2408.15664
  • Meta AI, “The Llama 4 herd” (2025) — Llama 4 Scout / Maverick / Behemoth 技术细节。ai.meta.com/blog/llama-4

Flash Attention 系列

  • Dao et al., “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness” (NeurIPS 2022) — FA-1。arxiv.org/abs/2205.14135
  • Dao, “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning” (2023) — FA-2。arxiv.org/abs/2307.08691
  • Shah et al., “FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-Precision” (NeurIPS 2024) — FA-3 / Hopper。arxiv.org/abs/2407.08608
  • Dao et al., “Flash-Decoding for long-context inference” (Stanford / Together blog, 2023) — Flash Decoding。crfm.stanford.edu

推理引擎与服务调度

Speculative Decoding 家族

  • Leviathan, Kalman, Matias, “Fast Inference from Transformers via Speculative Decoding” (ICML 2023)。arxiv.org/abs/2211.17192
  • Chen et al., “Accelerating Large Language Model Decoding with Speculative Sampling” (DeepMind, 2023) — 同期独立工作。arxiv.org/abs/2302.01318
  • Cai et al., “Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads” (2024)。arxiv.org/abs/2401.10774
  • Li et al., “EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty” (ICML 2024)。arxiv.org/abs/2401.15077
  • Fu et al., “Lookahead Decoding: Breaking the Sequential Dependency of LLM Inference” (2024)。arxiv.org/abs/2402.02057

KV Cache 压缩 / 量化

  • Xiao et al., “Efficient Streaming Language Models with Attention Sinks” (ICLR 2024) — StreamingLLM。arxiv.org/abs/2309.17453
  • Zhang et al., “H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models” (NeurIPS 2023)。arxiv.org/abs/2306.14048
  • Li et al., “SnapKV: LLM Knows What You are Looking for Before Generation” (2024)。arxiv.org/abs/2404.14469
  • Liu et al., “KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache” (ICML 2024)。arxiv.org/abs/2402.02750
  • Hooper et al., “KVQuant: Towards 10 Million Context Length LLM Inference with KV Cache Quantization” (NeurIPS 2024)。arxiv.org/abs/2401.18079

代表性开源模型技术报告

硬件 / Roofline

  • NVIDIA, “H100 Tensor Core GPU Architecture Whitepaper” (2022)。resources.nvidia.com
  • Williams, Waterman, Patterson, “Roofline: An Insightful Visual Performance Model for Multicore Architectures” (CACM 2009) — Roofline 模型原始论文。dl.acm.org
  • Hoffmann et al., “Training Compute-Optimal Large Language Models” (Chinchilla, 2022) — 2PN2PN FLOPs 经验法则的训练侧版本。arxiv.org/abs/2203.15556

其他长文 / 教程

  • Lilian Weng, “The Transformer Family Version 2.0”。lilianweng.github.io
  • Horace He, “Making Deep Learning Go Brrrr From First Principles” — compute / memory / overhead 三类瓶颈。horace.io/brrr_intro
  • Adam Casson, “Transformer Inference Arithmetic” — 推理 FLOPs / KV cache 心算。kipp.ly
  • Jay Mody, “LLM inference, in detail” — 张量 shape 流转的另一种讲法。jaykmody.com