Towards Stable and Storage-efficient Dataset Distillation: Matching Convexified Trajectory

Wenliang Zhong
School of software
Shandong University
Haoyu Tang
School of software
Shandong University
Qinghai Zheng
College of Software
Fuzhou University

Mingzhu Xu
School of software
Shandong University

Yupeng Hu
School of software
Shandong University

Liqiang Nie
School of Computer Science
Harbin Institute of Technology (Shenzhen)

Tang Haoyu is the corresponding author; email: [email protected], [email protected]
Abstract

The rapid evolution of deep learning and large language models has led to an exponential growth in the demand for training data, prompting the development of Dataset Distillation methods to address the challenges of managing large datasets. Among these, Matching Training Trajectories (MTT) has been a prominent approach, which replicates the training trajectory of an expert network on real data with a synthetic dataset. However, our investigation found that this method suffers from three significant limitations: 1. Instability of expert trajectory generated by Stochastic Gradient Descent (SGD); 2. Low convergence speed of the distillation process; 3. High storage consumption of the expert trajectory. To address these issues, we offer a new perspective on understanding the essence of Dataset Distillation and MTT through a simple transformation of the objective function, and introduce a novel method called Matching Convexified Trajectory (MCT), which aims to provide better guidance for the student trajectory. MCT leverages insights from the linearized dynamics of Neural Tangent Kernel methods to create a convex combination of expert trajectories, guiding the student network to converge rapidly and stably. This trajectory is not only easier to store, but also enables a continuous sampling strategy during distillation, ensuring thorough learning and fitting of the entire expert trajectory. Comprehensive experiments across three public datasets validate the superiority of MCT over traditional MTT methods.

1 Introduction

The advancement of deep learning has catalyzed an exponential surge in the requisite volume of training data (Wang et al., 2022). With the emergence of Large Language Models (LLMs), there has been a corresponding rise in model complexity, further intensifying the demand for extensive datasets to facilitate the training of these intricate models. However, collecting and managing large datasets presents significant challenges, including storage requirements, computational load, privacy concerns, and the costs of data labeling. To mitigate these challenges, Dataset Distillation (DD) has emerged as a compelling strategy (Wang et al., 2018). DD endeavors to distill the essence of a large, real-world dataset into a more compact, synthetic dataset that can train models with comparable efficacy.

Among DD methods, Matching Training Trajectories (MTT) has emerged as a prominent approach. MTT aims to generate a synthetic dataset that guides the learning trajectory of the student network to approximate the expert trajectory of the same network trained on real data. However, upon closer examination, we identify several limitations inherent in traditional MTT approaches:

Figure 1: (a) Visualization of the expert trajectory: PCA projection of all waypoint models in the expert trajectory, where the z-axis represents the value of $(1-\text{validation accuracy})$. (b) Illustration of convergence speed: the number of iterations required to converge for both the MCT and MTT methods during distillation. Convergence is defined by the condition that the difference between the accuracy at any iteration and the maximum accuracy is less than $\epsilon=2\%$.

1. Instability of expert trajectory: As shown in Figure 1(a), the validation accuracy of the expert network along the MTT trajectory oscillates. Matching the trajectory locally in each iteration leads to similar oscillations in the trajectory of the synthetic data, thereby impeding robust distillation.

2. Low Convergence Speed: Learning the expert trajectory is often slow. As shown in Figure 1(b), a considerable number of distillation iterations are required to generate a synthetic dataset capable of achieving satisfactory test accuracy, which makes the procedure time-consuming.

3. High Storage Consumption: During the distillation process, the conventional MTT approach requires storing the model weights at all timesteps, which is particularly burdensome in terms of storage (about 50 models must be stored). This high storage consumption is a significant limitation for applying existing DD methods to large-scale models.

Through careful observation, we have reformulated the loss function of MTT and introduced a novel perspective to interpret the essence of DD and MTT: obtaining a synthetic dataset that offers accurate guidance regarding the magnitude and direction of the next update for any given point in the parameter space of the student model, with this guidance determined by the expert trajectory’s update vector at that point. From this perspective, the three limitations can be addressed together by finding an optimized expert trajectory that guides the model to converge stably at each iteration and that is also easy to fit and simple to store.

How can such a trajectory be found? Drawing inspiration from the linearized dynamics of the Neural Tangent Kernel (NTK) method (Arora et al., 2019; Jacot et al., 2018), we present a simple yet novel Matching Convexified Trajectory (MCT) method. The MCT method creates a convex combination (linear) expert trajectory based on the network’s training process on real data. This trajectory, which starts from a randomly initialized model and points directly towards the optimal model point, facilitates stable and rapid convergence of the distillation. Moreover, recovering this trajectory only requires storing two models and a set of constants. Distinct from the MTT method, the convexified trajectory also permits a “continuous sampling” strategy during the distillation, ensuring comprehensive learning and fitting of the expert trajectory.

The contributions of this paper are as follows: 1) We highlight the three limitations of traditional MTT methods, and offer a novel perspective for understanding the objective of DD through a simple reformulation of MTT’s loss function. 2) We propose the MCT method, which creates an easy-to-store convexified expert trajectory with a continuous sampling strategy to enable rapid and stable distillation. 3) Comprehensive experiments on three datasets have verified the superiority of our MCT and the effectiveness of the continuous sampling strategy.

2 Preliminaries and Related Work

2.1 Preliminaries

We first formally define the dataset distillation task. A large-scale real dataset $\mathcal{T}=\{(x^{(i)}_{\mathcal{T}},y^{(i)}_{\mathcal{T}})\}_{i=1}^{|\mathcal{T}|}$ is first provided, where $x^{(i)}_{\mathcal{T}}\in\mathbb{R}^{d}$ and $y^{(i)}_{\mathcal{T}}\in\mathcal{Y}=\{1,2,\dots,C\}$ are the $i$-th instance and the corresponding label. $C$ denotes the number of classes. The core idea of this task is to learn a tiny synthetic dataset $\mathcal{S}=\{(x^{(i)}_{\mathcal{S}},y^{(i)}_{\mathcal{S}})\}_{i=1}^{|\mathcal{S}|}$ from the original dataset $\mathcal{T}$, where $x^{(i)}_{\mathcal{S}}\in\mathbb{R}^{d}$ and $y^{(i)}_{\mathcal{S}}\in\mathcal{Y}$. Typically, $ipc$ instances are crafted for each class, culminating in a total count for $\mathcal{S}$ of $|\mathcal{S}|=C\cdot ipc$. It is always expected that $|\mathcal{S}|\ll|\mathcal{T}|$, while $\mathcal{S}$ still preserves the majority of the pivotal information in $\mathcal{T}$. Consequently, a model trained on $\mathcal{S}$ should achieve performance comparable to the model trained with the original dataset $\mathcal{T}$ under the real data distribution $\mathcal{P}_{D}$. Formally, the optimization of the DD task can be formulated as:

$$\arg\min_{\mathcal{S}}\mathcal{L}(\mathcal{S},\mathcal{T}), \tag{1}$$

where $\mathcal{L}$ is a method-specific objective function, which differs across DD methods.

2.2 Dataset Distillation Methods

The field of DD contains four principal approaches. a. Meta-model Matching methods (Wang et al., 2018; Zhou et al., 2022; Nguyen et al., 2021; Loo et al., 2022) involve a bi-level optimization algorithm where the inner loop updates the weights of a differentiable model using gradient descent on a synthetic dataset while caching recursive computation graphs, and the outer loop validates models trained in the inner loops on a real dataset, back-propagating the validation loss through the unrolled computation graph to the synthetic dataset. b. Distribution Matching methods (Zhao and Bilen, 2023; Wang et al., 2022) align synthetic and real data by optimizing within a set of embedding spaces using maximum mean discrepancy. However, inaccurate estimation of the data distribution often results in suboptimal performance. c. Single-step Gradient Matching methods (Zhao et al., 2020; Zhao and Bilen, 2021) aim to align the gradient of the synthetic dataset with that of the real dataset during each training step. To enhance generalization with improved gradients, recent research efforts have focused on further optimizing the gradient matching objective by incorporating class-related information (Lee et al., 2022; Jiang et al., 2023). d. Multi-step Trajectory Matching methods (Cazenavette et al., 2022; Guo et al., 2023) address the accumulated trajectory errors of single-step methods by matching the multi-step training trajectories of models separately trained on synthetic and real datasets.

Our research primarily focuses on multi-step trajectory matching methods. The first method in this branch is MTT (Cazenavette et al., 2022). Building on MTT, Du et al. (2023) proposed adding random noise to the initial model weights to mitigate accumulated trajectory errors, and Cui et al. (2023) proposed decomposing the objective function of MTT to improve computational efficiency and reduce GPU memory usage without performance degradation. Further research has explored the robustness of the synthesized dataset (Guo et al., 2023; Li et al., 2022b; Du et al., 2024) and applied this technique to downstream tasks (Li et al., 2022a, 2020).

Despite their successes, none of these approaches addresses the detrimental effect of oscillations in the MTT expert trajectory on the stability and convergence speed of the distillation process. Furthermore, the need to retain all waypoint networks along the expert trajectory has yet to be addressed.

3 Motivation

3.1 Review of Multi-step Trajectory Matching

In this section, we first review the multi-step trajectory matching methods. Their essence is to minimize the discrepancy between the student training trajectory on $\mathcal{S}$ and the expert training trajectory on $\mathcal{T}$. Here we take MTT (Cazenavette et al., 2022) as an example. First, an expert trajectory $\tau_{\text{mtt}}=\{\theta_{\mathcal{T}}^{(t)}\,|\,0\leq t\leq K\}$ is generated by training a randomly initialized model $\theta_{\mathcal{T}}^{(0)}$ on the real dataset $\mathcal{T}$ for $K$ timesteps. Afterward, MTT matches the student trajectory with the expert trajectory $\tau_{\text{mtt}}$ over many iterations. During each iteration, MTT samples the parameters $\theta^{(t)}_{\mathcal{T}}$ at a random timestep and takes the target parameters $\theta^{(t+M)}_{\mathcal{T}}$, located $M$ steps later, from $\tau_{\text{mtt}}$. Meanwhile, $\theta^{(t)}_{\mathcal{T}}$ is also trained on the synthetic dataset $\mathcal{S}$ for $N$ steps to get the updated student parameters $\theta^{(t+N)}_{\mathcal{S}}$. Formally, the objective is to minimize the normalized squared $L_{2}$ error between the updated student parameters $\theta^{(t+N)}_{\mathcal{S}}$ and the future expert (target) parameters $\theta^{(t+M)}_{\mathcal{T}}$:

$$\mathcal{L}(\mathcal{S},\mathcal{T})=\frac{\|\theta^{(t+N)}_{\mathcal{S}}-\theta^{(t+M)}_{\mathcal{T}}\|^{2}_{2}}{\|\theta^{(t)}_{\mathcal{T}}-\theta^{(t+M)}_{\mathcal{T}}\|^{2}_{2}} \tag{2}$$

$$\theta_{\mathcal{S}}^{(t+1)}=\theta_{\mathcal{S}}^{(t)}-\alpha_{\mathcal{S}}\nabla\ell(\mathcal{S};\theta_{\mathcal{S}}^{(t)}) \tag{3}$$

$$\theta_{\mathcal{T}}^{(t+1)}=\theta_{\mathcal{T}}^{(t)}-\alpha_{\mathcal{T}}\nabla\ell(\mathcal{T};\theta_{\mathcal{T}}^{(t)}), \tag{4}$$

where $\theta_{\mathcal{T}}^{(t)}=\theta_{\mathcal{S}}^{(t)}$, i.e., the student starts from the sampled expert parameters. $\ell$ is the loss function for model training, for which the cross-entropy loss is often adopted, and $\alpha_{\mathcal{S}}$ and $\alpha_{\mathcal{T}}$ are the learning rates for training on the synthetic and real datasets, respectively. To ensure generalization, MTT usually performs the above trajectory matching process on a large number of expert trajectories starting from different initializations $\theta_{\mathcal{T}}^{(0)}$. Although subsequent methods have focused on optimizing model parameters (Du et al., 2023, 2024) and objective functions (Cui et al., 2023), the overall process remains roughly the same as in MTT.
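The matching objective in Eq. 2 can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the MTT implementation: the quadratic losses, the optima `opt_real` and `opt_syn`, the step counts, and the learning rates are all hypothetical stand-ins, and the sketch only evaluates the loss (in actual distillation the loss is differentiated with respect to the synthetic data through the unrolled student steps).

```python
import numpy as np

def sgd_steps(theta, grad_fn, lr, n_steps):
    """Apply n_steps of plain gradient descent (the update rule of Eqs. 3 and 4)."""
    theta = theta.copy()
    for _ in range(n_steps):
        theta = theta - lr * grad_fn(theta)
    return theta

def mtt_loss(theta_student, theta_start, theta_target):
    """Normalized squared L2 error of Eq. 2."""
    num = np.sum((theta_student - theta_target) ** 2)
    den = np.sum((theta_start - theta_target) ** 2)
    return num / den

# Toy stand-ins (assumptions): two quadratic losses 0.5 * ||theta - opt||^2 play
# the role of training on the real data and on the synthetic data.
rng = np.random.default_rng(0)
opt_real = rng.normal(size=5)
opt_syn = opt_real + 0.1 * rng.normal(size=5)
grad_real = lambda th: th - opt_real
grad_syn = lambda th: th - opt_syn

theta_t = rng.normal(size=5)  # sampled start, shared by expert and student
theta_expert = sgd_steps(theta_t, grad_real, lr=0.1, n_steps=20)   # M expert steps
theta_student = sgd_steps(theta_t, grad_syn, lr=0.2, n_steps=10)   # N student steps

loss = mtt_loss(theta_student, theta_t, theta_expert)
print(f"matching loss: {loss:.4f}")
```

By construction the loss is 0 when the student lands exactly on the expert target and 1 when the student does not move from the starting point, which is why the denominator is a convenient normalizer.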

3.2 Motivation: A New Perspective to Optimize the Trajectory

Through extensive preliminary experiments and visualizations, we found that the MTT method has three serious shortcomings: 1. Instability of the expert trajectory generated by mini-batch SGD: The expert trajectory $\tau_{\text{mtt}}$ trained on $\mathcal{T}$ exhibits erratic oscillations instead of following a path along which the loss steadily decreases, so the accuracy of the waypoint model $\theta^{(t)}_{\mathcal{T}}$ is subject to fluctuations. This makes it harder for the student network to learn good training dynamics from synthetic data. 2. Low convergence speed of the distillation process: When learning expert trajectories, a very large number of iterations is required to obtain a synthetic dataset that achieves good validation accuracy, which is very time-consuming. 3. High storage consumption of the expert trajectory: To expedite the distillation process, the expert trajectories are pre-generated and stored in memory as trajectory buffers. These trajectories serve as sources from which the initial point $\theta^{(t)}_{\mathcal{T}}$ and target parameters $\theta^{(t+M)}_{\mathcal{T}}$ are extracted. However, the need to store all waypoints for each expert trajectory incurs a substantial storage footprint.

To better explain the root cause of these drawbacks, we propose a novel perspective on dataset distillation that explains the essence of the MTT approach: the objective of the DD task can be regarded as obtaining a set of parameters (i.e., the synthetic dataset $\mathcal{S}$) that enables accurate prediction of how far (magnitude) and where (direction) to step next for any given network parameters $\theta$ (i.e., provides appropriate guidance $\vec{V}_{\mathcal{S}}$ to update the current network parameters $\theta$). From this perspective, each distillation iteration of the MTT method can be viewed as updating the synthetic dataset $\mathcal{S}$ so that the network update guidance $\vec{V}_{\mathcal{S}}=\theta^{(t+N)}_{\mathcal{S}}-\theta^{(t)}_{\mathcal{T}}$ of $N$-step SGD training on $\mathcal{S}$ aligns more closely with the $M$-step SGD guidance $\vec{V}_{\mathcal{T}}=\theta^{(t+M)}_{\mathcal{T}}-\theta^{(t)}_{\mathcal{T}}$ obtained from the expert trajectory, given an arbitrary initial point $\theta_{\mathcal{T}}^{(t)}$. A simple reformulation of Eq. 2 yields the same result:

$$\min_{\mathcal{S}}\mathcal{L}(\mathcal{S},\mathcal{T})=\min_{\mathcal{S}}\mathbb{E}_{\theta^{(t)}_{\mathcal{T}}\sim\tau_{\text{mtt}}}\frac{\|(\theta^{(t+N)}_{\mathcal{S}}-\theta^{(t)}_{\mathcal{T}})-(\theta^{(t+M)}_{\mathcal{T}}-\theta^{(t)}_{\mathcal{T}})\|^{2}_{2}}{\|\theta^{(t)}_{\mathcal{T}}-\theta^{(t+M)}_{\mathcal{T}}\|^{2}_{2}}=\min_{\mathcal{S}}\mathbb{E}_{\theta^{(t)}_{\mathcal{T}}\sim\tau_{\text{mtt}}}\frac{\|\vec{V}_{\mathcal{S}}-\vec{V}_{\mathcal{T}}\|^{2}_{2}}{\|\vec{V}_{\mathcal{T}}\|^{2}_{2}}. \tag{5}$$
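The equivalence between the parameter-space form (Eq. 2) and the update-vector form (Eq. 5) can be checked numerically for a single sampled starting point. The sketch below is illustrative only: random vectors stand in for the three parameter vectors (an assumption), and the two guidance vectors are taken as the differences that appear inside the numerator of Eq. 5, namely the student endpoint minus the start and the expert target minus the start.

```python
import numpy as np

rng = np.random.default_rng(0)
theta_t = rng.normal(size=8)        # shared starting point theta_T^(t)
theta_expert = rng.normal(size=8)   # expert target theta_T^(t+M)
theta_student = rng.normal(size=8)  # student endpoint theta_S^(t+N)

# Eq. 2: distance between endpoints, normalized by the expert's displacement.
loss_eq2 = (np.sum((theta_student - theta_expert) ** 2)
            / np.sum((theta_t - theta_expert) ** 2))

# Eq. 5: the same quantity written with update (guidance) vectors.
v_s = theta_student - theta_t   # guidance provided by the synthetic data
v_t = theta_expert - theta_t    # guidance provided by the expert trajectory
loss_eq5 = np.sum((v_s - v_t) ** 2) / np.sum(v_t ** 2)

print(np.isclose(loss_eq2, loss_eq5))
```

The shared starting point cancels in the numerator, so the two expressions agree up to floating-point rounding; the loss therefore depends only on how well the student's update vector matches the expert's, relative to the length of the expert's update.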

Thereafter, we can regard all the waypoints of $\tau_{\text{mtt}}$ as the training “dataset” to optimize $\vec{V}_{\mathcal{S}}$, i.e., $\{(\theta^{(t)}_{\mathcal{T}},\vec{V}^{(t)}_{\mathcal{T}})\,|\,\theta^{(t)}_{\mathcal{T}}\in\tau_{\text{mtt}},0\leq t\leq K\}$, where $\vec{V}^{(t)}_{\mathcal{T}}$ denotes the “label” of $\theta^{(t)}_{\mathcal{T}}$. From this perspective, the first two drawbacks can be easily explained: Given that the models on the expert trajectory $\tau_{\text{mtt}}$ are all obtained by SGD training, and considering the variations in sample distribution across mini-batches, the expert trajectory $\tau_{\text{mtt}}$ has huge oscillations. Therefore, the training dynamics $\vec{V}^{(t)}_{\mathcal{T}}$ obtained by sampling two arbitrary points with an interval of $M$ steps from $\tau_{\text{mtt}}$ cannot be guaranteed to always provide a favorable direction for $\vec{V}^{(t)}_{\mathcal{S}}$ to learn. The final result is that 1) poor $\vec{V}^{(t)}_{\mathcal{T}}$ leads to instability, and 2) considerable time is expended in identifying the optimal optimization direction to achieve convergence. This raises the question: Is there a superior trajectory $\hat{\tau}$ that consistently delivers more advantageous $\vec{V}^{(t)}_{\mathcal{T}}$ to optimize the synthetic dataset $\mathcal{S}$ through $\vec{V}^{(t)}_{\mathcal{S}}$?

We believe that an ideal expert trajectory should satisfy three conditions: 1) for any $\hat{\theta}_{\mathcal{T}}^{(t)}$ on $\hat{\tau}$, the obtained $\vec{V}^{(t)}_{\mathcal{T}}$ should always point in a direction that decreases the target loss $\ell(\mathcal{T};\theta_{\mathcal{T}}^{(t)})$; 2) the trajectory should be easy for $\mathcal{S}$ to fit, because the size of $\mathcal{S}$ is much smaller than that of the original dataset $\mathcal{T}$; 3) the trajectory should be easy to save and restore.

We draw inspiration from convex optimization (Boyd and Vandenberghe, 2004; Bubeck, 2015) and NTK (Jacot et al., 2018; Hanin and Nica, 2019). First, since training a deep network is essentially a non-convex problem, making expert trajectories exhibit more convex properties would make optimization much less difficult. How can such convex trajectories be found? NTK methods show that, for a neural network $f_{\theta}(x)$, its update can be approximated by its first-order Taylor expansion in the neural network tangent space (Lee et al., 2019):

$$f_{\theta}(x)\approx f_{lin,\theta}(x)=f_{\theta_{0}}(x)+(\theta-\theta_{0})^{\mathsf{T}}\nabla_{\theta}f_{\theta_{0}}(x). \tag{6}$$

From this, we believe that replacing the original trajectory with a convex-combination (linear) trajectory would be much more effective. The starting and ending points of this linear trajectory are the same as those of $\tau_{\text{mtt}}$, and all the waypoints are distributed along this line. This trajectory meets our needs very well: 1) The visualization in Figure 1(a) verifies that the validation accuracy of the model on this trajectory consistently increases; 2) The direction of any $\vec{V}^{(t)}_{\mathcal{T}}$ sampled from this trajectory always points from the starting point to the ending (optimal) point, which is easy for the distilled data to fit; 3) Only the parameters of its starting and ending points need to be stored, and the trajectory can be reconstructed by linear interpolation; 4) This trajectory is continuous, rather than consisting of intermittently sampled points like the original path, which greatly enriches our training set.
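Property 2) can be checked numerically on a toy problem. The sketch below is only an illustration under stated assumptions: the "SGD-like" trajectory is a synthetic random walk with drift (not a real expert trajectory), and `dim`, `K`, `M`, and the helper `cosine_to_overall` are hypothetical names introduced here. It compares, for both trajectories, the cosine between each $M$-step direction and the overall start-to-end direction.

```python
import numpy as np

rng = np.random.default_rng(0)
dim, K, M = 20, 50, 5  # hypothetical parameter dimension, trajectory length, sampling interval

# Toy "SGD-like" trajectory: a steady drift plus noise at every step (illustrative assumption).
drift = rng.normal(size=dim)
steps = 0.1 * drift + rng.normal(scale=0.5, size=(K, dim))
noisy = np.vstack([np.zeros(dim), np.cumsum(steps, axis=0)])  # shape (K + 1, dim)

# Linear trajectory sharing the same start and end points.
weights = np.linspace(0.0, 1.0, K + 1)[:, None]
linear = (1.0 - weights) * noisy[0] + weights * noisy[-1]


def cosine_to_overall(traj, interval):
    """Cosine between each direction traj[t + interval] - traj[t] and traj[-1] - traj[0]."""
    overall = traj[-1] - traj[0]
    seg = traj[interval:] - traj[:-interval]
    return seg @ overall / (np.linalg.norm(seg, axis=1) * np.linalg.norm(overall))


cos_noisy = cosine_to_overall(noisy, M)
cos_linear = cosine_to_overall(linear, M)
print("noisy  trajectory, min/max cosine:", cos_noisy.min(), cos_noisy.max())
print("linear trajectory, min/max cosine:", cos_linear.min(), cos_linear.max())
```

On the linear trajectory every sampled direction is parallel to the start-to-end direction (cosine 1 up to rounding), whereas the directions sampled from the noisy trajectory vary from one starting point to the next.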

Figure 2: An illustration of the proposed MCT method. The left figure illustrates a schematic of the landscape in the model parameter space, while the right figure shows the validation accuracy of waypoint models extracted from expert trajectories of both the MTT method and our MCT method. In the left figure, the original trajectory $\tau_{\text{mtt}}$ exhibits constant oscillations, causing $\vec{V}^{(t)}_{\mathcal{T}}$ to continuously change, resulting in fluctuating accuracy of the expert model in the right figure. In contrast, the trajectory $\tau_{\text{conv}}$ of our MCT method is very stable, thereby ensuring a consistent guidance direction, which leads to a steady improvement of the expert model as shown in the right figure.

4 Our proposed MCT Method

4.1 Matching Convexified Trajectory

The expert trajectory $\tau_{\text{mtt}}$ is pre-generated with the parameters of all waypoint models stored in memory, i.e., $\tau_{\text{mtt}}=\{\theta^{(0)}_{\mathcal{T}},\theta^{(1)}_{\mathcal{T}},\dots,\theta^{(t)}_{\mathcal{T}},\dots,\theta^{(K)}_{\mathcal{T}}\}$, where $\theta^{(t)}_{\mathcal{T}}$ is computed by multiple steps of mini-batch SGD (Cazenavette et al., 2022).

However, the trajectory generated by vanilla mini-batch SGD exhibits strong non-convexity, which makes it challenging for the synthetic data to converge to an optimal solution. To this end, we propose MCT, which creates a convexified trajectory $\tau_{\text{conv}}$ defined as:

$$\tau_{\text{conv}}=\{\hat{\theta}^{(t)}\mid 0\leq t\leq K\}, \tag{7}$$
$$\hat{\theta}^{(t)}=(1-\beta^{(t)})\,\theta^{(0)}_{\mathcal{T}}+\beta^{(t)}\,\theta^{(K)}_{\mathcal{T}}, \tag{8}$$

where $\beta^{(t)}\in[0,1]$ is a weight value that determines the distribution of all waypoints. The starting point $\hat{\theta}^{(0)}$ and ending point $\hat{\theta}^{(K)}$ of $\tau_{\text{conv}}$ coincide with $\theta^{(0)}_{\mathcal{T}}$ and $\theta^{(K)}_{\mathcal{T}}$ of $\tau_{\text{mtt}}$. In particular, the generated trajectory $\tau_{\text{conv}}$ points directly from $\theta^{(0)}_{\mathcal{T}}$ to $\theta^{(K)}_{\mathcal{T}}$, and $\beta^{(t)}$ is determined by the ratio of the accumulated step lengths between consecutive waypoints $\theta^{(l)}_{\mathcal{T}}$ and $\theta^{(l+1)}_{\mathcal{T}}$ ($l<t$) in $\tau_{\text{mtt}}$ to the total length of $\tau_{\text{mtt}}$:

$$\begin{split}\beta^{(0)}&=0,\\ \beta^{(t)}&=\frac{\sum_{l=0}^{t-1}\mathrm{Norm}(\theta_{\mathcal{T}}^{(l+1)}-\theta_{\mathcal{T}}^{(l)})}{\sum_{l=0}^{K-1}\mathrm{Norm}(\theta_{\mathcal{T}}^{(l+1)}-\theta_{\mathcal{T}}^{(l)})},\quad t=1,2,\dots,K,\end{split} \tag{9}$$

where $\mathrm{Norm}(\cdot)$ is the $L_2$ norm. To mitigate discrepancies among different network layers, we compute the $L_2$ norm for each layer individually, i.e., $\beta^{(t)}=[\beta^{(t)}_{1},\beta^{(t)}_{2},\dots,\beta^{(t)}_{n}]^{\mathsf{T}}$, where each element of $\beta^{(t)}$ is the weight value of one network layer. Note that our trajectory is generated based on $\tau_{\text{mtt}}$, and the calculation of $\beta^{(t)}$ does not require saving all the intermediate models $\theta^{(t)}_{\mathcal{T}}$. It only needs to save $\mathrm{Norm}(\theta_{\mathcal{T}}^{(l+1)}-\theta_{\mathcal{T}}^{(l)})$ obtained at each step of the expert trajectory $\tau_{\text{mtt}}$, so that $\beta^{(t)}$ can be calculated at the end of expert training. Given this expert trajectory $\tau_{\text{conv}}$, the distillation in Equ. 2 can be conducted. During distillation, our MCT method always provides a convexified guidance $\vec{V}^{(t)}_{\mathcal{T}}$ pointing from $\theta^{(0)}_{\mathcal{T}}$ to $\theta^{(K)}_{\mathcal{T}}$, leading to the steady optimization of $\vec{V}^{(t)}_{\mathcal{S}}$ and thus to rapid convergence of $\mathcal{S}$.
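The computation of Eqs. 8 and 9 can be sketched as follows. This is a minimal NumPy illustration, not the authors' implementation: parameters are represented as a dict of arrays (one entry per layer), and `step_fn`, `train_expert_with_norm_log`, `compute_beta`, and `convexified_waypoint` are hypothetical names. The toy `noisy_step` merely perturbs the parameters and stands in for real expert training.

```python
import numpy as np


def train_expert_with_norm_log(theta0, step_fn, K):
    """Run K expert steps, keeping only the current parameters and per-layer step norms.

    theta0  : dict mapping layer name -> np.ndarray (initial parameters)
    step_fn : hypothetical callable mapping the current parameters to the next waypoint
    Returns the final parameters and an array of shape (K, n_layers) of L2 step norms.
    """
    names = list(theta0)
    theta = {k: v.copy() for k, v in theta0.items()}
    norms = np.zeros((K, len(names)))
    for l in range(K):
        nxt = step_fn(theta)
        norms[l] = [np.linalg.norm(nxt[k] - theta[k]) for k in names]
        theta = nxt  # the previous waypoint is discarded, only its step norm is kept
    return theta, norms


def compute_beta(norms):
    """Eq. 9 per layer: beta[0] = 0, beta[t] = accumulated step length / total length.

    Assumes every layer moves at least once (non-zero total length).
    """
    cum = np.cumsum(norms, axis=0)
    return np.vstack([np.zeros(norms.shape[1]), cum / cum[-1]])  # shape (K + 1, n_layers)


def convexified_waypoint(theta_start, theta_end, beta_t):
    """Eq. 8 applied layer by layer (beta_t holds one weight per layer, in dict order)."""
    return {k: (1.0 - b) * theta_start[k] + b * theta_end[k]
            for k, b in zip(theta_start, beta_t)}


# Toy usage with a random "expert" step (illustrative assumption, not real training).
rng = np.random.default_rng(0)
theta0 = {"conv": rng.normal(size=(4, 3)), "fc": rng.normal(size=5)}


def noisy_step(theta):
    return {k: v + rng.normal(scale=0.1, size=v.shape) for k, v in theta.items()}


thetaK, norms = train_expert_with_norm_log(theta0, noisy_step, K=50)
beta = compute_beta(norms)
theta_hat_25 = convexified_waypoint(theta0, thetaK, beta[25])
```

Only the two endpoint models and the `(K, n_layers)` array of norms are held at the end, which mirrors the storage argument made in the text.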

4.2 Continuous Sampling

Due to the continuity of our convexified trajectory, we can perform continuous sampling from the trajectory during distillation. Unlike the MTT method, this enables the selection of intermediate positions such as "the 1.5th point." Specifically, the MTT method only performs discrete sampling on the expert trajectory (i.e., selecting $\theta^{(t)}_{\mathcal{T}}$ with an integer $t$). In contrast, for $\tau_{\text{conv}}$ with the starting point $\hat{\theta}^{(0)}$ and ending point $\hat{\theta}^{(K)}$, since all points lie on a straight line, we can obtain $\hat{\theta}^{(c)}$ at any real-valued timestep $c\in[0,K]$ on this line by interpolation:

$$\begin{split}\hat{\theta}^{(c)}&=(1-\hat{\beta})\,\hat{\theta}^{(0)}+\hat{\beta}\,\hat{\theta}^{(K)},\\ \hat{\beta}&=(1-\eta)\,\beta^{(\lfloor c\rfloor)}+\eta\,\beta^{(\lceil c\rceil)},\\ \eta&=c-\lfloor c\rfloor.\end{split} \tag{10}$$

After $\hat{\theta}^{(c)}$ and $\hat{\theta}^{(c+M)}$ are obtained, the distillation process can be conducted. This continuous sampling strategy ensures sufficient learning and fitting of the entire expert trajectory $\tau_{\text{conv}}$, facilitating thorough learning of the synthetic dataset $\mathcal{S}$.
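The interpolation in Eq. 10 can be sketched as follows. This is a minimal illustration under simplifying assumptions: parameters are a single flat vector with one scalar weight per timestep (the per-layer case applies the same formula to each layer), the weight schedule `beta` is a made-up monotone sequence rather than one computed from an expert run, and `interpolate_beta`, `sample_pair`, and `max_start` are hypothetical names.

```python
import math

import numpy as np


def interpolate_beta(beta, c):
    """Eq. 10: linearly interpolate beta at a real-valued timestep c in [0, K]."""
    lo, hi = math.floor(c), math.ceil(c)
    eta = c - lo
    return (1.0 - eta) * beta[lo] + eta * beta[hi]


def sample_pair(theta_start, theta_end, beta, M, max_start, rng):
    """Draw c uniformly from [0, max_start) and return c, theta_hat(c), theta_hat(c + M).

    Requires max_start + M <= K so that c + M stays on the trajectory.
    """
    c = rng.uniform(0.0, max_start)
    b_c = interpolate_beta(beta, c)
    b_target = interpolate_beta(beta, c + M)
    theta_c = (1.0 - b_c) * theta_start + b_c * theta_end
    theta_target = (1.0 - b_target) * theta_start + b_target * theta_end
    return c, theta_c, theta_target


# Toy usage with a hypothetical monotone weight schedule.
K, M = 50, 5
beta = np.linspace(0.0, 1.0, K + 1) ** 0.5
theta_start, theta_end = np.zeros(3), np.ones(3)
rng = np.random.default_rng(0)
c, theta_c, theta_target = sample_pair(theta_start, theta_end, beta, M,
                                       max_start=K - M, rng=rng)
print(f"c = {c:.3f}, beta(1.5) = {interpolate_beta(beta, 1.5):.4f}")
```

For an integer `c`, `floor` and `ceil` coincide and the stored weight is returned unchanged, so discrete sampling is recovered as a special case.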

4.3 Memory-Efficient Storage

In conventional MTT, the learning of the expert trajectory requires storing the parameters of all timesteps in memory, which incurs significant storage overhead. Formally, let $W$ denote the size of the network parameter $\theta^{(t)}_{\mathcal{T}}$ and $C$ denote the size of other irrelevant parameters. Since there are $K$ timesteps on $\tau_{\text{mtt}}$, the entire required storage will be:

$$\text{Storage}_{\text{mtt}}=K\times W+C=O(KW). \tag{11}$$

In contrast, our method only requires storing the starting point $\hat{\theta}^{(0)}$, the ending point $\hat{\theta}^{(K)}$, and the point distribution $\{\beta^{(t)}\mid 0\leq t\leq K\}$ along the trajectory. Therefore, the entire storage cost becomes:

$$\text{Storage}_{\text{conv}}=2\times W+K\times(\beta^{(t)})+C. \tag{12}$$

Since each $\beta^{(t)}$ consists of only a few floating-point numbers (one weight per network layer, see Sec. 4.1), its size is negligible compared with $W$, and the storage cost becomes $\text{Storage}_{\text{conv}}=O(W)$. In practice, $K$ is usually set to 50. Once the surrogate models in distillation become complex (e.g., LLMs), $K$ and $W$ will increase simultaneously, highlighting the significant storage advantages of our MCT method.
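To make Eqs. 11 and 12 concrete, the short calculation below plugs in example numbers. It is purely illustrative: `K = 50` follows the text, while the parameter count `W`, the number of layers `n_layers`, and the choice `C = 0` are hypothetical values picked for the example, and the term $K\times(\beta^{(t)})$ is read as one stored value per layer and timestep.

```python
def storage_mtt(K, W, C=0):
    """Eq. 11: K waypoints of W values each, plus other overhead C."""
    return K * W + C


def storage_conv(K, W, n_layers, C=0):
    """Eq. 12, reading K x (beta) as one weight per layer and timestep."""
    return 2 * W + K * n_layers + C


K = 50          # number of timesteps, as stated in the text
W = 500_000     # hypothetical parameter count of the surrogate network
n_layers = 10   # hypothetical number of layers (length of each beta vector)

ratio = storage_conv(K, W, n_layers) / storage_mtt(K, W)
print(f"convexified / MTT storage: {ratio:.4f}")
```

Under this simplified cost model the weights contribute 500 values against one million for the two endpoint models, so the ratio is dominated by $2/K$. Storing additional waypoint models would raise it proportionally.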

Figure 3: (a) and (b): Convergence comparisons of the distillation process on CIFAR-10 and CIFAR-100, where the “star” symbol denotes the convergence point. (c): Storage comparisons on three datasets.

5 Experiment

5.1 Experiment Setup

Experiment Settings: We evaluated our method on three datasets: CIFAR-10 and CIFAR-100 (Krizhevsky et al., 2009), and Tiny-ImageNet (Le and Yang). We first generated the convexified trajectories with our MCT method. Similar to MTT, we applied zero-phase component analysis (ZCA) whitening, implemented with Kornia (Riba et al., 2020), to the CIFAR-10, CIFAR-100, and Tiny-ImageNet datasets, and used the Differentiable Siamese Augmentation (DSA) technique (Zhao and Bilen, 2021) during training and evaluation.

Evaluation and Baselines: Our MCT method is compared with several baselines from different branches, including Dataset Condensation (DC) (Zhao et al., 2020), Distribution Matching (DM) (Zhao and Bilen, 2023), DSA (Zhao and Bilen, 2021), Condense Aligning FEatures (CAFE) (Wang et al., 2022), dataset distillation using Parameter Pruning (PP) (Li et al., 2023), and MTT. Following the conventional settings, we conducted dataset distillation using 1, 10, and 50 images per class (ipc) for evaluation. Images with a resolution of 32 × 32 and 64 × 64 were synthesized on the CIFAR and Tiny-ImageNet datasets, respectively. Subsequently, five randomly initialized networks were trained for 1000 iterations with the cross-entropy loss on the distilled dataset. These trained networks were then evaluated on the real validation set, and their average accuracy (Acc) was reported as the evaluation metric. To maintain consistency with MTT and DC, we use ConvNet (Gidaris and Komodakis, 2018) as the surrogate model. Its convolutional layers comprise 128 filters with a 3 × 3 kernel size, followed by instance normalization (Ulyanov et al., 2016) and ReLU activation. Additionally, an average pooling layer with a kernel size of 2 × 2 and a stride of 2 is incorporated into the network.

Implementation Details: We adopt the same settings as MTT in most cases. Specifically, 100 expert trajectories are generated, each spanning 50 epochs of training (i.e., 51 timesteps). In practice, we often insert two waypoint models from the expert trajectory of MTT to derive our convexified trajectory: the models of the 6th and 25th epochs for CIFAR-10, and the models of the 15th and 30th epochs for CIFAR-100 and Tiny-ImageNet. During the distillation process, 5,000 distillation iterations are conducted. In each iteration, $\hat{\theta}^{(c)}$ is generated from Equ. 10, where the real-valued $c$ is randomly sampled within [0, MaxStartEpoch]. We adopt the SGD optimizer, and a learnable learning rate is employed to distill the synthetic data. All experiments are run on four RTX3090 GPUs.

Table 1: Performance of Various Algorithms on Different Datasets

| Method | CIFAR-10, ipc=1 | CIFAR-10, ipc=10 | CIFAR-10, ipc=50 | CIFAR-100, ipc=1 | CIFAR-100, ipc=10 | CIFAR-100, ipc=50 | Tiny ImageNet, ipc=1 | Tiny ImageNet, ipc=10 | Tiny ImageNet, ipc=50 |
|---|---|---|---|---|---|---|---|---|---|
| Random | 15.4±0.3 | 31.0±0.5 | 50.6±0.3 | 4.2±0.3 | 14.6±0.5 | 33.4±0.4 | 1.4±0.1 | 5.0±0.2 | 15.0±0.4 |
| DC (Zhao et al., 2020) | 28.3±0.5 | 44.9±0.5 | 53.9±0.5 | 12.8±0.3 | 25.2±0.3 | - | - | - | - |
| DSA (Zhao and Bilen, 2021) | 28.8±0.7 | 52.1±0.5 | 60.6±0.5 | 13.9±0.3 | 32.3±0.3 | 42.8±0.4 | - | - | - |
| CAFE (Wang et al., 2022) | 30.3±1.1 | 46.3±0.6 | 55.5±0.6 | 12.9±0.3 | 27.8±0.3 | 37.9±0.3 | - | - | - |
| DM (Zhao and Bilen, 2023) | 26.0±0.8 | 48.9±0.6 | 63.0±0.4 | 11.4±0.3 | 29.7±0.3 | 43.6±0.4 | 3.9±0.2 | 12.9±0.4 | 24.1±0.3 |
| PP (Li et al., 2023) | 46.4±0.6 | 65.5±0.3 | 71.9±0.2 | 24.6±0.1 | 43.1±0.3 | 48.4±0.3 | - | - | - |
| MTT (Cazenavette et al., 2022) | 46.3±0.8 | 65.3±0.7 | 71.6±0.2 | 24.3±0.3 | 40.1±0.4 | 47.7±0.2 | 8.8±0.3 | 23.2±0.2 | 28.0±0.3 |
| Ours | 48.5±0.2 | 66.0±0.3 | 72.3±0.3 | 24.5±0.5 | 42.5±0.5 | 46.8±0.2 | 9.6±0.5 | 22.6±0.8 | 27.6±0.4 |

Full dataset: 84.8±0.1 (CIFAR-10), 56.2±0.3 (CIFAR-100), 37.6±0.4 (Tiny ImageNet).

5.2 Experiment Result

Validation Accuracy Comparison. Table 1 presents a comparison of validation accuracy between our method and various baselines across three datasets. Although accuracy is not the main focus of our MCT method, it achieves the best performance in all three ipc settings on CIFAR-10 as well as in the ipc=1 setting on Tiny ImageNet. Notably, compared to the MTT baseline, our MCT method improves accuracy in six of the nine settings, indicating that our convexified trajectory and continuous sampling strategy can indeed provide enhanced guidance for the optimization of synthetic datasets.

Convergence of Distillation Process. Figures 3(a) and 3(b) illustrate the distillation processes of the MCT and MTT methods on the CIFAR-10 and CIFAR-100 datasets. After every 100 distillation iterations, five randomly initialized networks are trained on the current distilled dataset, and their average accuracy on the validation set is recorded. The figures present the validation accuracy trends of both methods over the initial 2,500 iterations. As depicted, under all ipc settings, our MCT method reaches strong performance much sooner (200-1200 iterations ahead), indicating a faster convergence speed. After nearing convergence, the performance of the MCT method remains consistently stable as iterations proceed, whereas the MTT method still experiences significant performance fluctuations. These two phenomena suggest that our method effectively enhances training stability and accelerates the convergence process.

Comparison of Storage Requirement. Figure 3(c) compares the storage required for the expert trajectory by MTT and by our MCT method. In line with the analysis in Sec. 4.3, our convexified trajectories require significantly less memory (approximately 8%) than the expert trajectories needed by the MTT method. As model sizes and expert trajectories continue to grow, the space savings offered by our method can be expected to become even more substantial.

Visualization of Distilled Data. The synthetic data on CIFAR-10 with ipc=10 and CIFAR-100 with ipc=1 are visualized in Figures 5(a) and 5(b). As we can see, the synthetic set learned from our expert trajectories exhibits notable degrees of recognizability and authenticity, while it also tends to integrate various characteristic features of images within the same category.

Figure 4: Effects of Continuous Sampling over iterations with different expert trajectory numbers: (a) 1 expert trajectory; (b) 10 expert trajectories; (c) 50 expert trajectories.

5.3 Ablation Studies

Table 2: Effects of Continuous Sampling with different numbers of expert trajectories on CIFAR-10.

| Number of expert trajectories | 1 | 5 | 10 | 20 | 50 |
|---|---|---|---|---|---|
| w/o. Continuous Sampling | 54.8±0.2 | 60.6±0.2 | 61.5±0.3 | 62.3±0.3 | 62.1±0.4 |
| w. Continuous Sampling | 56.2±0.3 | 61.3±0.5 | 61.8±0.6 | 62.8±0.3 | 62.8±0.2 |

Table 3: Effects of different $M$ with different ipc on CIFAR-10.

| $M$ | 3 | 4 | 5 | 6 | 7 |
|---|---|---|---|---|---|
| ipc=1 | 46.7 | 47.1 | 48.5 | 48.0 | 45.6 |
| ipc=10 | 62.3 | 62.6 | 65.0 | 66.0 | 65.2 |
| ipc=50 | 70.0 | 71.4 | 71.8 | 72.3 | 71.8 |

Effects of Continuous Sampling. To verify the effect of continuous sampling, we set ipc=10 and randomized the starting epoch parameter within the range [0,5] on the CIFAR-10 dataset. The validation accuracy over iterations and the best accuracy throughout the entire distillation process are reported in Figure 4 and Table 2, respectively. Overall, continuous sampling improves the validation performance under all conditions. The improvement is most pronounced when only a single expert trajectory is available (from 54.8 to 56.2), and it remains between 0.3 and 0.7 points for 5 to 50 trajectories. These results suggest that continuous sampling effectively expands the sampling space, ultimately improving the final distillation outcome.

Effects of expert updating step $M$. Table 3 shows the effects of the updating step $M$ of the expert trajectory $\tau_{\text{conv}}$ on the CIFAR-10 dataset. $N$ is set to 50 for all results. As we can see, when ipc=1, the optimal performance is obtained at $M$=5, while when ipc=10 and ipc=50, the optimal performance is obtained at $M$=6. Overall, our MCT method is robust to the selection of $M$ and does not experience significant performance degradation as $M$ changes.

Figure 5: Visualization of synthetic dataset: (a) CIFAR-10, ipc=10; (b) CIFAR-100, ipc=1.

6 Conclusion

To address three major limitations of traditional MTT, this paper draws inspiration from NTK methods and proposes a novel perspective for understanding the essence of dataset distillation and MTT. A simple yet novel Matching Convexified Trajectory method is introduced to create a simplified, convexified expert trajectory that enhances the optimization process, leading to more stable and rapid convergence and reduced memory consumption. The convexified trajectory allows for continuous sampling during distillation, enriching the learning process and ensuring thorough fitting of the expert trajectory. Our experiments on the CIFAR-10, CIFAR-100, and Tiny-ImageNet datasets show that MCT matches or exceeds the accuracy of MTT and other baselines in most settings while converging faster and requiring less storage. The ablation studies confirm the benefits of continuous sampling and the impact of the convexified trajectory on distillation performance. The results indicate that MCT is a promising solution for training complex models with reduced data needs, offering an efficient, stable, and memory-friendly approach to dataset distillation.

7 Limitations

Our MCT method has three primary limitations: 1. Although MCT can effectively enhance training stability and convergence speed, the improvement in validation accuracy is not very significant because the starting and ending points are the same as those in MTT, and the enhancement is mainly attributed to the more thorough trajectory learning enabled by continuous sampling; future work could identify better endpoints to further improve performance. 2. While our trajectory provides a more stable and rapidly descending direction, the calculation of the magnitude $\beta$ is relatively simple (derived from the proportion of the MTT step size to the trajectory length), and there may exist better step sizes that allow for more rapid and robust trajectory learning. 3. The linear approximation of NTK is derived for infinitely wide networks and requires rather strict initialization methods. However, we did not conduct our work under these conditions; further theoretical analysis is required to address this issue.

References

  • Arora et al. [2019] Sanjeev Arora, Simon S Du, Wei Hu, Zhiyuan Li, Russ R Salakhutdinov, and Ruosong Wang. On exact computation with an infinitely wide neural net. Advances in neural information processing systems, 32, 2019.
  • Boyd and Vandenberghe [2004] Stephen Boyd and Lieven Vandenberghe. Convex Optimization. Cambridge University Press, 2004.
  • Bubeck [2015] Sébastien Bubeck. Convex optimization: Algorithms and complexity, 2015.
  • Cazenavette et al. [2022] George Cazenavette, Tongzhou Wang, Antonio Torralba, Alexei A Efros, and Jun-Yan Zhu. Dataset distillation by matching training trajectories. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 4750–4759, 2022.
  • Cui et al. [2023] Justin Cui, Ruochen Wang, Si Si, and Cho-Jui Hsieh. Scaling up dataset distillation to imagenet-1k with constant memory. In International Conference on Machine Learning, pages 6565–6590. PMLR, 2023.
  • Du et al. [2023] Jiawei Du, Yidi Jiang, Vincent YF Tan, Joey Tianyi Zhou, and Haizhou Li. Minimizing the accumulated trajectory error to improve dataset distillation. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 3749–3758, 2023.
  • Du et al. [2024] Jiawei Du, Qin Shi, and Joey Tianyi Zhou. Sequential subset matching for dataset distillation. Advances in Neural Information Processing Systems, 36, 2024.
  • Gidaris and Komodakis [2018] Spyros Gidaris and Nikos Komodakis. Dynamic few-shot visual learning without forgetting. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 4367–4375, 2018.
  • Guo et al. [2023] Ziyao Guo, Kai Wang, George Cazenavette, Hui Li, Kaipeng Zhang, and Yang You. Towards lossless dataset distillation via difficulty-aligned trajectory matching. arXiv preprint arXiv:2310.05773, 2023.
  • Hanin and Nica [2019] Boris Hanin and Mihai Nica. Finite depth and width corrections to the neural tangent kernel. arXiv preprint arXiv:1909.05989, 2019.
  • Jacot et al. [2018] Arthur Jacot, Franck Gabriel, and Clément Hongler. Neural tangent kernel: Convergence and generalization in neural networks. Advances in neural information processing systems, 31, 2018.
  • Jiang et al. [2023] Zixuan Jiang, Jiaqi Gu, Mingjie Liu, and David Z Pan. Delving into effective gradient matching for dataset condensation. In 2023 IEEE International Conference on Omni-layer Intelligent Systems (COINS), pages 1–6. IEEE, 2023.
  • Krizhevsky et al. [2009] Alex Krizhevsky, Geoffrey Hinton, et al. Learning multiple layers of features from tiny images. 2009.
  • Le and Yang [n.d.] Ya Le and Xuan Yang. Tiny imagenet visual recognition challenge.
  • Lee et al. [2019] Jaehoon Lee, Lechao Xiao, Samuel Schoenholz, Yasaman Bahri, Roman Novak, Jascha Sohl-Dickstein, and Jeffrey Pennington. Wide neural networks of any depth evolve as linear models under gradient descent. Advances in neural information processing systems, 32, 2019.
  • Lee et al. [2022] Saehyung Lee, Sanghyuk Chun, Sangwon Jung, Sangdoo Yun, and Sungroh Yoon. Dataset condensation with contrastive signals. In International Conference on Machine Learning, pages 12352–12364. PMLR, 2022.
  • Li et al. [2020] Guang Li, Ren Togo, Takahiro Ogawa, and Miki Haseyama. Soft-label anonymous gastric X-ray image distillation. In 2020 IEEE International Conference on Image Processing (ICIP), pages 305–309. IEEE, 2020.
  • Li et al. [2022a] Guang Li, Ren Togo, Takahiro Ogawa, and Miki Haseyama. Compressed gastric image generation based on soft-label dataset distillation for medical data sharing. Computer Methods and Programs in Biomedicine, 227:107189, 2022a.
  • Li et al. [2022b] Guang Li, Ren Togo, Takahiro Ogawa, and Miki Haseyama. Dataset distillation for medical dataset sharing. arXiv preprint arXiv:2209.14603, 2022b.
  • Li et al. [2023] Guang Li, Ren Togo, Takahiro Ogawa, and Miki Haseyama. Dataset distillation using parameter pruning. IEICE Transactions on Fundamentals of Electronics, Communications and Computer Sciences, 2023.
  • Loo et al. [2022] Noel Loo, Ramin Hasani, Alexander Amini, and Daniela Rus. Efficient dataset distillation using random feature approximation. Advances in Neural Information Processing Systems, 35:13877–13891, 2022.
  • Nguyen et al. [2021] Timothy Nguyen, Roman Novak, Lechao Xiao, and Jaehoon Lee. Dataset distillation with infinitely wide convolutional networks. Advances in Neural Information Processing Systems, 34:5186–5198, 2021.
  • Riba et al. [2020] Edgar Riba, Dmytro Mishkin, Daniel Ponsa, Ethan Rublee, and Gary Bradski. Kornia: an open source differentiable computer vision library for PyTorch. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, pages 3674–3683, 2020.
  • Ulyanov et al. [2016] Dmitry Ulyanov, Andrea Vedaldi, and Victor Lempitsky. Instance normalization: The missing ingredient for fast stylization. arXiv preprint arXiv:1607.08022, 2016.
  • Wang et al. [2022] Kai Wang, Bo Zhao, Xiangyu Peng, Zheng Zhu, Shuo Yang, Shuo Wang, Guan Huang, Hakan Bilen, Xinchao Wang, and Yang You. CAFE: Learning to condense dataset by aligning features. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 12196–12205, 2022.
  • Wang et al. [2018] Tongzhou Wang, Jun-Yan Zhu, Antonio Torralba, and Alexei A Efros. Dataset distillation. arXiv preprint arXiv:1811.10959, 2018.
  • Zhao and Bilen [2021] Bo Zhao and Hakan Bilen. Dataset condensation with differentiable siamese augmentation. In International Conference on Machine Learning, pages 12674–12685. PMLR, 2021.
  • Zhao and Bilen [2023] Bo Zhao and Hakan Bilen. Dataset condensation with distribution matching. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, pages 6514–6523, 2023.
  • Zhao et al. [2020] Bo Zhao, Konda Reddy Mopuri, and Hakan Bilen. Dataset condensation with gradient matching. In International Conference on Learning Representations, 2020.
  • Zhou et al. [2022] Yongchao Zhou, Ehsan Nezhadarya, and Jimmy Ba. Dataset distillation using neural feature regression. Advances in Neural Information Processing Systems, 35:9813–9827, 2022.