LLM Inference Walkthrough — Tensor Shapes and Core Formulas Across the Whole Pipeline

Lay out a Llama 3-style dense decoder-only model end to end, from embedding to sampling. Along the way, mark where the math of common variants plugs in (MHA / MQA / GQA / MLA, RMSNorm / LayerNorm, SwiGLU / GeGLU, Flash Attention / Paged Attention), so that after reading you can draw the whole diagram from memory. One fact frames the whole article: this skeleton has barely changed since the 2017 Transformer. Every “inference optimization” is local surgery somewhere on it.

Notation Conventions — Llama 3 8B as the Reference

The same notation is used throughout, with Llama 3 8B as the concrete example:

| Symbol | Meaning | Example value (Llama 3 8B) |
|---|---|---|
| $B$ | batch size | 2 |
| $S$ | prompt length | 10 |
| $L$ | number of layers | 32 |
| $H$ | hidden dim | 4096 |
| $V$ | vocab size | 128256 |
| $n_q$ | number of Q heads | 32 |
| $n_{kv}$ | number of KV heads (GQA) | 8 |
| $d$ | per-head dim, $d = H/n_q$ | 128 |
| $I$ | FFN intermediate dim | 14336 |
| $T$ | number of generated tokens | 100 |
| $t$ | current decode step | $1..T$ |

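Activation shapes follow PyTorch layout, written as [B, ..., H]. Weight matrices are written “in × out”, $W \in \mathbb{R}^{[\text{in}, \text{out}]}$, so every projection reads $XW$ with tokens as row vectors. (PyTorch's nn.Linear stores the transpose, [out, in].) $T$ is reused for the sampling temperature in the LM Head section; context disambiguates.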

Core Formulas Quick Reference — Embedding · Norm · Attn · FFN · LM Head

Embedding

$$\mathbf{x}_i = E[\text{token\_id}_i] \in \mathbb{R}^{H}$$
  • $\mathbf{x}_i$: embedding vector at position $i$
  • $E \in \mathbb{R}^{V \times H}$: embedding lookup table ($V$ tokens, each $H$-dim)
  • $\text{token\_id}_i \in \{0, 1, \ldots, V-1\}$: integer ID of the $i$-th input token

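Many models share $E$ with the LM Head's $W_{\text{lm}}$ (tied embedding, $W_{\text{lm}} = E^{\top}$). This saves $V \cdot H$ parameters and acts as a mild regularizer. Llama 3 8B does not tie them; at $V \cdot H \approx 525$M parameters, the untied head costs about 1 GB in bf16. The small Llama 3.2 1B / 3B models do tie them.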

Normalization: LayerNorm vs RMSNorm

Standard LayerNorm:

$$\text{LN}(\mathbf{x}) = \frac{\mathbf{x} - \mu}{\sqrt{\sigma^{2} + \epsilon}} \odot \boldsymbol{\gamma} + \boldsymbol{\beta}, \quad \mu = \tfrac{1}{H}\sum_{i=1}^{H} x_i,\ \sigma^{2} = \tfrac{1}{H}\sum_{i=1}^{H} (x_i - \mu)^{2}$$
  • $\mathbf{x} \in \mathbb{R}^{H}$: single-token hidden state
  • $\mu, \sigma^{2} \in \mathbb{R}$: mean and variance computed across the $H$ dimensions
  • $\epsilon$: small constant to avoid division by zero (typically $10^{-5}$)
  • $\boldsymbol{\gamma}, \boldsymbol{\beta} \in \mathbb{R}^{H}$: learnable scale / shift
  • $\odot$: element-wise multiplication

RMSNorm (the mainstream choice for Llama / Mistral / Qwen):

$$\text{RMSNorm}(\mathbf{x}) = \frac{\mathbf{x}}{\sqrt{\frac{1}{H}\sum_{i=1}^{H} x_i^{2} + \epsilon}} \odot \boldsymbol{\gamma}$$
  • $\mathbf{x} \in \mathbb{R}^{H}$: single-token hidden state
  • $\boldsymbol{\gamma} \in \mathbb{R}^{H}$: learnable per-channel scale (no $\boldsymbol{\beta}$)
  • $\epsilon$: small constant to avoid division by zero
  • The denominator is the root mean square of $\mathbf{x}$, hence “RMSNorm”

RMSNorm drops the mean subtraction and $\boldsymbol{\beta}$. That removes one reduction over $H$ and half the norm parameters, at an empirically negligible quality cost. Llama-family models use it in the pre-norm arrangement: the norm sits inside each residual branch, while the residual main path bypasses it.
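Here is a minimal pure-Python sketch of both norms for one hidden vector. It assumes plain lists in place of tensors and $\epsilon = 10^{-5}$; the function names are illustrative, not from any library. For a zero-mean input the two outputs coincide, which makes “RMSNorm is LayerNorm without the mean and $\boldsymbol{\beta}$” concrete.

```python
import math

def layer_norm(x, gamma, beta, eps=1e-5):
    """LayerNorm over one hidden vector: subtract mean, divide by std, scale and shift."""
    h = len(x)
    mu = sum(x) / h
    var = sum((xi - mu) ** 2 for xi in x) / h
    inv = 1.0 / math.sqrt(var + eps)
    return [(xi - mu) * inv * g + b for xi, g, b in zip(x, gamma, beta)]

def rms_norm(x, gamma, eps=1e-5):
    """RMSNorm: divide by the root mean square, then scale; no mean, no beta."""
    h = len(x)
    rms = math.sqrt(sum(xi * xi for xi in x) / h + eps)
    return [xi / rms * g for xi, g in zip(x, gamma)]

x = [1.0, -2.0, 3.0, 0.5]
print([round(v, 4) for v in rms_norm(x, [1.0] * 4)])
print([round(v, 4) for v in layer_norm(x, [1.0] * 4, [0.0] * 4)])
```

With a nonzero mean (as in the example `x`) the two outputs differ. Only LayerNorm recenters the vector.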

Q/K/V Projection + Positional Encoding

$$Q = X W_Q, \quad K = X W_K, \quad V = X W_V$$
  • $X \in \mathbb{R}^{B \times S \times H}$: output of the preceding RMSNorm
  • $W_Q \in \mathbb{R}^{H \times n_q d}$: Q projection weight
  • $W_K, W_V \in \mathbb{R}^{H \times n_{kv} d}$: K, V projection weights (narrower under GQA, since $n_{kv} < n_q$)
  • $Q \in \mathbb{R}^{B \times S \times n_q d}$, $K, V \in \mathbb{R}^{B \times S \times n_{kv} d}$: projection outputs. They are later reshaped and transposed to $[B, n_q, S, d]$ and $[B, n_{kv}, S, d]$ to expose the head dimension.

RoPE (Rotary Positional Embedding) core idea: instead of adding “absolute position $m$” to the embedding (as the original Transformer's sinusoidal encoding does), let position act as a rotation on Q and K. When the two rotated vectors are dotted, only the relative position $n - m$ survives.

Split each head's $d$ dimensions into $d/2$ two-dim subspaces. The $k$-th subspace ($k = 0, 1, \ldots, d/2 - 1$) holds the coordinates $(q_{2k}, q_{2k+1})$. At position $m$, that pair is multiplied by a 2D rotation matrix with angle $m\theta_k$:

$$\begin{pmatrix} q'^{(m)}_{2k} \\ q'^{(m)}_{2k+1} \end{pmatrix} = \begin{pmatrix} \cos(m\theta_k) & -\sin(m\theta_k) \\ \sin(m\theta_k) & \cos(m\theta_k) \end{pmatrix} \begin{pmatrix} q_{2k} \\ q_{2k+1} \end{pmatrix}, \quad \theta_k = \text{base}^{-2k/d}$$
  • $m \in \{0, 1, \ldots, S-1\}$: absolute position of the current token (during decode, $m$ continues past $S - 1$)
  • $k \in \{0, 1, \ldots, d/2 - 1\}$: 2D subspace index (every two dims of the per-head dim $d$ form one pair)
  • $(q_{2k}, q_{2k+1})$: the $k$-th 2D pair of the projected Q vector
  • $(q'^{(m)}_{2k}, q'^{(m)}_{2k+1})$: the rotated pair at position $m$
  • $\theta_k$: angular velocity of the $k$-th subspace
  • $\text{base}$: frequency base ($10000$ in the original RoPE and Llama 1/2; Llama 3 uses $500000$; NTK-style scaling enlarges it further)

Expanded to scalars:

$$q'^{(m)}_{2k} = q_{2k}\cos(m\theta_k) - q_{2k+1}\sin(m\theta_k), \qquad q'^{(m)}_{2k+1} = q_{2k}\sin(m\theta_k) + q_{2k+1}\cos(m\theta_k)$$

The symbols are the same as above. This just unrolls the $2\times 2$ matrix into two scalar identities, which map more directly to code.

This is the standard counterclockwise rotation matrix $R(\phi) = \begin{pmatrix} \cos\phi & -\sin\phi \\ \sin\phi & \cos\phi \end{pmatrix}$. Acting on a 2D vector, it is equivalent to multiplying by $e^{i\phi}$ in the complex plane. That is why the official Llama repo views each pair $(q_{2k}, q_{2k+1})$ as one complex number and multiplies it by $\cos(m\theta_k) + i\sin(m\theta_k)$ directly. HuggingFace instead pairs $(q_k, q_{k+d/2})$ in a “rotate-half” layout. The two differ only by a permutation of coordinates, which the checkpoint conversion applies to the $W_Q$, $W_K$ columns. K is rotated by the same rule at its own position $n$ (angle $n\theta_k$), giving $\mathbf{k}'^{(n)}$.
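The sketch below checks that three formulations agree: the scalar rotation, the complex-multiplication view, and the rotate-half layout (up to the stated permutation). It assumes a single head vector stored as a Python list. The helper names (rope_interleaved, rope_complex, rope_half, interleaved_to_half) are hypothetical, not the Llama or HuggingFace function names.

```python
import cmath
import math

def rope_interleaved(vec, pos, base=10000.0):
    """Scalar form: rotate each pair (v[2k], v[2k+1]) by angle pos * theta_k."""
    d = len(vec)
    out = list(vec)
    for k in range(d // 2):
        ang = pos * base ** (-2 * k / d)
        c, s = math.cos(ang), math.sin(ang)
        a, b = vec[2 * k], vec[2 * k + 1]
        out[2 * k] = a * c - b * s
        out[2 * k + 1] = a * s + b * c
    return out

def rope_complex(vec, pos, base=10000.0):
    """Complex view: treat each pair as a + ib and multiply by e^{i * pos * theta_k}."""
    d = len(vec)
    out = []
    for k in range(d // 2):
        ang = pos * base ** (-2 * k / d)
        z = complex(vec[2 * k], vec[2 * k + 1]) * cmath.exp(1j * ang)
        out.extend([z.real, z.imag])
    return out

def rope_half(vec, pos, base=10000.0):
    """Rotate-half layout: the pair (v[k], v[k + d/2]) shares angle pos * theta_k."""
    d = len(vec)
    half = d // 2
    out = list(vec)
    for k in range(half):
        ang = pos * base ** (-2 * k / d)
        c, s = math.cos(ang), math.sin(ang)
        a, b = vec[k], vec[k + half]
        out[k] = a * c - b * s
        out[k + half] = a * s + b * c
    return out

def interleaved_to_half(vec):
    """Coordinate permutation [v0, v1, v2, v3, ...] -> [v0, v2, ..., v1, v3, ...]."""
    return vec[0::2] + vec[1::2]

v = [0.3, -1.2, 2.5, 0.7, -0.4, 1.1, 0.9, -2.0]
print(rope_interleaved(v, 7))
print(rope_complex(v, 7))
```

Because `rope_half(interleaved_to_half(v), m)` equals `interleaved_to_half(rope_interleaved(v, m))`, the two layouts compute the same attention scores once the weights are permuted consistently.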

Why does rotation encode relative position? Let $\mathbf{R}_m$ denote the block-diagonal full-$d$ rotation, whose $k$-th $2\times 2$ block uses angle $m\theta_k$. It is orthogonal and satisfies $\mathbf{R}_m^{\top}\mathbf{R}_n = \mathbf{R}_{n-m}$, so:

$$\langle \mathbf{R}_m \mathbf{q},\; \mathbf{R}_n \mathbf{k} \rangle = \mathbf{q}^{\top} \mathbf{R}_m^{\top} \mathbf{R}_n \mathbf{k} = \mathbf{q}^{\top} \mathbf{R}_{n-m} \mathbf{k}$$
  • $\mathbf{q}, \mathbf{k} \in \mathbb{R}^{d}$: per-head Q, K vectors (pre-rotation)
  • $\mathbf{R}_m, \mathbf{R}_n \in \mathbb{R}^{d \times d}$: block-diagonal rotation matrices for positions $m$, $n$
  • $\langle \cdot, \cdot \rangle$: standard inner product
  • The last step uses $\mathbf{R}_m^{\top} \mathbf{R}_n = \mathbf{R}_{n-m}$: rotations are orthogonal and their angles add

When attention computes $QK^{\top}$, the positional part of every $(q_i, k_j)$ score depends only on the offset $j - i$. Absolute position cancels inside the dot product, leaving relative position. This is RoPE's key property, and the main reason it is preferred over additive positional encodings.
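Here is a quick numeric check of $\langle \mathbf{R}_m \mathbf{q}, \mathbf{R}_n \mathbf{k} \rangle = \mathbf{q}^{\top}\mathbf{R}_{n-m}\mathbf{k}$, written as a minimal sketch. The same $(\mathbf{q}, \mathbf{k})$ pair scored at positions (5, 8) and at (100, 103) gives the same value, because the offset is 3 in both cases. The vector size $d = 8$ and the random seed are arbitrary assumptions.

```python
import math
import random

def rope(vec, pos, base=10000.0):
    """Interleaved RoPE: rotate each pair (v[2k], v[2k+1]) by pos * base^(-2k/d)."""
    d = len(vec)
    out = list(vec)
    for k in range(d // 2):
        ang = pos * base ** (-2 * k / d)
        c, s = math.cos(ang), math.sin(ang)
        out[2 * k] = vec[2 * k] * c - vec[2 * k + 1] * s
        out[2 * k + 1] = vec[2 * k] * s + vec[2 * k + 1] * c
    return out

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

random.seed(0)
d = 8
q = [random.gauss(0.0, 1.0) for _ in range(d)]
k = [random.gauss(0.0, 1.0) for _ in range(d)]

near = dot(rope(q, 5), rope(k, 8))       # offset n - m = 3
far = dot(rope(q, 100), rope(k, 103))    # same offset, much larger absolute positions
direct = dot(q, rope(k, 3))              # q^T R_{n-m} k with n - m = 3
print(near, far, direct)
```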

Frequency spectrum design: $\theta_k = \text{base}^{-2k/d}$ assigns the $d/2$ subspaces angular velocities from fast to slow. Here $d$ is the per-head dim, not the model hidden dim. With $\text{base} = 10000$ and $d = 128$:

  • $k = 0$: $\theta_0 = 1$, period $2\pi \approx 6.28$ tokens. This subspace carries the “near-neighbor” signal.
  • $k = d/2 - 1$: $\theta = \text{base}^{-(d-2)/d} \approx 1.15 \times 10^{-4}$, period $\approx 54\text{K}$ tokens. This subspace carries the “long-range” signal. Llama 3's base of $500000$ stretches this period to about $2.6\text{M}$ tokens.

The geometric schedule lets one head carry positional signal at many scales at once. It is the same frequency ladder as the original Transformer's sinusoidal encoding, applied multiplicatively instead of additively.
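The periods quoted above can be reproduced in a few lines. This sketch assumes $d = 128$ and compares base $10000$ with Llama 3's $500000$; the function name is illustrative.

```python
import math

def rope_periods(d, base):
    """Period in tokens of each 2D subspace: 2*pi / theta_k with theta_k = base^(-2k/d)."""
    return [2 * math.pi * base ** (2 * k / d) for k in range(d // 2)]

for base in (10_000.0, 500_000.0):
    periods = rope_periods(128, base)
    print(f"base={base:>9,.0f}: fastest {periods[0]:.2f} tokens, slowest {periods[-1]:,.0f} tokens")
```

The fastest subspace always has period $2\pi$ regardless of base. Only the slow end moves when the base changes.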

Long-context scaling: training only sees positions $m \le L_{\text{train}}$, and the low-frequency subspaces do not even complete one full period inside $L_{\text{train}}$. Once inference reaches $m > L_{\text{train}}$, those low-frequency angles fall outside the training distribution and attention quality degrades quickly. Three common fixes all modify $\theta_k$:

  • Position Interpolation (Chen et al. 2023): $m \to m/s$, equivalent to $\theta_k \to \theta_k/s$. This compresses all frequencies uniformly; it is simple but wastes high-frequency resolution.
  • NTK-aware scaling: scales low frequencies while preserving high ones, via $\text{base} \to \text{base} \cdot s^{d/(d-2)}$. This leaves $\theta_0$ unchanged and divides the lowest frequency $\theta_{d/2-1}$ by exactly $s$.
  • YaRN (Peng et al. 2023): treats the spectrum band by band.
    • High frequencies (period $\ll L_{\text{train}}$, many full periods seen in training) are left alone.
    • Low frequencies (period $\gg L_{\text{train}}$, no full period seen) are PI-scaled.
    • The middle band interpolates smoothly between the two.
    • A temperature $t$ on the attention logits (each of $q$, $k$ scaled by $1/\sqrt{t}$) counters the attention-entropy drift that the scaling introduces.
  Llama 3.1 extends 8K → 128K with a closely related band-wise frequency scaling (the “llama3” rope type in HuggingFace configs), without YaRN's temperature term.
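Below is a minimal sketch of how the first two fixes rewrite the frequency table. It assumes $d = 128$, base $10000$, and scale factor $s = 4$, all chosen for illustration. YaRN's per-band ramp and its temperature are not implemented here.

```python
def rope_thetas(d, base):
    """Original RoPE angular velocities theta_k = base^(-2k/d)."""
    return [base ** (-2 * k / d) for k in range(d // 2)]

def pi_thetas(d, base, s):
    """Position Interpolation: m -> m/s, i.e. every theta_k divided by s."""
    return [t / s for t in rope_thetas(d, base)]

def ntk_thetas(d, base, s):
    """NTK-aware scaling: keep the formula, enlarge the base to base * s^(d/(d-2))."""
    return rope_thetas(d, base * s ** (d / (d - 2)))

d, base, s = 128, 10000.0, 4.0
for name, th in (("orig", rope_thetas(d, base)), ("PI", pi_thetas(d, base, s)), ("NTK", ntk_thetas(d, base, s))):
    print(f"{name:>4}: fastest {th[0]:.4f}, slowest {th[-1]:.3e}")
```

The printout shows the design difference. PI slows every subspace by $s$, including the fastest. NTK-aware scaling leaves the fastest subspace at exactly 1 and matches PI only at the slowest one.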

RoPE acts only on Q and K, not on V — V is the weighted value itself and needs no positional signal.

Scaled Dot-Product Attention

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d}} + M\right) V$$
  • $Q, K, V$: projected, head-split tensors whose last two dims are $[S, d]$ (batch and head dims are broadcast implicitly)
  • $QK^{\top} \in \mathbb{R}^{S \times S}$: similarity matrix over every $(q_i, k_j)$ pair
  • $\sqrt{d}$: scaling factor that keeps the logits' variance around 1, so softmax does not saturate into its near-zero-gradient region
  • $M \in \mathbb{R}^{S \times S}$: causal mask with $-\infty$ strictly above the diagonal (position $i$ only sees $j \le i$)
  • softmax normalizes along the K dim; the resulting attention weights re-weight $V$ back to $[S, d]$
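To make the mask concrete, here is a single-head, unbatched sketch in plain Python, with lists standing in for tensors; the names are illustrative. Instead of materializing $M$ with $-\infty$ entries, it simply skips keys with $j > i$. Since $e^{-\infty} = 0$, this yields the same softmax weights.

```python
import math

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    z = sum(es)
    return [e / z for e in es]

def causal_attention(Q, K, V):
    """Single-head scaled dot-product attention with a causal mask.
    Q, K, V: lists of S row vectors (length d). Returns S output rows."""
    d = len(Q[0])
    out = []
    for i, q in enumerate(Q):
        # Causal mask: query i only scores keys j <= i (masked entries would get weight 0)
        scores = [sum(a * b for a, b in zip(q, K[j])) / math.sqrt(d) for j in range(i + 1)]
        w = softmax(scores)
        out.append([sum(w[j] * V[j][c] for j in range(i + 1)) for c in range(len(V[0]))])
    return out

K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(causal_attention(K, K, V))
```

The first row can only attend to itself, so it returns `V[0]` unchanged. Every later row is a convex combination of the value rows up to its own position.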

Multi-Head Variants: MHA / MQA / GQA / MLA

All four variants differ on the K, V side. Q always has $n_q$ independent heads; MLA additionally low-ranks the Q projection, but Q is never cached. What changes is how many independent K, V sets exist, and whether K, V are low-rank compressed. Below, the original MHA and the three variants are written as head-level formulas in one notation. The formulas drop the batch and layer dims and look at head $h$ at position $i$.

| Variant | $n_{kv}$ | Per-token cache (per layer, fp16) | Representative models |
|---|---|---|---|
| MHA | $= n_q$ | $2 \cdot n_q \cdot d \cdot 2\text{B}$ | GPT-2/3, Llama 1/2 7B |
| GQA | grouped, $< n_q$ | $2 \cdot n_{kv} \cdot d \cdot 2\text{B}$ | Llama 3, Qwen2, Mistral |
| MQA | $= 1$ | $2 \cdot d \cdot 2\text{B}$ | PaLM, Falcon |
| MLA | low-rank compressed | $(d_c + d_r) \cdot 2\text{B}$ | DeepSeek V2/V3 |

MHA — Multi-Head Attention (original, Vaswani et al. 2017)

Each Q head gets its own K, V, for $n_q$ independent sets:

$$\mathbf{q}^{(h)}_i = \mathbf{x}_i W_Q^{(h)}, \quad \mathbf{k}^{(h)}_i = \mathbf{x}_i W_K^{(h)}, \quad \mathbf{v}^{(h)}_i = \mathbf{x}_i W_V^{(h)}, \quad h = 1, \ldots, n_q$$
$$\text{head}^{(h)} = \text{softmax}\!\left(\frac{Q^{(h)} {K^{(h)}}^{\top}}{\sqrt{d}} + M\right) V^{(h)}, \qquad \text{Attn} = [\text{head}^{(1)}; \ldots; \text{head}^{(n_q)}] \, W_O$$
  • $W_Q^{(h)}, W_K^{(h)}, W_V^{(h)} \in \mathbb{R}^{H \times d}$: per-head Q, K, V projections (in practice concatenated into one big matrix, which is mathematically equivalent)
  • Each token writes $n_q$ pairs $(\mathbf{k}^{(h)}, \mathbf{v}^{(h)})$ into the cache, $2 n_q d$ scalars in total
  • Llama 2 7B: $n_q = 32, d = 128$, per-token cache $= 2 \cdot 32 \cdot 128 \cdot 2\text{B} = 16\text{ KB}$ / layer

MQA — Multi-Query Attention (Shazeer 2019)

All $n_q$ Q heads share a single K, V, so only 1 set is left:

$$\mathbf{q}^{(h)}_i = \mathbf{x}_i W_Q^{(h)}, \quad \mathbf{k}_i = \mathbf{x}_i W_K, \quad \mathbf{v}_i = \mathbf{x}_i W_V$$
$$\text{head}^{(h)} = \text{softmax}\!\left(\frac{Q^{(h)} K^{\top}}{\sqrt{d}} + M\right) V, \qquad h = 1, \ldots, n_q$$
  • $W_K, W_V \in \mathbb{R}^{H \times d}$: one K, V projection shared across all heads
  • The cache shrinks to $2d$ scalars, $n_q\times$ smaller than MHA. However, the single shared K, V limits capacity, and large models lose quality when MQA is applied directly. MQA is therefore rarely used standalone today and has mostly been superseded by GQA.

GQA — Grouped-Query Attention (Ainslie et al. 2023)

Split the $n_q$ Q heads into $n_{kv}$ groups and share K, V within each group. GQA is a continuous interpolation between MHA and MQA:

$$\mathbf{q}^{(h)}_i = \mathbf{x}_i W_Q^{(h)}, \quad \mathbf{k}^{(g)}_i = \mathbf{x}_i W_K^{(g)}, \quad \mathbf{v}^{(g)}_i = \mathbf{x}_i W_V^{(g)}$$
$$\text{head}^{(h)} = \text{softmax}\!\left(\frac{Q^{(h)} {K^{(g(h))}}^{\top}}{\sqrt{d}} + M\right) V^{(g(h))}, \qquad g(h) = \lfloor h \cdot n_{kv} / n_q \rfloor$$
  • $W_K^{(g)}, W_V^{(g)} \in \mathbb{R}^{H \times d}$, $g = 0, \ldots, n_{kv} - 1$: one K, V set per group
  • $g(h)$: group index of Q head $h$, with heads and groups 0-indexed ($h = 0, \ldots, n_q - 1$). Each run of $n_q / n_{kv}$ consecutive Q heads shares one group.
  • $n_{kv} = n_q$ recovers MHA; $n_{kv} = 1$ recovers MQA
  • Optimized kernels store only $n_{kv}$ K, V sets. During score computation, K and V are broadcast along the group dim up to $n_q$ rather than physically replicated.
  • Llama 3 70B: $n_q = 64, n_{kv} = 8, d = 128$, per-token cache $= 2 \cdot 8 \cdot 128 \cdot 2\text{B} = 4\text{ KB}$ / layer, 8× smaller than MHA with the same shape
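This sketch shows the head-to-group mapping with 0-indexed heads. It also shows the naive expansion that copies each KV head $n_q / n_{kv}$ times, as simple reference code materializes it; fused kernels typically just index the right group instead of copying. Both function names are hypothetical.

```python
def group_of(h, n_q, n_kv):
    """0-indexed Q head h -> 0-indexed KV group, g(h) = floor(h * n_kv / n_q)."""
    assert n_q % n_kv == 0, "GQA assumes n_q is a multiple of n_kv"
    return h * n_kv // n_q

def expand_kv(kv_heads, n_q):
    """Naive broadcast of n_kv KV heads up to n_q Q heads by repetition.
    A fused kernel would read kv_heads[group_of(h, ...)] directly instead of copying."""
    n_kv = len(kv_heads)
    return [kv_heads[group_of(h, n_q, n_kv)] for h in range(n_q)]

# Llama 3 8B shape: 32 Q heads share 8 KV heads, 4 Q heads per group
print([group_of(h, 32, 8) for h in range(32)])
print(expand_kv(["kv0", "kv1"], 4))
```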

MLA — Multi-Head Latent Attention (DeepSeek V2 2024)

GQA only shrinks the cache linearly with head count. MLA instead jointly compresses K and V into a low-rank latent $\mathbf{c}^{KV}$, and splits RoPE off into a separate, narrow, shared branch. Four steps:

(1) Content branch: K and V share a down-projection, and only the latent $\mathbf{c}^{KV}_i$ goes into the cache:

$$\mathbf{c}^{KV}_i = \mathbf{x}_i W^{DKV}, \quad \mathbf{k}^{C,(h)}_i = \mathbf{c}^{KV}_i W_K^{U,(h)}, \quad \mathbf{v}^{(h)}_i = \mathbf{c}^{KV}_i W_V^{U,(h)}$$

(2) Q-side low-rank: this saves training memory; Q is never cached at inference anyway:

$$\mathbf{c}^{Q}_i = \mathbf{x}_i W^{DQ}, \quad \mathbf{q}^{C,(h)}_i = \mathbf{c}^{Q}_i W_Q^{U,(h)}$$

(3) Decoupled RoPE branch: the K side computes one $d_r$-dim vector shared by all heads:

$$\mathbf{q}^{R,(h)}_i = \text{RoPE}(\mathbf{c}^{Q}_i W_Q^{R,(h)}), \quad \mathbf{k}^{R}_i = \text{RoPE}(\mathbf{x}_i W^{KR})$$

(4) Concat + attention: content and RoPE parts are concatenated along the head dim before scoring:

$$\mathbf{q}^{(h)}_i = [\mathbf{q}^{C,(h)}_i; \mathbf{q}^{R,(h)}_i], \quad \mathbf{k}^{(h)}_j = [\mathbf{k}^{C,(h)}_j; \mathbf{k}^{R}_j]$$
$$\text{head}^{(h)}_i = \text{softmax}\!\left(\frac{\mathbf{q}^{(h)}_i {K^{(h)}_{\le i}}^{\top}}{\sqrt{d + d_r}}\right) V^{(h)}_{\le i}$$
  • $W^{DKV} \in \mathbb{R}^{H \times d_c}$: shared KV down-projection (DeepSeek V3: $d_c = 512$)
  • $W_K^{U,(h)}, W_V^{U,(h)} \in \mathbb{R}^{d_c \times d}$: per-head up-projections
  • $W^{KR} \in \mathbb{R}^{H \times d_r}$: shared RoPE K branch (DeepSeek V3: $d_r = 64$)
  • $K^{(h)}_{\le i}, V^{(h)}_{\le i}$: keys $\mathbf{k}^{(h)}_j$ and values $\mathbf{v}^{(h)}_j$ for $j \le i$, stacked row-wise (this makes the causal mask implicit)
  • $\mathbf{c}^{KV}_i \in \mathbb{R}^{d_c}$, $\mathbf{k}^{R}_i \in \mathbb{R}^{d_r}$: **these two are all that MLA actually caches**, $d_c + d_r$ scalars
  • DeepSeek V3: $n_q = 128, d = 128, d_c = 512, d_r = 64$, per-token cache $= (512 + 64) \cdot 2\text{B} = 1152$ bytes / layer. That is about 57× below MHA with the same 128 heads (64 KB), and about 3.6× below an 8-KV-head GQA (4 KB).

Why does RoPE need its own branch? At inference there is a “weight absorption” trick: by associativity, $W_K^{U,(h)}$ can be folded into $W_Q^{U,(h)}$:

$$\mathbf{q}^{C,(h)}_i {\mathbf{k}^{C,(h)}_j}^{\top} = \mathbf{c}^{Q}_i \underbrace{W_Q^{U,(h)} {W_K^{U,(h)}}^{\top}}_{\text{pre-multiplied offline}} {\mathbf{c}^{KV}_j}^{\top}$$

K's content part is never actually reconstructed; attention runs directly on the cached $\mathbf{c}^{KV}$. $W_V^{U,(h)}$ can likewise be absorbed into $W_O$. RoPE breaks this trick. Applying RoPE to the reconstructed K would place a position-dependent rotation between $W_Q^{U,(h)}$ and ${W_K^{U,(h)}}^{\top}$, so their product would no longer be a fixed matrix that can be precomputed. DeepSeek's fix is to pull RoPE out into an independent $d_r$-dim branch. The foldable content part and the must-rotate-at-runtime position part then do not interfere, so both the cache compression and the relative-position signal survive.
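A toy numeric check of the absorption identity above, as a minimal sketch. It uses random matrices with made-up sizes ($d_q' = 6$, $d_c = 4$, $d = 3$) and only the content branch, with no RoPE. It also ignores how production kernels actually schedule this computation.

```python
import random

def matmul(A, B):
    """Plain-list matrix product: A is (n x m), B is (m x p)."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(r) for r in zip(*A)]

def rand(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

random.seed(0)
d_qp, d_c, d = 6, 4, 3   # toy sizes: Q latent d_q', KV latent d_c, head dim d
W_QU = rand(d_qp, d)     # W_Q^{U,(h)}: c^Q -> q^C
W_KU = rand(d_c, d)      # W_K^{U,(h)}: c^KV -> k^C
c_q = rand(1, d_qp)      # c^Q_i for the current query (row vector)
c_kv = rand(5, d_c)      # cached c^KV_j for 5 past positions

# Naive path: reconstruct q^C and every k^C, then take dot products
naive = matmul(matmul(c_q, W_QU), transpose(matmul(c_kv, W_KU)))

# Absorbed path: precompute W_QU @ W_KU^T once (d_q' x d_c), then score the latents directly
W_absorb = matmul(W_QU, transpose(W_KU))
absorbed = matmul(matmul(c_q, W_absorb), transpose(c_kv))

print(max(abs(a - b) for a, b in zip(naive[0], absorbed[0])))
```

The absorbed path never forms any $\mathbf{k}^{C}$. It scores each query directly against the cached $d_c$-dim latents, which is why the cache only needs $\mathbf{c}^{KV}$.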

Overview: Params / FLOPs / KV cache

Below we decompose each variant's cost step by step. Conventions:
  • One attention sublayer per layer, prefill length $S$, batch of 1.
  • Norm, bias, softmax, and other non-matmul terms are ignored.
  • A linear projection $XW$ (with $X \in \mathbb{R}^{S \times m}, W \in \mathbb{R}^{m \times n}$) counts as $2Smn$ FLOPs (2 FLOPs per multiply-accumulate).
  • $QK^{\top}$ and $AV$ ($A$ being the softmax weights) do not deduct the causally masked triangle.
  • For MLA, $d_q'$ denotes the Q-side latent dim (DeepSeek V3 uses $1536$).

Parameters (per layer)

| Step | MHA | MQA | GQA | MLA |
|---|---|---|---|---|
| Q proj | $H \cdot n_q d$ | $H \cdot n_q d$ | $H \cdot n_q d$ | $H d_q' + d_q' \cdot n_q d$ |
| K proj | $H \cdot n_q d$ | $H \cdot d$ | $H \cdot n_{kv} d$ | $H d_c + d_c \cdot n_q d$ |
| V proj | $H \cdot n_q d$ | $H \cdot d$ | $H \cdot n_{kv} d$ | $d_c \cdot n_q d$ (shares $W^{DKV}$ with K) |
| RoPE branch | 0 | 0 | 0 | $d_q' \cdot n_q d_r + H d_r$ |
| $W_O$ | $n_q d \cdot H$ | $n_q d \cdot H$ | $n_q d \cdot H$ | $n_q d \cdot H$ |
| Total | $4 H n_q d$ | $2 H n_q d + 2 H d$ | $2 H n_q d + 2 H n_{kv} d$ | sum of the rows above |

Prefill FLOPs (per layer, sequence length SS)

| Step | MHA | MQA | GQA | MLA |
|---|---|---|---|---|
| Q proj | $2 S H n_q d$ | $2 S H n_q d$ | $2 S H n_q d$ | $2 S H d_q' + 2 S d_q' n_q d$ |
| K proj | $2 S H n_q d$ | $2 S H d$ | $2 S H n_{kv} d$ | $2 S H d_c + 2 S d_c n_q d$ |
| V proj | $2 S H n_q d$ | $2 S H d$ | $2 S H n_{kv} d$ | $2 S d_c n_q d$ |
| RoPE branch | $\mathcal{O}(S n_q d)$ | $\mathcal{O}(S n_q d)$ | $\mathcal{O}(S n_q d)$ | $2 S d_q' n_q d_r + 2 S H d_r$ |
| $QK^{\top}$ | $2 n_q S^2 d$ | $2 n_q S^2 d$ | $2 n_q S^2 d$ | $2 n_q S^2 (d + d_r)$ |
| softmax · $V$ | $2 n_q S^2 d$ | $2 n_q S^2 d$ | $2 n_q S^2 d$ | $2 n_q S^2 d$ |
| $W_O$ | $2 S H n_q d$ | $2 S H n_q d$ | $2 S H n_q d$ | $2 S H n_q d$ |

In the MHA/MQA/GQA columns, “RoPE branch” means the elementwise rotation applied to Q and K. It costs $\mathcal{O}(S n_q d)$, which is negligible next to the matmuls. In MLA it refers specifically to the extra $W_Q^R$ / $W^{KR}$ projections, which are genuine matmuls and must be counted.

KV cache (per token, per layer, fp16)

| Variant | What's cached | Bytes |
|---|---|---|
| MHA | $n_q$ pairs of $(\mathbf{k}, \mathbf{v}) \in \mathbb{R}^d$ | $2 \cdot n_q \cdot d \cdot 2\text{B}$ |
| MQA | 1 pair $(\mathbf{k}, \mathbf{v})$ | $2 \cdot d \cdot 2\text{B}$ |
| GQA | $n_{kv}$ pairs $(\mathbf{k}, \mathbf{v})$ | $2 \cdot n_{kv} \cdot d \cdot 2\text{B}$ |
| MLA | $\mathbf{c}^{KV} \in \mathbb{R}^{d_c}$ + $\mathbf{k}^R \in \mathbb{R}^{d_r}$ | $(d_c + d_r) \cdot 2\text{B}$ |
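Here is a back-of-the-envelope calculator for the table above. It is a sketch that assumes fp16/bf16 (2 bytes per scalar) and the Llama 3 8B and DeepSeek V3 shapes quoted earlier; the function and its parameter names are illustrative.

```python
def kv_cache_bytes(variant, n_q, d, n_kv=None, d_c=None, d_r=None, bytes_per=2):
    """Per-token, per-layer KV cache size in bytes for each attention variant."""
    if variant == "MHA":
        return 2 * n_q * d * bytes_per
    if variant == "GQA":
        return 2 * n_kv * d * bytes_per
    if variant == "MQA":
        return 2 * d * bytes_per
    if variant == "MLA":
        return (d_c + d_r) * bytes_per
    raise ValueError(f"unknown variant: {variant}")

# Llama 3 8B: GQA with 8 KV heads, 32 layers
per_layer = kv_cache_bytes("GQA", n_q=32, d=128, n_kv=8)
per_token = per_layer * 32
print(per_layer, per_token, per_token * 8192 / 2**30)  # bytes/layer, bytes/token, GiB per 8K-token sequence

# DeepSeek V3 MLA vs. MHA with the same 128 heads
mla = kv_cache_bytes("MLA", n_q=128, d=128, d_c=512, d_r=64)
mha = kv_cache_bytes("MHA", n_q=128, d=128)
print(mla, mha, round(mha / mla, 1))
```

At these shapes, one 8K-token Llama 3 8B sequence needs about 1 GiB of cache, so the walkthrough's batch of 2 at 8K context would need about 2 GiB.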

Three observations from reading these tables horizontally:

  1. Changes concentrate on the K/V side. Q projection, $W_O$, and $AV$ are identical across MHA / MQA / GQA. MLA additionally low-ranks Q and widens $QK^{\top}$ to $d + d_r$. None of these changes moves total attention params or FLOPs by an order of magnitude at the same model size.
  2. MHA → MQA/GQA saves params, FLOPs, and cache together; MLA trades params/FLOPs for cache. GQA shrinks $n_{kv}$, which scales the K/V projection params, their FLOPs, and the cache down linearly. MLA goes the other way: it adds the DKV down-projection, the UK/UV up-projections, and a dedicated RoPE branch. This keeps params and prefill FLOPs on the same order as a same-size GQA, in exchange for shrinking per-token cache from several KB to about 1 KB per layer.
  3. The $S^2$ term is the same for MHA/MQA/GQA. $QK^{\top}$ and $AV$ each cost $2 n_q S^2 d$; K, V sharing is only a broadcast and never reduces the quadratic term. Meanwhile the projection savings grow only linearly in $S$. Once $S \gg H$, the quadratic term dominates and the three variants' prefill FLOPs converge. What really separates them is how many cache bytes each decoded token must read, which is a memory I/O cost, not a compute cost.
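The sketch below evaluates the prefill FLOPs rows for MHA versus GQA at Llama 3 8B shapes ($H = 4096$, $n_q = 32$, $n_{kv} = 8$, $d = 128$), to check observation 3 numerically. Like the tables, it counts only the matmul terms. The function name is illustrative.

```python
def attn_prefill_flops(S, H, n_q, n_kv, d):
    """Matmul FLOPs of one attention sublayer at prefill (2 FLOPs per MAC, no causal discount)."""
    q_proj = 2 * S * H * n_q * d
    kv_proj = 2 * (2 * S * H * n_kv * d)     # K and V
    o_proj = 2 * S * n_q * d * H
    quadratic = 2 * (2 * n_q * S * S * d)    # QK^T and softmax(.) V
    return q_proj + kv_proj + o_proj + quadratic

H, n_q, d = 4096, 32, 128
for S in (128, 4096, 32_768, 409_600):
    mha = attn_prefill_flops(S, H, n_q, n_q, d)   # MHA: n_kv = n_q
    gqa = attn_prefill_flops(S, H, n_q, 8, d)     # Llama 3 8B GQA
    print(f"S={S:>7}: GQA / MHA prefill FLOPs = {gqa / mha:.3f}")
```

With $n_q d = H$ and $n_{kv} d = H/4$, the ratio simplifies to $(5H + 4S)/(8H + 4S)$. It starts near $5/8$ for short prompts and is already $0.75$ at $S = H$. It approaches 1 only once $S$ is much larger than $H$.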

Output Projection + Residual

$$\mathbf{h} = \mathbf{x} + \text{Attn}(Q', K', V') \, W_O$$
  • $\mathbf{x} \in \mathbb{R}^{H}$: attention sublayer input, i.e. the value before pre-norm ($Q, K, V$ are projected from $\text{RMSNorm}(\mathbf{x})$)
  • $Q', K', V'$: RoPE-rotated Q and K, with $V' = V$ (V is not rotated; the prime is just for notational uniformity)
  • $W_O \in \mathbb{R}^{n_q d \times H}$: output projection that maps the concatenated multi-head output back to $H$
  • $\mathbf{h} \in \mathbb{R}^{H}$: attention sublayer output plus residual

FFN Variants

Classic Two-Layer FFN (GPT-2)

$$\text{FFN}(\mathbf{x}) = \phi(\mathbf{x} W_1) W_2$$
  • $\mathbf{x} \in \mathbb{R}^{H}$: FFN input
  • $W_1 \in \mathbb{R}^{H \times I}$: up-projection
  • $W_2 \in \mathbb{R}^{I \times H}$: down-projection
  • $\phi$: element-wise scalar nonlinearity (see below)

Activation Functions

The activation is the scalar nonlinearity applied after the up-projection. Input and output are both scalars $x$, and it is applied independently per token and per hidden dim. Four functions cover the mainstream choices of the Transformer era.

ReLU (Nair & Hinton 2010)

$$\text{ReLU}(x) = \max(0, x)$$
  • Identity for $x > 0$, zero for $x \le 0$
  • Cheapest to compute, but the gradient is zero on the negative side (the “dead neuron” problem)
  • Used in the original Transformer and the original T5 (BERT already used GeLU)

Sigmoid

$$\sigma(x) = \frac{1}{1 + e^{-x}}$$
  • Squashes $\mathbb{R}$ into $(0, 1)$, a natural “gate” signal; the original GLU's $\phi$ is exactly this
  • Largely abandoned as a standalone FFN activation: it saturates at both ends, killing gradients

GeLU (Gaussian Error Linear Unit, Hendrycks & Gimpel 2016)

$$\text{GeLU}(x) = x \cdot \Phi(x), \quad \Phi(x) = \tfrac{1}{2}\big(1 + \text{erf}(x/\sqrt{2})\big)$$

In practice, the tanh approximation is common. It is given in the GeLU paper and was used by GPT-2; its absolute error stays below $10^{-3}$, and it avoids the erf call:

$$\text{GeLU}(x) \approx 0.5\, x \left(1 + \tanh\!\left[\sqrt{\tfrac{2}{\pi}}\left(x + 0.044715\, x^{3}\right)\right]\right)$$
  • $\Phi$ is the standard normal CDF. Intuitively, GeLU lets $x$ through, weighted by the probability that a standard normal variable falls below $x$.
  • Differentiable everywhere, non-monotonic (a small negative dip for $x < 0$), smoother than ReLU
  • Used in GPT-2/3, BERT, ViT

SiLU / Swish (Ramachandran et al. 2017)

$$\text{SiLU}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}$$
  • Shape very close to GeLU (also smooth, non-monotonic, and passing through the origin), but with a simpler closed form: no $\text{erf}$, no cubic term
  • Self-gated: its own sigmoid controls how much signal passes
  • Used by PaLM and the entire Llama family (as the gating $\phi$ in SwiGLU)

Of these four, ReLU and sigmoid have largely left mainstream FFNs. The active ones are GeLU (GPT era) and SiLU (PaLM / Llama and later). The two curves have similar shapes, and the reported quality gap between them is small, so the choice in papers is largely path-dependent. The real engineering inflection point was moving the activation from the projection itself (classic FFN) to a separate gate (the GLU family below).
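The four activations can be compared directly with the standard library. This sketch uses `math.erf` for exact GeLU and measures the tanh approximation's error on an evenly spaced grid over $[-6, 6]$ (the grid is an arbitrary choice). It also prints a few sample values, showing that GeLU and SiLU look alike without being numerically interchangeable.

```python
import math

def gelu_exact(x):
    """GeLU(x) = x * Phi(x), with Phi the standard normal CDF via math.erf."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    """tanh approximation of GeLU."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def silu(x):
    """SiLU / Swish: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))

def relu(x):
    return max(0.0, x)

grid = [i / 100 for i in range(-600, 601)]
max_err = max(abs(gelu_exact(x) - gelu_tanh(x)) for x in grid)
print(f"max |exact - tanh approx| on [-6, 6]: {max_err:.1e}")
for x in (-1.0, 1.0, 3.0):
    print(f"x={x:+.1f}  ReLU={relu(x):.4f}  GeLU={gelu_exact(x):.4f}  SiLU={silu(x):.4f}")
```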

GLU Family (used by Llama, PaLM, Mistral)

$$\text{GLU}(\mathbf{x}) = \big(\phi(\mathbf{x} W_{\text{gate}}) \odot (\mathbf{x} W_{\text{up}})\big) W_{\text{down}}$$
  • $\mathbf{x} \in \mathbb{R}^{H}$: FFN input
  • $W_{\text{gate}}, W_{\text{up}} \in \mathbb{R}^{H \times I}$: two independent up-projections
  • $W_{\text{down}} \in \mathbb{R}^{I \times H}$: down-projection
  • $\odot$: element-wise multiplication
  • $\phi$: gating activation; its choice gives the variant its name:

| Variant | $\phi$ | Representative model |
|---|---|---|
| GLU | $\sigma$ | Dauphin et al. 2017 (original) |
| ReGLU | $\text{ReLU}$ | |
| GeGLU | $\text{GeLU}$ | T5 v1.1 |
| SwiGLU | $\text{SiLU}$ | PaLM, Llama 1/2/3 |
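A minimal pure-Python sketch of the GLU-family FFN, with lists standing in for tensors and the document's “in × out” weight convention. Passing a different `phi` switches between SwiGLU, ReGLU, or GeGLU. The toy sizes $H = 2$, $I = 3$ and the weights are made up for illustration.

```python
import math

def silu(x):
    return x / (1.0 + math.exp(-x))

def matvec(x, W):
    """Row vector x (length in) times W (in x out), following the in x out convention."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]

def glu_ffn(x, W_gate, W_up, W_down, phi=silu):
    """(phi(x W_gate) elementwise-times (x W_up)) W_down; phi=silu gives SwiGLU."""
    gate = [phi(g) for g in matvec(x, W_gate)]
    up = matvec(x, W_up)
    hidden = [g * u for g, u in zip(gate, up)]   # length I
    return matvec(hidden, W_down)                # back to length H

# Toy sizes H = 2, I = 3
x = [1.0, 2.0]
W_gate = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
W_up = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
W_down = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(glu_ffn(x, W_gate, W_up, W_down))                               # SwiGLU
print(glu_ffn(x, W_gate, W_up, W_down, phi=lambda z: max(0.0, z)))    # ReGLU
```

The gate and up projections read the same input but play different roles. Only the gate passes through $\phi$; the up path stays linear and is modulated elementwise.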

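SwiGLU has one more projection than the classic GeLU FFN (three matrices instead of two). To keep the parameter budget matched, implementations typically set $I \approx \tfrac{2}{3} \cdot 4H$, where $4H$ is the GPT-2 convention; this works because $3 \cdot H \cdot \tfrac{2}{3} \cdot 4H = 2 \cdot H \cdot 4H$. Llama 3 8B then applies an extra multiplier of $1.3$ and rounds up to a multiple of $1024$: $\tfrac{2}{3} \cdot 4 \cdot 4096 \approx 10922$, then $\times 1.3 \approx 14198$, rounded up to $I = 14336$.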
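The sizing rule above can be written as a small function. This sketch follows our reading of the Llama reference code's logic: start at $4H$, take $2/3$, apply an optional multiplier, then round up. The config values used below (Llama 3 8B: multiplier $1.3$, multiple of $1024$; Llama 2 7B: no multiplier, multiple of $256$) are assumptions taken from the published model configs, not from this article.

```python
def llama_ffn_dim(hidden, multiple_of, ffn_dim_multiplier=None):
    """Sketch of Llama-style SwiGLU sizing: 2/3 of 4H, optional multiplier,
    then round up to a multiple of `multiple_of`."""
    dim = int(2 * (4 * hidden) / 3)
    if ffn_dim_multiplier is not None:
        dim = int(ffn_dim_multiplier * dim)
    return multiple_of * ((dim + multiple_of - 1) // multiple_of)

print(llama_ffn_dim(4096, multiple_of=1024, ffn_dim_multiplier=1.3))  # assumed Llama 3 8B config
print(llama_ffn_dim(4096, multiple_of=256))                            # assumed Llama 2 7B config
```

Rounding up to a multiple of a power of two keeps the GEMM dimensions friendly to tensor-core tiling. This costs a slightly larger FFN than the exact $\tfrac{2}{3}$ rule would give.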

Mixture-of-Experts (MoE)

In a classic dense FFN, every token passes through the same $W_{\text{up}} / W_{\text{down}}$ pair: all parameters are used and all FLOPs are paid. MoE (Shazeer et al. 2017) replicates the FFN $N$ times (“experts”) and routes each token through only the top $k$. Total parameters scale linearly with $N$, while per-token activated params and FLOPs stay almost flat, so capacity grows without a matching compute bill.

Formally the FFN sublayer becomes:

$$\text{MoE}(\mathbf{x}) = \sum_{i \in \mathcal{T}_k(\mathbf{x})} g_i(\mathbf{x}) \cdot \text{FFN}_i(\mathbf{x})$$

The router (gating network) decides which experts each token visits:

$$\mathbf{s}(\mathbf{x}) = \text{softmax}(\mathbf{x} W_g), \quad \mathcal{T}_k(\mathbf{x}) = \text{TopK}(\mathbf{s}(\mathbf{x}), k)$$
$$g_i(\mathbf{x}) = \frac{s_i(\mathbf{x}) \cdot \mathbb{1}[i \in \mathcal{T}_k(\mathbf{x})]}{\sum_{j \in \mathcal{T}_k(\mathbf{x})} s_j(\mathbf{x})}$$
  • $\mathbf{x} \in \mathbb{R}^{H}$: current token's hidden state (FFN sublayer input)
  • $W_g \in \mathbb{R}^{H \times N}$: router projection mapping the hidden state to $N$ expert logits
  • $\mathbf{s}(\mathbf{x}) \in \mathbb{R}^{N}$: router probabilities across all experts
  • $\mathcal{T}_k(\mathbf{x})$: index set of the top-$k$ selected experts
  • $g_i(\mathbf{x})$: combine weight. Mixtral and DeepSeek V3 re-normalize the top-$k$ scores so they sum to 1; DeepSeek V3 also computes $s_i$ with a per-expert sigmoid instead of a softmax.
  • $\text{FFN}_i$: the $i$-th expert, typically a SwiGLU FFN with its own (unshared) weights
MoE vs Dense FFN

| Dimension | Dense FFN (SwiGLU) | MoE (top-$k$ of $N$) |
|---|---|---|
| Params (FFN block) | $3 H I$ | $3 H I \cdot N + H N$ (router) |
| Per-token activated FLOPs | $6 H I$ | $6 H I \cdot k + 2 H N$ (router) |
| Weights read from HBM per token (decode, batch 1) | $3 H I$ params | $3 H I \cdot k$ params |
| VRAM footprint (weights) | $3 H I$ params | $3 H I \cdot N$ params (all experts must fit) |
| Kernel shape | fixed GEMM | grouped GEMM / token permutation |
| Multi-GPU comm | no extra | all-to-all under expert parallelism |
| Training stability | straightforward | router prone to collapse, needs load balancing |

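Key trade-off: decoupling capacity from FLOPs. At the same activated parameter count (i.e. the same FLOP budget), MoE can carry several to a few dozen times more total parameters. In the table below, the ratio is about 3.6× for Mixtral 8×7B, 18× for DeepSeek V3, and 24× for Llama 4 Maverick. This raises knowledge capacity without raising per-token compute. The costs are VRAM (all experts must fit), routing stability (avoiding hot experts), and harder inference batching (tokens take different paths).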

Load Balance

A naive router collapses onto a few hot experts during training. GShard / Switch use an auxiliary loss as a soft constraint:

$$\mathcal{L}_{\text{aux}} = \alpha \cdot N \cdot \sum_{i=1}^{N} f_i \cdot \bar{s}_i$$
  • $f_i$: fraction of tokens in the current batch routed to expert $i$
  • $\bar{s}_i$: mean router probability assigned to expert $i$ within the batch
  • $\alpha$: auxiliary loss weight (Switch uses $\sim 0.01$)
  • Intuition: a large $f_i$ together with a large $\bar{s}_i$ means the expert is chosen often and with high confidence. The loss penalizes this, and the gradient pushes that router logit back down. Only $\bar{s}_i$ is differentiable, since $f_i$ comes from a hard top-$k$.

DeepSeek V3 goes further with auxiliary-loss-free load balancing. It keeps a bias $b_i$ per expert and runs TopK on $s_i + b_i$, raising $b_i$ for under-used experts and lowering it for over-used ones. The bias only affects selection; the combine weights still use $s_i$. Balance is therefore steered without polluting the gradients, avoiding the main-task accuracy hit that an aux loss can cause.
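The sketch below implements both balancing ideas in plain Python, assuming top-1 assignments for the Switch-style loss. Its checks are sanity cases: perfectly uniform routing gives a loss of exactly $\alpha$, and full collapse onto one expert gives $\alpha N$. The bias-based selector only mimics the aux-loss-free idea described above; the rule for updating $b_i$ is not shown. Function names are hypothetical.

```python
def switch_aux_loss(assignments, router_probs, n_experts, alpha=0.01):
    """L_aux = alpha * N * sum_i f_i * s_bar_i over one batch of tokens.
    assignments: selected (top-1) expert per token; router_probs: full router distribution per token."""
    n_tokens = len(assignments)
    f = [assignments.count(i) / n_tokens for i in range(n_experts)]
    s_bar = [sum(p[i] for p in router_probs) / n_tokens for i in range(n_experts)]
    return alpha * n_experts * sum(fi * si for fi, si in zip(f, s_bar))

def biased_topk(scores, bias, k):
    """Aux-loss-free style routing (sketch): pick experts by score + bias,
    but build combine weights from the unbiased scores of the picked experts."""
    picked = sorted(range(len(scores)), key=lambda i: scores[i] + bias[i], reverse=True)[:k]
    total = sum(scores[i] for i in picked)
    return {i: scores[i] / total for i in picked}

balanced = switch_aux_loss([0, 1, 2, 3], [[0.25] * 4] * 4, n_experts=4)
collapsed = switch_aux_loss([0, 0, 0, 0], [[1.0, 0.0, 0.0, 0.0]] * 4, n_experts=4)
print(balanced, collapsed)
print(biased_topk([0.9, 0.8, 0.1], bias=[0.0, 0.0, 1.0], k=1))  # bias steers selection to expert 2
```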

Modern Implementations

| Model | Total / activated | $N$ routed | top-$k$ | Shared | Routing notes |
|---|---|---|---|---|---|
| Switch Transformer (2021) | 1.6 T / ~26 B | 2048 | 1 | 0 | First stably trained large MoE; hard top-1 + capacity factor |
| GLaM (2022) | 1.2 T / 97 B | 64 | 2 | 0 | Google decoder-only MoE; halves inference cost vs. dense |
| Mixtral 8×7B (2023) | 47 B / 13 B | 8 | 2 | 0 | Open-source “8 big experts” reference; per-layer routing |
| Mixtral 8×22B (2024) | 141 B / 39 B | 8 | 2 | 0 | Scaled-up 8×7B |
| Qwen1.5-MoE-A2.7B (2024) | 14 B / 2.7 B | 60 | 4 | 4 | Alibaba's first fine-grained MoE |
| DeepSeek V2 (2024) | 236 B / 21 B | 160 | 6 | 2 | Established the fine-grained + shared-expert paradigm |
| DeepSeek V3 (2024) | 671 B / 37 B | 256 | 8 | 1 | Aux-loss-free load balance; paired with MLA |
| Qwen3-MoE 235B-A22B (2025) | 235 B / 22 B | 128 | 8 | 0 | DeepSeek-style fine-grained experts, without a shared expert |
| Llama 4 Scout (2025) | 109 B / 17 B | 16 | 1 | 1 | top-1 + 1 shared; extreme sparsity |
| Llama 4 Maverick (2025) | 400 B / 17 B | 128 | 1 | 1 | Same idea, expert count pushed to 128 |
| GPT-4 (rumored) | ~1.8 T / ~280 B | 16 | 2 | unknown | Never officially disclosed; reconstructions by semiconductor analysts |

Four patterns:

  1. Coarse → fine. Mixtral-era designs are “few and fat”: 8 experts, each close to a dense FFN in size. From DeepSeek onward, each expert shrinks and the count passes 100, going “many and thin”, so the same activated FLOPs span combinatorially more expert combinations.
  2. Shared experts become common. DeepSeek V2/V3, Qwen1.5-MoE, and Llama 4 reserve 1 to 4 shared experts per layer that every token visits. These absorb generic patterns and leave the routed experts to specialize. Qwen3-MoE is a notable exception that drops them.
  3. MoE pairs with KV-cache compression. Pushing total params into the hundreds of billions also puts pressure on long-context decode. It is no coincidence that DeepSeek V3 ships MLA and fine-grained MoE together.
  4. top-$k$ spreads to both extremes: Switch ($k=1$) → Mixtral ($k=2$) → DeepSeek V3 ($k=8$ with small experts) → Llama 4 ($k=1$ + shared). Small $k$ keeps batching and capacity bounds tractable, while fine-graining and shared experts make up the lost expressivity.

Residual

$$\mathbf{x}_{\text{out}} = \mathbf{h} + \text{FFN}(\text{RMSNorm}(\mathbf{h}))$$
  • $\mathbf{h}$: attention sublayer output (already includes the first residual)
  • $\mathbf{x}_{\text{out}} \in \mathbb{R}^{H}$: full Transformer layer output, fed into the next layer
  • RMSNorm sits inside the residual branch (pre-norm); the main path $\mathbf{h}$ passes straight through

LM Head + Sampling

$$\text{logits} = \mathbf{x}_{\text{final}} W_{\text{lm}}, \quad \mathbf{p} = \text{softmax}(\text{logits}/T)$$
  • $\mathbf{x}_{\text{final}} \in \mathbb{R}^{H}$: hidden state after the final RMSNorm, at the last position
  • $W_{\text{lm}} \in \mathbb{R}^{H \times V}$: output embedding matrix (optionally tied to $E$, i.e. $W_{\text{lm}} = E^{\top}$)
  • $\text{logits} \in \mathbb{R}^{V}$: unnormalized score per vocab token
  • $T > 0$: temperature (not the generated-token count from the notation table). Larger values flatten the distribution: $T \to \infty$ gives uniform, $T \to 0$ gives argmax.
  • $\mathbf{p} \in \mathbb{R}^{V}$: normalized probability distribution

Several logits transformations are typically applied before sampling:

Repetition penalties. The additive presence / frequency penalties are:

$$\text{logits}'_v = \text{logits}_v - \alpha \cdot \mathbb{1}[v \in \text{history}] - \beta \cdot \text{count}(v)$$
  • $v$: a token in the vocabulary
  • $\text{logits}_v$: the raw logit for that token
  • $\alpha \ge 0$: presence penalty (subtracted once if the token has appeared at all)
  • $\beta \ge 0$: frequency penalty (subtracted once per occurrence)
  • $\mathbb{1}[\cdot]$: indicator function (1 if the condition holds, else 0)
  • $\text{count}(v)$: number of times token $v$ has appeared in the generated sequence
  • The separately named repetition penalty (CTRL, Keskar et al. 2019) is multiplicative instead: a previously seen token's logit is divided by $r > 1$ if positive and multiplied by $r$ if negative
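This is a minimal sketch of the additive presence / frequency penalty from the formula above. For contrast, it also implements the multiplicative CTRL-style repetition penalty. Logits are a plain list indexed by token id, and the example values are made up.

```python
from collections import Counter

def presence_frequency_penalty(logits, history, alpha, beta):
    """logits'_v = logits_v - alpha * 1[v in history] - beta * count(v)."""
    counts = Counter(history)
    return [z - (alpha if counts[v] else 0.0) - beta * counts[v] for v, z in enumerate(logits)]

def repetition_penalty(logits, history, r):
    """Multiplicative variant: seen tokens get positive logits divided by r, negative ones multiplied by r."""
    seen = set(history)
    return [(z / r if z > 0 else z * r) if v in seen else z for v, z in enumerate(logits)]

logits = [1.0, 2.0, -1.0, 0.5]
history = [1, 1, 2]   # token 1 seen twice, token 2 once
print(presence_frequency_penalty(logits, history, alpha=0.5, beta=0.25))
print(repetition_penalty(logits, history, r=2.0))
```

The additive form grows with repetition count. The multiplicative form only checks whether a token was seen, and it always pushes that token's logit toward lower probability, whatever the logit's sign.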

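Top-k: keep the $k$ largest logits and set the rest to $-\infty$.

Top-p (nucleus): sort by probability in descending order and keep the smallest prefix whose cumulative probability reaches $p$ (i.e. is $\ge p$).

Min-p: keep tokens with $p_v \ge p_{\min} \cdot p_{\max}$. Because the cutoff scales with the top token's probability, min-p truncates hard when the model is confident and keeps more candidates when the distribution is flat.

Typical-p: rank tokens by how far their surprisal deviates from the distribution's entropy, $|-\log p_v - \mathcal{H}(\mathbf{p})|$. Then keep the smallest such set whose cumulative probability reaches the threshold.

All of these only reshape or truncate the distribution, and the surviving tokens are renormalized before the draw. The logits → softmax → sample skeleton stays the same.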
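Here is a sketch of the sampling tail with temperature, top-k, top-p, and min-p, plus renormalization and a seeded draw via `random.Random.choices`. It uses plain lists; the function names and example logits are illustrative, and typical-p is omitted for brevity.

```python
import math
import random

def softmax(logits, T=1.0):
    """Temperature softmax: p = softmax(logits / T), shifted by the max for stability."""
    m = max(logits)
    exps = [math.exp((z - m) / T) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k_filter(probs, k):
    """Zero out everything except the k most probable tokens."""
    keep = set(sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k])
    return [p if i in keep else 0.0 for i, p in enumerate(probs)]

def top_p_filter(probs, p):
    """Nucleus: keep the smallest high-probability prefix whose cumulative mass reaches p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cum = set(), 0.0
    for i in order:
        keep.add(i)
        cum += probs[i]
        if cum >= p:
            break
    return [q if i in keep else 0.0 for i, q in enumerate(probs)]

def min_p_filter(probs, p_min):
    """Keep tokens whose probability is at least p_min times the top probability."""
    threshold = p_min * max(probs)
    return [q if q >= threshold else 0.0 for q in probs]

def sample(probs, rng):
    """Renormalize whatever survived truncation and draw one token index."""
    total = sum(probs)
    return rng.choices(range(len(probs)), weights=[q / total for q in probs], k=1)[0]

rng = random.Random(0)
probs = softmax([2.0, 1.5, 0.3, -1.0], T=0.8)
print(sample(top_p_filter(min_p_filter(probs, 0.05), 0.9), rng))
```

Filters compose because each returns an unnormalized probability vector. Renormalizing once inside `sample` is therefore enough.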

Prefill Stage Shape Transitions — S Tokens Walk Through Once

Input: input_ids [B, S] = [2, 10]. The figure below shows the forward pass of one Transformer layer, repeated 32 times. On the residual main path the shape stays at [B, S, H] = [2, 10, 4096] throughout, and residual edges are shown as orange dashed lines.

[Figure: Prefill, one Transformer layer × 32. input_ids [2, 10] → Embedding, x = E[token_id] with E [V, H] → x [2, 10, 4096] → RMSNorm x / √(mean(x²) + ε) ⊙ γ → Q/K/V proj + RoPE: Q [2, 32, 10, 128], K and V [2, 8, 10, 128] (GQA) → write KV cache, cache[l][:,:,0:10,:] = K, V → attention softmax(Q·Kᵀ / √d + M) · V · W_O, scores [2, 32, 10, 10] with causal mask M → residual add, h = x + Attn(RMSNorm(x)), h [2, 10, 4096] → RMSNorm → SwiGLU FFN (SiLU(h·W_gate) ⊙ h·W_up) · W_down, intermediate [2, 10, 14336] → residual add, x_out = h + FFN(RMSNorm(h)), x_out [2, 10, 4096] → final RMSNorm → last position only [2, 4096] → LM head logits = x_final · W_lm, W_lm [H, V], logits [2, 128256] → sampling p = softmax(logits / T), then top-k / top-p → next_token [2, 1], the first output token]
Prefill — shape transitions and core formulas of a single Transformer layer. Orange dashed lines are pre-norm residuals; ⊕ marks residual merges; the left bracket marks “× 32 layers.”

After prefill, each layer’s KV Cache holds the first 10 positions (0 to 9).

Decode Stage Shape Transitions — Step t Processes Only 1 Token

Prior state: $\text{cache\_len} = S + t - 1$ positions filled.

Input: input_ids [B, 1] = [2, 1] (the 1 token generated at the previous step).

[Figure: Decode step t, one Transformer layer × 32. input_ids [2, 1], the token from the previous step → Embedding → x [2, 1, 4096] → RMSNorm → Q/K/V proj for the 1 new token: Q_new [2, 32, 1, 128], K_new and V_new [2, 8, 1, 128] → RoPE at position m = cache_len → write cache[l][:,:,cache_len,:] = K_new, V_new, then cache_len += 1 → read K_full, V_full = cache[l][:,:,:cache_len,:], shape [2, 8, cache_len, 128] → attention softmax(Q_new · K_fullᵀ / √d) · V_full · W_O with no causal mask, scores [2, 32, 1, cache_len] → residual add, h [2, 1, 4096] → RMSNorm → SwiGLU FFN, intermediate [2, 1, 14336] → residual add, x_out [2, 1, 4096] → final RMSNorm → LM head, logits [2, 1, 128256] → [2, 128256] → sampling → next_token [2, 1], fed back as the next step’s input]
Decode — step t processes only 1 token, with core formulas; the KV Cache keeps growing; the blue dashed line shows the generated next_token feeding back as the next step’s input.
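The per-layer cache bookkeeping from the prefill and decode diagrams can be traced with a shape-only sketch that allocates no tensors. `LayerKVCache` and `scores_shape` are hypothetical names. The sketch assumes the write-then-read order shown in the decode diagram: RoPE uses the position before the increment, and attention covers the cache after it.

```python
from dataclasses import dataclass

@dataclass
class LayerKVCache:
    """Shape-only model of one layer's cache: [B, n_kv, S_max, d] for K and for V."""
    batch: int
    n_kv: int
    s_max: int
    d: int
    cache_len: int = 0

    def write(self, num_new):
        """Write num_new positions starting at cache_len (prefill: S, decode: 1)."""
        if self.cache_len + num_new > self.s_max:
            raise ValueError("cache full")
        start = self.cache_len
        self.cache_len += num_new
        return (start, self.cache_len)   # half-open slice [start, end) that was written

    def full_shape(self):
        """Shape of K_full / V_full = cache[:, :, :cache_len, :]."""
        return (self.batch, self.n_kv, self.cache_len, self.d)

def scores_shape(batch, n_q, num_new, cache_len):
    """Attention scores [B, n_q, num_new, cache_len]."""
    return (batch, n_q, num_new, cache_len)

B, S, n_q, n_kv, d = 2, 10, 32, 8, 128
cache = LayerKVCache(B, n_kv, s_max=8192, d=d)
print("prefill writes", cache.write(S), "scores", scores_shape(B, n_q, S, cache.cache_len))
for t in range(1, 4):                     # decode steps t = 1, 2, 3
    rope_pos = cache.cache_len            # RoPE position of the new token = S + t - 1
    cache.write(1)
    print(t, "rope pos", rope_pos, "K_full", cache.full_shape(),
          "scores", scores_shape(B, n_q, 1, cache.cache_len))
```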

Prefill vs Decode Shape Comparison — GEMM vs GEMV · Compute vs Bandwidth

| Position | Prefill | Decode (per step) |
| --- | --- | --- |
| input_ids | $[B, S]$ | $[B, 1]$ |
| after embedding | $[B, S, H]$ | $[B, 1, H]$ |
| Q | $[B, n_q, S, d]$ | $[B, n_q, 1, d]$ |
| K_new / V_new | $[B, n_{kv}, S, d]$ | $[B, n_{kv}, 1, d]$ |
| K_full / V_full (from cache) | same as K_new | $[B, n_{kv}, \text{cache\_len}, d]$ |
| attention scores | $[B, n_q, S, S]$ | $[B, n_q, 1, \text{cache\_len}]$ |
| attention output | $[B, n_q, S, d]$ | $[B, n_q, 1, d]$ |
| FFN intermediate | $[B, S, I]$ | $[B, 1, I]$ |
| logits | $[B, V]$ (last position) | $[B, V]$ |
| Operation type | GEMM (matrix × matrix) | GEMV (matrix × vector) |
| Bottleneck | compute | memory bandwidth |

This table is the starting point for understanding every inference acceleration effort. Prefill behaves like training’s forward pass and is compute-bound. Decode is a chain of GEMVs and is memory-bound: it spends most of its time streaming weights from HBM into the SMs. The optimization directions of the two are worlds apart.
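The GEMM-versus-GEMV gap can be quantified for a single projection. The sketch below assumes fp16 and assumes each operand crosses HBM exactly once (ideal on-chip reuse). `matmul_intensity` is a hypothetical helper, not a library call.

```python
def matmul_intensity(n_tokens, k, m, bytes_per_elem=2):
    """FLOPs per HBM byte for Y[N, M] = X[N, K] @ W[K, M], each tensor moved once."""
    flops = 2 * n_tokens * k * m                        # one MAC = 2 FLOPs
    bytes_moved = bytes_per_elem * (n_tokens * k + k * m + n_tokens * m)
    return flops / bytes_moved

H = 4096
print("decode  (N=1):   ", round(matmul_intensity(1, H, H), 2))
print("prefill (N=2048):", round(matmul_intensity(2048, H, H), 1))
```

With $N = 1$ the weight matrix dominates the traffic and the intensity is about 1 FLOP/byte. With $N = 2048$ the same weights are reused across every token, and the intensity reaches about 1024.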

KV Cache Shape and Growth — Few KB Per Token · Hundreds of MB at Long Context

One pair of caches per layer:

$$K_{\text{cache}}, V_{\text{cache}} \in \mathbb{R}^{B \times n_{kv} \times S_{\max} \times d}$$
  • $S_{\max}$ — pre-allocated max sequence length (typically the model’s context cap or the scheduler’s limit)
  • Other symbols follow the top-of-article convention ($B, n_{kv}, d$); the 4 dimensions follow PyTorch’s [B, head, seq, head_dim] ordering

Per-token, per-layer cache size (fp16):

$$2 \times n_{kv} \times d \times 2\ \text{bytes} = 2 \times 8 \times 128 \times 2 = 4\ \text{KB}$$
  • Leftmost $2$ — one for K, one for V
  • $2\ \text{bytes}$ — fp16 element size (an 8-bit fp8 / int8 cache halves this; 4-bit formats cut it to 1/4)
  • Right side plugs in Llama 3 8B: $n_{kv} = 8$, $d = 128$

Per token across the full 32-layer model: $4\ \text{KB} \times 32 = 128\ \text{KB} / \text{token}$. A 4096-token request: $128\ \text{KB} \times 4096 \approx 512\ \text{MB}$.
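The cache-size arithmetic fits in one helper. `kv_cache_bytes` is a hypothetical name. Its defaults are the Llama 3 8B values from the notation table, and "KB" in the text corresponds to KiB here.

```python
def kv_cache_bytes(tokens, layers=32, n_kv=8, d=128, bytes_per_elem=2, batch=1):
    """K+V cache: 2 (K and V) x n_kv x d x bytes per token per layer, times layers and tokens."""
    per_token_per_layer = 2 * n_kv * d * bytes_per_elem
    return batch * tokens * layers * per_token_per_layer

KiB, MiB = 1024, 1024 ** 2
print(kv_cache_bytes(1, layers=1) / KiB, "KiB per token per layer")
print(kv_cache_bytes(1) / KiB, "KiB per token, 32 layers")
print(kv_cache_bytes(4096) / MiB, "MiB for a 4096-token request")
print(kv_cache_bytes(4096, n_kv=32) / MiB, "MiB if all 32 heads kept their own K/V (MHA)")
print(kv_cache_bytes(4096, bytes_per_elem=1) / MiB, "MiB with an 8-bit cache")
```

The MHA line shows why GQA matters: with 32 KV heads instead of 8, the same request would need 4× the cache.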

Several engineering optimizations:

  • Paged Attention (vLLM): split the cache into fixed-size blocks (16 tokens by default in vLLM), with a per-request block table mapping logical blocks to physical ones. This removes external fragmentation and limits waste to the partly filled last block. The formulas don’t change; only the tensor layout and access pattern change.
  • Sliding Window Attention (Mistral): keep only the most recent $W$ tokens of K and V. The cache cap drops from $S_{\max}$ to $W$. The cost is losing direct access to older tokens, although long-range information can still propagate indirectly through stacked layers.
  • INT8 / FP8 KV Cache: quantize the fp16 cache to 8 bits, halving its footprint, with per-channel or per-token scales and controllable error. Lower-bit schemes go further: KIVI uses 2 bits, with keys quantized per-channel and values per-token, and KVQuant targets sub-4-bit precision.
  • KV compression / eviction (H2O, StreamingLLM, SnapKV): drop positions judged unimportant, used at very long context lengths. H2O and SnapKV select positions by attention weights; StreamingLLM keeps a few initial “attention sink” tokens plus a recent window.
  • MLA: mentioned earlier. It changes the cache shape at the model-architecture level, caching a compressed latent instead of full K and V, rather than as a post-processing step.
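The block-table idea behind Paged Attention can be sketched in a few lines. `PagedKVAllocator` and its methods are hypothetical names for illustration, not vLLM’s API. The sketch only tracks where each token position lives: blocks are taken on demand from a shared pool and returned when a request finishes.

```python
class PagedKVAllocator:
    """Toy block table: logical token position -> (physical block, offset)."""

    def __init__(self, num_physical_blocks, block_size=16):
        self.block_size = block_size
        self.free_blocks = list(range(num_physical_blocks))
        self.block_tables = {}            # request id -> list of physical block ids
        self.lengths = {}                 # request id -> number of tokens stored

    def append_token(self, req):
        table = self.block_tables.setdefault(req, [])
        n = self.lengths.get(req, 0)
        if n % self.block_size == 0:      # last block full (or none yet): take a new one
            if not self.free_blocks:
                raise MemoryError("no free KV blocks")
            table.append(self.free_blocks.pop(0))
        self.lengths[req] = n + 1
        return self.locate(req, n)

    def locate(self, req, pos):
        """Where position `pos` of request `req` lives in the physical cache."""
        block = self.block_tables[req][pos // self.block_size]
        return block, pos % self.block_size

    def free(self, req):
        """Request finished: return its blocks to the pool."""
        self.free_blocks.extend(self.block_tables.pop(req, []))
        self.lengths.pop(req, None)

alloc = PagedKVAllocator(num_physical_blocks=8, block_size=16)
for _ in range(20):
    alloc.append_token("a")               # 20 tokens -> 2 blocks
for _ in range(5):
    alloc.append_token("b")               # 5 tokens -> 1 block
print(alloc.block_tables)
print(alloc.locate("a", 17))              # (physical block, offset inside the block)
```

At most one partly filled block per request is wasted, which is the fragmentation bound described above.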

Per-Step Compute and Memory Cost — Llama 3 8B fp16 · H100 Knee ~330 FLOPs/byte

The shape diagrams above show shapes but not magnitudes. 90% of inference-optimization discussion is about “how many FLOPs does this step cost, and how many bytes does it move,” so the tables below lay out each step’s cost directly.

Using Llama 3 8B, fp16, $B=1$ as the baseline: for prefill take $S=2048$; for decode take $\text{cache\_len}=2048$ (some step during generation around the 2048th token).

Reference hardware knee: H100 SXM dense fp16 Tensor Core peak ~989 TFLOPs, HBM3 bandwidth ~3.35 TB/s (rounded to ~3 TB/s here). The resulting roofline knee is $\text{AI}^{*} \approx 989/3 \approx 330\ \text{FLOPs/byte}$. Above it is compute-bound; below it is memory-bound.
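The roofline model reduces to a `max` of two times. This sketch uses the rounded H100 figures quoted above. `roofline_time_s` is a hypothetical helper name.

```python
def roofline_time_s(flops, bytes_moved, peak_flops=989e12, hbm_bw=3e12):
    """Roofline lower bound: limited by compute or by memory traffic, whichever is slower."""
    return max(flops / peak_flops, bytes_moved / hbm_bw)

knee = 989e12 / 3e12                 # FLOPs/byte where compute time equals memory time
print(round(knee))                   # 330 with the rounded 3 TB/s figure
print(round(989e12 / 3.35e12))       # 295 with the 3.35 TB/s spec figure
# Decode step from the tables below: ~15 GFLOPs against ~14.25 GB of HBM traffic
print(f"{roofline_time_s(15e9, 14.25e9) * 1e3:.2f} ms")
```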

Weight Distribution

| Component | Shape | fp16 size | Full model (× 32 layers) |
| --- | --- | --- | --- |
| Embedding $E$ | $[V, H]$ | 1.0 GB | 1.0 GB |
| $W_Q$ | $[H, H]$ | 32 MB | 1.0 GB |
| $W_K$ | $[H, n_{kv}d]$ | 8 MB | 256 MB |
| $W_V$ | $[H, n_{kv}d]$ | 8 MB | 256 MB |
| $W_O$ | $[H, H]$ | 32 MB | 1.0 GB |
| $W_{\text{gate}}$ | $[H, I]$ | 117 MB | 3.7 GB |
| $W_{\text{up}}$ | $[H, I]$ | 117 MB | 3.7 GB |
| $W_{\text{down}}$ | $[I, H]$ | 117 MB | 3.7 GB |
| RMSNorm $\boldsymbol{\gamma}$ (2 per layer) | $[H] \times 2$ | 16 KB | 500 KB |
| LM head $W_{\text{lm}}$ | $[H, V]$ | 1.0 GB | 1.0 GB |
| Total | | ~432 MB / layer | ~16 GB |

Full-model fp16 weights are ~16 GB, and the “floor price” of every forward pass is scanning these 16 GB from HBM. At H100’s 3 TB/s this takes $16/3000 \approx 5.3\ \text{ms}$, which is the physical lower bound on a single-request decode step (at most roughly 190 tokens/s).
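The weight table can be reproduced from the notation-table hyperparameters. This sketch assumes untied embedding and LM head (the Llama default) and counts in decimal units. That is why the per-layer figure lands near 436 MB rather than the table’s rounded ~432 MB.

```python
V, H, L, n_q, n_kv, I = 128256, 4096, 32, 32, 8, 14336
d = H // n_q

per_layer = {
    "W_Q": H * H, "W_K": H * n_kv * d, "W_V": H * n_kv * d, "W_O": H * H,
    "W_gate": H * I, "W_up": H * I, "W_down": I * H,
    "rmsnorm_gammas": 2 * H,
}
layer_params = sum(per_layer.values())
total_params = L * layer_params + V * H + H * V + H   # + embedding, LM head, final RMSNorm
bytes_fp16 = 2 * total_params

print(f"per layer: {layer_params / 1e6:.1f} M params, {2 * layer_params / 1e6:.0f} MB")
print(f"total:     {total_params / 1e9:.2f} B params, {bytes_fp16 / 1e9:.1f} GB")
print(f"decode floor at 3 TB/s: {bytes_fp16 / 3e12 * 1e3:.1f} ms/step")
```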

Per-Layer, Per-Step Compute / Memory I/O

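Compare the same layer’s substeps under prefill ($N=S$) and decode ($N=1$). Here $N$ is the number of new tokens and $L_k$ is the number of keys attended to ($S$ in prefill, $\text{cache\_len}$ in decode). “Weight HBM” is the weight bytes fetched from VRAM. “KV HBM” is the KV-cache bytes written (W) or read (R), with (P) / (D) marking prefill / decode. Intermediate activations are assumed fused into kernels and not counted separately.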

| Step | Prefill FLOPs ($S=2048$) | Decode FLOPs ($N=1$) | Weight HBM | KV HBM |
| --- | --- | --- | --- | --- |
| RMSNorm | $5BSH$ ≈ 42 MF | 20 KF | $\boldsymbol{\gamma}$ 8 KB | |
| $Q_{\text{proj}}$ | $2BSH^{2}$ ≈ 68.7 GF | 33.5 MF | $W_Q$ 32 MB | |
| $K_{\text{proj}}$ (+write cache) | $2BSH \cdot n_{kv}d$ ≈ 17.2 GF | 8.4 MF | $W_K$ 8 MB | W 4 MB (P) / 2 KB (D) |
| $V_{\text{proj}}$ (+write cache) | 17.2 GF | 8.4 MF | $W_V$ 8 MB | W 4 MB (P) / 2 KB (D) |
| RoPE | ~50 MF | 25 KF | | |
| Attn $QK^{\top}$ | $2B n_q N L_k d$ ≈ 34.4 GF | 16.8 MF | | R 4 MB (decode) |
| softmax | ~700 MF | 260 KF | | |
| Attn $\cdot V$ | 34.4 GF | 16.8 MF | | R 4 MB (decode) |
| $W_O$ | $2BSH^{2}$ ≈ 68.7 GF | 33.5 MF | $W_O$ 32 MB | |
| RMSNorm | 42 MF | 20 KF | $\boldsymbol{\gamma}$ 8 KB | |
| $W_{\text{gate}}$ | $2BSHI$ ≈ 241 GF | 117 MF | $W_{\text{gate}}$ 117 MB | |
| $W_{\text{up}}$ | 241 GF | 117 MF | $W_{\text{up}}$ 117 MB | |
| SiLU + gate | ~90 MF | 45 KF | | |
| $W_{\text{down}}$ | $2BSIH$ ≈ 241 GF | 117 MF | $W_{\text{down}}$ 117 MB | |
| Per-layer total | ~960 GFLOPs | ~470 MFLOPs | ~432 MB | W 8 MB (P) / R 8 MB (D) |

A few direct conclusions:

  • FFN is the real protagonist. $W_{\text{gate}} + W_{\text{up}} + W_{\text{down}}$ consume ~75% of FLOPs and ~80% of weight bandwidth. MoE, sparse activation, and FFN quantization all target this block.
  • The 4 attention projections (Q/K/V/O) account for ~18%, and the actual $QK^{\top}$ and $\cdot V$ only ~7%. At $S = 2048$ the attention core is not the prefill bottleneck; the linear projections (FFN plus Q/K/V/O) are.
  • Decode’s KV reads are 8 MB per layer; at $\text{cache\_len}=2048$ this is only ~2% of the ~432 MB of weight reads. But KV reads grow linearly with context, and also with batch size, while weight reads do not. At 64K they reach 256 MB per layer, and at 128K 512 MB, overtaking weight reads as the new bottleneck. Sliding window, KV quantization, and MLA attack these bytes; Paged Attention tackles the related capacity and fragmentation problem.
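The matmul rows of the table can be recomputed from the hyperparameters. `layer_flops` is a hypothetical helper that counts only the matmuls; RMSNorm, RoPE, and softmax are left out because they are small. Like the table, it counts the full score matrix with no causal-mask savings.

```python
def layer_flops(n_new, l_k, B=1, H=4096, n_q=32, n_kv=8, I=14336):
    """Per-layer matmul FLOPs for n_new query tokens attending to l_k keys (one MAC = 2 FLOPs)."""
    d = H // n_q
    proj_q_o = 2 * (2 * B * n_new * H * H)            # W_Q and W_O
    proj_k_v = 2 * (2 * B * n_new * H * n_kv * d)     # W_K and W_V (GQA: n_kv * d outputs)
    attn = 2 * (2 * B * n_q * n_new * l_k * d)        # Q K^T and scores @ V
    ffn = 3 * (2 * B * n_new * H * I)                 # W_gate, W_up, W_down
    return {"qkvo": proj_q_o + proj_k_v, "attn": attn, "ffn": ffn}

for name, (n_new, l_k) in {"prefill": (2048, 2048), "decode": (1, 2048)}.items():
    f = layer_flops(n_new, l_k)
    total = sum(f.values())
    shares = {part: f"{v / total:.0%}" for part, v in f.items()}
    print(name, f"{total:.3g} FLOPs", shares)
```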

One Full Forward Pass

Summing the 32 layers, with the LM head listed separately (the embedding lookup only gathers rows and adds no FLOPs):

| Stage | FLOPs | HBM I/O | Arithmetic intensity | Bottleneck |
| --- | --- | --- | --- | --- |
| Prefill, $S=2048$, $B=1$ | ~31 TFLOPs | ~14 GB (weights) + 256 MB (KV write) | ~2200 FLOPs/byte | compute |
| Decode step, $\text{cache\_len}=2048$, $B=1$ | ~15 GFLOPs | ~14 GB (weights) + 256 MB (KV read) | ~1.05 FLOPs/byte | bandwidth |
| LM head (prefill, last position only) | ~1 GFLOP | 1 GB | ~1 FLOPs/byte | bandwidth |
| LM head (decode) | ~1 GFLOP | 1 GB | ~1 FLOPs/byte | bandwidth |

Decode’s 1.05 FLOPs/byte is 2.5 orders of magnitude below H100’s knee of 330, so ideal single-request decode compute utilization is only $1.05/330 \approx 0.3\%$. This is the mathematical basis for batching decode, which continuous batching keeps full. Push $B$ to 32 and the same weight read is amortized across 32 requests, so the weight-bound part’s arithmetic intensity grows about 32×. Decode throughput grows almost linearly until the attention portion (whose KV reads also grow with $B$) or compute itself becomes the wall.
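The weight-amortization argument can be checked against the table numbers. This sketch assumes every request in the batch sits at $\text{cache\_len}=2048$ and leaves out the LM head. `decode_intensity` and the constants are illustrative names, with values taken from the tables above.

```python
WEIGHT_BYTES = 13.95e9          # 32 layers of fp16 weights (~14 GB, LM head excluded)
KV_BYTES_PER_REQ = 256 * 2**20  # 32 layers x 8 MiB of K+V read at cache_len = 2048
FLOPS_PER_REQ = 15.0e9          # ~15 GFLOPs per decode step per request

def decode_intensity(batch):
    """Weights are read once per step for the whole batch; KV is read once per request."""
    flops = batch * FLOPS_PER_REQ
    bytes_moved = WEIGHT_BYTES + batch * KV_BYTES_PER_REQ
    return flops / bytes_moved

for B in (1, 8, 32, 128):
    print(B, round(decode_intensity(B), 1))
```

Intensity rises with $B$ but by less than $B$, because each added request brings its own KV reads. At long contexts those KV reads cap the gain well before the 330 FLOPs/byte knee.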

Mental-Math Rules

Two rules cover 90% of inference performance estimation:

  1. FLOPs ≈ $2PN$: $P$ is the parameter count (~8B); $N$ is the total number of tokens this forward pass processes. Each parameter is used once per token (one MAC = 2 FLOPs). E.g., prefill $S=2048$: $2 \times 8\text{B} \times 2048 \approx 33\ \text{TFLOPs}$, close to the itemized sum of 31 TFLOPs (the rule also charges the embedding and LM-head parameters to every token).
  2. Weight HBM I/O ≈ $2P$ bytes (fp16): one forward pass scans the model once, about 16 GB.

Arithmetic intensity is essentially $\frac{2PN}{2P} = N$ (ignoring KV traffic): the total number of tokens participating in this forward pass. Prefill has $S \cdot B$ tokens; decode has only $B$. This single number directly determines why prefill and decode have different bottlenecks.
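Both rules fit in a two-line helper. `forward_cost` is a hypothetical name. It deliberately ignores KV-cache traffic, exactly as the rules do.

```python
def forward_cost(params, tokens, bytes_per_param=2):
    """Mental-math rules: FLOPs ~ 2*P*N, weight bytes ~ bytes_per_param * P (KV traffic ignored)."""
    flops = 2 * params * tokens
    weight_bytes = bytes_per_param * params
    return flops, weight_bytes, flops / weight_bytes

P = 8.03e9
for label, n in [("prefill S=2048, B=1", 2048), ("decode B=1", 1), ("decode B=32", 32)]:
    flops, wbytes, ai = forward_cost(P, n)
    print(f"{label}: {flops / 1e12:.1f} TFLOPs, {wbytes / 1e9:.1f} GB, AI = {ai:.0f}")
```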

Compute Complexity Overview — With vs Without KV Cache Differ by Three Orders of Magnitude

Prefill (process $S$ tokens at once):

$$\text{FLOPs} \sim \underbrace{O(L \cdot S \cdot H^{2})}_{\text{linear layers}} + \underbrace{O(L \cdot S^{2} \cdot H)}_{\text{attention}}$$
  • $L \cdot S \cdot H^{2}$ — per layer, the 4 attention projections ($H \times H$ for Q and O, $H \times n_{kv}d$ for K and V under GQA) plus 3 $H \times I$ FFN projections (with $I$ a small multiple of $H$; $I = 3.5H$ in Llama 3 8B), applied to $S$ tokens
  • $L \cdot S^{2} \cdot H$ — attention’s $QK^{\top}$ and $\cdot V$, with the $S \times S$ score matrix
  • At short sequences linear layers dominate. The attention quadratic catches up once $S$ reaches a small multiple of $H$: with Llama 3 8B’s constants and the full $S \times S$ scores computed, around $S \approx 6.5H \approx 27\text{K}$ tokens

Decode per step (process 1 token, history $\text{cache\_len}$):

$$\text{FLOPs} \sim \underbrace{O(L \cdot H^{2})}_{\text{linear layers, constant}} + \underbrace{O(L \cdot \text{cache\_len} \cdot H)}_{\text{attention, linear in cache}}$$
  • $\text{cache\_len}$ — number of cached positions ($S + t - 1$ before this step’s write, $S + t$ after it)
  • A single step processes 1 new token, so the linear layers’ $S$ becomes $1$; attention still scans the full cache and grows linearly with it

Total complexity to generate TT tokens:

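$$\text{FLOPs}_{\text{total}} \sim O\!\left(L \cdot T \cdot H^{2} + L \cdot T \cdot (S + T) \cdot H\right)$$
  • $T$ — number of generated tokens (top-of-article convention)
  • $(S + T)$ — upper bound on the attention span (prompt + generated segment)
  • First term sums linear layers across $T$ decode steps; second term approximates the attention portion summed over $t = 1, \ldots, T$ (exact form is $\sum_t (S + t) = TS + T(T+1)/2$)

Without KV cache, every step re-runs the forward pass over all $S + t$ tokens: $\text{FLOPs}_{\text{total}} \sim O\!\left(L \cdot T \cdot (S + T) \cdot H^{2} + L \cdot T \cdot (S + T)^{2} \cdot H\right)$. That is roughly $S + T$ times more work, and cubic in sequence length for the attention term when $T$ is comparable to $S$. A massive difference.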
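The two totals can be summed step by step with Llama 3 8B constants. `generation_flops` is a hypothetical helper, and the sketch excludes prefill. The per-token linear cost (436,207,616 FLOPs per layer) is the sum of the projection and FFN rows of the per-layer table. The per-(query, key) attention cost is $4 n_q d$.

```python
def generation_flops(S, T, use_cache, L=32, lin_per_token=436_207_616, attn_per_pair=16_384):
    """FLOPs to generate T tokens after an S-token prompt (prefill excluded).

    lin_per_token: per-layer linear FLOPs per token (projections + FFN);
    attn_per_pair: per-layer FLOPs per (query, key) pair = 4 * n_q * d.
    """
    total = 0
    for t in range(1, T + 1):
        ctx = S + t                                     # keys visible at step t
        if use_cache:
            total += L * (lin_per_token + attn_per_pair * ctx)              # 1 new query
        else:
            total += L * (lin_per_token * ctx + attn_per_pair * ctx * ctx)  # recompute all
    return total

with_cache = generation_flops(2048, 100, use_cache=True)
no_cache = generation_flops(2048, 100, use_cache=False)
print(f"{with_cache:.3g} vs {no_cache:.3g} FLOPs, ratio {no_cache / with_cache:.0f}x")
```

Under this cost model, each no-cache step costs exactly `ctx` times the cached step. The overall ratio is therefore a weighted average of $S + t$, about 2100 here: the three orders of magnitude in the heading.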

Reading FLOPs and bandwidth together is even more illuminating, as the previous section’s tables show: prefill challenges the compute ceiling and decode challenges the bandwidth ceiling. Continuous batching’s point is to fuse $N$ requests’ decode GEMVs into one skinny GEMM. This amortizes the weight-fetch cost across $N$ requests, with throughput rising roughly linearly until compute or the attention portion becomes the bottleneck.

How Engineering Optimizations Plug into the Formulas — Flash Attn / Spec Decode / Continuous Batch

Flash Attention: mathematically equivalent to standard attention; the formulas don’t change a single character. Engineering-wise it fuses softmax and the two matmuls into one kernel. It updates softmax’s running statistics (max, sum) in a streaming manner over blocks, so the $S \times S$ attention matrix is never written back to HBM. FLOPs are unchanged, extra memory drops from $O(S^{2})$ to $O(S)$, and the speedup comes mainly from reduced HBM access. FA-2 added parallelism over query blocks (the sequence dimension) on top of batch and heads, and cut non-matmul work. FA-3 targets Hopper GPUs (H100/H200) with warpgroup MMA, producer-consumer async pipelining, and FP8 support.
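The running-statistics trick can be shown for a single query in pure Python. This is a sketch of the online-softmax recurrence only, not a kernel, and both function names are hypothetical. The streaming version never holds more than one block of scores, yet matches the materialized version to rounding error.

```python
import math
import random

def attention_reference(q, keys, values):
    """Standard attention for one query: materialize all scores, softmax, weighted sum of V."""
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wj * v[i] for wj, v in zip(w, values)) / z for i in range(len(values[0]))]

def attention_online(q, keys, values, block=2):
    """Same output, streaming over key blocks with running max m, running sum l, accumulator acc."""
    d = len(q)
    m, l = -math.inf, 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), block):
        ks, vs = keys[start:start + block], values[start:start + block]
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in ks]
        m_new = max(m, max(scores))
        scale = math.exp(m - m_new)             # rescale old statistics to the new max
        p = [math.exp(s - m_new) for s in scores]
        l = l * scale + sum(p)
        acc = [a * scale + sum(pj * v[i] for pj, v in zip(p, vs)) for i, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]

rng = random.Random(0)
q = [rng.uniform(-1, 1) for _ in range(4)]
keys = [[rng.uniform(-1, 1) for _ in range(4)] for _ in range(7)]
values = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(7)]
ref = attention_reference(q, keys, values)
online = attention_online(q, keys, values, block=3)
print(max(abs(a - b) for a, b in zip(ref, online)))   # floating-point rounding only
```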

Flash Decoding: at decode, Flash Attention’s Q has only one row, so kernel parallelism is too low. Flash Decoding splits the $\text{cache\_len}$ dimension of K and V into chunks for parallelism and then does a final log-sum-exp reduction. The formula is the same softmax, just split into two passes.
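The split-then-merge structure can be sketched for one query. Each chunk returns its log-sum-exp and its locally normalized output. The merge weights each chunk by $\exp(\text{LSE}_c)$, which recovers the global softmax exactly. The function names are hypothetical, and on a GPU the chunks would run in parallel rather than in a list comprehension.

```python
import math
import random

def partial_attention(q, keys, values):
    """One chunk of keys: return (log-sum-exp of its scores, its locally normalized output)."""
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    out = [sum(wj * v[i] for wj, v in zip(w, values)) / z for i in range(len(values[0]))]
    return m + math.log(z), out

def flash_decoding(q, keys, values, num_chunks=4):
    """Split the cache_len axis into chunks, then merge the partial results by log-sum-exp."""
    size = math.ceil(len(keys) / num_chunks)
    parts = [partial_attention(q, keys[i:i + size], values[i:i + size])
             for i in range(0, len(keys), size)]
    top = max(lse for lse, _ in parts)
    weights = [math.exp(lse - top) for lse, _ in parts]    # chunk c's share is exp(LSE_c)
    total = sum(weights)
    return [sum(w * out[i] for w, (_, out) in zip(weights, parts)) / total
            for i in range(len(values[0]))]

rng = random.Random(1)
q = [rng.uniform(-1, 1) for _ in range(4)]
keys = [[rng.uniform(-1, 1) for _ in range(4)] for _ in range(10)]
values = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(10)]
single_pass = partial_attention(q, keys, values)[1]      # one chunk = plain softmax attention
split = flash_decoding(q, keys, values, num_chunks=4)
print(max(abs(a - b) for a, b in zip(single_pass, split)))
```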

Speculative Decoding: a small “draft” model generates $k$ tokens sequentially, then the large model verifies them with one prefill over the $k$ positions. Acceptance rule:

$$\text{accept with prob } \min\!\left(1, \frac{p_{\text{target}}(x)}{p_{\text{draft}}(x)}\right)$$
  • $x$ — a candidate token produced by the draft model
  • $p_{\text{target}}(x)$ — probability the large (target) model assigns to $x$ at that position
  • $p_{\text{draft}}(x)$ — probability the small (draft) model assigns to $x$ at the same position
  • Combined with “on rejection, resample from the normalized $\max(0, p_{\text{target}} - p_{\text{draft}})$ and discard the remaining drafts”, this rule provably yields the same sampling distribution as direct target-model decoding, i.e. zero quality loss
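The accept/reject loop can be sketched with explicit probability tables standing in for the two models. `speculative_step` is a hypothetical name. The sketch assumes the target’s verification pass yields $k + 1$ distributions, the last one used for a bonus token when every draft is accepted. The Monte Carlo check below confirms that the first emitted token follows $p_{\text{target}}$, not $p_{\text{draft}}$.

```python
import random

def speculative_step(draft_tokens, p_draft, p_target, rng=random):
    """Verify k draft tokens; p_draft[i] / p_target[i] are full distributions at position i.

    p_target has k + 1 rows: the last one is used for a bonus token if all drafts pass.
    """
    out = []
    for i, x in enumerate(draft_tokens):
        if rng.random() < min(1.0, p_target[i][x] / p_draft[i][x]):
            out.append(x)                        # accepted
            continue
        # rejected: resample from the normalized residual max(0, p_target - p_draft)
        resid = [max(0.0, pt - pd) for pt, pd in zip(p_target[i], p_draft[i])]
        out.append(rng.choices(range(len(resid)), weights=resid, k=1)[0])
        return out                               # later draft tokens are discarded
    k = len(draft_tokens)
    out.append(rng.choices(range(len(p_target[k])), weights=p_target[k], k=1)[0])
    return out

rng = random.Random(0)
p_draft = [[0.6, 0.3, 0.1]]
p_target = [[0.2, 0.5, 0.3], [1 / 3, 1 / 3, 1 / 3]]
counts = [0, 0, 0]
for _ in range(100_000):
    x = rng.choices(range(3), weights=p_draft[0], k=1)[0]   # draft proposes one token
    counts[speculative_step([x], p_draft, p_target, rng)[0]] += 1
print([c / 100_000 for c in counts])   # close to p_target[0] = [0.2, 0.5, 0.3]
```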

The crux is fusing $k$ decode GEMVs into one $k$-length GEMM. This raises the large model’s arithmetic intensity about $k$-fold, pushing its memory-bound regime back toward compute-bound. With $\bar k$ tokens accepted per step on average, throughput scales by roughly $\bar k$ (minus draft overhead). Variants: Medusa (multi-head prediction), EAGLE (feature-level draft), Lookahead Decoding (no draft model).

Continuous Batching (vLLM, TGI): instead of running a fixed batch to completion with every request padded to a common length, schedule at the iteration level. After each step, finished requests leave the batch and waiting ones join. Engines differ on whether one step runs prefill and decode requests separately or mixes them. Mathematically each request is independent; only the ordering changes. Original paper: Orca (OSDI’22).

Chunked Prefill: split a long prompt’s prefill into chunks and mix them with decode requests in the same step, reducing decode latency jitter. No formula changes. It is the core scheduling primitive of SARATHI / Sarathi-Serve. DistServe takes the opposite route, disaggregating prefill and decode onto separate GPUs.
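A toy iteration-level scheduler shows both ideas together. `Request`, `schedule_step`, and `run` are hypothetical names, and the policy is a sketch in the spirit of chunked prefill rather than any engine’s actual algorithm. Each step gives every decoding request one token, then fills the leftover token budget with prefill chunks. Requests join once first scheduled and leave as soon as they finish.

```python
from collections import deque
from dataclasses import dataclass

@dataclass(eq=False)                     # compare requests by identity
class Request:
    rid: str
    prompt_left: int                     # prompt tokens not yet prefilled
    gen_left: int                        # tokens still to generate

def schedule_step(running, waiting, token_budget=512):
    """One iteration: each decoding request gets 1 token; leftover budget goes to prefill chunks."""
    batch = []
    for r in running:
        if r.prompt_left == 0 and token_budget > 0:
            batch.append((r.rid, "decode", 1))
            token_budget -= 1
    for r in list(running) + list(waiting):
        if r.prompt_left > 0 and token_budget > 0:
            chunk = min(r.prompt_left, token_budget)
            batch.append((r.rid, "prefill", chunk))
            token_budget -= chunk
    return batch

def run(requests, token_budget=512, max_steps=100):
    waiting, running, log = deque(requests), [], []
    for _ in range(max_steps):
        if not waiting and not running:
            break
        batch = schedule_step(running, waiting, token_budget)
        by_id = {r.rid: r for r in list(running) + list(waiting)}
        for rid, kind, n in batch:
            r = by_id[rid]
            if kind == "decode":
                r.gen_left -= 1
            else:
                r.prompt_left -= n
                if r.prompt_left == 0:
                    r.gen_left -= 1      # the final prefill chunk also yields the first token
            if r in waiting:             # first time scheduled: join the running batch
                waiting.remove(r)
                running.append(r)
        running = [r for r in running if r.gen_left > 0]   # finished requests leave
        log.append(batch)
    return log

log = run([Request("a", 600, 3), Request("b", 100, 2), Request("c", 700, 1)])
for step, batch in enumerate(log):
    print(step, batch)
```

In step 2 of this trace, two decode tokens share the iteration with the tail of request c’s prefill. This is the mixing that chunked prefill uses to keep decode latency steady.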

One Sentence Spanning the Whole Process — From Token ID to Next Token

Input token IDs → look up embedding → through $L$ layers (Pre-RMSNorm → Attention with RoPE → residual → Pre-RMSNorm → SwiGLU FFN → residual) → Final RMSNorm → LM Head → logits → sampling. In prefill, $S$ tokens pass in parallel, producing the first token + full KV Cache. In decode, each step inputs 1 token; at attention it reads historical K and V from the cache, while all other operations are per-token independent.

Key Engineering Invariants — 6 Rules Worth Memorizing

  1. The main-line tensor shape is always $[B, S_{\text{current}}, H]$. The residual structure preserves this dimension. Wherever the width differs from $H$, it has either been split into heads inside attention or lifted to $I$ inside the FFN, and it returns to $H$ on exit.
  2. K and V, once computed, never change. They are linear projections $W_K, W_V$ (plus RoPE at a fixed position) of that position’s already-fixed layer input, and the causal structure ensures later positions cannot reach back to modify earlier representations. This is the mathematical basis for KV Cache.
  3. Attention is the only cross-token operation; all others (norm, projection, FFN, activation) are per-token independent. So only K and V, the inputs to cross-token operations, need caching; everything else can be computed and discarded immediately.
  4. Decode’s scores shape is $[B, n_q, 1, \text{cache\_len}]$. The “1” is the Q side (the current new token), and the $\text{cache\_len}$ dim is eliminated when weighted-summing with $V$, returning to one row.
  5. Decode’s non-attention compute per step is constant; only attention grows linearly with cache length. So the true reason “generation gets slower over time” is that attention’s $\text{cache\_len}$ keeps growing, and with it the KV bytes each step must pull through HBM bandwidth.
  6. Prefill uses GEMM; decode uses GEMV. This one-letter difference dictates that every inference engine has two kernel sets, two scheduling strategies. Internalize this and no inference-optimization paper will lose you.

Internalize these six and you’ll find that Flash Attention, PagedAttention, MLA, and speculative decoding are all local optimizations at some spot on this skeleton. The skeleton itself has barely changed since the original 2017 Transformer.

References — Formulas · Papers · Engineering Blogs

Architecture and Core Operators

  • Vaswani et al., “Attention Is All You Need” (NeurIPS 2017) — the original Transformer paper. arxiv.org/abs/1706.03762
  • Shazeer, “GLU Variants Improve Transformer” (2020) — source of SwiGLU / GeGLU / ReGLU. arxiv.org/abs/2002.05202
  • Zhang & Sennrich, “Root Mean Square Layer Normalization” (NeurIPS 2019) — RMSNorm. arxiv.org/abs/1910.07467
  • Hendrycks & Gimpel, “Gaussian Error Linear Units (GELUs)” (2016) — GeLU definition and approximation. arxiv.org/abs/1606.08415

Positional Encoding

  • Su et al., “RoFormer: Enhanced Transformer with Rotary Position Embedding” (2021) — RoPE. arxiv.org/abs/2104.09864
  • Peng et al., “YaRN: Efficient Context Window Extension of Large Language Models” (2023) — long-context RoPE scaling. arxiv.org/abs/2309.00071
  • bloc97 & emozilla, “NTK-Aware Scaled RoPE” — discussion of early NTK-aware open-source work. reddit / LocalLLaMA

Multi-Head Variants (KV Sharing / Compression)

  • Shazeer, “Fast Transformer Decoding: One Write-Head is All You Need” (2019) — MQA. arxiv.org/abs/1911.02150
  • Ainslie et al., “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints” (EMNLP 2023) — GQA. arxiv.org/abs/2305.13245
  • DeepSeek-AI, “DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model” (2024) — MLA introduced. arxiv.org/abs/2405.04434
  • DeepSeek-AI, “DeepSeek-V3 Technical Report” (2024) — MLA + MoE engineering. arxiv.org/abs/2412.19437

Mixture-of-Experts

  • Shazeer et al., “Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer” (ICLR 2017) — foundational paper for top-$k$ gated MoE in deep nets. arxiv.org/abs/1701.06538
  • Lepikhin et al., “GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding” (ICLR 2021) — the standard aux-loss load-balancing recipe. arxiv.org/abs/2006.16668
  • Fedus et al., “Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity” (JMLR 2022) — top-1 + capacity factor. arxiv.org/abs/2101.03961
  • Du et al., “GLaM: Efficient Scaling of Language Models with Mixture-of-Experts” (ICML 2022) — Google’s 1.2 T MoE. arxiv.org/abs/2112.06905
  • Jiang et al., “Mixtral of Experts” (Mistral AI, 2024) — Mixtral 8×7B tech report. arxiv.org/abs/2401.04088
  • DeepSeek-AI, “DeepSeekMoE: Towards Ultimate Expert Specialization in Mixture-of-Experts Language Models” (2024) — fine-grained + shared-expert paradigm. arxiv.org/abs/2401.06066
  • Wang et al., “Auxiliary-Loss-Free Load Balancing Strategy for Mixture-of-Experts” (2024) — the load-balancing approach used in DeepSeek V3. arxiv.org/abs/2408.15664
  • Meta AI, “The Llama 4 herd” (2025) — Llama 4 Scout / Maverick / Behemoth technical details. ai.meta.com/blog/llama-4

Flash Attention Series

  • Dao et al., “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness” (NeurIPS 2022) — FA-1. arxiv.org/abs/2205.14135
  • Dao, “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning” (2023) — FA-2. arxiv.org/abs/2307.08691
  • Shah et al., “FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-Precision” (NeurIPS 2024) — FA-3 / Hopper. arxiv.org/abs/2407.08608
  • Dao et al., “Flash-Decoding for long-context inference” (Stanford / Together blog, 2023) — Flash Decoding. crfm.stanford.edu

Inference Engines and Serving Schedulers

Speculative Decoding Family

  • Leviathan, Kalman, Matias, “Fast Inference from Transformers via Speculative Decoding” (ICML 2023). arxiv.org/abs/2211.17192
  • Chen et al., “Accelerating Large Language Model Decoding with Speculative Sampling” (DeepMind, 2023) — concurrent independent work. arxiv.org/abs/2302.01318
  • Cai et al., “Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads” (2024). arxiv.org/abs/2401.10774
  • Li et al., “EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty” (ICML 2024). arxiv.org/abs/2401.15077
  • Fu et al., “Lookahead Decoding: Breaking the Sequential Dependency of LLM Inference” (2024). arxiv.org/abs/2402.02057

KV Cache Compression / Quantization

  • Xiao et al., “Efficient Streaming Language Models with Attention Sinks” (ICLR 2024) — StreamingLLM. arxiv.org/abs/2309.17453
  • Zhang et al., “H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models” (NeurIPS 2023). arxiv.org/abs/2306.14048
  • Li et al., “SnapKV: LLM Knows What You are Looking for Before Generation” (2024). arxiv.org/abs/2404.14469
  • Liu et al., “KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache” (ICML 2024). arxiv.org/abs/2402.02750
  • Hooper et al., “KVQuant: Towards 10 Million Context Length LLM Inference with KV Cache Quantization” (NeurIPS 2024). arxiv.org/abs/2401.18079

Representative Open-Source Model Technical Reports

Hardware / Roofline

  • NVIDIA, “H100 Tensor Core GPU Architecture Whitepaper” (2022). resources.nvidia.com
  • Williams, Waterman, Patterson, “Roofline: An Insightful Visual Performance Model for Multicore Architectures” (CACM 2009) — original Roofline paper. dl.acm.org
  • Hoffmann et al., “Training Compute-Optimal Large Language Models” (Chinchilla, 2022) — training-side version of the $2PN$ FLOPs rule of thumb. arxiv.org/abs/2203.15556

Other Long-Form Articles / Tutorials

  • Lilian Weng, “The Transformer Family Version 2.0”. lilianweng.github.io
  • Horace He, “Making Deep Learning Go Brrrr From First Principles” — three bottleneck classes: compute / memory / overhead. horace.io/brrr_intro
  • Carol Chen, “Transformer Inference Arithmetic” — mental math for inference FLOPs / KV cache. kipp.ly
  • Jay Mody, “LLM inference, in detail” — another take on tensor shape flow. jaykmody.com