Dataset Distillation: Compressing an Entire Dataset Into a Handful of Images
Today I read a paper on Dataset Distillation. The core idea is to compress the training set into a small set of synthetic images plus a learned learning rate. Training a network on the distilled images with that learning rate for only a few gradient steps then reaches accuracy close to that of a network trained on the full original dataset. The basic setting assumes a fixed, known network initialization: the network trained on the distilled images starts from the same initial weights that were used during distillation. Later in the paper there are also experiments under random initialization, pretrained initialization, and other settings.
The distilled images do not have to follow the pixel distribution of real images, and in the extreme case the dataset can be compressed down to a single image per class. The benefit is that a network with comparable capability can be rebuilt quickly from the distilled data, rather than re-running the full training pipeline every time.
The Basic Method
Concretely, the approach treats the network's weights after training as a differentiable function of the synthetic training images, and then optimizes those images by gradient descent. This requires the network's initial weights to be specified in advance. To remove that restriction, the paper goes on to develop a variant that works for randomly initialized networks, as well as an iterative version that supports multiple gradient steps and multiple epochs. Finally, it uses a linear model to derive a lower bound on the number of distilled samples needed to match the performance of the original dataset.
Exploration in Stages
Stage One: Fixed Initialization + Single-Step Gradient Descent
Once you have the learning rate and the distilled data, a single step of gradient descent is enough to get good results on the real test set.
The learning rate and the distilled dataset are obtained as follows:
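The formula did not survive in these notes, so here is my reconstruction of the single-step objective. The symbols are my reading of the paper's notation: \tilde{\mathbf{x}} is the distilled data, \tilde{\eta} the learned learning rate, \theta_0 the fixed initial weights, \mathbf{x} the real training data, and \ell(\cdot, \theta) the training loss of a network with weights \theta.

```latex
\[
\theta_1 = \theta_0 - \tilde{\eta}\,\nabla_{\theta_0}\ell(\tilde{\mathbf{x}}, \theta_0),
\qquad
\tilde{\mathbf{x}}^{*}, \tilde{\eta}^{*}
  = \arg\min_{\tilde{\mathbf{x}},\,\tilde{\eta}}\ \ell(\mathbf{x}, \theta_1)
  = \arg\min_{\tilde{\mathbf{x}},\,\tilde{\eta}}\
    \ell\bigl(\mathbf{x},\ \theta_0 - \tilde{\eta}\,\nabla_{\theta_0}\ell(\tilde{\mathbf{x}}, \theta_0)\bigr)
\]
```

The outer loss is evaluated on real data, while the inner gradient step only ever sees the distilled data. Because \theta_1 depends on \tilde{\mathbf{x}} and \tilde{\eta} through that gradient step, both can be optimized with ordinary gradient descent.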
Stage Two: Random Initialization
The learning rate and distilled images obtained above were optimized for one specific set of initial weights, and may not transfer when the initialization changes. The distilled images encode information about the original data and also about the initial weights used during distillation, so swapping in a different set of initial weights can render them useless.
To handle the random-initialization case, the objective for the distilled images and the learning rate is changed to minimize the expected loss over initial weights sampled from the initialization distribution, rather than the loss at one fixed set of weights:
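My reconstruction of that objective is below. The notation is my reading of the paper: \tilde{\mathbf{x}} is the distilled data, \tilde{\eta} the learned learning rate, \mathbf{x} the real training data, \ell the training loss, and p(\theta_0) the distribution that initial weights are drawn from.

```latex
\[
\tilde{\mathbf{x}}^{*}, \tilde{\eta}^{*}
  = \arg\min_{\tilde{\mathbf{x}},\,\tilde{\eta}}\
    \mathbb{E}_{\theta_0 \sim p(\theta_0)}\,
    \ell\bigl(\mathbf{x},\ \theta_0 - \tilde{\eta}\,\nabla_{\theta_0}\ell(\tilde{\mathbf{x}}, \theta_0)\bigr)
\]
```

In practice the expectation is approximated by sampling a batch of initial weights at every optimization step.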
The specific algorithm is as follows:
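The algorithm figure is missing from these notes, so below is a minimal sketch of the loop as I understand it: sample a minibatch of real data, sample several initializations, take one gradient step on the synthetic data from each, evaluate the result on the real minibatch, and update the synthetic data and the learning rate. To keep it runnable with NumPy alone, I swap the neural network for linear regression with squared loss and derive the outer gradients by hand. The names `distill`, `outer_grads`, and `sample_theta0`, and all step sizes, are my own illustrative choices and not from the paper.

```python
import numpy as np

def inner_step(theta0, x_syn, y_syn, lr):
    """One GD step on the synthetic data; also returns residual and gradient."""
    resid = x_syn @ theta0 - y_syn              # shape (M,)
    grad = x_syn.T @ resid / len(y_syn)         # shape (D,)
    return theta0 - lr * grad, resid, grad

def real_loss(theta, x_real, y_real):
    """Half mean squared error on real data."""
    return 0.5 * np.mean((x_real @ theta - y_real) ** 2)

def outer_grads(theta0, x_syn, y_syn, lr, x_real, y_real):
    """Gradient of real_loss(theta1) w.r.t. x_syn and lr via the chain rule."""
    theta1, resid, grad = inner_step(theta0, x_syn, y_syn, lr)
    u = x_real.T @ (x_real @ theta1 - y_real) / len(y_real)   # dL/dtheta1
    v = -lr * u                                               # dL/dgrad
    g_x = (np.outer(resid, v) + np.outer(x_syn @ v, theta0)) / len(y_syn)
    g_lr = -u @ grad
    return g_x, g_lr

def distill(x_real, y_real, y_syn, sample_theta0, iters=1000, alpha=0.01,
            n_inits=8, batch=64, seed=0):
    """Learn synthetic inputs x_syn (labels y_syn stay fixed) and a learning rate."""
    rng = np.random.default_rng(seed)
    x_syn = rng.normal(size=(len(y_syn), x_real.shape[1]))
    lr = 0.1
    for _ in range(iters):
        # Minibatch of real data for the outer loss.
        idx = rng.choice(len(y_real), size=min(batch, len(y_real)), replace=False)
        g_x, g_lr = np.zeros_like(x_syn), 0.0
        # Average the outer gradient over several sampled initializations.
        for _ in range(n_inits):
            theta0 = sample_theta0(rng)
            gx_j, glr_j = outer_grads(theta0, x_syn, y_syn, lr,
                                      x_real[idx], y_real[idx])
            g_x += gx_j
            g_lr += glr_j
        x_syn = x_syn - alpha * g_x / n_inits
        lr = lr - alpha * g_lr / n_inits
    return x_syn, lr
```

For a real network the hand-written `outer_grads` would be replaced by automatic differentiation through the inner gradient step. The step sizes here are untuned, and I average over sampled initializations where a sum would work equally well up to a rescaling of `alpha`.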
Stage Three: Linear-Model Analysis
A linear model is used to illustrate the properties and limitations of the method, along with a discussion of how the initial weight distribution affects distillation quality.
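To make the linear case concrete, here is a small check I wrote, assuming linear regression with squared loss. One gradient step on M synthetic points maps theta0 to (I - (lr/M) X̃ᵀX̃) theta0 + (lr/M) X̃ᵀỹ. For this to land on the least-squares solution from any theta0, X̃ᵀX̃ must equal (M/lr) I, which is only possible when M is at least the feature dimension D. The sketch builds one such set with M = D; the function names are mine.

```python
import numpy as np

def closed_form_distilled_set(x_real, y_real, lr):
    """Build D synthetic points that reach the least-squares optimum in one step."""
    d = x_real.shape[1]
    theta_star, *_ = np.linalg.lstsq(x_real, y_real, rcond=None)
    scale = np.sqrt(d / lr)        # here M = D
    x_syn = scale * np.eye(d)      # gives (lr / M) * x_syn.T @ x_syn = I
    y_syn = scale * theta_star     # gives (lr / M) * x_syn.T @ y_syn = theta_star
    return x_syn, y_syn, theta_star

def one_step(theta0, x_syn, y_syn, lr):
    """Single GD step on the synthetic set with half mean squared error."""
    grad = x_syn.T @ (x_syn @ theta0 - y_syn) / len(y_syn)
    return theta0 - lr * grad

rng = np.random.default_rng(0)
x_real = rng.normal(size=(50, 4))
y_real = rng.normal(size=50)
lr = 0.02
x_syn, y_syn, theta_star = closed_form_distilled_set(x_real, y_real, lr)
theta0 = 10.0 * rng.normal(size=4)   # arbitrary starting point
theta1 = one_step(theta0, x_syn, y_syn, lr)
print(np.allclose(theta1, theta_star))
```

With fewer than D synthetic points, X̃ᵀX̃ has rank below D, so the dependence on theta0 cannot be cancelled and the result depends on where the initial weights come from.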
Stage Four: Multi-Step Gradient Descent and Multiple Epochs
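Building on the algorithm above, the single gradient step is extended to a sequence of steps, each with its own batch of distilled data and its own learned learning rate. The sequence can be repeated over multiple epochs, so using the distilled dataset looks much like a normal neural-network training run.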
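A minimal sketch of how such a distilled sequence would be applied at training time, again using linear regression as a stand-in model. The function names are hypothetical, and giving every (epoch, step) pair its own learning rate is my assumption about the setup. Learning the batches themselves requires backpropagating through all of these updates, which is not shown here.

```python
import numpy as np

def linear_grad(theta, x, y):
    """Gradient of half mean squared error for a linear model."""
    return x.T @ (x @ theta - y) / len(y)

def train_on_distilled(theta0, batches, lrs, grad_fn=linear_grad):
    """Apply distilled batches in order, repeating the sequence once per epoch.

    batches: list of (x_syn, y_syn) pairs, one per GD step
    lrs:     array of shape (epochs, steps) with one learning rate per update
    """
    theta = np.array(theta0, dtype=float)
    for epoch_lrs in lrs:
        for (x_syn, y_syn), lr in zip(batches, epoch_lrs):
            theta = theta - lr * grad_fn(theta, x_syn, y_syn)
    return theta
```

For example, two distilled batches and `lrs` of shape `(3, 2)` give six updates in total: three passes over the same two batches.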
Stage Five: Obtaining Distilled Images Under Different Initialization Distributions and Learning Objectives
The paper explores how to obtain distilled images under a variety of initialization distributions and learning objectives. The distilled data can be used to fine-tune a pretrained model very quickly.
Experimental Results on Classification
The experimental results on classification tasks (test accuracy, %) are as follows:
| Dataset | Ours: Fixed init. | Ours: Random init. | Baselines (GD steps): Random real | Baselines (GD steps): Optimized real | Baselines (GD steps): k-means | Baselines (GD steps): Average real | Baselines (K-NN): Random real | Baselines (K-NN): k-means |
|---|---|---|---|---|---|---|---|---|
| MNIST | 96.6 | 79.5 ± 8.1 | 68.6 ± 9.8 | 73.0 ± 7.6 | 76.4 ± 9.5 | 77.1 ± 2.7 | 71.5 ± 2.1 | 92.2 ± 0.1 |
| CIFAR10 | 54.0 | 36.8 ± 1.2 | 21.3 ± 1.5 | 23.4 ± 1.3 | 22.5 ± 3.1 | 22.3 ± 0.7 | 18.8 ± 1.3 | 29.4 ± 0.3 |
The paper includes some other experiments as well, which I’ll set aside for now.