Knowledge Distillation: Teaching a Small Network to Mimic a Big One

Knowledge distillation is a common technique for improving a model’s accuracy and speeding up its convergence during training. It is often used to fine-tune models after compression and to train the small candidate networks in neural architecture search (NAS).

The idea is straightforward: take a high-accuracy teacher model’s outputs as labels for a student model, and have the student mimic the teacher’s behavior.

[Figure: The distillation process. The teacher’s soft outputs are used to teach the student.]

Classic distillation: soft labels and temperature

The most common implementation adds the student’s own classification loss to a distillation loss between teacher and student outputs. Let $\mathcal{H}$ be cross-entropy, $\sigma(\cdot\,; T)$ the softmax with temperature $T$, $\tau$ the temperature used for the distillation term, $z_s$ the student’s logits and $z_t$ the teacher’s, $y$ the label, and $\alpha, \beta$ the two loss weights:

$$\mathcal{L}(x;W) = \alpha\,\mathcal{H}\big(y,\ \sigma(z_s; T{=}1)\big) + \beta\,\mathcal{H}\big(\sigma(z_t; T{=}\tau),\ \sigma(z_s; T{=}\tau)\big)$$

The first term is the student’s ordinary loss against the true label; the second pulls the student’s soft output toward the teacher’s. The temperature-scaled softmax is:

$$\sigma(z^i; T)=\frac{\exp(z^i/T)}{\sum_j \exp(z^j/T)}$$

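At $T=1$ this is the usual softmax. $T>1$ smooths the class distribution, so the small probabilities the teacher assigns to the non-target classes stay visible, which makes the output easier for the student to imitate. These are the soft labels and soft predictions in the figure above.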
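The loss and the temperature-scaled softmax can be sketched in a few lines of plain Python. This is a minimal illustration rather than a reference implementation: the logits, the weights `alpha` and `beta`, and the value `tau=4.0` are made-up values chosen only for the example.

```python
import math

def softmax(logits, T=1.0):
    """Temperature-scaled softmax: exp(z_i / T) / sum_j exp(z_j / T)."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(target, pred, eps=1e-12):
    """H(target, pred) = -sum_i target_i * log(pred_i)."""
    return -sum(t * math.log(p + eps) for t, p in zip(target, pred))

def distillation_loss(y_onehot, z_s, z_t, alpha=0.5, beta=0.5, tau=4.0):
    """alpha * H(y, softmax(z_s; 1)) + beta * H(softmax(z_t; tau), softmax(z_s; tau))."""
    hard = cross_entropy(y_onehot, softmax(z_s, T=1.0))
    soft = cross_entropy(softmax(z_t, T=tau), softmax(z_s, T=tau))
    return alpha * hard + beta * soft

# Hypothetical logits for a 3-class problem (illustrative values only).
z_t = [6.0, 2.0, 1.0]   # teacher
z_s = [3.0, 1.5, 0.5]   # student
y = [1.0, 0.0, 0.0]     # one-hot true label

p_sharp = softmax(z_t, T=1.0)   # nearly one-hot
p_soft = softmax(z_t, T=4.0)    # same ranking, flatter distribution
loss = distillation_loss(y, z_s, z_t)
```

Comparing `p_sharp` with `p_soft` shows the smoothing effect: the class ranking is unchanged, but the top class takes a smaller share of the probability mass at the higher temperature. Hinton et al. also suggest scaling the soft term by $\tau^2$ to keep gradient magnitudes comparable; here that factor could simply be folded into `beta`.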

Distilling at intermediate layers: FitNets and FSP

Beyond transferring knowledge at the final output, FitNets proposes hints training: transferring knowledge between an intermediate layer of the teacher and the student, giving the student “hints.”

[Figure: Knowledge transfer at intermediate layers. FitNets transfers knowledge between intermediate layers of teacher and student.]

The hint loss is below, where $r$ is an extra layer to match dimensions (with parameters $W_r$), $u_h$ and $v_g$ are the teacher’s and student’s intermediate outputs, and $W_{Hint}$ and $W_{Guided}$ are the teacher’s and student’s parameters up to those layers:

$$\mathcal{L}_{HT}(W_{Guided}, W_r)=\tfrac{1}{2}\big\lVert u_h(x; W_{Hint})-r\big(v_g(x; W_{Guided}); W_r\big)\big\rVert^2$$
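As a small numeric sketch of this loss, the snippet below treats the regressor $r$ as a plain linear map. That is a simplifying assumption made for illustration (the regressor can be any layer that matches dimensions), and the activations and weights are invented values.

```python
def matvec(W, v):
    """Multiply matrix W (a list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def hint_loss(u_h, v_g, W_r):
    """0.5 * || u_h - r(v_g; W_r) ||^2, with r assumed to be a linear map."""
    projected = matvec(W_r, v_g)  # lift the student output to the teacher's size
    return 0.5 * sum((u - p) ** 2 for u, p in zip(u_h, projected))

# Hypothetical activations: the teacher's hint layer has 3 units,
# the student's guided layer has 2.
u_h = [1.0, 0.0, -1.0]
v_g = [0.5, -0.5]
W_r = [[1.0, 0.0],
       [0.0, 1.0],
       [1.0, 1.0]]  # 3x2 regressor weights (illustrative)

loss = hint_loss(u_h, v_g, W_r)  # 0.75 for these values
```

During hints training, both the student’s parameters up to the guided layer and `W_r` would be updated to reduce this value; the regressor exists only to make the two outputs comparable.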

After hints training, you still finish training the whole student via output-layer distillation.

[Figure: Distillation after hints training. After hints training, a round of output-layer distillation follows.]

Another work (A Gift from Knowledge Distillation) transfers knowledge via the FSP (flow of solution procedure) matrix, in which each element is the inner product between a channel of one layer’s feature map and a channel of another layer’s.

[Figure: Computing the FSP matrix, built from inner products between feature maps.]

Like FitNets it transfers at intermediate layers, but across multiple layers (not just one). It suits cases where teacher and student have similar structure but different depth, so the FSP matrices share dimensions and an L2 loss can be computed.

[Figure: Knowledge transfer across multiple layers, by transferring FSP matrices.]
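The FSP computation can be sketched as follows. The layout is an assumption made for the example: each feature map is a list of channels, each channel a flat list of its spatial activations, and the inner products are averaged over spatial positions. The feature values are invented.

```python
def fsp_matrix(f1, f2):
    """FSP matrix between two feature maps with the same spatial size.

    f1 has m channels and f2 has n channels; each channel is a flat list of
    h*w activations. Entry (i, j) is the inner product of channel i of f1
    with channel j of f2, divided by the number of spatial positions.
    """
    hw = len(f1[0])
    return [[sum(a * b for a, b in zip(c1, c2)) / hw for c2 in f2] for c1 in f1]

def fsp_l2_loss(g_teacher, g_student):
    """Squared L2 distance between two FSP matrices of equal shape."""
    return sum((t - s) ** 2
               for row_t, row_s in zip(g_teacher, g_student)
               for t, s in zip(row_t, row_s))

# Hypothetical 2x2 feature maps, flattened to 4 spatial positions.
teacher_in = [[1.0, 2.0, 3.0, 4.0], [0.0, 1.0, 0.0, 1.0]]    # 2 channels
teacher_out = [[1.0, 1.0, 1.0, 1.0], [2.0, 0.0, 2.0, 0.0],
               [0.0, 0.0, 1.0, 1.0]]                          # 3 channels
student_in = [[1.0, 2.0, 3.0, 3.0], [0.0, 1.0, 0.0, 1.0]]
student_out = teacher_out

g_t = fsp_matrix(teacher_in, teacher_out)  # 2x3 matrix
g_s = fsp_matrix(student_in, student_out)  # same shape, so L2 is defined
loss = fsp_l2_loss(g_t, g_s)
```

The matrix shape depends only on the channel counts of the two layers, not on how many layers sit between them, which is why a shallower student can be compared with a deeper teacher as long as the paired layers have matching channel counts.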

Teacher assistant: keep the gap small

When the gap between teacher and student is too large, transfer becomes hard. TAKD addresses this by inserting a mid-sized teacher assistant between them: knowledge flows teacher → assistant → student, with the extra step smoothing the transfer.

[Figure: A teacher assistant smooths distillation. TAKD adds a teacher assistant to smooth the teacher→student transfer.]
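The teacher → assistant → student flow amounts to running distillation repeatedly along a chain. The sketch below shows only that control flow: `distill` is a hypothetical callable standing in for whatever distillation routine is used, and the string model names are placeholders.

```python
def distill_chain(models, distill):
    """Distill through a chain of models ordered from largest to smallest.

    `distill(teacher, student)` is a hypothetical callable that trains
    `student` against `teacher` and returns the trained student.
    """
    teacher = models[0]
    for student in models[1:]:
        # The freshly trained student becomes the teacher for the next step.
        teacher = distill(teacher, student)
    return teacher

# Stand-in for a real training step: it only records who taught whom.
steps = []

def fake_distill(teacher, student):
    steps.append((teacher, student))
    return student

final = distill_chain(["teacher", "assistant", "student"], fake_distill)
```

With the three placeholder models, `steps` ends up holding the two hops (teacher to assistant, then assistant to student), and `final` is the student.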

Learning from each other: DML

The methods above all have the student learn one-way from the teacher. DML (Deep Mutual Learning) instead has a cohort of networks learn from one another: two differing architectures complement each other during training, each ending up better. Each network’s loss includes not only the cross-entropy against the label but also the KL divergence against the other network’s output.

[Figure: DML training. Each network is the other’s “soft label”: loss = cross-entropy + mutual KL divergence.]

The two networks’ losses $L_{\Theta_1}$, $L_{\Theta_2}$ are:

$$L_{\Theta_1}=L_{C_1}+D_{KL}(p_2\,\Vert\,p_1),\qquad L_{\Theta_2}=L_{C_2}+D_{KL}(p_1\,\Vert\,p_2)$$

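where $L_{C_1}$ and $L_{C_2}$ are the two networks’ cross-entropy losses, $p_1$ and $p_2$ are their predicted class probabilities, and $D_{KL}$ is the Kullback-Leibler divergence, summed over $N$ samples and $M$ classes: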

$$D_{KL}(p_2\,\Vert\,p_1)=\sum_{i=1}^{N}\sum_{m=1}^{M}p_2^m(x_i)\log\frac{p_2^m(x_i)}{p_1^m(x_i)}$$
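A minimal sketch of the two DML losses, assuming each network’s predicted probabilities are already available as plain lists. The batch, labels, and probabilities below are invented for illustration, and the small `eps` is added only to avoid taking the log of zero.

```python
import math

def kl_divergence(p, q, eps=1e-12):
    """D_KL(p || q) = sum_m p_m * log(p_m / q_m) for one sample."""
    return sum(pm * math.log((pm + eps) / (qm + eps)) for pm, qm in zip(p, q))

def cross_entropy(label_index, p, eps=1e-12):
    """Cross-entropy against a hard label given as a class index."""
    return -math.log(p[label_index] + eps)

def dml_losses(labels, probs_1, probs_2):
    """Return (L_theta1, L_theta2), each summed over the batch."""
    l1 = l2 = 0.0
    for y, p1, p2 in zip(labels, probs_1, probs_2):
        l1 += cross_entropy(y, p1) + kl_divergence(p2, p1)  # network 1 mimics 2
        l2 += cross_entropy(y, p2) + kl_divergence(p1, p2)  # network 2 mimics 1
    return l1, l2

# Hypothetical batch of two samples over three classes.
labels = [0, 2]
probs_1 = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]]
probs_2 = [[0.6, 0.3, 0.1], [0.2, 0.2, 0.6]]

l1, l2 = dml_losses(labels, probs_1, probs_2)
```

Because the KL divergence is not symmetric, the two mimicry terms generally differ, so each network receives its own loss rather than a shared one. When the two networks agree exactly, the KL terms vanish and each loss reduces to plain cross-entropy.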

References

  • Hinton, Geoffrey, Vinyals, Oriol, Dean, Jeff. Distilling the Knowledge in a Neural Network. arXiv:1503.02531, 2015.
  • Gou, Jianping, et al. Knowledge Distillation: A Survey. arXiv:2006.05525, 2020.
  • Romero, Adriana, et al. FitNets: Hints for Thin Deep Nets. arXiv:1412.6550, 2014.
  • Yim, Junho, et al. A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning. CVPR, 2017.
  • Mirzadeh, Seyed Iman, et al. Improved Knowledge Distillation via Teacher Assistant. AAAI, 2020.
  • Zhang, Ying, et al. Deep Mutual Learning. CVPR, 2018.