LLM 推理全过程的维度变化与核心公式
把 Llama 3 风格的 dense decoder-only 模型从 embedding 到采样的整条路径一次写清楚,顺便把常见变体(MHA / MQA / GQA / MLA、RMSNorm / LayerNorm、SwiGLU / GeGLU、Flash Attention / Paged Attention)的数学位置都嵌进去 — 读完之后自己能把这张图默写出来。整篇围绕一个事实展开:这套结构二十年没怎么变过,所有”推理优化”都是在同一张骨架的某个位置做局部手术。
参数符号约定 — Llama 3 8B 为基准
全篇沿用同一套符号,举例用 Llama 3 8B:
| 符号 | 含义 | 示例值(Llama 3 8B) |
|---|---|---|
| batch size | 2 | |
| prompt 长度 | 10 | |
| 层数 | 32 | |
| hidden dim | 4096 | |
| 词表大小 | 128256 | |
| Q 头数 | 32 | |
| KV 头数(GQA) | 8 | |
| 每头维度 | 128 | |
| FFN 中间维度 | 14336 | |
| 生成 token 数 | 100 | |
| decode 当前步数 |
所有 shape 标注都按 PyTorch 习惯写成 [B, ..., H];权重矩阵按”输入维 × 输出维”约定写成 。
核心公式速查 — Embedding · Norm · Attn · FFN · LM Head
Embedding
- — 位置 的 embedding 向量
- — embedding 查找表(V 个 token,每个 H 维)
- — 第 个输入 token 的整数 ID
许多实现把 与 LM Head 的 共享(tied embedding),省显存也略微正则;Llama 系列默认不共享。
归一化:LayerNorm vs RMSNorm
标准 LayerNorm:
- — 单 token 的隐状态向量
- — 在 维上算的均值和方差
- — 防除零小常数(通常 )
- — 可学习的 scale / shift
- — 逐元素乘
RMSNorm(Llama / Mistral / Qwen 主流选择):
- — 单 token 的隐状态向量
- — 可学习的 per-channel scale(无 )
- — 防除零小常数
- 分母是 的均方根(root mean square),故得名 RMSNorm
RMSNorm 省了均值、也省了 ,计算量和参数都约减半;实证发现对质量几乎无损。所有主流推理引擎都按 pre-norm 组织:norm 在残差分支内部,残差主干不经过 norm。
Q/K/V 投影 + 位置编码
- — 上一步 RMSNorm 的输出
- — Q 投影权重
- — K、V 投影权重(GQA 下 ,所以更窄)
- , — 投影输出,会再 reshape 出 head 维
RoPE(Rotary Positional Embedding)的核心想法是:与其把”绝对位置 “加到 embedding 上(像原始 Transformer 的 sinusoidal),不如让位置以旋转的形式作用在 Q、K 上,使两个向量做内积时只剩下相对位置 。
把每个头的 维切成 个二维子空间,第 个子空间()对应坐标 。位置 上对这一对乘旋转角 的 2D 旋转矩阵:
- — 当前 token 的绝对位置
- — 二维子空间索引(在单头维度 上每两维一组)
- — 投影后 Q 向量的第 个 2D pair
- — 在位置 旋转后的 pair
- — 第 个子空间的基础角速度
- — 频率衰减基(Llama 默认 ;YaRN 等会动态拉大)
展开成标量:
符号同上式 — 这只是把 2×2 矩阵展开成两条标量等式,方便对照实现。
这就是标准的逆时针旋转矩阵 ,作用在 2D 向量上等价于在复平面上乘 — 也是为什么 Llama 官方 repo 直接把 reshape 成 complex64 然后乘 (HuggingFace 则等价地拆成 的半旋转形式,二者只差一个坐标排列)。K 用同样的角度做同样的旋转,得到 。
为什么旋转能编码相对位置? 把整块 维旋转记作 (分块对角,第 个 块用角度 )。它正交,且 ,因此:
- — 单头的 Q、K 向量(未做位置旋转)
- — 位置 、 对应的分块对角旋转矩阵
- — 标准内积
- 最后一步用了 :旋转矩阵正交且角度可加
attention 算 时,每对 的分数只依赖差 — 绝对位置在内积里被自动消掉,留下相对位置。这是 RoPE 的关键性质,也是它比加性位置编码更稳的根本原因。
频率谱设计:( 一般取 ; 是单头维度,不是模型 hidden dim)让 个子空间分到从快到慢的角速度:
- :,周期 token,承担”近邻”信号。
- :,周期 token,承担”远距离”信号。
几何级数分布让一个 head 内同时携带不同尺度的位置信号 — 与原始 Transformer 的 sinusoidal 同构,只是从加法搬到了乘法。
长上下文缩放:训练只见过 ,低频子空间在 内甚至跑不完一个周期;推理一旦 ,低频角度进入训练分布外,attention 立刻退化。三类常见解法都在改 :
- Position Interpolation(Chen et al. 2023):,等价 同比例压缩所有频率。简单但牺牲高频精度。
- NTK-aware scaling:只缩放低频、保留高频,等价于 。
- YaRN(Peng et al. 2023):分频段处理 — 高频(周期 ,已学过完整周期)不变,低频(周期 ,没学过完整周期)按 PI 缩放,中间段平滑过渡,再叠加一个 的温度修正抵消缩放带来的注意力熵漂移。Llama 3.1 / 3.2 从 8K → 128K 用的就是 YaRN。
RoPE 只作用于 Q、K,不作用于 V — V 是被加权求和的值本身,不需要位置信号。
Scaled Dot-Product Attention
- — 投影并 reshape 出 head 维后的张量,最后两维为 (多头维隐式 broadcast)
- — 每对 的相似度矩阵
- — 缩放因子,防止 softmax 进入梯度近乎为零的区域
- — causal mask,上三角为 (位置 只能看 的位置)
- softmax 沿 K 维归一化,输出 attention 权重,再与 V 加权求和回到
多头变体:MHA / MQA / GQA / MLA
四种变体的差别只在 K、V 这一支:Q 永远是 个独立头,变的是有多少套独立 K、V,以及 K、V 是否被低秩压缩。下面用同一套记号一遍把原版 MHA 和三种变体的 head-level 公式写齐 — 省略 batch 和 layer 维,看位置 上第 个 head。
| 变体 | 单 token cache(per layer,fp16) | 代表模型 | |
|---|---|---|---|
| MHA | GPT-2/3, Llama 1/2 7B | ||
| GQA | 分组 | Llama 3, Qwen2, Mistral | |
| MQA | PaLM, Falcon | ||
| MLA | 低秩压缩 | DeepSeek V2/V3 |
MHA — Multi-Head Attention(原版,Vaswani et al. 2017)
每个 Q head 都配一套独立的 K、V, 套全独立:
- — 第 个 head 独立的 Q、K、V 投影(实现上拼成大矩阵一次算完,数学上等价)
- 每个 token 入 cache 的是 套 — 共 个数
- Llama 2 7B:,单 token cache = / layer
MQA — Multi-Query Attention(Shazeer 2019)
所有 个 Q head 共用同一套 K、V,只剩 1 套:
- — 所有 head 共享同一份 K、V 投影
- cache 收缩到 ,比 MHA 小 倍;但表达能力受限,大模型直接套 MQA 容易掉点 — 工程上很少单独使用,更多被 GQA 替代
GQA — Grouped-Query Attention(Ainslie et al. 2023)
把 个 Q head 切成 组,组内共享 K、V — 是 MHA 与 MQA 之间的连续插值:
- , — 每组一份 K、V 投影
- — 第 个 Q head 所属的组索引
- 退化为 MHA, 退化为 MQA
- kernel 里只实际算 套 K、V,分数计算时把 K、V 沿 group 维 broadcast 到 ,不真的复制张量
- Llama 3 70B:,单 token cache = / layer — 比同样大小的 MHA 小 8 倍
MLA — Multi-Head Latent Attention(DeepSeek V2 2024)
GQA 只是按 head 数线性缩 cache;MLA 直接把 K、V 共同压成一个低秩潜向量 ,再为 RoPE 单独走一条共享的小分支。分成四步看:
(1) 内容分支 — K、V 共享一份 down-projection,入 cache 的只是潜向量 :
(2) Q 端同样低秩 — 训练省显存,推理时 Q 不入 cache:
(3) RoPE 解耦分支 — K 侧只算一份共享的 维向量,所有 head 共用:
(4) 拼接 + Attention — 内容部分和 RoPE 部分沿 head 维拼起来再算分数:
- — KV 共用 down-projection(DeepSeek V3 取 )
- — 第 个 head 自己的 up-projection
- — 共享的 RoPE K 分支(DeepSeek V3 取 )
- — MLA 实际缓存的就这两条 — 共 个数
- DeepSeek V3:,单 token cache = 字节 / layer — 比同等规模 GQA 又小一个量级
为什么 RoPE 要单走一条? 推理时有个”权重折叠”技巧 — 用矩阵结合律把 折进 :
这样 K 内容部分根本不用真的重建,attention 直接在缓存的 上算。但 RoPE 的旋转角依赖绝对位置 ,无法折叠进固定权重 — 一旦把 RoPE 加在重建后的 K 上,前面的折叠就失效。DeepSeek 的办法是把 RoPE 拆成独立的 维浅分支,让”可折叠的内容部分”和”必须实时旋转的位置部分”互不干扰,KV-cache 压缩和相对位置信号才能同时保留。
对比总览:参数量 / 计算量 / KV cache
下面把每种变体的开销按步骤拆开。约定:单层 attention 子层,prefill 长度 ,忽略 norm / bias / softmax 等非矩阵乘项;linear 投影 ()FLOPs 记 (每个 MAC 算 2 FLOPs),attention 的 和 不扣 causal mask 的一半。MLA 额外用到 = Q 端潜向量维度(DeepSeek V3 取 )。
参数量(单层)
| 步骤 | MHA | MQA | GQA | MLA |
|---|---|---|---|---|
| Q 投影 | ||||
| K 投影 | ||||
| V 投影 | (与 K 共用 ) | |||
| RoPE 分支 | — | — | — | |
| 合计 | 上述各行之和 |
Prefill FLOPs(单层,序列长度 )
| 步骤 | MHA | MQA | GQA | MLA |
|---|---|---|---|---|
| Q 投影 | ||||
| K 投影 | ||||
| V 投影 | ||||
| RoPE 分支 | ||||
| softmax · | ||||
表里 “RoPE 分支” 在 MHA/MQA/GQA 指的是把旋转矩阵作用在 Q、K 上的逐元素 mul/add,量级 ,相对矩阵乘可忽略;在 MLA 里专指那条额外的 / 投影,属于真矩阵乘,不能忽略。
KV cache(单 token,单层,fp16)
| 变体 | 缓存内容 | 字节数 |
|---|---|---|
| MHA | 套 | |
| MQA | 1 套 | |
| GQA | 套 | |
| MLA | + |
把三张表横着读,能看出三件事:
- 变体只动 K/V 一侧。Q 投影、、 在四种变体里完全相同 — 真正动手术的只有 K、V 这一支,所以同尺寸下 attention 的总参数和总 FLOPs 不会差到数量级。
- MHA → MQA/GQA 是同时省 params、FLOPs、cache;MLA 是用 params/FLOPs 换 cache。GQA 通过减小 把 K、V 投影、cache 同步线性缩小;MLA 反过来,加了 DKV + 两段 UK/UV + 单独的 RoPE 分支,参数和 prefill FLOPs 跟同尺寸 GQA 同量级(不会显著降低),换来的是单 token cache 从 KB 量级压到 ~1 KB。
- 项在 MHA/MQA/GQA 三者间没差。 和 都是 (K、V 共享只是 broadcast,不省二次项)。也就是说当 时,三种变体的 prefill FLOPs 会趋同 — 它们的差距全部体现在 decode 阶段每生成一 token 要读多少 KV cache,而那是 memory I/O,不是 compute。
Output 投影 + 残差
- — attention 子层的输入(pre-norm 之前的值)
- — Q、K 经过 RoPE 旋转后的版本;(V 不旋转,记号统一而已)
- — output projection,把多头 concat 后投回 维
- — attention 子层输出 + 残差
FFN 变体
经典双线性 FFN(GPT-2)
- — FFN 输入
- — 升维投影
- — 降维投影
- — 逐元素的非线性激活(见下节)
激活函数
FFN 升维之后逐元素套的标量非线性。输入输出都是标量 ;它在每个 token、每个隐藏维度上独立应用。下面四个函数覆盖了 Transformer 时代用过的所有主流选择。
ReLU(Nair & Hinton 2010)
- 时恒等映射, 时输出 0
- 计算最便宜,但负区梯度恒为 0 — “dead neuron” 问题
- 用于原始 Transformer 和早期 BERT 实现
Sigmoid
- 把 压到 ,天然适合做 “门” 信号 — GLU 原版的 就是它
- 单独当 FFN 激活已经基本不用 — 两端饱和导致梯度消失
GeLU(Gaussian Error Linear Unit,Hendrycks & Gimpel 2016)
实现里常用 OpenAI 给的 tanh 近似(数值上误差 ,但省去 erf 调用):
- 是标准正态 CDF — 直观上 “按 的尾部概率加权地放行 ”
- 处处可导、非单调(在 一侧有一段轻微的负向凹陷),比 ReLU 平滑
- 用于 GPT-2/3、BERT、ViT
SiLU / Swish(Ramachandran et al. 2017)
- 形状非常接近 GeLU(同样平滑、非单调、过原点),但表达式更简单 — 不用 也不用三次项
- 自门控(self-gated):用自身的 sigmoid 控制信号通过率
- 用于 PaLM、Llama 全系(在 SwiGLU 里担任门控 )
这四个里 ReLU 和 sigmoid 已经退出主流 FFN,活跃的是 GeLU(GPT 时代)和 SiLU(Llama / PaLM 之后)。两者在 bf16 下数值差别 < 1%,论文里挑哪个多是路径依赖;真正的工程拐点是把激活从 “套在投影上”(经典 FFN)换成 “套在门控上”(GLU 家族)。
GLU 家族(Llama、PaLM、Mistral 都在用)
-
— FFN 输入
-
— 两个独立的升维投影
-
— 降维投影
-
— 逐元素乘
-
— 门控激活,按选择决定变体名:
变体 代表模型 GLU Dauphin et al. 2017 原版 ReGLU — GeGLU T5 v1.1 SwiGLU PaLM, Llama 1/2/3
SwiGLU 比经典 GeLU-FFN 多一个投影(三矩阵 vs 两矩阵),为了参数预算对齐,实现里一般把 设成 ( 是 GPT-2 的惯例),Llama 3 8B 的 。
Mixture-of-Experts(MoE)
经典 dense FFN 里每个 token 都要过同一对 — 参数全用上,FLOPs 也全付。MoE(Shazeer et al. 2017)的核心是把 FFN 复制成 份(“expert”),每个 token 只挑前 份算,总参数线性放大,单 token 激活参数和 FLOPs 几乎不变 — 用容量换知识量,不换 compute。
形式上把 FFN 子层换成:
Router(gating network) 决定每个 token 走哪几个 expert:
- — 当前 token 的隐状态(FFN 子层输入)
- — router 投影,把 hidden state 映到 个 expert 的 logit
- — 所有 expert 的 router 概率
- — 被选中的 top- expert 索引集合
- — combine 权重;Mixtral / DeepSeek 都对 top- 的 再做一次归一化让它们和为 1
- — 第 个 expert,结构通常就是 SwiGLU,仅权重不共享
MoE vs 传统 FFN
| 维度 | Dense FFN (SwiGLU) | MoE (top- of ) |
|---|---|---|
| 参数量(FFN 段) | + (router) | |
| 每 token 激活 FLOPs | + (router) | |
| 每 token 权重带宽(decode) | bytes | bytes |
| 显存占用 | (必须全装下) | |
| Kernel 形状 | 固定 GEMM | grouped GEMM / token permutation |
| 多卡通信 | — | expert-parallel 时跨卡 all-to-all |
| 训练稳定性 | 直接 | router 易塌缩,要 load-balance |
关键 trade-off:容量与 FLOPs 解耦。同样激活参数(即同 FLOPs 预算)下,MoE 可以塞 8–32 × 总参数 — 知识容量上去了。代价是显存(要装下所有 expert)、路由稳定性(避免热门 expert)、推理 batching(每 token 路径不同)。
Load Balance
最朴素的 router 训着训着会”塌缩”到少数热门 expert。GShard / Switch 用辅助损失(auxiliary loss)做软约束:
- — 当前 batch 内被路由到 expert 的 token 占比
- — 当前 batch 内 expert 的平均 router 概率
- — 辅助损失权重(Switch 用 量级)
- 直觉:当 和 都大时 = 这个 expert 既被选很多次又信心很高 → 惩罚 → 反向梯度把 router 的这条 logit 推回去
DeepSeek V3 进一步抛弃辅助损失,用无损失 load balance:给每个 expert 维护偏置 ,TopK 选择基于 ;被选少的 慢慢调高,被选多的 调低 — 只影响选择不污染梯度,避免 aux loss 对主任务造成性能下拉。
现代实现对照
| 模型 | 总参数 / 激活 | routed | top- | 共享 expert | 路由特征 |
|---|---|---|---|---|---|
| Switch Transformer (2021) | 1.6 T / ~26 B | 2048 | 1 | — | 第一个能稳定训的大 MoE;硬 top-1 + capacity factor 限流 |
| GLaM (2022) | 1.2 T / 97 B | 64 | 2 | — | Google decoder-only MoE,证明 MoE 推理可比 dense 省一半 |
| Mixtral 8×7B (2023) | 47 B / 13 B | 8 | 2 | — | “8 大 expert” 开源样板;每层独立路由 |
| Mixtral 8×22B (2024) | 141 B / 39 B | 8 | 2 | — | 8×7B 放大版 |
| Qwen1.5-MoE-A2.7B (2024) | 14 B / 2.7 B | 60 | 4 | 4 | 阿里第一代细粒度 MoE |
| DeepSeek V2 (2024) | 236 B / 21 B | 160 | 6 | 2 | 细粒度 + 共享 expert 范式确立 |
| DeepSeek V3 (2024) | 671 B / 37 B | 256 | 8 | 1 | 无辅助损失 load balance;与 MLA 组合 |
| Qwen3-MoE 235B-A22B (2025) | 235 B / 22 B | 128 | 8 | — | DeepSeek 风格细粒度 |
| Llama 4 Scout (2025) | 109 B / 17 B | 16 | 1 | 1 | top-1 + 1 共享,极端稀疏 |
| Llama 4 Maverick (2025) | 400 B / 17 B | 128 | 1 | 1 | 同上,把 expert 数推到 128 |
| GPT-4(社区推测) | ~1.8 T / ~280 B | 16 | 2 | — | 从未官方公开;半导体分析师拆解推测 |
几条规律:
- 从粗到细。Mixtral 这代是”少而胖”(8 个接近 dense 大小的 expert);DeepSeek 起把每个 expert 切小、数量推到上百,相当于”多而瘦” — 同样激活 FLOPs 下表达组合数指数级上升。
- 共享 expert 成为标配。DeepSeek / Qwen-MoE / Llama 4 都给每层保留 1–2 个”所有 token 必过”的共享 expert,专门吃通用模式;剩下的稀疏 expert 才负责特化。
- MoE 总和 KV cache 压缩配套。MoE 把总参数推到几百 B,long-context decode 同样压力大 — DeepSeek V3 同时上 MLA + 细粒度 MoE 不是巧合。
- top- 走极端。Switch (k=1) → Mixtral (k=2) → DeepSeek V3 (k=8 但 expert 小) → Llama 4 (k=1 + 共享)。top- 小好做 batching 和容量约束,但要靠细粒度 / 共享 expert 把表达力补回来。
残差
- — attention 子层输出(已含第一道残差)
- — 整层 Transformer 输出,进入下一层
- RMSNorm 仍在残差分支内(pre-norm),主干 直接跨过
LM Head + 采样
- — 最后一层 RMSNorm 后、最后一个位置的隐状态
- — 输出嵌入矩阵(可与 共享,即 tied embedding)
- — 词表上每个 token 的未归一化分数
- — temperature,越大概率分布越平坦( 趋近均匀, 趋近 argmax)
- — 归一化后的概率分布
采样前通常叠几层 logits 变换:
重复惩罚(repetition / frequency / presence penalty):
- — 词表中某个 token
- — 该 token 的原始 logit
- — presence penalty(只要出现过就扣一次)
- — frequency penalty(按出现次数线性扣)
- — 指示函数(条件成立为 1,否则为 0)
- — token 在已生成序列中的出现次数
Top-k:只保留最大 个 logits,其它置 。
Top-p(nucleus):按概率降序累加,保留累积概率 的集合。
Min-p:保留 的集合,对低熵分布更友好。
Typical-p:基于与条件熵的偏差做截断,保留 小的集合。
所有截断都作用在概率分布本身之前/之后,不改变公式的骨架。
Prefill 阶段维度流转 — S 个 token 一次性走一遍
输入:input_ids [B, S] = [2, 10]。下图是单层 Transformer 的前向,外面再套 32 层;shape 始终保持在 [B, S, H] = [2, 10, 4096],残差边以橙色虚线表示。
Prefill 结束后,KV Cache 状态:每层已填入前 10 个位置。
Decode 阶段维度流转 — 第 t 步只走 1 个 token
前置状态:已有 个位置。
输入:input_ids [B, 1] = [2, 1](上一步生成的 1 个 token)。
Prefill vs Decode 维度对照 — GEMM vs GEMV · 算力 vs 带宽
| 位置 | Prefill | Decode 每步 |
|---|---|---|
| input_ids | ||
| embedding 后 | ||
| Q | ||
| K_new / V_new | ||
| K_full / V_full(从 cache) | 同 K_new | |
| attention scores | ||
| attention 输出 | ||
| FFN 中间 | ||
| logits | (取最后位置) | |
| 运算性质 | GEMM(矩 × 矩) | GEMV(矩 × 量) |
| 瓶颈 | 算力 | 内存带宽 |
这张表是理解所有推理加速工作的起点:prefill 像训练的 forward,compute-bound;decode 是一串 GEMV,memory-bound,绝大部分时间在往 SM 里搬权重。两段的优化方向天差地别。
KV Cache 的形状与增长 — 单 token 几 KB · 长 context 几百 MB
每层一对 cache:
- — 预分配的最大序列长度(一般 = 模型上下文上限或调度器允许的上限)
- 其它符号沿用顶部约定();4 个维度的顺序按 PyTorch
[B, head, seq, head_dim]习惯
每个 token、每层的 cache 大小(fp16):
- 最左的 — K 和 V 两份
- — fp16 每元素 2 字节(用 fp8 / int8 可减半到 1/4)
- 右侧代入 Llama 3 8B:,
每个 token、全模型(32 层):。一个 4096-token 请求:。
几种工程优化:
- Paged Attention(vLLM):把 cache 拆成固定大小 block(通常 16 token),用一张 block table 做虚地址到物理地址的映射,消除碎片。对应公式里没变化,只是张量布局和访问模式变了。
- Sliding Window Attention(Mistral):只保留最近 个 token 的 K、V。cache 上限从 降到 ,代价是信息截断,靠层间层叠传递远距离依赖。
- INT8 / FP8 KV Cache:把 fp16 cache 量化到 int8 甚至 fp8,per-channel 或 per-token 量化,误差可控,cache 占用减半到 1/4。代表工作 KIVI / KVQuant。
- KV 压缩 / 丢弃(H2O、StreamingLLM、SnapKV):按注意力权重把不重要的位置踢出 cache,上下文超长时用。
- MLA:前面已经提过,是从模型结构层面改 cache 形状,不是后处理。
每步的计算量与内存开销 — Llama 3 8B fp16 · H100 拐点 ~330 FLOPs/byte
前面的维度图只说 shape,不说量级。推理优化 90% 的讨论都在算”这一步要花多少 FLOPs、搬多少字节”,所以这里直接把每步的开销拍成表。
以 Llama 3 8B、fp16、 为基准,prefill 取 ,decode 取 (即生成到第 2048 个 token 时的某一步)。
参考硬件拐点:H100 SXM fp16 理论算力 ~989 TFLOPs,HBM 带宽 ~3 TB/s,roofline 拐点 。高于它是 compute-bound,低于它是 memory-bound。
权重分布
| 组件 | Shape | fp16 大小 | 全模型(×32 层) |
|---|---|---|---|
| Embedding | 1.0 GB | 1.0 GB | |
| 32 MB | 1.0 GB | ||
| 8 MB | 256 MB | ||
| 8 MB | 256 MB | ||
| 32 MB | 1.0 GB | ||
| 117 MB | 3.7 GB | ||
| 117 MB | 3.7 GB | ||
| 117 MB | 3.7 GB | ||
| RMSNorm (每层 2 份) | 16 KB | 500 KB | |
| LM head | 1.0 GB | 1.0 GB | |
| 合计 | ~432 MB / 层 | ~16 GB |
全模型 fp16 权重约 16 GB,每次 forward 的”底价”就是把这 16 GB 从 HBM 里扫一遍。H100 @ 3 TB/s 下 ,这就是单请求 decode 的物理下限。
每层每步的计算 / 内存读写
对同一层在 prefill()和 decode()下的各个子步做对照。“Weight HBM” 是要从显存搬的权重字节,“KV HBM” 是要读/写的 KV Cache 字节。中间 activation 默认被 kernel 融合,不单独算。
| 步骤 | Prefill FLOPs(S=2048) | Decode FLOPs(S=1) | Weight HBM | KV HBM |
|---|---|---|---|---|
| RMSNorm | ≈ 42 MF | 20 KF | 8 KB | — |
| ≈ 68.7 GF | 33.5 MF | 32 MB | — | |
| (+写 cache) | ≈ 17.2 GF | 8.4 MF | 8 MB | W 4 MB / 2 KB |
| (+写 cache) | 17.2 GF | 8.4 MF | 8 MB | W 4 MB / 2 KB |
| RoPE | ~50 MF | 25 KF | — | — |
| Attn | ≈ 34.4 GF | 16.8 MF | — | R 4 MB(decode) |
| softmax | ~700 MF | 260 KF | — | — |
| Attn | 34.4 GF | 16.8 MF | — | R 4 MB(decode) |
| ≈ 68.7 GF | 33.5 MF | 32 MB | — | |
| RMSNorm | 42 MF | 20 KF | 8 KB | — |
| ≈ 241 GF | 117 MF | 117 MB | — | |
| 241 GF | 117 MF | 117 MB | — | |
| SiLU + gate | ~90 MF | 45 KF | — | — |
| ≈ 241 GF | 117 MF | 117 MB | — | |
| 每层合计 | ~960 GFLOPs | ~470 MFLOPs | ~432 MB | W 8 MB(P)/ R 8 MB(D) |
几个直接结论:
- FFN 是真正的主角。 吃掉 ~75% 的 FLOPs 和 ~80% 的权重带宽。MoE、稀疏激活、FFN 量化全都冲着这块去。
- Attention 的 4 个投影(Q/K/V/O) 占 ~18%,真正的 和 只占 ~7% — prefill 里 attention 并不是瓶颈,投影才是。
- Decode 的 KV 读取 每层 8 MB,在 时只占权重读取的 ~2%。但当上下文拉到 64K、128K,它会翻几十倍、直接反超权重带宽,成为新瓶颈(这也是 Paged Attention、sliding window、KV 量化出现的原因)。
全模型一次 forward
把 32 层 + embedding + LM head 加起来:
| 阶段 | FLOPs | HBM I/O | Arithmetic Intensity | 瓶颈 |
|---|---|---|---|---|
| Prefill S=2048, B=1 | ~31 TFLOPs | ~14 GB(权重)+ 256 MB(KV 写) | ~2200 FLOPs/byte | 算力 |
| Decode step, cache_len=2048, B=1 | ~15 GFLOPs | ~14 GB(权重)+ 256 MB(KV 读) | ~1.05 FLOPs/byte | 带宽 |
| LM head(prefill 只算最后一位) | ~1 GFLOP | 1 GB | ~1 FLOPs/byte | 带宽 |
| LM head(decode) | ~1 GFLOP | 1 GB | ~1 FLOPs/byte | 带宽 |
Decode 的 1.05 FLOPs/byte 比 H100 拐点 330 低 两个半数量级 — 意味着理想情况下单请求 decode 的算力利用率只有 。这就是 continuous batching 的数学依据:把 拉到 32,同一批权重 read 被 32 个请求共享,arithmetic intensity 直接 ×32,decode 吞吐几乎线性增长,直到 attention 部分或算力先撞墙。
心算口诀
两条规则覆盖 90% 的推理性能估算:
- FLOPs ≈ : 是参数量(~8B), 是这次 forward 处理的 token 总数。每个参数被每个 token 各用一次 MAC,一次 MAC 算 2 FLOPs。例如 prefill : ,与分项加总的 31 TFLOPs 吻合。
- 权重 HBM I/O ≈ bytes(fp16):一次 forward 就是把模型扫一遍,约 16 GB。
Arithmetic intensity 本质是 — forward 里一共参与的 token 数。Prefill 有 个 token,decode 只有 个。这一个数字直接决定了 prefill / decode 瓶颈不同的根源。
计算复杂度总览 — 有 KV Cache vs 没 KV Cache 差三个数量级
Prefill(一次性处理 个 token):
- — 每层 4 个 投影 + FFN 三个 投影(),各作用在 个 token 上
- — attention 的 与 ,含 的 score 矩阵
- 短序列下线性层主导; 后 attention 二次项追上
Decode 每步(处理 1 个 token,历史 ):
- — 当前已缓存的位置数()
- 单步只处理 1 个新 token,所以线性层把 替换为 ;attention 仍要扫整段 cache,故随 cache 线性增长
生成 个 token 总复杂度:
- — 生成 token 数(沿用顶部约定)
- — 平均 attention 跨距(prompt 长度 + 已生成段)
- 第一项是 decode 线性层对 步累加;第二项是 attention 部分对 的求和近似(精确是 )
对比无 KV Cache:,差异巨大。
把 FLOPs 和带宽一起看就更直观 — 这正是上一节那张表展示的:prefill 挑算力上限,decode 挑带宽上限;continuous batching 的意义就是把 个请求的 decode 拼成一个大 GEMV,搬权重的成本被 个请求摊薄,吞吐线性上涨直到算力或 attention 部分成为瓶颈。
工程优化怎么嵌进公式 — Flash Attn / Spec Decode / Continuous Batch
Flash Attention:数学上完全等价于标准 attention,公式一字不变。工程上把 softmax 和 matmul 融成一个 kernel,按 block 流式更新 softmax 的运行统计量(max、sum),避免把 的 attention 矩阵写回 HBM。复杂度不变,显存占用从 降到 ,速度提升主要来自 HBM 访问减少。FA-2 把切分粒度从 head 改到 query block;FA-3 在 H100/H200 上叠加 warpgroup MMA + producer-consumer 异步流水。
Flash Decoding:Flash Attention 在 decode 时 Q 只有 1 行,kernel 并行度不够。Flash Decoding 把 K、V 的 维切成多块并行,再做一次 log-sum-exp 归约。公式还是同一个 softmax,只是拆成两段算。
Speculative Decoding:用一个小的 draft 模型连续生成 个 token,再用大模型一次 prefill 那 个位置做验证。接受规则是
- — draft 模型生成的某个候选 token
- — 大(target)模型在该位置上对 的概率
- — 小(draft)模型同位置对 的概率
- 该规则配合”拒绝后从 重采”可证明最终采样分布严格等于 target 模型直采,无质量损失
关键是把 decode 的 次 GEMV 合并成一次 长度的 GEMM,把大模型压在 memory-bound 边界的时间重新变成 compute-bound。期望每步接受 个 token,吞吐放大 倍(减去 draft 开销)。变种:Medusa(多头预测)、EAGLE(特征级 draft)、Lookahead Decoding(无 draft 模型)。
Continuous Batching(vLLM、TGI):不在 prefill 边界 pad 到齐,而是请求级的 step 调度。每个 step 选一批正处于同一 phase(prefill 或 decode)的请求拼 batch,完成一个释放一个。数学上各请求互不影响,只是编排顺序变了。原始论文是 OSDI’22 的 Orca。
Chunked Prefill:把长 prompt 的 prefill 拆成多段,和 decode 请求混进同一个 step,减少 decode 请求的等待抖动。公式无变化。SARATHI / DistServe 等系统的核心调度原语。
一句话串起整个过程 — 从 token ID 到下一个 token
输入 token ID → 查 embedding → 过 层(Pre-RMSNorm → Attention 带 RoPE → 残差 → Pre-RMSNorm → SwiGLU FFN → 残差) → Final RMSNorm → LM Head → logits → 采样。Prefill 时 个 token 并行走一遍,产物是第一个 token + 完整 KV Cache;Decode 时每步只输入 1 个 token,attention 处从 cache 读历史 K、V,其他所有操作都是逐 token 独立的。
关键工程不变量 — 值得背下来的 6 条
- 主线张量 shape 始终是 。残差结构保证维度不变,一旦你在某一步看到 变了,要么是在 attention 内部展成多头,要么是在 FFN 内部升到 ,出子层就回到 。
- K、V 一旦算出就不再变。因为它们只是线性投影 作用在已经定了的输入 上,而 causal 结构保证更晚的位置不会反过来改更早位置的表示。这是 KV Cache 成立的数学基础。
- Attention 是唯一跨 token 的操作,其他所有操作(norm、投影、FFN、激活)都逐 token 独立。所以只需要缓存跨 token 操作需要的 K、V;其他都可以即时算完扔掉。
- Decode 的 scores shape 是 。“1” 是 Q 侧(当前新 token), 维在和 做加权求和时被消掉,输出又回到 1 行。
- Decode 每步非 attention 部分的计算量恒定,只有 attention 随 cache 长度线性增长。所以真正”越生成越慢”的本质,是 attention 的 越来越长,加上 KV Cache 把 memory footprint 顶到 HBM 带宽上限。
- Prefill 用 GEMM,decode 用 GEMV。这一字之差决定了所有推理引擎都要做两套 kernel、两套调度策略。理解这一条,后面看任何推理优化论文都不会迷路。
如果把这十条装进脑子里,再去读 Flash Attention、PagedAttention、MLA、投机解码这些论文,会发现它们都是在这张骨架的某个位置做局部优化,而骨架本身二十年没怎么变过。
参考资料 — 公式 · 论文 · 工程 blog
架构与核心算子
- Vaswani et al., “Attention Is All You Need” (NeurIPS 2017) — Transformer 原始论文。arxiv.org/abs/1706.03762
- Shazeer, “GLU Variants Improve Transformer” (2020) — SwiGLU / GeGLU / ReGLU 的来源。arxiv.org/abs/2002.05202
- Zhang & Sennrich, “Root Mean Square Layer Normalization” (NeurIPS 2019) — RMSNorm。arxiv.org/abs/1910.07467
- Hendrycks & Gimpel, “Gaussian Error Linear Units (GELUs)” (2016) — GeLU 定义与近似。arxiv.org/abs/1606.08415
位置编码
- Su et al., “RoFormer: Enhanced Transformer with Rotary Position Embedding” (2021) — RoPE。arxiv.org/abs/2104.09864
- Peng et al., “YaRN: Efficient Context Window Extension of Large Language Models” (2023) — 长上下文 RoPE 缩放。arxiv.org/abs/2309.00071
- bloc97 & emozilla, “NTK-Aware Scaled RoPE” — 早期 NTK-aware 工作的开源实现讨论。reddit / LocalLLaMA
多头变体(KV 共享 / 压缩)
- Shazeer, “Fast Transformer Decoding: One Write-Head is All You Need” (2019) — MQA。arxiv.org/abs/1911.02150
- Ainslie et al., “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints” (EMNLP 2023) — GQA。arxiv.org/abs/2305.13245
- DeepSeek-AI, “DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model” (2024) — MLA 首次提出。arxiv.org/abs/2405.04434
- DeepSeek-AI, “DeepSeek-V3 Technical Report” (2024) — MLA + MoE 工程化。arxiv.org/abs/2412.19437
Mixture-of-Experts
- Shazeer et al., “Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer” (ICLR 2017) — top- gated MoE 在深度网络里的奠基论文。arxiv.org/abs/1701.06538
- Lepikhin et al., “GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding” (ICLR 2021) — 加 aux loss 的标准 load-balance 实现。arxiv.org/abs/2006.16668
- Fedus et al., “Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity” (JMLR 2022) — top-1 + capacity factor。arxiv.org/abs/2101.03961
- Du et al., “GLaM: Efficient Scaling of Language Models with Mixture-of-Experts” (ICML 2022) — Google 1.2T MoE。arxiv.org/abs/2112.06905
- Jiang et al., “Mixtral of Experts” (Mistral AI, 2024) — Mixtral 8×7B 技术报告。arxiv.org/abs/2401.04088
- DeepSeek-AI, “DeepSeekMoE: Towards Ultimate Expert Specialization in Mixture-of-Experts Language Models” (2024) — 细粒度 + 共享 expert 范式。arxiv.org/abs/2401.06066
- Wang et al., “Auxiliary-Loss-Free Load Balancing Strategy for Mixture-of-Experts” (2024) — DeepSeek V3 的无损失负载均衡。arxiv.org/abs/2408.15664
- Meta AI, “The Llama 4 herd” (2025) — Llama 4 Scout / Maverick / Behemoth 技术细节。ai.meta.com/blog/llama-4
Flash Attention 系列
- Dao et al., “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness” (NeurIPS 2022) — FA-1。arxiv.org/abs/2205.14135
- Dao, “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning” (2023) — FA-2。arxiv.org/abs/2307.08691
- Shah et al., “FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-Precision” (NeurIPS 2024) — FA-3 / Hopper。arxiv.org/abs/2407.08608
- Dao et al., “Flash-Decoding for long-context inference” (Stanford / Together blog, 2023) — Flash Decoding。crfm.stanford.edu
推理引擎与服务调度
- Kwon et al., “Efficient Memory Management for Large Language Model Serving with PagedAttention” (SOSP 2023) — vLLM / PagedAttention。arxiv.org/abs/2309.06180
- Yu et al., “Orca: A Distributed Serving System for Transformer-Based Generative Models” (OSDI 2022) — Continuous batching 原始论文。usenix.org/osdi22
- Agrawal et al., “SARATHI: Efficient LLM Inference by Piggybacking Decodes with Chunked Prefills” (2023) — Chunked prefill。arxiv.org/abs/2308.16369
- Zhong et al., “DistServe: Disaggregating Prefill and Decoding for Goodput-optimized Large Language Model Serving” (OSDI 2024) — Prefill / decode 分离。arxiv.org/abs/2401.09670
- vLLM 项目主仓库(Paged Attention 工程实现)。github.com/vllm-project/vllm
- Hugging Face Text Generation Inference (TGI)。github.com/huggingface/text-generation-inference
- NVIDIA TensorRT-LLM 文档(FA / In-flight batching 实现)。nvidia.github.io/TensorRT-LLM
Speculative Decoding 家族
- Leviathan, Kalman, Matias, “Fast Inference from Transformers via Speculative Decoding” (ICML 2023)。arxiv.org/abs/2211.17192
- Chen et al., “Accelerating Large Language Model Decoding with Speculative Sampling” (DeepMind, 2023) — 同期独立工作。arxiv.org/abs/2302.01318
- Cai et al., “Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads” (2024)。arxiv.org/abs/2401.10774
- Li et al., “EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty” (ICML 2024)。arxiv.org/abs/2401.15077
- Fu et al., “Lookahead Decoding: Breaking the Sequential Dependency of LLM Inference” (2024)。arxiv.org/abs/2402.02057
KV Cache 压缩 / 量化
- Xiao et al., “Efficient Streaming Language Models with Attention Sinks” (ICLR 2024) — StreamingLLM。arxiv.org/abs/2309.17453
- Zhang et al., “H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models” (NeurIPS 2023)。arxiv.org/abs/2306.14048
- Li et al., “SnapKV: LLM Knows What You are Looking for Before Generation” (2024)。arxiv.org/abs/2404.14469
- Liu et al., “KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache” (ICML 2024)。arxiv.org/abs/2402.02750
- Hooper et al., “KVQuant: Towards 10 Million Context Length LLM Inference with KV Cache Quantization” (NeurIPS 2024)。arxiv.org/abs/2401.18079
代表性开源模型技术报告
- Meta AI, “The Llama 3 Herd of Models” (2024) — Llama 3 全家族。arxiv.org/abs/2407.21783
- Jiang et al., “Mistral 7B” (2023) — Sliding Window Attention。arxiv.org/abs/2310.06825
- Qwen Team, “Qwen2.5 Technical Report” (2024)。arxiv.org/abs/2412.15115
硬件 / Roofline
- NVIDIA, “H100 Tensor Core GPU Architecture Whitepaper” (2022)。resources.nvidia.com
- Williams, Waterman, Patterson, “Roofline: An Insightful Visual Performance Model for Multicore Architectures” (CACM 2009) — Roofline 模型原始论文。dl.acm.org
- Hoffmann et al., “Training Compute-Optimal Large Language Models” (Chinchilla, 2022) — FLOPs 经验法则的训练侧版本。arxiv.org/abs/2203.15556
其他长文 / 教程
- Lilian Weng, “The Transformer Family Version 2.0”。lilianweng.github.io
- Horace He, “Making Deep Learning Go Brrrr From First Principles” — compute / memory / overhead 三类瓶颈。horace.io/brrr_intro
- Adam Casson, “Transformer Inference Arithmetic” — 推理 FLOPs / KV cache 心算。kipp.ly
- Jay Mody, “LLM inference, in detail” — 张量 shape 流转的另一种讲法。jaykmody.com