LLM Inference Walkthrough — Tensor Shapes and Core Formulas Across the Whole Pipeline
Lay out a Llama 3-style dense decoder-only model end to end, from embedding to sampling, marking along the way where the common variants (MHA / MQA / GQA / MLA, RMSNorm / LayerNorm, SwiGLU / GeGLU, Flash Attention / Paged Attention) plug into the math, so that after reading you can draw the whole pipeline from memory. One fact frames the whole article: this architecture has barely changed since the original Transformer of 2017; every “inference optimization” is a local surgery somewhere on the same skeleton.
Notation Conventions — Llama 3 8B as the Reference
The same notation is used throughout, with Llama 3 8B as the concrete example:
| Symbol | Meaning | Example value (Llama 3 8B) |
|---|---|---|
| $B$ | batch size | 2 |
| $S$ | prompt length | 10 |
| $L$ | number of layers | 32 |
| $H$ | hidden dim | 4096 |
| $V$ | vocab size | 128256 |
| $n_h$ | number of Q heads | 32 |
| $n_{kv}$ | number of KV heads (GQA) | 8 |
| $d_h$ | per-head dim | 128 |
| $I$ | FFN intermediate dim | 14336 |
| $T$ | number of generated tokens | 100 |
| $t$ | current decode step | $1 \le t \le T$ |
All shape annotations follow PyTorch convention, written as [B, ..., H]; weight matrices follow “in × out”, written as $W \in \mathbb{R}^{d_{\text{in}} \times d_{\text{out}}}$, so a projection reads $y = xW$.
Core Formulas Quick Reference — Embedding · Norm · Attn · FFN · LM Head
Embedding
$$x_i = E[\text{tok}_i]$$
- $x_i \in \mathbb{R}^{H}$: embedding vector at position $i$
- $E \in \mathbb{R}^{V \times H}$: embedding lookup table ($V$ tokens, each $H$-dim)
- $\text{tok}_i$: integer ID of the $i$-th input token
Many implementations share $E$ with the LM Head’s $W_{\text{LM}}$ (tied embedding), saving memory and providing mild regularization; Llama 3 8B keeps the two separate.
Normalization: LayerNorm vs RMSNorm
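Standard LayerNorm:
$$\text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta, \qquad \mu = \frac{1}{H}\sum_{j=1}^{H} x_j, \quad \sigma^2 = \frac{1}{H}\sum_{j=1}^{H} (x_j - \mu)^2$$
- $x \in \mathbb{R}^{H}$: single-token hidden state
- $\mu, \sigma^2$: mean and variance computed across the $H$ dimensions
- $\epsilon$: small constant to avoid divide-by-zero (typically $10^{-5}$ or $10^{-6}$)
- $\gamma, \beta \in \mathbb{R}^{H}$: learnable scale / shift
- $\odot$: element-wise multiplication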
RMSNorm (the mainstream choice for Llama / Mistral / Qwen):
$$\text{RMSNorm}(x) = \gamma \odot \frac{x}{\sqrt{\frac{1}{H}\sum_{j=1}^{H} x_j^2 + \epsilon}}$$
- $x \in \mathbb{R}^{H}$: single-token hidden state
- $\gamma \in \mathbb{R}^{H}$: learnable per-channel scale (no $\beta$)
- $\epsilon$: small constant to avoid divide-by-zero
- The denominator is the root mean square of $x$, hence “RMSNorm”
RMSNorm drops the mean and $\beta$, roughly halving both compute and parameters; empirically the quality cost is nearly zero. Llama-style models use it as pre-norm: the norm lives inside the residual branch, and the residual main path bypasses it.
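To make the difference concrete, here is a minimal pure-Python sketch of both norms acting on one token's hidden vector. The function names and the toy input are illustrative, and `eps=1e-5` is an assumed default rather than a value taken from any particular checkpoint.

```python
import math

def layer_norm(x, gamma, beta, eps=1e-5):
    """LayerNorm: subtract the mean, divide by the std, then scale and shift."""
    n = len(x)
    mu = sum(x) / n
    var = sum((v - mu) ** 2 for v in x) / n
    inv = 1.0 / math.sqrt(var + eps)
    return [g * (v - mu) * inv + b for v, g, b in zip(x, gamma, beta)]

def rms_norm(x, gamma, eps=1e-5):
    """RMSNorm: divide by the root mean square; no mean, no shift."""
    n = len(x)
    inv = 1.0 / math.sqrt(sum(v * v for v in x) / n + eps)
    return [g * v * inv for v, g in zip(x, gamma)]

x = [1.0, -2.0, 3.0, 0.5]            # toy hidden state, H = 4
ones, zeros = [1.0] * 4, [0.0] * 4
y_ln = layer_norm(x, ones, zeros)    # zero mean, unit variance
y_rms = rms_norm(x, ones)            # unit mean square, mean not removed
```

With unit `gamma`, the LayerNorm output has zero mean while the RMSNorm output only has unit mean square, which is exactly the statistic each one normalizes.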
Q/K/V Projection + Positional Encoding
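$$Q = x W_Q, \qquad K = x W_K, \qquad V = x W_V$$
- $x$, shape $[B, S, H]$: output of the previous RMSNorm
- $W_Q \in \mathbb{R}^{H \times n_h d_h}$: Q projection weight
- $W_K, W_V \in \mathbb{R}^{H \times n_{kv} d_h}$: K, V projection weights (narrower under GQA since $n_{kv} < n_h$)
- $Q$ of shape $[B, S, n_h d_h]$, and $K, V$ of shape $[B, S, n_{kv} d_h]$: projection outputs, later reshaped to $[B, n_h, S, d_h]$ and $[B, n_{kv}, S, d_h]$ to expose the head dimension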
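RoPE (Rotary Positional Embedding) core idea: rather than adding “absolute position $m$” to the embedding (like the original Transformer’s sinusoidal encoding), let position act as a rotation on Q and K, so that when the two vectors are dotted only the relative position survives.
Split each head’s $d_h$ dimensions into $d_h/2$ two-dim subspaces; the $i$-th subspace ($i = 0, \dots, d_h/2 - 1$) corresponds to coordinates $(2i, 2i+1)$. At position $m$ that pair is multiplied by a 2D rotation matrix with angle $m\theta_i$:
$$\begin{pmatrix} q'_{2i} \\ q'_{2i+1} \end{pmatrix} = \begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix} \begin{pmatrix} q_{2i} \\ q_{2i+1} \end{pmatrix}, \qquad \theta_i = b^{-2i/d_h}$$
- $m$: absolute position of the current token
- $i$: 2D subspace index (every two dims of the per-head dim $d_h$ form one pair)
- $(q_{2i}, q_{2i+1})$: the $i$-th 2D pair of the projected Q vector
- $(q'_{2i}, q'_{2i+1})$: the rotated pair at position $m$
- $\theta_i$: base angular velocity of the $i$-th subspace
- $b$: frequency-decay base (10000 in Llama 1/2, 500000 in Llama 3; NTK-style scaling enlarges it)
Expanded to scalars:
$$q'_{2i} = q_{2i}\cos m\theta_i - q_{2i+1}\sin m\theta_i, \qquad q'_{2i+1} = q_{2i}\sin m\theta_i + q_{2i+1}\cos m\theta_i$$
Same symbols as above; this just unrolls the matrix into two scalar identities, which are easier to map to code.
This is the standard counterclockwise rotation matrix $R(m\theta_i)$; acting on a 2D vector is equivalent to multiplying $q_{2i} + \mathrm{i}\,q_{2i+1}$ by $e^{\mathrm{i} m\theta_i}$ in the complex plane, which is why the official Llama repo reshapes to complex64 and multiplies by $e^{\mathrm{i} m\theta_i}$ directly (HuggingFace equivalently pairs $(j, j + d_h/2)$ in a “half-rotation” form; the two differ only by a coordinate permutation). K rotates by the same angle to get $k'$.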
Why does rotation encode relative position? Let $R_m$ denote the block-diagonal full-$d_h$ rotation (the $i$-th block uses angle $m\theta_i$). It’s orthogonal and satisfies $R_m^\top R_n = R_{n-m}$, so:
$$\langle R_m q, R_n k \rangle = q^\top R_m^\top R_n\, k = q^\top R_{n-m}\, k$$
- $q, k \in \mathbb{R}^{d_h}$: per-head Q, K vectors (pre-rotation)
- $R_m, R_n$: block-diagonal rotation matrices for positions $m$, $n$
- $\langle \cdot, \cdot \rangle$: standard inner product
- Last step uses $R_m^\top R_n = R_{n-m}$: rotations are orthogonal and angles are additive
When attention computes $q_m'^\top k_n'$, every score depends only on the difference $n - m$; absolute position is automatically canceled inside the dot product, leaving relative position. This is RoPE’s key property and the fundamental reason it is more stable than additive positional encodings.
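The relative-position property is easy to check numerically. The sketch below applies the interleaved-pair rotation to toy vectors (the vectors, `d_h = 8`, and the helper names are illustrative assumptions) and compares two scores whose positions differ by the same offset.

```python
import math

def rope(vec, pos, base=10000.0):
    """Rotate each (2i, 2i+1) pair of vec by angle pos * theta_i."""
    d = len(vec)
    out = [0.0] * d
    for i in range(d // 2):
        theta = base ** (-2.0 * i / d)
        c, s = math.cos(pos * theta), math.sin(pos * theta)
        a, b = vec[2 * i], vec[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

q = [0.3, -1.2, 0.7, 0.5, -0.4, 0.9, 1.1, -0.6]   # toy per-head vectors, d_h = 8
k = [1.0, 0.2, -0.8, 0.4, 0.6, -0.3, 0.1, 0.9]

score_near = dot(rope(q, 5), rope(k, 2))       # positions (5, 2), offset 3
score_far = dot(rope(q, 105), rope(k, 102))    # positions (105, 102), offset 3
```

Both scores agree up to floating-point error even though the absolute positions differ by 100, and the rotation leaves each vector's norm unchanged.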
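Frequency spectrum design: $\theta_i = b^{-2i/d_h}$ ($b$ usually $10000$; $d_h$ is the per-head dim, not the model hidden dim) assigns the $d_h/2$ subspaces angular velocities from fast to slow:
- $i = 0$: $\theta_0 = 1$, period $2\pi \approx 6.3$ tokens; carries the “near-neighbor” signal.
- $i = d_h/2 - 1$: $\theta_i = b^{-(d_h - 2)/d_h} \approx 1/b$, period $\approx 2\pi b$, i.e. roughly $6 \times 10^4$ tokens for $b = 10^4$; carries the “long-range” signal.
The geometric distribution lets one head carry positional signals at many scales simultaneously. It is the same frequency ladder as the original Transformer’s sinusoidal encoding, just moved from addition to multiplication.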
Long-context scaling: training only saw positions $m \le L_{\text{train}}$, and the lowest-frequency subspaces can’t even complete one full period inside $L_{\text{train}}$; once inference reaches $m > L_{\text{train}}$, the low-frequency angles fall outside the training distribution and attention quality degrades sharply. Three common fixes all modify $\theta_i$ (write $s = L_{\text{target}} / L_{\text{train}}$ for the extension factor):
- Position Interpolation (Chen et al. 2023): $m \to m/s$, equivalent to $\theta_i \to \theta_i / s$, which compresses all frequencies uniformly. Simple but wastes high-frequency precision.
- NTK-aware scaling: scales low frequencies while preserving high ones, equivalent to $b \to b \cdot s^{d_h/(d_h - 2)}$.
- YaRN (Peng et al. 2023): band-wise treatment. High frequencies (period $\ll L_{\text{train}}$, full periods seen during training) are left alone, low frequencies (period $> L_{\text{train}}$, no full periods seen) are PI-scaled, and the mid band interpolates smoothly; additionally a temperature correction cancels the attention-entropy drift introduced by the scaling. Llama 3.1 / 3.2 went from 8K → 128K with a closely related band-wise frequency scaling (no temperature term) rather than YaRN proper.
RoPE acts only on Q and K, not on V — V is the weighted value itself and needs no positional signal.
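The first two scaling rules can be compared directly on the frequency table. This is a minimal sketch assuming the standard forms $\theta_i \to \theta_i / s$ for Position Interpolation and $b \to b \cdot s^{d_h/(d_h-2)}$ for NTK-aware scaling; the function names and the choice $s = 4$ are illustrative.

```python
def rope_freqs(d_h, base=10000.0):
    """theta_i = base^(-2i/d_h) for i = 0 .. d_h/2 - 1."""
    return [base ** (-2.0 * i / d_h) for i in range(d_h // 2)]

def pi_freqs(d_h, s, base=10000.0):
    """Position Interpolation: every frequency is divided by s."""
    return [f / s for f in rope_freqs(d_h, base)]

def ntk_freqs(d_h, s, base=10000.0):
    """NTK-aware scaling: enlarge the base, leaving theta_0 untouched."""
    return rope_freqs(d_h, base * s ** (d_h / (d_h - 2)))

d_h, s = 128, 4.0
plain = rope_freqs(d_h)
pi = pi_freqs(d_h, s)
ntk = ntk_freqs(d_h, s)
```

The exponent $d_h/(d_h-2)$ is chosen so that the slowest subspace is slowed by exactly $s$ (matching PI there) while the fastest subspace keeps $\theta_0 = 1$; PI, by contrast, slows the fastest subspace by $s$ as well.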
Scaled Dot-Product Attention
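$$\text{Attn}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_h}} + M\right) V$$
- $Q, K, V$: projected and head-split tensors, with the last two dims being $[S, d_h]$ (batch and head dims broadcast implicitly)
- $QK^\top \in \mathbb{R}^{S \times S}$: similarity matrix between every (query, key) pair
- $1/\sqrt{d_h}$: scaling factor that keeps softmax out of the near-zero-gradient region
- $M$: causal mask with the strict upper triangle set to $-\infty$ (position $i$ can only see $j \le i$)
- softmax normalizes along the K dim, producing attention weights that then re-weight V back to $[S, d_h]$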
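The formula can be traced by hand on a single head. The sketch below is a deliberately naive pure-Python version for one head with $S = 3$ and $d_h = 2$ (toy values and helper names are illustrative); real kernels batch this over $[B, n_h]$ and never loop in Python.

```python
import math

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]   # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

def causal_attention(Q, K, V):
    """Single-head scaled dot-product attention with a causal mask."""
    S, d = len(Q), len(Q[0])
    weights, out = [], []
    for i in range(S):
        scores = [
            sum(a * b for a, b in zip(Q[i], K[j])) / math.sqrt(d) if j <= i
            else float("-inf")               # mask: position i cannot see j > i
            for j in range(S)
        ]
        w = softmax(scores)
        weights.append(w)
        out.append([sum(w[j] * V[j][c] for j in range(S)) for c in range(len(V[0]))])
    return out, weights

Q = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
K = [[1.0, 0.2], [0.3, 0.8], [-0.5, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out, weights = causal_attention(Q, K, V)
```

Row 0 can only attend to itself, so its output is exactly `V[0]`; later rows are convex combinations of the values at or before their own position.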
Multi-Head Variants: MHA / MQA / GQA / MLA
All four variants differ mainly on the K, V side: Q always has $n_h$ independent heads; what changes is how many independent K, V sets exist, and whether K, V are low-rank compressed. Below we write head-level formulas for original MHA and each of the three variants under one notation, dropping the batch and layer dims and looking at head $h$ for a single token whose normed hidden state is $x$.
| Variant | K/V sets | Per-token cache (per layer, fp16) | Representative models |
|---|---|---|---|
| MHA | $n_h$ | $2 \cdot n_h d_h \cdot 2$ bytes | GPT-2/3, Llama 1/2 7B |
| GQA | $n_{kv}$ (grouped) | $2 \cdot n_{kv} d_h \cdot 2$ bytes | Llama 3, Qwen2, Mistral |
| MQA | 1 | $2 \cdot d_h \cdot 2$ bytes | PaLM, Falcon |
| MLA | low-rank compressed | $(d_c + d_R) \cdot 2$ bytes | DeepSeek V2/V3 |
MHA — Multi-Head Attention (original, Vaswani et al. 2017)
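Each Q head gets its own K, V, giving $n_h$ independent sets:
$$q_h = x W_Q^{(h)}, \qquad k_h = x W_K^{(h)}, \qquad v_h = x W_V^{(h)}, \qquad h = 0, \dots, n_h - 1$$
- $W_Q^{(h)}, W_K^{(h)}, W_V^{(h)} \in \mathbb{R}^{H \times d_h}$: per-head Q, K, V projections (concatenated into a big matrix in practice, mathematically equivalent)
- Each token writes $n_h$ $(k, v)$ pairs into the cache, $2 n_h d_h$ scalars in total
- Llama 2 7B: $n_h = 32$, $d_h = 128$, per-token cache = $2 \cdot 32 \cdot 128 \cdot 2\ \text{B} = 16$ KB / layer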
MQA — Multi-Query Attention (Shazeer 2019)
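All $n_h$ Q heads share a single K, V, so only 1 set is left:
$$q_h = x W_Q^{(h)}, \qquad k = x W_K, \qquad v = x W_V$$
- $W_K, W_V \in \mathbb{R}^{H \times d_h}$: one shared K, V projection across all heads
- Cache shrinks to $2 d_h$ scalars, $n_h\times$ smaller than MHA; but with so little K/V capacity, large models lose quality if MQA is applied directly, so it is rarely used standalone in practice and has mostly been superseded by GQA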
GQA — Grouped-Query Attention (Ainslie et al. 2023)
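Split the $n_h$ Q heads into $n_{kv}$ groups and share K, V within each group, an interpolation between MHA and MQA:
$$q_h = x W_Q^{(h)}, \qquad k_{g(h)} = x W_K^{(g(h))}, \qquad v_{g(h)} = x W_V^{(g(h))}, \qquad g(h) = \left\lfloor \frac{h}{n_h / n_{kv}} \right\rfloor$$
- $W_K^{(g)}, W_V^{(g)} \in \mathbb{R}^{H \times d_h}$, $g = 0, \dots, n_{kv} - 1$: one K, V set per group
- $g(h)$: group index of Q head $h$ (heads 0-indexed)
- $n_{kv} = n_h$ recovers MHA, $n_{kv} = 1$ recovers MQA
- Kernels materialize only $n_{kv}$ K, V sets; during score computation, K and V are broadcast along the group dim up to $n_h$, not actually replicated
- Llama 3 70B: $n_h = 64$, $n_{kv} = 8$, $d_h = 128$, per-token cache = $2 \cdot 8 \cdot 128 \cdot 2\ \text{B} = 4$ KB / layer, 8× smaller than a same-shape MHA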
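The head-to-group mapping above is a single integer division. A minimal sketch, assuming contiguous grouping of Q heads (the function name is illustrative):

```python
def kv_head_for(q_head, n_h, n_kv):
    """Index of the K/V set that Q head `q_head` reads under GQA."""
    if n_h % n_kv != 0:
        raise ValueError("n_h must be a multiple of n_kv")
    return q_head // (n_h // n_kv)

# Llama 3 8B: 32 Q heads share 8 K/V sets, 4 Q heads per group
llama3_8b_map = [kv_head_for(h, n_h=32, n_kv=8) for h in range(32)]
mha_map = [kv_head_for(h, n_h=32, n_kv=32) for h in range(32)]   # MHA limit
mqa_map = [kv_head_for(h, n_h=32, n_kv=1) for h in range(32)]    # MQA limit
```

Setting `n_kv = n_h` gives the identity map (MHA) and `n_kv = 1` sends every head to set 0 (MQA), matching the two limiting cases in the list.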
MLA — Multi-Head Latent Attention (DeepSeek V2 2024)
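GQA only shrinks cache linearly in head count; MLA jointly compresses K and V into a low-rank latent $c^{KV}$, and breaks RoPE off into a shared shallow branch. Four steps:
(1) Content branch: K, V share a down-projection; only the latent goes into the cache:
$$c^{KV} = x W_{DKV}, \qquad k_h^{C} = c^{KV} W_{UK}^{(h)}, \qquad v_h = c^{KV} W_{UV}^{(h)}$$
(2) Q-side low-rank: saves training memory; Q is never cached at inference:
$$c^{Q} = x W_{DQ}, \qquad q_h^{C} = c^{Q} W_{UQ}^{(h)}$$
(3) Decoupled RoPE branch: the K side computes one shared $d_R$-dim vector reused by all heads:
$$k^{R} = \text{RoPE}(x W_{KR}), \qquad q_h^{R} = \text{RoPE}\big(c^{Q} W_{QR}^{(h)}\big)$$
(4) Concat + attention: content and RoPE parts are concatenated along the head dim before scoring:
$$q_h = [\,q_h^{C};\, q_h^{R}\,], \qquad k_h = [\,k_h^{C};\, k^{R}\,], \qquad o_h = \text{softmax}\!\left(\frac{q_h k_h^\top}{\sqrt{d_h + d_R}}\right) v_h$$
- $W_{DKV} \in \mathbb{R}^{H \times d_c}$: shared KV down-projection (DeepSeek V3 uses $d_c = 512$)
- $W_{UK}^{(h)}, W_{UV}^{(h)} \in \mathbb{R}^{d_c \times d_h}$: per-head up-projections
- $W_{KR} \in \mathbb{R}^{H \times d_R}$: shared RoPE K branch (DeepSeek V3 uses $d_R = 64$)
- $c^{KV}, k^{R}$: these two are all that MLA actually caches, $d_c + d_R$ scalars
- DeepSeek V3: $d_c = 512$, $d_R = 64$, per-token cache = $(512 + 64) \cdot 2 = 1152$ bytes / layer, about 3.5× below an 8-KV-head GQA (4 KB) and more than 50× below MHA at the same $n_h = 128$, $d_h = 128$ (64 KB)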
Why does RoPE need its own branch? At inference there’s a “weight-folding” trick: by associativity, fold $W_{UK}^{(h)}$ into $W_{UQ}^{(h)}$:
$$q_h^{C} \big(k_h^{C}\big)^\top = \big(c^{Q} W_{UQ}^{(h)}\big)\big(c^{KV} W_{UK}^{(h)}\big)^\top = c^{Q}\, \big(W_{UQ}^{(h)} W_{UK}^{(h)\top}\big)\, \big(c^{KV}\big)^\top$$
K’s content part is never actually reconstructed; attention runs directly on the cached $c^{KV}$. But RoPE’s rotation angle depends on position $m$, so it cannot be folded into fixed weights: applying RoPE on reconstructed K kills the fold. DeepSeek’s fix: pull RoPE out into an independent $d_R$-dim shallow branch, so “foldable content” and “must-rotate-live position” don’t interfere. Cache compression and relative-position signal both survive.
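The fold is plain matrix associativity, which a few lines of pure Python can confirm. Everything below is a toy sketch: the sizes `d_cq`, `d_c`, `d_h`, the random weights, and the helper names are hypothetical and far smaller than any real model.

```python
import random

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(r) for r in zip(*A)]

rng = random.Random(0)

def rand(rows, cols):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]

d_cq, d_c, d_h, n_cached = 6, 4, 3, 5     # toy latent / head sizes
c_q = rand(1, d_cq)                       # Q latent of the current token
c_kv = rand(n_cached, d_c)                # cached KV latents
W_uq, W_uk = rand(d_cq, d_h), rand(d_c, d_h)

# Naive path: reconstruct per-head K from the latent, then score.
q = matmul(c_q, W_uq)                     # [1, d_h]
k = matmul(c_kv, W_uk)                    # [n_cached, d_h]
naive = matmul(q, transpose(k))           # [1, n_cached]

# Folded path: precompute W_uq @ W_uk^T once, score against the latent directly.
W_fold = matmul(W_uq, transpose(W_uk))    # [d_cq, d_c]
folded = matmul(matmul(c_q, W_fold), transpose(c_kv))
```

The folded path never materializes `k`, which is why the cache can hold only the latent. Inserting a position-dependent rotation between `W_uq` and `W_uk` would make `W_fold` depend on the position pair, which is the obstacle the decoupled RoPE branch avoids.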
Overview: Params / FLOPs / KV cache
Below we decompose each variant’s cost step by step. Conventions: one attention sublayer per layer, prefill length $S$, ignoring norm / bias / softmax and other non-matmul terms; a linear projection $W \in \mathbb{R}^{a \times b}$ applied to $S$ tokens (with $B = 1$) counts as $2Sab$ FLOPs (2 FLOPs per MAC); $QK^\top$ and softmax$\cdot V$ do not deduct the causal-mask triangular half. MLA additionally uses $d_c'$ = Q-side latent dim (DeepSeek V3 uses $d_c' = 1536$).
Parameters (per layer)
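| Step | MHA | MQA | GQA | MLA |
|---|---|---|---|---|
| Q proj | $H n_h d_h$ | $H n_h d_h$ | $H n_h d_h$ | $H d_c' + d_c' n_h d_h$ |
| K proj | $H n_h d_h$ | $H d_h$ | $H n_{kv} d_h$ | $H d_c + d_c n_h d_h$ |
| V proj | $H n_h d_h$ | $H d_h$ | $H n_{kv} d_h$ | $d_c n_h d_h$ (shares $W_{DKV}$ with K) |
| RoPE branch | — | — | — | $d_c' n_h d_R + H d_R$ |
| Total | sum of rows above | sum of rows above | sum of rows above | sum of rows above |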
Prefill FLOPs (per layer, sequence length )
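| Step | MHA | MQA | GQA | MLA |
|---|---|---|---|---|
| Q proj | $2S H n_h d_h$ | $2S H n_h d_h$ | $2S H n_h d_h$ | $2S (H d_c' + d_c' n_h d_h)$ |
| K proj | $2S H n_h d_h$ | $2S H d_h$ | $2S H n_{kv} d_h$ | $2S (H d_c + d_c n_h d_h)$ |
| V proj | $2S H n_h d_h$ | $2S H d_h$ | $2S H n_{kv} d_h$ | $2S\, d_c n_h d_h$ |
| RoPE branch | $O(S n_h d_h)$ | $O(S n_h d_h)$ | $O(S n_h d_h)$ | $2S (d_c' n_h d_R + H d_R)$ |
| $QK^\top$ | $2S^2 n_h d_h$ | $2S^2 n_h d_h$ | $2S^2 n_h d_h$ | $2S^2 n_h (d_h + d_R)$ |
| softmax $\cdot V$ | $2S^2 n_h d_h$ | $2S^2 n_h d_h$ | $2S^2 n_h d_h$ | $2S^2 n_h d_h$ |
“RoPE branch” in the MHA/MQA/GQA columns means the elementwise rotation applied to Q, K: $O(S n_h d_h)$, negligible next to matmuls. In MLA it refers specifically to the extra $W_{QR}$ / $W_{KR}$ projections, which are genuine matmuls and must be counted.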
KV cache (per token, per layer, fp16)
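| Variant | What’s cached | Bytes |
|---|---|---|
| MHA | $n_h$ pairs of $(k, v)$ | $2 n_h d_h \cdot 2$ |
| MQA | 1 pair | $2 d_h \cdot 2$ |
| GQA | $n_{kv}$ pairs | $2 n_{kv} d_h \cdot 2$ |
| MLA | $c^{KV}$ + $k^{R}$ | $(d_c + d_R) \cdot 2$ |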
Three observations from reading these tables horizontally:
- The variants mainly touch the K/V side. The Q projection (apart from MLA’s low-rank factorization), $QK^\top$, and softmax$\cdot V$ cost essentially the same across the four, so total attention params and FLOPs never differ by an order of magnitude at the same model size.
- MHA → MQA/GQA saves params, FLOPs, and cache together; MLA trades params/FLOPs for cache. GQA shrinks $n_{kv}$, scaling K/V projections, FLOPs, and cache down linearly. MLA does the opposite: it adds $W_{DKV}$, the two $W_{UK}$ / $W_{UV}$ up-projection stages, and a dedicated RoPE branch, leaving params and prefill FLOPs at the same order as a same-size GQA, in exchange for shrinking per-token cache from the multi-KB scale down to ~1 KB.
- The $S^2$ term doesn’t differ across MHA/MQA/GQA. $QK^\top$ and softmax$\cdot V$ are both $2S^2 n_h d_h$ (K, V sharing is broadcast only, never reducing the quadratic term). So once $S \gg H$, the three variants’ prefill FLOPs converge; what truly separates them is how many cache bytes each decoded token must read, which is memory I/O, not compute.
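The per-token cache sizes quoted for each variant follow from two one-line formulas. A small sketch, with illustrative function names and fp16 (2 bytes per element) assumed as the default:

```python
def kv_bytes_per_token_layer(n_kv, d_h, elem_bytes=2):
    """MHA / GQA / MQA: n_kv (k, v) pairs of d_h scalars each."""
    return 2 * n_kv * d_h * elem_bytes

def mla_bytes_per_token_layer(d_c, d_r, elem_bytes=2):
    """MLA: one shared latent c_KV plus one shared RoPE key."""
    return (d_c + d_r) * elem_bytes

cache = {
    "MHA (Llama 2 7B, n_h=32)": kv_bytes_per_token_layer(32, 128),
    "GQA (Llama 3 70B, n_kv=8)": kv_bytes_per_token_layer(8, 128),
    "MQA (1 shared pair)": kv_bytes_per_token_layer(1, 128),
    "MLA (DeepSeek V3)": mla_bytes_per_token_layer(512, 64),
}
```

These are per-layer figures; multiplying by the layer count and the context length gives the footprint of one request.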
Output Projection + Residual
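$$x_{\text{attn}} = x + \text{Attn}(Q', K', V')\, W_O, \qquad Q', K', V' \text{ computed from } \text{RMSNorm}(x)$$
- $x$: attention sublayer input (the value before pre-norm)
- $Q', K'$: RoPE-rotated Q and K; $V' = V$ (V is not rotated, the prime is just notational uniformity)
- $W_O \in \mathbb{R}^{n_h d_h \times H}$: output projection that maps concatenated multi-head output back to $H$
- $x_{\text{attn}}$: attention sublayer output + residual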
FFN Variants
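Classic Two-Layer FFN (GPT-2)
$$\text{FFN}(x) = \sigma(x W_{\text{up}})\, W_{\text{down}} \qquad \text{(biases omitted)}$$
- $x \in \mathbb{R}^{H}$: FFN input
- $W_{\text{up}} \in \mathbb{R}^{H \times I}$: up-projection
- $W_{\text{down}} \in \mathbb{R}^{I \times H}$: down-projection
- $\sigma$: element-wise scalar nonlinearity (see below)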
Activation Functions
The scalar nonlinearity applied after the up-projection. Input and output are both scalars ($\mathbb{R} \to \mathbb{R}$); it’s applied independently per token, per hidden dim. Four functions cover all the mainstream choices used in the Transformer era.
ReLU (Nair & Hinton 2010)
$$\text{ReLU}(x) = \max(0, x)$$
- Identity on $x > 0$, zero on $x \le 0$
- Cheapest to compute; but zero gradient on the negative side, the “dead neuron” problem
- Used in the original Transformer and the original T5 (BERT already uses GeLU)
Sigmoid
$$\sigma(x) = \frac{1}{1 + e^{-x}}$$
- Compresses $\mathbb{R}$ into $(0, 1)$, a natural “gate” signal; the original GLU’s gate $\sigma$ is exactly this
- Largely abandoned as a standalone FFN activation: it saturates at both ends, killing gradients
GeLU (Gaussian Error Linear Unit, Hendrycks & Gimpel 2016)
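$$\text{GeLU}(x) = x\, \Phi(x) = \frac{x}{2}\left[1 + \text{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right]$$
In practice, the tanh approximation from the GELU paper (popularized by OpenAI’s GPT-2 code) is often used; it has a small numerical error and avoids the erf call:
$$\text{GeLU}(x) \approx \frac{x}{2}\left[1 + \tanh\!\left(\sqrt{\tfrac{2}{\pi}}\,\big(x + 0.044715\, x^3\big)\right)\right]$$
- $\Phi$ is the standard normal CDF; intuitively “let $x$ through, weighted by the probability that a standard normal falls below it”
- Everywhere differentiable, non-monotonic (a small negative dip on the $x < 0$ side), smoother than ReLU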
- Used in GPT-2/3, BERT, ViT
SiLU / Swish (Ramachandran et al. 2017)
$$\text{SiLU}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}$$
- Shape very close to GeLU (also smooth, non-monotonic, passes through origin) but with a simpler closed form: no erf, no cubic term
- Self-gated: uses its own sigmoid to control signal throughput
- Used by PaLM and the entire Llama family (as the gating in SwiGLU)
Of these four, ReLU and sigmoid have largely exited mainstream FFNs; the active ones are GeLU (GPT era) and SiLU (Llama / PaLM and later). The bf16 numerical difference between the two is < 1%; paper choice is mostly path-dependent. The real engineering inflection point was switching activation from “applied to the projection” (classic FFN) to “applied to the gate” (the GLU family below).
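For reference, all four activations fit in a few lines of standard-library Python. This is a sketch for checking values by hand, not a performance-oriented implementation; `gelu_tanh` is the tanh approximation with the usual constant 0.044715.

```python
import math

def relu(x):
    return max(0.0, x)

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def gelu(x):
    """Exact GeLU: x times the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    """Tanh approximation of GeLU."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def silu(x):
    return x * sigmoid(x)

grid = [i / 10.0 for i in range(-50, 51)]                     # x in [-5, 5]
max_gelu_gap = max(abs(gelu(x) - gelu_tanh(x)) for x in grid)
```

On this grid the exact and approximate GeLU stay within a few thousandths of each other, and both GeLU and SiLU show the small negative dip for $x < 0$ that ReLU lacks.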
GLU Family (used by Llama, PaLM, Mistral)
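$$\text{FFN}_{\text{GLU}}(x) = \big(\sigma(x W_{\text{gate}}) \odot (x W_{\text{up}})\big)\, W_{\text{down}}$$
- $x \in \mathbb{R}^{H}$: FFN input
- $W_{\text{gate}}, W_{\text{up}} \in \mathbb{R}^{H \times I}$: two independent up-projections
- $W_{\text{down}} \in \mathbb{R}^{I \times H}$: down-projection
- $\odot$: element-wise multiplication
- $\sigma$: gating activation; choice determines the variant name:

| Variant | $\sigma$ | Representative model |
|---|---|---|
| GLU | sigmoid | Dauphin et al. 2017 original |
| ReGLU | ReLU | — |
| GeGLU | GeLU | T5 v1.1 |
| SwiGLU | SiLU / Swish | PaLM, Llama 1/2/3 |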
SwiGLU has one more projection than the classic GeLU-FFN (three matrices vs two); to align parameter budget, implementations typically set $I \approx \tfrac{8}{3} H$ ($I = 4H$ is the GPT-2 convention), and Llama 3 8B’s $I = 14336 = 3.5H$.
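A minimal sketch of the SwiGLU forward pass and of the $\tfrac{8}{3}H$ parameter-parity argument. The toy weights and helper names are illustrative; weights follow the article's “in × out” layout.

```python
import math

def silu(x):
    return x / (1.0 + math.exp(-x))

def matvec(x, W):
    """x: [in], W: [in][out] -> [out]."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]

def swiglu_ffn(x, W_gate, W_up, W_down):
    gate = [silu(v) for v in matvec(x, W_gate)]   # activation sits on the gate path
    up = matvec(x, W_up)
    hidden = [g * u for g, u in zip(gate, up)]    # element-wise gating
    return matvec(hidden, W_down)

def classic_ffn_params(H, I):
    return 2 * H * I          # W_up + W_down

def glu_ffn_params(H, I):
    return 3 * H * I          # W_gate + W_up + W_down

W_gate = [[0.5, -1.0, 0.2], [0.1, 0.3, -0.7]]     # toy H = 2, I = 3
W_up = [[1.0, 0.0, -0.5], [0.2, 0.4, 0.6]]
W_down = [[0.3, -0.2], [0.1, 0.5], [-0.4, 0.7]]
y = swiglu_ffn([1.0, 2.0], W_gate, W_up, W_down)
```

With $H = 3072$, a classic FFN at $I = 4H = 12288$ and a GLU FFN at $I = \tfrac{8}{3}H = 8192$ have the same parameter count, which is the budget-matching rule described above.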
Mixture-of-Experts (MoE)
In a classic dense FFN, every token passes through the same $(W_{\text{up}}, W_{\text{down}})$ pair: all parameters used, all FLOPs paid. MoE (Shazeer et al. 2017) replicates the FFN $N$ times (“experts”) and routes each token through only the top $k$, so total parameters scale linearly with $N$ while per-token activated params and FLOPs stay almost flat; capacity grows without a matching growth in compute.
Formally the FFN sublayer becomes:
$$y = \sum_{i \in \mathcal{T}} g_i\, E_i(x)$$
The router (gating network) decides which experts each token visits:
$$p = \text{softmax}(x W_r), \qquad \mathcal{T} = \text{TopK}(p, k), \qquad g_i = \frac{p_i}{\sum_{j \in \mathcal{T}} p_j}$$
- $x \in \mathbb{R}^{H}$: current token’s hidden state (FFN sublayer input)
- $W_r \in \mathbb{R}^{H \times N}$: router projection mapping the hidden state to $N$ expert logits
- $p \in \mathbb{R}^{N}$: router probabilities across all experts
- $\mathcal{T}$: index set of top-$k$ selected experts
- $g_i$: combine weight; Mixtral and DeepSeek re-normalize the top-$k$ scores so they sum to 1
- $E_i$: $i$-th expert, typically a SwiGLU FFN with its own (unshared) weights
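The routing equations can be sketched with scalar “experts” so the arithmetic stays visible. This is a toy illustration under the softmax-then-renormalize convention written above; the expert functions and logits are hypothetical.

```python
import math

def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def route(router_logits, k):
    """Return {expert index: combine weight} for the top-k experts."""
    p = softmax(router_logits)
    top = sorted(range(len(p)), key=lambda i: p[i], reverse=True)[:k]
    norm = sum(p[i] for i in top)
    return {i: p[i] / norm for i in top}      # weights re-normalized to sum to 1

def moe_forward(x, router_logits, experts, k):
    gates = route(router_logits, k)
    return sum(g * experts[i](x) for i, g in gates.items())   # only k experts run

experts = [lambda x, s=s: s * x for s in (1.0, 2.0, 3.0, 4.0)]   # toy scalar experts
logits = [2.0, 0.5, 1.0, -1.0]
gates = route(logits, k=2)
y = moe_forward(2.0, logits, experts, k=2)
```

Only experts 0 and 2 are evaluated for this token; the other two contribute neither FLOPs nor weight reads, which is the source of the capacity/compute decoupling.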
MoE vs Dense FFN
| Dimension | Dense FFN (SwiGLU) | MoE (top- of ) |
|---|---|---|
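| Params (FFN block) | $3HI$ | $N \cdot 3HI_e$ + $HN$ (router), with $I_e$ the per-expert intermediate dim |
| Per-token activated FLOPs | $6HI$ | $k \cdot 6HI_e$ + $2HN$ (router) |
| Per-token weight HBM (decode) | $3HI \cdot 2$ bytes | $k \cdot 3HI_e \cdot 2$ bytes |
| VRAM footprint | $3HI \cdot 2$ bytes | $N \cdot 3HI_e \cdot 2$ bytes (all experts must fit) |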
| Kernel shape | fixed GEMM | grouped GEMM / token permutation |
| Multi-GPU comm | — | all-to-all under expert parallelism |
| Training stability | straightforward | router prone to collapse, needs load-balance |
Key trade-off: decoupling capacity from FLOPs. At the same activated parameter count (i.e. same FLOP budget), MoE can pack several to tens of times more total parameters (about 4× for Mixtral 8×7B and 18× for DeepSeek V3, going by the total / activated figures below), so knowledge capacity goes up without extra per-token compute. The costs are VRAM (must fit all experts), routing stability (avoiding hot experts), and inference batching (per-token paths differ).
Load Balance
A naive router collapses onto a few hot experts during training. GShard / Switch use an auxiliary loss as a soft constraint:
$$\mathcal{L}_{\text{aux}} = \alpha \cdot N \sum_{i=1}^{N} f_i\, P_i$$
- $f_i$: fraction of tokens in the current batch routed to expert $i$
- $P_i$: mean router probability assigned to expert $i$ within the batch
- $\alpha$: auxiliary loss weight (Switch uses $10^{-2}$)
- Intuition: large $f_i$ AND large $P_i$ means the expert is chosen often and with high confidence → penalized → gradient pushes that router logit back down
DeepSeek V3 goes further with auxiliary-loss-free load balance: maintain a bias $b_i$ per expert, base TopK on $s_i + b_i$ (with $s_i$ the router score); raise $b_i$ for under-used experts and lower it for over-used ones. This affects selection without polluting gradients, avoiding the main-task accuracy hit that an aux loss can cause.
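The auxiliary loss is small enough to evaluate by hand on a toy batch. A sketch assuming Switch-style top-1 assignments; the function name and the two example batches are illustrative.

```python
def load_balance_loss(assignments, router_probs, n_experts, alpha=1e-2):
    """alpha * N * sum_i f_i * P_i for one batch.

    assignments:  chosen expert index per token (top-1 routing assumed)
    router_probs: per-token probability vectors over the N experts
    """
    n_tok = len(assignments)
    f = [0.0] * n_experts
    for e in assignments:
        f[e] += 1.0 / n_tok                                   # token fraction per expert
    P = [sum(p[i] for p in router_probs) / n_tok for i in range(n_experts)]
    return alpha * n_experts * sum(fi * Pi for fi, Pi in zip(f, P))

uniform = [[0.25] * 4 for _ in range(4)]
balanced = load_balance_loss([0, 1, 2, 3], uniform, n_experts=4)

one_hot = [[1.0, 0.0, 0.0, 0.0] for _ in range(4)]
collapsed = load_balance_loss([0, 0, 0, 0], one_hot, n_experts=4)
```

A perfectly balanced batch gives $\alpha$, while full collapse onto one expert gives $\alpha N$, so the loss grows by a factor of $N$ between the two extremes.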
Modern Implementations
| Model | Total / activated | routed | top- | Shared | Routing notes |
|---|---|---|---|---|---|
| Switch Transformer (2021) | 1.6 T / ~26 B | 2048 | 1 | — | First stably-trained large MoE; hard top-1 + capacity factor |
| GLaM (2022) | 1.2 T / 97 B | 64 | 2 | — | Google decoder-only MoE; halves inference cost vs dense |
| Mixtral 8×7B (2023) | 47 B / 13 B | 8 | 2 | — | Open-source “8 big experts” reference; per-layer routing |
| Mixtral 8×22B (2024) | 141 B / 39 B | 8 | 2 | — | Scaled-up 8×7B |
| Qwen1.5-MoE-A2.7B (2024) | 14 B / 2.7 B | 60 | 4 | 4 | Alibaba’s first fine-grained MoE |
| DeepSeek V2 (2024) | 236 B / 21 B | 160 | 6 | 2 | Fine-grained + shared-expert paradigm established |
| DeepSeek V3 (2024) | 671 B / 37 B | 256 | 8 | 1 | Aux-loss-free load balance; paired with MLA |
| Qwen3-MoE 235B-A22B (2025) | 235 B / 22 B | 128 | 8 | — | DeepSeek-style fine-grained |
| Llama 4 Scout (2025) | 109 B / 17 B | 16 | 1 | 1 | top-1 + 1 shared; extreme sparsity |
| Llama 4 Maverick (2025) | 400 B / 17 B | 128 | 1 | 1 | Same idea, expert count pushed to 128 |
| GPT-4 (rumored) | ~1.8 T / ~280 B | 16 | 2 | — | Never officially disclosed; semiconductor-analyst reconstructions |
Four patterns:
- Coarse → fine. Mixtral-era designs are “few-and-fat” (8 experts each near a dense FFN’s size); DeepSeek onward shrinks each expert and pushes count past 100, going “many-and-thin”, so the same activated FLOPs now span combinatorially more expert combinations.
- Shared experts become common. DeepSeek, Qwen1.5-MoE, and Llama 4 reserve a few “everyone-must-visit” shared experts per layer (1 to 4 in the table above) to absorb generic patterns; the sparse experts handle specialization. Qwen3-MoE is a counterexample with none.
- MoE pairs with KV-cache compression. Pushing total params to hundreds of billions makes long-context decode equally pressured — it’s no coincidence DeepSeek V3 ships MLA + fine-grained MoE together.
- top- goes to the extremes. Switch (k=1) → Mixtral (k=2) → DeepSeek V3 (k=8 with small experts) → Llama 4 (k=1 + shared). Small makes batching and capacity bounds tractable; fine-graining and shared experts make up the lost expressivity.
Residual
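$$x_{\text{out}} = x_{\text{attn}} + \text{FFN}\big(\text{RMSNorm}(x_{\text{attn}})\big)$$
- $x_{\text{attn}}$: attention sublayer output (already contains the first residual)
- $x_{\text{out}}$: full Transformer layer output, fed into the next layer
- RMSNorm sits inside the residual branch (pre-norm); the main path goes straight through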
LM Head + Sampling
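$$z = h_{\text{last}} W_{\text{LM}}, \qquad p_v = \frac{\exp(z_v / \tau)}{\sum_{u} \exp(z_u / \tau)}$$
- $h_{\text{last}} \in \mathbb{R}^{H}$: hidden state after the final RMSNorm, at the last position
- $W_{\text{LM}} \in \mathbb{R}^{H \times V}$: output embedding matrix (optionally shared with $E^\top$, i.e. tied embedding)
- $z \in \mathbb{R}^{V}$: unnormalized score per vocab token
- $\tau$: temperature; larger $\tau$ flattens the distribution ($\tau \to \infty$ → uniform, $\tau \to 0$ → argmax)
- $p \in \mathbb{R}^{V}$: normalized probability distribution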
Several logits transformations are typically applied before sampling:
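Repetition penalties (repetition / frequency / presence penalty), shown here in the additive presence + frequency form:
$$z'_v = z_v - \alpha_p \cdot \mathbb{1}[c_v > 0] - \alpha_f \cdot c_v$$
- $v$: a token in the vocabulary
- $z_v$: the raw logit for that token
- $\alpha_p$: presence penalty (subtract once if ever seen)
- $\alpha_f$: frequency penalty (subtract linearly by occurrence count)
- $\mathbb{1}[\cdot]$: indicator function (1 if the condition holds, else 0)
- $c_v$: number of times token $v$ has appeared in the generated sequence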
Top-k: keep the $k$ largest logits, set the rest to $-\infty$.
Top-p (nucleus): sort by probability descending, keep the smallest set whose cumulative probability is $\ge p$.
Min-p: keep tokens with $p_v \ge p_{\min} \cdot \max_u p_u$; friendlier to low-entropy distributions.
Typical-p: truncate by deviation from the conditional entropy $H(p)$, keeping the set where $\lvert -\log p_v - H(p) \rvert$ is small.
All of these truncations act on the logits or on the probability distribution right before sampling, without changing the skeleton of the formula.
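The temperature and truncation rules can be sketched as small filters over a probability list. This is an illustrative pure-Python version (function names are assumptions); each filter zeroes the dropped tokens and re-normalizes the survivors.

```python
import math

def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def renormalize(probs, keep):
    total = sum(probs[i] for i in keep)
    return [probs[i] / total if i in keep else 0.0 for i in range(len(probs))]

def top_k_filter(probs, k):
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return renormalize(probs, set(order[:k]))

def top_p_filter(probs, top_p):
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = set(), 0.0
    for i in order:                      # smallest prefix whose mass reaches top_p
        keep.add(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    return renormalize(probs, keep)

def min_p_filter(probs, min_p):
    threshold = min_p * max(probs)       # threshold relative to the top token
    return renormalize(probs, {i for i, p in enumerate(probs) if p >= threshold})

probs = [0.5, 0.3, 0.15, 0.05]
after_top_k = top_k_filter(probs, 2)
after_top_p = top_p_filter(probs, 0.7)
after_min_p = min_p_filter(probs, 0.2)
```

On this example top-k with $k = 2$ and top-p with $p = 0.7$ both keep the first two tokens, while min-p with $p_{\min} = 0.2$ keeps three because its threshold scales with the top probability.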
Prefill Stage Shape Transitions — S Tokens Walk Through Once
Input: input_ids [B, S] = [2, 10]. The figure below shows the forward pass of one Transformer layer, wrapped in 32 layers; the shape stays at [B, S, H] = [2, 10, 4096] throughout, with residual edges shown as orange dashed lines.
After prefill, the KV Cache state: each layer holds K and V of shape $[B, n_{kv}, 10, d_h] = [2, 8, 10, 128]$, covering the first 10 positions.
Decode Stage Shape Transitions — Step t Processes Only 1 Token
Prior state: the first $S + t - 1$ cache positions are filled (the prompt plus the tokens generated before step $t$).
Input: input_ids [B, 1] = [2, 1] (the 1 token generated at the previous step).
Prefill vs Decode Shape Comparison — GEMM vs GEMV · Compute vs Bandwidth
| Position | Prefill | Decode (per step) |
|---|---|---|
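| input_ids | $[B, S]$ | $[B, 1]$ |
| after embedding | $[B, S, H]$ | $[B, 1, H]$ |
| Q | $[B, n_h, S, d_h]$ | $[B, n_h, 1, d_h]$ |
| K_new / V_new | $[B, n_{kv}, S, d_h]$ | $[B, n_{kv}, 1, d_h]$ |
| K_full / V_full (from cache) | same as K_new | $[B, n_{kv}, S + t, d_h]$ |
| attention scores | $[B, n_h, S, S]$ | $[B, n_h, 1, S + t]$ |
| attention output | $[B, S, H]$ | $[B, 1, H]$ |
| FFN intermediate | $[B, S, I]$ | $[B, 1, I]$ |
| logits | $[B, V]$ (last position) | $[B, V]$ |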
| Operation type | GEMM (matrix × matrix) | GEMV (matrix × vector) |
| Bottleneck | compute | memory bandwidth |
This table is the starting point for understanding every inference acceleration effort: prefill is like training’s forward, compute-bound; decode is a chain of GEMVs, memory-bound, with most time spent fetching weights into SMs. The optimization directions of the two are worlds apart.
KV Cache Shape and Growth — Few KB Per Token · Hundreds of MB at Long Context
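One pair of caches per layer:
$$K_{\text{cache}},\ V_{\text{cache}} \in \mathbb{R}^{B \times n_{kv} \times S_{\max} \times d_h}$$
- $S_{\max}$: pre-allocated max sequence length (typically the model’s context cap or the scheduler’s limit)
- Other symbols follow the top-of-article convention ($B$, $n_{kv}$, $d_h$); the 4 dimensions follow PyTorch’s [B, head, seq, head_dim] ordering
Per-token, per-layer cache size (fp16):
$$2 \cdot n_{kv} \cdot d_h \cdot 2\ \text{bytes} = 2 \cdot 8 \cdot 128 \cdot 2 = 4096\ \text{bytes} = 4\ \text{KB}$$
- Leftmost $2$: one for K, one for V
- Trailing $2$: fp16 element size (fp8 / int8 cuts this to 1/2, 4-bit formats to 1/4)
- Right side plugs in Llama 3 8B: $n_{kv} = 8$, $d_h = 128$
Per-token across the full 32-layer model: $4\ \text{KB} \times 32 = 128$ KB. A 4096-token request: $128\ \text{KB} \times 4096 = 512$ MB.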
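The same arithmetic, extended to a capacity estimate, fits in three functions. A sketch with illustrative names; the 64 GiB of free memory used in the example is a hypothetical budget, not a measured figure.

```python
def kv_bytes_per_token(n_layers, n_kv, d_h, elem_bytes=2):
    """K and V for one token across all layers."""
    return 2 * n_kv * d_h * elem_bytes * n_layers

def kv_bytes_per_request(n_tokens, **model):
    return n_tokens * kv_bytes_per_token(**model)

def max_concurrent_requests(free_bytes, n_tokens, **model):
    """How many requests of n_tokens each fit into free_bytes of cache space."""
    return free_bytes // kv_bytes_per_request(n_tokens, **model)

llama3_8b = dict(n_layers=32, n_kv=8, d_h=128)
per_token = kv_bytes_per_token(**llama3_8b)                 # 128 KB
per_request = kv_bytes_per_request(4096, **llama3_8b)       # 512 MB
fits = max_concurrent_requests(64 * 1024**3, 4096, **llama3_8b)
```

Under these assumptions a 64 GiB cache budget holds 128 full 4096-token requests, and halving `elem_bytes` (fp8 / int8) doubles that count.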
Several engineering optimizations:
- Paged Attention (vLLM): split the cache into fixed-size blocks (typically 16 tokens), with a block table mapping virtual to physical addresses, eliminating fragmentation. The formulas don’t change; only the tensor layout and access pattern change.
- Sliding Window Attention (Mistral): keep only the most recent $W$ tokens of K and V. Cache cap drops from $S_{\max}$ to $W$, at the cost of information truncation, with long-range dependencies relayed through cross-layer stacking.
- Quantized KV Cache (INT8 / FP8 / lower): quantize the fp16 cache down to 8-bit (half the footprint) or to 4-bit and below (a quarter or less), with per-channel or per-token scales keeping the error controllable. Representative work: KIVI / KVQuant.
- KV compression / eviction (H2O, StreamingLLM, SnapKV): drop unimportant positions based on attention weights; used at very long context lengths.
- MLA: mentioned earlier — modifies cache shape at the model-structure level, not as a postprocess.
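The block-table idea behind Paged Attention can be sketched as a tiny address translator. This is a hypothetical illustration of the mapping only (class and method names are invented here and are not vLLM's API); it stores no tensors.

```python
class BlockTable:
    """Maps a request's logical token positions to (physical block, offset)."""

    def __init__(self, block_size=16):
        self.block_size = block_size
        self.blocks = []       # physical block ids, in logical order
        self.length = 0        # tokens stored so far

    def append_token(self, free_blocks):
        # Allocate a new physical block only when the current one is full.
        if self.length % self.block_size == 0:
            self.blocks.append(free_blocks.pop())
        self.length += 1

    def locate(self, pos):
        if not 0 <= pos < self.length:
            raise IndexError(pos)
        return self.blocks[pos // self.block_size], pos % self.block_size

free_blocks = [7, 3, 9]        # hypothetical pool of free physical block ids
table = BlockTable(block_size=16)
for _ in range(20):            # a 20-token request needs two blocks
    table.append_token(free_blocks)
```

Physical blocks need not be contiguous or ordered, which is what removes fragmentation: a request wastes at most one partially filled block.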
Per-Step Compute and Memory Cost — Llama 3 8B fp16 · H100 Knee ~330 FLOPs/byte
The shape diagrams above show shapes but not magnitudes. 90% of inference optimization discussion is about “how many FLOPs does this step cost, how many bytes move,” so let’s lay each step’s cost into tables directly.
Using Llama 3 8B, fp16, as the baseline; for prefill take $B = 1$, $S = 2048$; for decode take $B = 1$ with cache length 2048 (some step during generation around the 2048th token).
Reference hardware knee: H100 SXM fp16 theoretical compute ~989 TFLOPs, HBM bandwidth ~3 TB/s, roofline knee $\approx 989 / 3 \approx 330$ FLOPs/byte. Above it is compute-bound, below is memory-bound.
Weight Distribution
| Component | Shape | fp16 size | full model (× 32 layers) |
|---|---|---|---|
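| Embedding $E$ | $[128256, 4096]$ | 1.0 GB | 1.0 GB |
| $W_Q$ | $[4096, 4096]$ | 32 MB | 1.0 GB |
| $W_K$ | $[4096, 1024]$ | 8 MB | 256 MB |
| $W_V$ | $[4096, 1024]$ | 8 MB | 256 MB |
| $W_O$ | $[4096, 4096]$ | 32 MB | 1.0 GB |
| $W_{\text{gate}}$ | $[4096, 14336]$ | 117 MB | 3.7 GB |
| $W_{\text{up}}$ | $[4096, 14336]$ | 117 MB | 3.7 GB |
| $W_{\text{down}}$ | $[14336, 4096]$ | 117 MB | 3.7 GB |
| RMSNorm (2 per layer) | $2 \times [4096]$ | 16 KB | 512 KB |
| LM head $W_{\text{LM}}$ | $[4096, 128256]$ | 1.0 GB | 1.0 GB |
| Total | | ~432 MB / layer | ~16 GB |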
Full-model fp16 weights are ~16 GB; the “floor price” of every forward pass is to scan these 16 GB from HBM. At H100’s 3 TB/s, $16\ \text{GB} / 3\ \text{TB/s} \approx 5.3$ ms per token (about 190 tokens/s); this is the physical lower bound of single-request decode.
Per-Layer, Per-Step Compute / Memory I/O
Compare the same layer’s substeps under prefill ($S = 2048$) and decode ($S = 1$, cache length 2048). “Weight HBM” is the weight bytes fetched from VRAM; “KV HBM” is the KV-cache bytes read/written. Intermediate activations are assumed fused into kernels and not counted separately.
| Step | Prefill FLOPs (S=2048) | Decode FLOPs (S=1) | Weight HBM | KV HBM |
|---|---|---|---|---|
| RMSNorm | ≈ 42 MF | 20 KF | 8 KB | — |
| $W_Q$ | ≈ 68.7 GF | 33.5 MF | 32 MB | — |
| $W_K$ (+write cache) | ≈ 17.2 GF | 8.4 MF | 8 MB | W 4 MB / 2 KB |
| $W_V$ (+write cache) | 17.2 GF | 8.4 MF | 8 MB | W 4 MB / 2 KB |
| RoPE | ~50 MF | 25 KF | — | — |
| Attn $QK^\top$ | ≈ 34.4 GF | 16.8 MF | — | R 4 MB (decode) |
| softmax | ~700 MF | 260 KF | — | — |
| Attn softmax $\cdot V$ | 34.4 GF | 16.8 MF | — | R 4 MB (decode) |
| $W_O$ | ≈ 68.7 GF | 33.5 MF | 32 MB | — |
| RMSNorm | 42 MF | 20 KF | 8 KB | — |
| $W_{\text{gate}}$ | ≈ 241 GF | 117 MF | 117 MB | — |
| $W_{\text{up}}$ | 241 GF | 117 MF | 117 MB | — |
| SiLU + gate | ~90 MF | 45 KF | — | — |
| $W_{\text{down}}$ | ≈ 241 GF | 117 MF | 117 MB | — |
| Per-layer total | ~960 GFLOPs | ~470 MFLOPs | ~432 MB | W 8 MB (P) / R 8 MB (D) |
A few direct conclusions:
- FFN is the real protagonist. $W_{\text{gate}}$, $W_{\text{up}}$, $W_{\text{down}}$ consume ~75% of FLOPs and ~80% of weight bandwidth. MoE, sparse activation, and FFN quantization all target this block.
- The 4 attention projections (Q/K/V/O) account for ~18%; the actual $QK^\top$ and softmax$\cdot V$ only ~7%. At $S = 2048$ prefill, the attention core isn’t the bottleneck; the linear layers are.
- Decode’s KV reads are 8 MB per layer; at cache length 2048 this is only ~2% of weight reads. But once context stretches to 64K or 128K, it grows tens of times, overtaking weight bandwidth as the new bottleneck (this is why Paged Attention, sliding window, and KV quantization exist).
One Full Forward Pass
Adding 32 layers + embedding + LM head:
| Stage | FLOPs | HBM I/O | Arithmetic Intensity | Bottleneck |
|---|---|---|---|---|
| Prefill S=2048, B=1 | ~31 TFLOPs | ~14 GB (weights) + 256 MB (KV write) | ~2200 FLOPs/byte | compute |
| Decode step, cache_len=2048, B=1 | ~15 GFLOPs | ~14 GB (weights) + 256 MB (KV read) | ~1.05 FLOPs/byte | bandwidth |
| LM head (prefill, last position only) | ~1 GFLOP | 1 GB | ~1 FLOPs/byte | bandwidth |
| LM head (decode) | ~1 GFLOP | 1 GB | ~1 FLOPs/byte | bandwidth |
Decode’s 1.05 FLOPs/byte is 2.5 orders of magnitude below H100’s knee of 330, meaning ideal single-request decode compute utilization is only $1.05 / 330 \approx 0.3\%$. This is the mathematical basis for continuous batching: push $B$ to 32 so the same weight read is amortized across 32 requests; arithmetic intensity scales by 32×, and decode throughput grows almost linearly until the attention portion or compute itself becomes the wall.
Mental-Math Rules
Two rules cover 90% of inference performance estimation:
- FLOPs ≈ $2 \cdot P \cdot n$: $P$ is the parameter count (~8B); $n$ is the total number of tokens this forward pass processes. Each parameter is used once per token (one MAC = 2 FLOPs). E.g., prefill $S = 2048$: $2 \times 8 \times 10^9 \times 2048 \approx 33$ TFLOPs, in line with the itemized sum of 31 TFLOPs.
- Weight HBM I/O ≈ $2P$ bytes (fp16): one forward pass scans the model once, about 16 GB.
Arithmetic intensity is essentially $2Pn / 2P = n$, the total number of tokens participating in this forward. Prefill has $B \cdot S$ tokens; decode has only $B$. This single number directly determines why prefill and decode have different bottlenecks.
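Both rules plus the roofline test fit in a handful of functions. A back-of-envelope sketch using the article's H100 figures (989 TFLOPs, 3 TB/s) as defaults; function names are illustrative, and KV-cache and activation traffic are deliberately ignored.

```python
def forward_flops(params, n_tokens):
    return 2 * params * n_tokens              # one MAC = 2 FLOPs per parameter per token

def weight_bytes(params, elem_bytes=2):
    return params * elem_bytes                # one full scan of the weights

def arithmetic_intensity(params, n_tokens, elem_bytes=2):
    return forward_flops(params, n_tokens) / weight_bytes(params, elem_bytes)

def bottleneck(intensity, peak_flops=989e12, bandwidth=3e12):
    knee = peak_flops / bandwidth             # about 330 FLOPs/byte
    return "compute" if intensity > knee else "bandwidth"

def min_decode_latency_s(params, bandwidth=3e12, elem_bytes=2):
    return weight_bytes(params, elem_bytes) / bandwidth

P = 8e9
prefill_intensity = arithmetic_intensity(P, 2048)   # B * S tokens
decode_intensity = arithmetic_intensity(P, 1)       # B = 1
batched_intensity = arithmetic_intensity(P, 32)     # B = 32
```

With fp16 weights the intensity equals the token count, so a batch of 32 decode requests is still below the knee under this estimate, while a 2048-token prefill is far above it.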
Compute Complexity Overview — With vs Without KV Cache Differ by Three Orders of Magnitude
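Prefill (process $S$ tokens at once, $B = 1$):
$$\text{FLOPs}_{\text{prefill}}(S) \approx L \left[\, 2S \left(2H^2 + 2H n_{kv} d_h + 3HI\right) + 4 S^2 H \,\right]$$
- $2S(2H^2 + 2H n_{kv} d_h + 3HI)$: per layer, 4 projections + 3 FFN projections (with $n_h d_h = H$), applied to $S$ tokens
- $4S^2 H$: attention’s $QK^\top$ and softmax$\cdot V$, with the score matrix $[S, S]$ per head
- At short sequences linear layers dominate; the attention quadratic catches up once $S \approx H + n_{kv} d_h + 1.5 I$ (about 27K tokens for Llama 3 8B under these conventions)
Decode per step (process 1 token, history $S_c$):
$$\text{FLOPs}_{\text{step}}(S_c) \approx L \left[\, 2 \left(2H^2 + 2H n_{kv} d_h + 3HI\right) + 4 S_c H \,\right]$$
- $S_c$: number of currently cached positions ($S_c = S + t$ at step $t$)
- Single step processes 1 new token, so linear layers’ $S$ becomes $1$; attention still scans the full cache and grows linearly with it
Total complexity to generate $T$ tokens:
$$\text{FLOPs}_{\text{gen}} \approx T \cdot L \cdot 2 \left(2H^2 + 2H n_{kv} d_h + 3HI\right) + 4 L H \cdot T \cdot \bar{S}$$
- $T$: number of generated tokens (top-of-article convention)
- $\bar{S} \approx S + T/2$: average attention span (prompt + generated segment)
- First term sums linear layers across $T$ decode steps; second term approximates the attention portion summed over $t$ (exact form is $4LH \sum_{t=1}^{T} (S + t)$)
Without KV cache: every step re-runs the full forward over all $S + t$ tokens, costing $\sum_{t=1}^{T} \text{FLOPs}_{\text{prefill}}(S + t)$, roughly $\bar{S}\times$ the cached cost. With contexts in the thousands of tokens that is a massive difference.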
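These formulas are easy to evaluate with exact integers. A sketch assuming $B = 1$ and no causal-mask discount, matching the conventions above; function names are illustrative, and the uncached variant simply re-runs a full prefill at every step.

```python
def linear_flops_per_token(H, n_kv, d_h, I):
    """Q/O (2*H*H), K/V (2*H*n_kv*d_h) and three FFN matrices (3*H*I), 2 FLOPs per MAC."""
    return 2 * (2 * H * H + 2 * H * n_kv * d_h + 3 * H * I)

def prefill_flops(S, L, H, n_kv, d_h, I):
    return L * (S * linear_flops_per_token(H, n_kv, d_h, I) + 4 * S * S * H)

def decode_step_flops(ctx, L, H, n_kv, d_h, I):
    return L * (linear_flops_per_token(H, n_kv, d_h, I) + 4 * ctx * H)

def generate_flops_cached(S, T, **m):
    return prefill_flops(S, **m) + sum(decode_step_flops(S + t, **m) for t in range(1, T + 1))

def generate_flops_uncached(S, T, **m):
    return prefill_flops(S, **m) + sum(prefill_flops(S + t, **m) for t in range(1, T + 1))

llama3_8b = dict(L=32, H=4096, n_kv=8, d_h=128, I=14336)
per_layer_prefill = prefill_flops(2048, **llama3_8b) // 32      # about 962 GFLOPs
per_layer_decode = decode_step_flops(2048, **llama3_8b) // 32   # about 470 MFLOPs
ratio = generate_flops_uncached(2048, 100, **llama3_8b) / generate_flops_cached(2048, 100, **llama3_8b)
```

The per-layer numbers reproduce the ~960 GFLOPs and ~470 MFLOPs totals from the earlier table, and even for only 100 generated tokens after a 2048-token prompt the uncached path costs on the order of a hundred times more.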
Reading FLOPs and bandwidth together is even more illuminating, and that’s what the previous section’s table shows: prefill challenges the compute ceiling; decode challenges the bandwidth ceiling; continuous batching’s point is to fuse $B$ requests’ decode GEMVs into one batched GEMM (a $[B, H]$ activation against each weight), amortizing the weight-fetch cost across requests, with throughput rising almost linearly until compute or the attention portion becomes the bottleneck.
How Engineering Optimizations Plug into the Formulas — Flash Attn / Spec Decode / Continuous Batch
Flash Attention: mathematically equivalent to standard attention; the formulas don’t change a single character. Engineering-wise it fuses softmax and matmul into one kernel, updating softmax’s running statistics (max, sum) in a streaming manner over blocks, avoiding writing the $[S, S]$ attention matrix back to HBM. Compute complexity is unchanged; extra memory drops from $O(S^2)$ to $O(S)$; speedup comes mainly from reduced HBM access. FA-2 shifted the partition granularity from heads to query blocks; FA-3 on H100/H200 adds warpgroup MMA + producer-consumer async pipelining.
Flash Decoding: at decode, Flash Attention’s Q has only one row, so kernel parallelism is too low. Flash Decoding splits the sequence dimension of K and V into chunks for parallelism and then does a final log-sum-exp reduction. The formula is the same softmax, just split into two passes.
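The streaming-statistics idea shared by both kernels can be checked in plain Python: compute a partial (max, exp-sum, unnormalized output) per K/V chunk, then merge the partials with a rescale. This is a single-query sketch of the math only, with illustrative names and random toy data; it says nothing about how the real kernels tile memory.

```python
import math
import random

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def attend_full(q, K, V):
    """Reference: one softmax over the whole sequence."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [dot(q, k) * scale for k in K]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[c] for wi, v in zip(w, V)) / z for c in range(len(V[0]))]

def attend_chunk(q, K, V):
    """Partial result for one chunk: (local max, exp-sum, unnormalized output)."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [dot(q, k) * scale for k in K]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    acc = [sum(wi * v[c] for wi, v in zip(w, V)) for c in range(len(V[0]))]
    return m, sum(w), acc

def merge(parts):
    """Log-sum-exp style reduction: rescale every partial to the global max."""
    m_all = max(m for m, _, _ in parts)
    z, out = 0.0, [0.0] * len(parts[0][2])
    for m, l, acc in parts:
        r = math.exp(m - m_all)
        z += l * r
        out = [o + a * r for o, a in zip(out, acc)]
    return [o / z for o in out]

def attend_split(q, K, V, chunk):
    return merge([attend_chunk(q, K[i:i + chunk], V[i:i + chunk]) for i in range(0, len(K), chunk)])

rng = random.Random(0)
q = [rng.uniform(-1, 1) for _ in range(4)]
K = [[rng.uniform(-1, 1) for _ in range(4)] for _ in range(10)]
V = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(10)]
full = attend_full(q, K, V)
split = attend_split(q, K, V, chunk=3)
```

The chunked result matches the single-pass softmax up to floating-point error, which is why these kernels can avoid ever holding the full score row.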
Speculative Decoding: a small “draft” model generates $\gamma$ tokens sequentially, then the large model verifies them with one forward pass over the $\gamma$ positions. Acceptance rule:
$$P(\text{accept } x) = \min\!\left(1, \frac{p(x)}{q(x)}\right)$$
- $x$: a candidate token produced by the draft model
- $p(x)$: probability the large (target) model assigns to $x$ at that position
- $q(x)$: probability the small (draft) model assigns to $x$ at the same position
- Combined with “on rejection, resample from $\text{norm}(\max(0,\, p - q))$”, this rule provably yields the same sampling distribution as direct target-model decoding, so there is zero quality loss
The crux is fusing $\gamma$ decode GEMVs into one $\gamma$-row GEMM, which raises the large model’s arithmetic intensity by about $\gamma\times$ and moves it away from the purely memory-bound regime. With an expected $\bar{n}$ accepted tokens per verification step, throughput scales by about $\bar{n}\times$ (minus draft overhead). Variants: Medusa (multi-head prediction), EAGLE (feature-level draft), Lookahead Decoding (no draft model).
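The losslessness claim can be verified exactly for a single position. The sketch below implements the accept / resample rule and then computes, in closed form, the distribution of the emitted token when the candidate is drawn from the draft distribution; the toy distributions and function names are illustrative, and every draft probability is assumed to be positive.

```python
import random

def accept_prob(p_x, q_x):
    return min(1.0, p_x / q_x)

def residual_distribution(p, q):
    """norm(max(0, p - q)): where to resample after a rejection."""
    r = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    z = sum(r)
    if z == 0.0:                  # p == q, so a rejection can never happen
        return list(p)
    return [ri / z for ri in r]

def verify_token(x, p, q, rng):
    """Return (emitted token, accepted?) for a draft candidate x ~ q."""
    if rng.random() < accept_prob(p[x], q[x]):
        return x, True
    resampled = rng.choices(range(len(p)), weights=residual_distribution(p, q))[0]
    return resampled, False

def output_distribution(p, q):
    """Exact law of the emitted token when x ~ q goes through verify_token."""
    accepted = [qi * accept_prob(pi, qi) for pi, qi in zip(p, q)]
    reject_mass = 1.0 - sum(accepted)
    residual = residual_distribution(p, q)
    return [a + reject_mass * r for a, r in zip(accepted, residual)]

p = [0.5, 0.3, 0.2]   # target model
q = [0.2, 0.5, 0.3]   # draft model
emitted = output_distribution(p, q)
token, accepted = verify_token(1, p, q, random.Random(0))
```

The emitted distribution equals the target distribution `p` exactly, even though the draft `q` over-weights tokens 1 and 2; rejected mass is redirected to token 0, the only token the draft under-weights.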
Continuous Batching (vLLM, TGI): instead of padding prefill to align at boundaries, schedule at the per-step request level. Each step picks a batch of requests in the same phase (prefill or decode), releases finished ones. Mathematically each request is independent; only the ordering changes. Original paper: OSDI’22’s Orca.
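Chunked Prefill: split long prompts’ prefill into chunks and mix them with decode requests in the same step, reducing decode latency jitter. No formula changes. This is the core scheduling primitive of SARATHI / Sarathi-Serve (DistServe attacks the same interference differently, by placing prefill and decode on separate GPUs).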
One Sentence Spanning the Whole Process — From Token ID to Next Token
Input token IDs → look up embedding → through $L$ layers (Pre-RMSNorm → Attention with RoPE → residual → Pre-RMSNorm → SwiGLU FFN → residual) → Final RMSNorm → LM Head → logits → sampling. In prefill, $S$ tokens pass in parallel, producing the first token + full KV Cache; in decode, each step inputs 1 token; at attention it reads historical K and V from the cache, while all other operations are per-token independent.
Key Engineering Invariants — 6 Rules Worth Memorizing
- The main-line tensor shape is always $[B, S, H]$. Residual structure preserves the dimension; whenever the last dim differs from $H$ somewhere, either it’s spread into heads $[n_h, d_h]$ inside attention, or lifted to $I$ inside FFN, and back to $H$ on exit.
- K and V, once computed, never change. Because they’re linear projections applied to the already-fixed input , and the causal structure ensures later positions cannot reach back to modify earlier representations. This is the mathematical basis for KV Cache.
- Attention is the only cross-token operation; all others (norm, projection, FFN, activation) are per-token independent. So only K and V — the inputs to cross-token operations — need caching; everything else can be computed and discarded immediately.
- Decode’s scores shape is [B, number of Q heads, 1, cache length]. The “1” is the Q side (the current new token), and the cache-length dim is eliminated by the weighted sum with V, returning to one row.
- Decode’s non-attention compute per step is constant; only attention grows linearly with cache length. So the real reason “generation gets slower over time” is that attention’s context length keeps growing, and every step has to read the whole, ever larger KV Cache through limited HBM bandwidth.
- Prefill uses GEMM; decode uses GEMV. This one-letter difference dictates that every inference engine has two kernel sets, two scheduling strategies. Internalize this and no inference-optimization paper will lose you.
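The KV Cache invariants can be checked numerically. The sketch below is a deliberately stripped-down stand-in for one attention layer: a single head, no RoPE, no normalization, tiny random weights, all of which are simplifying assumptions. It compares recomputing causal attention from scratch against decoding one token at a time with a growing cache.

```python
import math
import random


def matvec(x, W):
    """Row vector x [d_in] times W [d_in][d_out] ("in x out" convention)."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]


def attend(q, keys, values):
    """Softmax attention of one query over the given K/V rows (scores: 1 x len(keys))."""
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    norm = sum(weights)
    return [sum(w / norm * v[j] for w, v in zip(weights, values)) for j in range(d)]


def causal_attention_full(xs, Wq, Wk, Wv):
    """Recompute everything: position t attends to positions 0..t."""
    ks = [matvec(x, Wk) for x in xs]
    vs = [matvec(x, Wv) for x in xs]
    return [attend(matvec(x, Wq), ks[:t + 1], vs[:t + 1]) for t, x in enumerate(xs)]


def decode_step(x_new, cache_k, cache_v, Wq, Wk, Wv):
    """Incremental decode: append one K/V row, attend with a single query."""
    cache_k.append(matvec(x_new, Wk))
    cache_v.append(matvec(x_new, Wv))
    return attend(matvec(x_new, Wq), cache_k, cache_v)


def rand_matrix(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
d, T = 4, 6
Wq, Wk, Wv = rand_matrix(d, d), rand_matrix(d, d), rand_matrix(d, d)
xs = rand_matrix(T, d)  # fixed per-position inputs to the layer

full = causal_attention_full(xs, Wq, Wk, Wv)
cache_k, cache_v = [], []
incremental = [decode_step(x, cache_k, cache_v, Wq, Wk, Wv) for x in xs]
max_diff = max(abs(a - b) for row_f, row_i in zip(full, incremental)
               for a, b in zip(row_f, row_i))
print(max_diff)
```

The two paths agree because no cached row is ever rewritten: each step only appends one K row and one V row, and the query is the only thing that has to be recomputed.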
Internalize these six and you’ll find that Flash Attention, PagedAttention, MLA, and speculative decoding are all local optimizations at some spot on this skeleton, while the skeleton itself has barely changed since the original 2017 Transformer.
References — Formulas · Papers · Engineering Blogs
Architecture and Core Operators
- Vaswani et al., “Attention Is All You Need” (NeurIPS 2017) — the original Transformer paper. arxiv.org/abs/1706.03762
- Shazeer, “GLU Variants Improve Transformer” (2020) — source of SwiGLU / GeGLU / ReGLU. arxiv.org/abs/2002.05202
- Zhang & Sennrich, “Root Mean Square Layer Normalization” (NeurIPS 2019) — RMSNorm. arxiv.org/abs/1910.07467
- Hendrycks & Gimpel, “Gaussian Error Linear Units (GELUs)” (2016) — GeLU definition and approximation. arxiv.org/abs/1606.08415
Positional Encoding
- Su et al., “RoFormer: Enhanced Transformer with Rotary Position Embedding” (2021) — RoPE. arxiv.org/abs/2104.09864
- Peng et al., “YaRN: Efficient Context Window Extension of Large Language Models” (2023) — long-context RoPE scaling. arxiv.org/abs/2309.00071
- bloc97 & emozilla, “NTK-Aware Scaled RoPE” — discussion of early NTK-aware open-source work. reddit / LocalLLaMA
Multi-Head Variants (KV Sharing / Compression)
- Shazeer, “Fast Transformer Decoding: One Write-Head is All You Need” (2019) — MQA. arxiv.org/abs/1911.02150
- Ainslie et al., “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints” (EMNLP 2023) — GQA. arxiv.org/abs/2305.13245
- DeepSeek-AI, “DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model” (2024) — MLA introduced. arxiv.org/abs/2405.04434
- DeepSeek-AI, “DeepSeek-V3 Technical Report” (2024) — MLA + MoE engineering. arxiv.org/abs/2412.19437
Mixture-of-Experts
- Shazeer et al., “Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer” (ICLR 2017) — foundational paper for top-k gated MoE in deep nets. arxiv.org/abs/1701.06538
- Lepikhin et al., “GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding” (ICLR 2021) — the standard aux-loss load-balancing recipe. arxiv.org/abs/2006.16668
- Fedus et al., “Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity” (JMLR 2022) — top-1 + capacity factor. arxiv.org/abs/2101.03961
- Du et al., “GLaM: Efficient Scaling of Language Models with Mixture-of-Experts” (ICML 2022) — Google’s 1.2T-parameter MoE. arxiv.org/abs/2112.06905
- Jiang et al., “Mixtral of Experts” (Mistral AI, 2024) — Mixtral 8×7B tech report. arxiv.org/abs/2401.04088
- DeepSeek-AI, “DeepSeekMoE: Towards Ultimate Expert Specialization in Mixture-of-Experts Language Models” (2024) — fine-grained + shared-expert paradigm. arxiv.org/abs/2401.06066
- Wang et al., “Auxiliary-Loss-Free Load Balancing Strategy for Mixture-of-Experts” (2024) — the load-balancing approach used in DeepSeek V3. arxiv.org/abs/2408.15664
- Meta AI, “The Llama 4 herd” (2025) — Llama 4 Scout / Maverick / Behemoth technical details. ai.meta.com/blog/llama-4
Flash Attention Series
- Dao et al., “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness” (NeurIPS 2022) — FA-1. arxiv.org/abs/2205.14135
- Dao, “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning” (2023) — FA-2. arxiv.org/abs/2307.08691
- Shah et al., “FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-Precision” (NeurIPS 2024) — FA-3 / Hopper. arxiv.org/abs/2407.08608
- Dao et al., “Flash-Decoding for long-context inference” (Stanford / Together blog, 2023) — Flash Decoding. crfm.stanford.edu
Inference Engines and Serving Schedulers
- Kwon et al., “Efficient Memory Management for Large Language Model Serving with PagedAttention” (SOSP 2023) — vLLM / PagedAttention. arxiv.org/abs/2309.06180
- Yu et al., “Orca: A Distributed Serving System for Transformer-Based Generative Models” (OSDI 2022) — original Continuous Batching paper. usenix.org/osdi22
- Agrawal et al., “SARATHI: Efficient LLM Inference by Piggybacking Decodes with Chunked Prefills” (2023) — chunked prefill. arxiv.org/abs/2308.16369
- Zhong et al., “DistServe: Disaggregating Prefill and Decoding for Goodput-optimized Large Language Model Serving” (OSDI 2024) — prefill/decode disaggregation. arxiv.org/abs/2401.09670
- vLLM main repo (PagedAttention engineering implementation). github.com/vllm-project/vllm
- Hugging Face Text Generation Inference (TGI). github.com/huggingface/text-generation-inference
- NVIDIA TensorRT-LLM documentation (FA / in-flight batching). nvidia.github.io/TensorRT-LLM
Speculative Decoding Family
- Leviathan, Kalman, Matias, “Fast Inference from Transformers via Speculative Decoding” (ICML 2023). arxiv.org/abs/2211.17192
- Chen et al., “Accelerating Large Language Model Decoding with Speculative Sampling” (DeepMind, 2023) — concurrent independent work. arxiv.org/abs/2302.01318
- Cai et al., “Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads” (2024). arxiv.org/abs/2401.10774
- Li et al., “EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty” (ICML 2024). arxiv.org/abs/2401.15077
- Fu et al., “Break the Sequential Dependency of LLM Inference Using Lookahead Decoding” (2024). arxiv.org/abs/2402.02057
KV Cache Compression / Quantization
- Xiao et al., “Efficient Streaming Language Models with Attention Sinks” (ICLR 2024) — StreamingLLM. arxiv.org/abs/2309.17453
- Zhang et al., “H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models” (NeurIPS 2023). arxiv.org/abs/2306.14048
- Li et al., “SnapKV: LLM Knows What You are Looking for Before Generation” (2024). arxiv.org/abs/2404.14469
- Liu et al., “KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache” (ICML 2024). arxiv.org/abs/2402.02750
- Hooper et al., “KVQuant: Towards 10 Million Context Length LLM Inference with KV Cache Quantization” (NeurIPS 2024). arxiv.org/abs/2401.18079
Representative Open-Source Model Technical Reports
- Meta AI, “The Llama 3 Herd of Models” (2024) — Llama 3 family. arxiv.org/abs/2407.21783
- Jiang et al., “Mistral 7B” (2023) — Sliding Window Attention. arxiv.org/abs/2310.06825
- Qwen Team, “Qwen2.5 Technical Report” (2024). arxiv.org/abs/2412.15115
Hardware / Roofline
- NVIDIA, “H100 Tensor Core GPU Architecture Whitepaper” (2022). resources.nvidia.com
- Williams, Waterman, Patterson, “Roofline: An Insightful Visual Performance Model for Multicore Architectures” (CACM 2009) — original Roofline paper. dl.acm.org
- Hoffmann et al., “Training Compute-Optimal Large Language Models” (Chinchilla, 2022) — training-side version of the FLOPs rule of thumb. arxiv.org/abs/2203.15556
Other Long-Form Articles / Tutorials
- Lilian Weng, “The Transformer Family Version 2.0”. lilianweng.github.io
- Horace He, “Making Deep Learning Go Brrrr From First Principles” — three bottleneck classes: compute / memory / overhead. horace.io/brrr_intro
- kipply, “Transformer Inference Arithmetic” — mental math for inference FLOPs / KV cache. kipp.ly
- Jay Mody, “LLM inference, in detail” — another take on tensor shape flow. jaykmody.com