The Benefits of Reusing Batches for Gradient Descent in Two-Layer Networks: Breaking the Curse of Information and Leap Exponents
Abstract
We investigate the training dynamics of two-layer neural networks when learning multi-index target functions. We focus on multi-pass gradient descent (GD) that reuses the batches multiple times and show that it significantly changes the conclusion about which functions are learnable compared to single-pass gradient descent. In particular, multi-pass GD with finite stepsize is found to overcome the limitations of gradient flow and single-pass GD given by the information exponent (Ben Arous et al., 2021) and leap exponent (Abbe et al., 2023) of the target function. We show that upon re-using batches, the network achieves in just two time steps an overlap with the target subspace even for functions not satisfying the staircase property (Abbe et al., 2021). We characterize the (broad) class of functions efficiently learned in finite time. The proof of our results is based on the analysis of the Dynamical Mean-Field Theory (DMFT). We further provide a closed-form description of the dynamical process of the low-dimensional projections of the weights, and numerical experiments illustrating the theory.
1 Introduction
Recent years have witnessed significant theoretical advances in understanding the dynamics of neural networks trained with gradient descent. The aim is to unravel the learning mechanisms of these networks, particularly how they adapt to data and identify pivotal features for predicting the target function. Significant progress has been made over the last few years in the case of two-layer networks, in large part thanks to the so-called mean-field analysis (Mei et al., 2018; Chizat and Bach, 2018; Rotskoff and Vanden-Eijnden, 2022; Sirignano and Spiliopoulos, 2020). Most of the theoretical efforts, in particular, focused either on one-pass optimization algorithms, where each iteration involves a fresh batch of data, or on the limit of gradient flow on the population loss. For high-dimensional synthetic Gaussian data and a low-dimensional target function (a multi-index model), the class of functions efficiently learned by these one-pass methods has been thoroughly analyzed in a series of recent works, and has been shown to be limited by the so-called information exponent (Ben Arous et al., 2021) and leap exponent (Abbe et al., 2022; 2023) of the target. These analyses have sparked many follow-up theoretical works over the last few months, see, e.g., (Damian et al., 2022; 2023; Dandi et al., 2023; Bietti et al., 2023; Ba et al., 2023; Moniri et al., 2023; Mousavi-Hosseini et al., 2023; Zweig and Bruna, 2023).
However, a common practice in machine learning involves repeatedly traversing the same mini-batch of data. This paper, therefore, aims to go beyond the constraints of single-pass algorithms and to evaluate whether multiple-pass training overcomes these inherent flaws of single-pass methods. We focus on gradient descent, certainly the most straightforward procedure in this family. The theoretical framework we use to prove our main results is based on Dynamical Mean Field Theory (DMFT), which was developed in the statistical physics community (Sompolinsky et al., 1988) to analyze correlated systems, and recently made rigorous in the context of high dimensional machine-learning problems in (Celentano et al., 2021; Gerbelot et al., 2023).
Our findings significantly alter the prevailing narrative in the literature. We demonstrate that gradient descent surpasses the limitations imposed by the information and leap exponents, achieving a positive correlation with the target function for a much broader class than the staircase functions (Abbe et al., 2021), even with minimal (that is, two) repetition of data batches. We characterize the (broad) class of functions efficiently learned in finite time. Among the exceptions are symmetric functions that remain a challenge due to their extended symmetry-breaking times (a natural feature of physical dynamics (Bouchaud et al., 1998)).
With independent Gaussian datapoints as inputs, both one-pass SGD and gradient flow on population loss result in pre-activations remaining distributed as Gaussian random variables. This Gaussianity underlies the analysis of such settings, starting from the seminal work of Saad and Solla (1995). In contrast, upon re-using batches, the pre-activations develop non-Gaussian components correlated with the targets. This non-Gaussianity is a crucial aspect of the stark contrast in the learning of directions compared to the one-pass setting. While we establish our results for discrete steps with extensive batches of size where denotes the input dimension, we expect our conclusions about the learning of new directions to also hold while performing gradient flow on the empirical loss, since such a setup would lead to the development of similar non-Gaussian components in the pre-activations, in contrast to gradient flow on the population loss where the pre-activations remain Gaussian.
Our results demonstrate that, contrary to the common wisdom “the more data the better”, gradient descent on the same batch can surpass one-pass SGD on different batches, even when one-pass SGD utilizes a larger number of datapoints. More generally, we believe that our analysis provides insights into incremental learning of features in the presence of correlations between datapoints across batches, which is typical in most high dimensional datasets having a small number of “latent” factors. Our conclusions follow from a rigorous mathematical proof rooted in DMFT, which we also use to provide an analytic description of the dynamic processes of low-dimensional weight projections. This analysis is of independent interest.
2 Setting and main contributions
Let be the set of data . The input data are taken as a standard i.i.d. Gaussian, while the labels are generated by a teacher, or target, function . We consider multi-index target function, dependent on a low-dimensional subspace of the input space:
(1) |
We assume for convenience that is normalized row-wise on the sphere , with orthogonal weights i.e for .
The data are fed to a two-layer network (the student) with first layer weights ( is the number of neurons in the hidden layer) and second layer weights with an activation function , that is:
(2) |
Our main goal is to analyze the dynamics of gradient descent that minimizes the empirical Mean Squared Error (MSE) loss at time :
(3) | ||||
in the high-dimensional limit where with . We use a common assumption that is amenable to rigorous theoretical guarantees: we keep the second layer weights fixed at initialization. For convenience, we further impose the constraint of symmetric initialization common in such analyses (Dandi et al., 2023; Damian et al., 2022). Concretely, we assume that the number of neurons is even and the weights satisfy at initialization:
(4) |
which ensures that the output equals at initialization. For , the weights are initialized as Subsequently, with fixed, the first layer weights are learned using gradient descent, producing the following sequence of iterates up to a final time :
(5) |
where is the learning rate and is the explicit regularisation. We may refer to these steps as the representation learning steps, in which the first layer weights learn how to adapt to the low dimensional structure identified by the teacher subspace .
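The training protocol described above can be sketched in a few lines of NumPy. This is only an illustrative sketch under assumptions that are not taken from the text: a `tanh` activation, a `1/sqrt(p)` output scaling, a `1/(2n)` normalization of the squared loss, regularization applied as plain weight decay, and a symmetric initialization in which paired neurons share first-layer weights and carry opposite second-layer signs. All function names are hypothetical.

```python
import numpy as np

def symmetric_init(p, d, rng):
    """Assumed symmetric initialization: neuron j and neuron j + p/2 share
    first-layer weights and carry opposite second-layer signs."""
    assert p % 2 == 0, "the number of neurons must be even"
    W_half = rng.standard_normal((p // 2, d)) / np.sqrt(d)
    a_half = np.ones(p // 2)
    return np.vstack([W_half, W_half]), np.concatenate([a_half, -a_half])

def student_output(W, a, Z):
    """Two-layer network with tanh activation and 1/sqrt(p) scaling (assumed)."""
    return np.tanh(Z @ W.T) @ a / np.sqrt(a.size)

def mse_loss(W, a, Z, y):
    """Empirical squared loss with an assumed 1/(2n) normalization."""
    return 0.5 * np.mean((y - student_output(W, a, Z)) ** 2)

def grad_first_layer(W, a, Z, y):
    """Gradient of mse_loss with respect to the first-layer weights W."""
    pre = Z @ W.T                                   # pre-activations, shape (n, p)
    resid = student_output(W, a, Z) - y             # residuals, shape (n,)
    # Chain rule: residual * second-layer weight / sqrt(p) * tanh'(pre-activation)
    back = resid[:, None] * (1.0 - np.tanh(pre) ** 2) * (a / np.sqrt(a.size))
    return back.T @ Z / Z.shape[0]                  # shape (p, d)

def gd_step(W, a, Z, y, eta, reg=0.0):
    """One full-batch step; the second layer a is kept fixed."""
    return W - eta * (grad_first_layer(W, a, Z, y) + reg * W)

def run_multipass(W, a, Z, y, eta, steps, reg=0.0):
    """Multi-pass GD: the same batch (Z, y) is reused at every step."""
    for _ in range(steps):
        W = gd_step(W, a, Z, y, eta, reg)
    return W
```

A one-pass variant would draw a new `(Z, y)` inside the loop instead of reusing the same batch; that single change is the difference between the two protocols compared in this paper.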
Our main contributions in this paper are the following:
• We characterize the class of multi-index targets that can be learned efficiently by two-layer networks trained with a finite number of iterations of gradient descent in the high dimensional limit with large batch sizes (). We establish a strong separation between what can be learned with one-pass algorithms (that use fresh batches at every step) and multi-pass gradient approaches that can use the same batch many times (see Figs. 1 and 2 for examples).
• We show that while both gradient flow (Bietti et al., 2023) and single-pass algorithms suffer from the curse of the information exponent (Ben Arous et al., 2021), and are limited to staircase learning (Abbe et al., 2023), requiring a diverging number of iterations for non-staircase functions, some of these problems become trivial when samples may be reused multiple times, and features can be learned in just two iterations. This disproves, in particular, a recent conjecture by Abbe et al. (2023).
• The simplest examples of directions that cannot be learned in a finite number of steps relate to symmetries in the target function. This includes phase retrieval (Maillard et al., 2020) or the specialization transition in committees, as discussed in the Bayes optimal approaches of single-index (Barbier et al., 2019) and multi-index (Aubin et al., 2019) models.
• The proof of our results is based on the concept of “hidden progress”, and crucially uses the rigorous Dynamical Mean Field Theory (DMFT) (Celentano et al., 2021; Gerbelot et al., 2023). This is of independent interest, as it provides a sharp example of how DMFT can help to understand batch reusing and go beyond the current state-of-the-art results.
• Finally, we use DMFT to provide a closed-form description of the dynamics of gradient descent for two-layer nets. Keeping track of the correlations induced by re-using the same batch leads to a set of integro-differential equations. We provide rigorous theoretical guarantees in the correlated samples regime without assuming the resampling of a fresh batch for each iteration of the algorithm. We corroborate the theoretical claims with numerical simulations (see https://github.com/IdePHICS/benefit-reusing-batch).
Other Related works –
A major issue in machine learning theory is figuring out how well two-layer neural networks adapt to low-dimensional structures in the data. Different results have tightly characterized the limitations of networks in which the first layer of weights is kept fixed, i.e. equivalent to kernel approaches (Dietrich et al., 1999; Ghorbani et al., 2019; 2020; Bordelon et al., 2020; Loureiro et al., 2021; Cui et al., 2021). This class of learning algorithms, although amenable to theoretical analysis, is unable to learn features in the data. Therefore, one central avenue of research in this context is to understand the efficiency of the representation learning (or feature learning) when training with gradient-based algorithms to overcome the limitations of the kernel regime. Sharp separation results between the performance of neural networks at initialization (random features) and trained with only one step of gradient descent (with a large learning rate) have been offered (Ba et al., 2022; Damian et al., 2022; Dandi et al., 2023).
The class of features efficiently learned with multiple steps of one-pass SGD with one sample per batch is characterized by the information exponent () (Ben Arous et al., 2021) of the target function. In the context of single-index learning, denoting the of the target, the algorithm needs steps to perform weak recovery of the teacher direction, i.e., obtaining an overlap between learned weights and better than random guessing (Ben Arous et al., 2021). Recently, these results have been improved up to the Correlational Statistical Query (CSQ) lower bound of , by considering an appropriate smoothing of the loss (Damian et al., 2023). A generalization to large batch one-pass SGD is in (Dandi et al., 2023).
Similarly, multi-index feature learning presents an unavoidable computational barrier for one-pass algorithms. (Abbe et al., 2021) first characterizes a hierarchical picture of learning in the Boolean data case: informally, the features efficiently learned at each step of the one-pass algorithm need to be linearly connected with the previously learned features. This concept is formalized by the definition of the staircase property (Abbe et al., 2021). This hierarchical picture of learning is extended to large batches in the SGD and non-Boolean data in (Abbe et al., 2022; 2023; Dandi et al., 2023). Moreover, (Abbe et al., 2023) conjecture that re-using the batch can reduce the sample complexity of the target with leap only up to , corresponding to the lower bound for Correlational Statistical Query (CSQ) algorithms.
We disprove this conjecture and show that the sample complexity for a large class of functions can be reduced to independently of the leap exponent . More generally, our results show that CSQ lower bounds and the notions of staircase property and information exponent are limited to online-SGD on Gaussian/Boolean data, and do not describe the class of functions inherently easy or hard to learn by gradient-based methods. We also show that learning non-even single-index functions does not require techniques such as spectral warm-start (Chen and Meka, 2020).
Dynamical Mean Field Theory has a long history in statistical physics. Early theories of dynamics in complex systems were pioneered in soft spin glass models (Sompolinsky and Zippelius, 1981) and toy models of random feature deep networks (Sompolinsky et al., 1988). The DMFT approach used in this paper was first proposed as a way to study “hard spins” in spin glass models (Eissfeller and Opper, 1992; 1994), and was later generalized to “soft spins” (Cugliandolo, 2003) and more realistic models in condensed matter (Georges et al., 1996). In the context of learning, DMFT was used for optimization problems (Mannelli et al., 2019a; b; 2020; Mannelli and Urbani, 2021) and for analyzing the behavior and the noise of gradient-based algorithms (Mignacco et al., 2020; 2021; Mignacco and Urbani, 2022). From the mathematics point of view, these DMFT equations were first proven rigorously in the seminal work of (Ben Arous et al., 1997) in the context of spin glasses. Important progress was achieved recently with rigorous proofs of the DMFT equations for multi-index models (Celentano et al., 2021; Gerbelot et al., 2023) that we use to prove our main results.
![Figure 1](x1.png)
![Figure 2](x2.png)
3 Statement of the results
Here, we introduce the main results covering the theoretical learning guarantees with gradient descent and contrast them with the known one-pass results. We exploit the rigorous DMFT construction to prove the first key result: two-layer networks efficiently learn a large class of multi-index targets in only two iterations, breaking the curse of one-pass algorithms dictated by the information and leap exponents.
3.1 Finite-T Learnable and Non-learnable directions
We first identify which target directions are hard to learn for multi-pass gradient descent. Define to be the subspace spanned by the rows of the target weights . The “hard” directions are the ones where any transformation of the output does not lead to a linear correlation along the direction. We now define the subspace of such directions:
Definition 3.1.
We define as the subspace of directions such that for any polynomial with coefficients in , the following condition is satisfied:
(6) |
Similarly, we denote by , the subspace of directions where the above condition is satisfied for all real-valued analytic functions .
One part of our main result shows that directions in cannot be learned even by re-using batches of size in a finite number of gradient steps. Furthermore, under suitable conditions on and (discussed in Theorem 3.2 and Appendix A.5), we show that after two gradient steps, the first layer learns all directions in the complement . We are now ready to state our main result:
Theorem 3.2.
Suppose that . Let denote an arbitrary direction in the orthogonal complement of the subspace defined in definition 3.1 with norm and a fixed representation in the basis . Suppose further that the activation function is analytic, with polynomially bounded derivatives satisfying and . Then, for any with derivatives bounded by polynomials, there exist such that almost surely over the choice of , we have:
(7) |
with high probability as . Furthermore, for large enough , asymptotically spans :
(8) |
with high probability as . In other words, directions are learned in gradient steps.
Suppose, however that the teacher subspace , then:
(9) |
with high probability as , for any finite time . Thus, none of the directions are learned in any finite number of GD steps.
The proof, given in App. A, is based on the analysis of the DMFT equations discussed in Sec. 4.2; we provide an informal heuristic derivation in Sec. 4.1. While the above negative result requires all directions in to be in and thus in , in App. A.6 we discuss the more general setup where learning of certain directions in can affect the learning of directions in in subsequent timesteps.
When the expectation in Equation 6 is non-zero for being the identity mapping, i.e. , is in fact learned in the first gradient step (Ba et al., 2022; Dandi et al., 2023) or through online SGD (Ben Arous et al., 2022; Abbe et al., 2023). We discuss this further in Section 3.3.
Our analysis reveals that the effect of re-using batches is to implicitly transform the output in the subsequent steps, allowing a larger set of directions to be learned. However, for directions in , such transformations are still insufficient.
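To make the role of output transformations concrete, the following sketch evaluates Gaussian expectations of the form E[T(y) z] for a single-index target y = g(z) with monomial transformations T(y) = y^k. It assumes that the condition in Equation 6 takes this form (the formula itself is not reproduced in this text), and the two example targets, the Hermite polynomials He_2 and He_3, are chosen here purely for illustration. Expectations are computed with Gauss-Hermite quadrature, which is exact for polynomial integrands of the degrees used.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

def gaussian_expectation(func, order=40):
    """E[func(z)] for z ~ N(0, 1), via probabilists' Gauss-Hermite quadrature."""
    nodes, weights = hermegauss(order)      # weight function exp(-z^2 / 2)
    return np.sum(weights * func(nodes)) / np.sqrt(2.0 * np.pi)

def he2(z):
    return z ** 2 - 1.0                     # even target (illustrative)

def he3(z):
    return z ** 3 - 3.0 * z                 # odd target with zero linear correlation

def transformed_correlation(g, k):
    """E[g(z)^k * z]: linear correlation of the transformed output T(y) = y^k."""
    return gaussian_expectation(lambda z: g(z) ** k * z)

for name, g in [("He2", he2), ("He3", he3)]:
    values = [transformed_correlation(g, k) for k in range(1, 4)]
    print(name, np.round(values, 6))
```

For He_3 the raw output (k = 1) has no linear correlation with z, but the cubic transformation (k = 3) does, which is the kind of correlation that batch reuse can exploit implicitly. For the even target He_2 every power gives zero, consistent with the discussion of symmetric directions.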
3.2 Characterization of hard directions through symmetries
While Definition 3.1 characterizes the subspace of hard directions , it requires checking that the equality in Equation 6 holds for any real analytic transformation . We now show that a sufficient condition for is for to possess certain symmetries along . This leads us to identify subspaces of hard directions, contained in , linked to symmetries w.r.t certain transformations. We characterize such subspaces below. The simplest such symmetry is defined through reflection along :
Definition 3.3.
For any direction , let denote the reflection operator along , i.e. . We say that a direction is even-symmetric w.r.t if for any :
(10) |
We denote by the subspace spanned by all even-symmetric directions in .
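The even-symmetry condition can be probed numerically. The sketch below assumes the reflection operator is the usual Householder reflection, mapping z to z minus twice its projection on the unit vector v, and checks on random Gaussian samples whether a candidate target is invariant under it. The helper names and the two toy targets are hypothetical, and a sampled check like this can only refute the symmetry, not prove it.

```python
import numpy as np

def reflect(Z, v):
    """Reflect each row of Z along the direction v (Householder reflection)."""
    v = v / np.linalg.norm(v)
    return Z - 2.0 * np.outer(Z @ v, v)

def is_even_symmetric(f, v, dim, n_samples=2000, tol=1e-9, seed=0):
    """Sampled check that f(z) == f(reflect(z, v)) on Gaussian inputs."""
    rng = np.random.default_rng(seed)
    Z = rng.standard_normal((n_samples, dim))
    return bool(np.max(np.abs(f(Z) - f(reflect(Z, v)))) < tol)

def quad_plus_linear(Z):
    """Toy target: even in the first coordinate, linear in the second."""
    return Z[:, 0] ** 2 + Z[:, 1]

def committee(Z):
    """Toy two-neuron committee, invariant under exchanging its two inputs."""
    return np.tanh(Z[:, 0]) + np.tanh(Z[:, 1])

print(is_even_symmetric(quad_plus_linear, np.array([1.0, 0.0]), dim=2))
print(is_even_symmetric(committee, np.array([1.0, -1.0]), dim=2))
```

In the committee example, reflecting along the difference of the two teacher directions swaps the two inputs, so exchange symmetry of the target shows up as even symmetry along that direction.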
It is straightforward to see that any leads to Equation 6 being satisfied for any transformation , since remains even w.r.t the function . Therefore, However, the set of non-learnable directions can be larger due to the presence of additional symmetries. We now define such a larger subspace of hard directions arising due to a symmetry w.r.t reflections along coupled with orthogonal transformations along the orthogonal subspace:
Definition 3.4.
For any direction in , let be as defined in Definition 3.3. Let be a matrix in the orthogonal group on the dimensional subspace i.e the orthogonal complement of the linear subspace spanned by . We say that a direction is orthogonally-even-symmetric w.r.t , if there exists an , such that for any :
(11) |
We denote by the subspace spanned by all orthogonally-even-symmetric directions in .
By setting as the identity mapping in the above definition, we recover the condition for . Therefore, we have that . While is the largest set of directions we have identified as being hard, the true set of hard directions may be larger still and is given by in Definition 3.1. We show in Appendix A.8 that the directions in are indeed hard as per Definition 3.1:
Proposition 3.5.
App.A.7 gives several examples where , such as single-index targets with odd Hermite activations, staircase functions, etc. Interestingly, we show that there exist functions where the set is strictly larger than . Consequently, for such functions, is strictly contained in . We discuss such target functions in Appendix A.9. For example, we show in Appendix A.9, that for the target function , the direction does not lie in but lies in and thus in .
3.3 Comparison between one-pass and multi-pass GD
Our results are particularly interesting in the context of a recent line of work on the limitations of one-pass algorithms (Ben Arous et al., 2021; Abbe et al., 2021; 2022; 2023; Dandi et al., 2023; Bietti et al., 2023; Zweig and Bruna, 2023). We can demonstrate, in particular, a sharp performance separation between one-pass and multiple-pass protocols.
Learning single-index targets –
First, we consider single-index targets. Targets that are hard to learn for one-pass algorithms starting from uninformed initialization in high dimension are characterized by the Information Exponent (). Informally, the information exponent is the index of the first non-zero coefficient in the Hermite expansion of the target activation.
Definition 3.6 (Information Exponent).
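As a hedged illustration of the informal description above, the information exponent of a scalar target can be estimated numerically as the smallest degree k ≥ 1 whose (unnormalized) Hermite coefficient E[g(z) He_k(z)] is non-zero. The function names, the tolerance, and the cutoff `max_degree` below are assumptions of this sketch.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss, hermeval

def hermite_coefficient(g, k, order=60):
    """Unnormalized Hermite coefficient E[g(z) He_k(z)] for z ~ N(0, 1)."""
    nodes, weights = hermegauss(order)
    basis = np.zeros(k + 1)
    basis[k] = 1.0                           # selects the polynomial He_k
    integrand = g(nodes) * hermeval(nodes, basis)
    return np.sum(weights * integrand) / np.sqrt(2.0 * np.pi)

def information_exponent(g, max_degree=10, tol=1e-8):
    """Smallest k >= 1 with a non-zero Hermite coefficient, or None if not found."""
    for k in range(1, max_degree + 1):
        if abs(hermite_coefficient(g, k)) > tol:
            return k
    return None

print(information_exponent(np.tanh))
print(information_exponent(lambda z: z ** 2 - 1.0))
print(information_exponent(lambda z: z ** 3 - 3.0 * z))
```

The quadrature is exact only for polynomial targets of moderate degree; for other activations the result depends on the tolerance and should be read as an estimate.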
Higher are associated to harder problems for one-pass training protocols. Indeed, (Ben Arous et al., 2021) provably show that one-pass SGD, with one sample per batch, weakly recovers the teacher direction only upon iterating the training schedule for time iterations:
(13) |
Recently, the time complexity has been improved up to the Correlational Statistical Query (CSQ) lower bound of , by considering an appropriate smoothing of the loss (Damian et al., 2023). Definition 3.6 has been extended to larger batch sizes in (Abbe et al., 2022; Dandi et al., 2023), without changing the overall picture; more precisely, even with fresh samples per batch, one-pass training procedures are still not able to weakly recover the signal in finite iteration time. The case corresponds to the expectation in Equation 6 being non-zero for . However, since Definition 3.1 allows for general transformations to the output , may not be in even when . The presence of general transformations in definition 3.1 allows our algorithm to bypass CSQ bounds, which are restricted to . Such general transformations are however permitted under the framework of Statistical Query (SQ) algorithms (Kearns, 1998). We thus expect gradient descent with sample complexity to inherit the hardness results established for the class of SQ algorithms (Diakonikolas et al., 2020; Goel et al., 2020; Chen et al., 2021; 2022). We emphasize however that unlike explicit SQ algorithms, our analysis shows that gradient descent performs such transformations implicitly, allowing it to reach the optimal complexity of SQ algorithms for certain class of target functions.
We illustrate the sharp contrast between one-pass and multiple-pass protocols with the examples depicted in Figure 1, which shows the scalar product (called overlap) between the learned weights and the teacher direction as a function of the time steps and compares simulation (dots) with theoretical predictions (continuous lines). There are three cases:
- •
• Multi-pass finite- learnable single-index targets: Fig. 1 (center) depicts the learning curve for a non-even target function, with . Here, one-pass GD is not able to achieve any significant correlation with the teacher (it would require a number of iterations to achieve weak recovery, see Eq. (13)). However, multiple-pass GD performs weak recovery in only two steps. As before, the non-symmetric subspace corresponds to the teacher one (Def. 3.3).
• Finite- non-learnable single-index targets: Fig. 1 (right) considers an even problem, with . Neither of the training procedures achieves weak recovery in finite time. The computational hardness of this problem is associated with the presence of symmetry in the teacher function that requires time to break. Indeed, following Definition 3.3, the even-symmetric subspace is equivalent to the teacher subspace . These results agree with the emergence of computational barriers in symmetric single-index problems like phase retrieval (Maillard et al., 2020). In fact, for such problems, regardless of the number of iterations, learnability requires to be larger than critical values even for the most efficient known algorithms (see (Barbier et al., 2019), Sec. 3.1).
Learning multi-index targets –
The hardness of multi-index targets learning has been the subject of numerous recent studies for single-pass algorithms (Abbe et al., 2021; 2022; 2023; Bietti et al., 2023; Zweig and Bruna, 2023; Dandi et al., 2023). The class of multi-index targets efficiently learned by one-pass algorithms has been provably associated with the Leap Complexity () of the target to be learned, which generalizes the information exponent:
Remark 3.7.
Informally, the learning dynamics of one-pass routines follow this behavior: initially, the network learns in the first step the first Hermite coefficient of the target . For every time of the one-pass schedule, the network is bound to learn in finite time only features that are linearly connected to the previously learned directions; functions possessing only such linearly connected features are leap functions (), e.g. . Similarly, functions that are quadratically connected to the learned features are leap (), e.g. . Higher target functions correspond to harder learning problems for one-pass algorithms: one-pass SGD, with one sample per batch, weakly recovers the teacher subspace by iterating the training protocol for time steps, where the substitutes the in eq. (13) (Abbe et al., 2023).
We illustrate the behavior of one-pass and multiple-pass algorithms when learning multi-index functions in Fig. 2. Using different two-index teachers (), it shows the scalar product between the learned weights and two reference vectors: a) the first Hermite coefficient of the target , called in the following ; b) the vector in the teacher subspace orthogonal to , referred to as . The figure shows these correlation metrics as a function of time, labeled as overlap (resp. orthogonal overlap) in the upper (resp. lower) section. There are, again, three cases:
• Finite- learnable multi-index targets: Fig. 2 (left) depicts a target with . The teacher subspace spanned by the standard basis vectors is learned by both one-pass and multi-pass GD in finite time. At , is learned; this enables the recovery of the direction at as the target is linear in once has been learned. This hierarchical picture of learning is called the staircase mechanism. In the notation of Def. 3.3, the non-symmetric teacher subspace is equivalent to the full teacher subspace .
• Multi-pass finite- learnable multi-index targets: The central panel in Fig. 2 illustrates a teacher with . Both algorithms are successful in weakly recovering the direction in the first step. However, as the training continues, one-pass GD never recovers the full teacher subspace in finite time (exemplified by the zero orthogonal overlap in the lower panel). Conversely, multi-pass GD is able to perform weak recovery of the full teacher subspace by achieving a non-vanishing correlation with (non-zero orthogonal overlap in the lower section) in just two steps. Again, the non-symmetric subspace is equivalent to the full teacher subspace (Def. 3.3).
• Finite- non-learnable multi-index targets: The right panel of Fig. 2 considers a committee machine teacher with symmetric activation, i.e. , here . Both protocols, in this case, are only able to learn a single-index approximation of the target function in finite time, achieving non-zero correlation only with throughout the dynamics. The computational hardness of this problem is associated with the presence of a neuron exchange symmetry. Indeed, in the notation of Def. 3.3, we observe that the even-symmetric subspace is a non-empty subspace of the teacher one . Therefore, as for one-pass routines, multiple-pass ones are bound to learn only in finite time steps. Such difficulties have been described in the analysis of the specialization transition in the information-theoretic/Bayes optimal case of symmetric committees (Aubin et al., 2019). As for single-index models, breaking the symmetry requires to be large enough and, even in this case, the best-known algorithms require a diverging number of iterations (see (Aubin et al., 2019), Sec. 3).
3.4 From weak recovery to generalization
While Th. 3.2 provides conditions for the weak recovery (a finite overlap with directions in ), once this is done, it becomes straightforward to learn the function up to any desired accuracy with only additional samples. Indeed, strong generalization guarantees can be proven by utilizing existing results either for subsequent training with online SGD (Ben Arous et al., 2021) (to use their terminology, once you escape mediocrity, the ballistic phase is easy) or training of the second layer using an independent batch of samples as in (Damian et al., 2022; Abbe et al., 2023). See App.A.10 for such generalization sample-complexity results.
4 Main proof ideas
4.1 Learning by hidden progress: heuristic argument
While we give a rigorous proof of Thm. 3.2 in App. A, we provide now an informal description of the hidden progress in the first step of gradient descent that allows subsequent development of overlaps in the second step, that is at the root of the difference between single and multi-pass algorithms. For simplicity, we focus on the case of a single hidden neuron (). We denote the pre-activation for the training point along the neuron with , and a vector in the span of with .
From the gradient update in Eq. (5), the update lies in the span of the training inputs , with the gradient of the training example given by . For squared loss, assuming that , the gradient reads:
(14) |
At initialization and the projections along the teacher subspace (which we denote ) are approximately independent since is approximately orthogonal to the teacher subspace as well as to the inputs . The projection of the gradient along the teacher subspace is given by:
(15) |
We do expect that, due to concentration, the component of the full-batch gradient update along the teacher subspace lies along the direction given by:
(16) |
where we used the approximate independence of and to factorize the expectation. Thus, the neuron parameters at the first step are correlated with the teacher subspace only along the direction .
If , the parameters remain orthogonal to the teacher subspace. This is true whenever the of the target function is larger than . To make progress, it is thus necessary for the pre-activations to become correlated with the teacher pre-activation . This can happen in two different ways:
(i) By directly gaining components along the teacher subspace . Under online SGD, the data is used only once for the gradient updates, so only this mechanism is possible. It allows the directions learned by at any step to depend on the directions already learned by . This underlies the “staircase” phenomenon in online SGD (Abbe et al., 2021; 2022; 2023) as well as the notion of information exponent when applied to a single direction (Ben Arous et al., 2021).
(ii) By gaining components along . Recall that the target is defined as and thus can correlate with . This is what happens when using gradient descent with multi-pass in our setting. This implies that even when does not learn a direction , the pre-activation can develop a dependence on through the component of the gradient update along .
Let us see how this phenomenon, which we call hidden progress, happens in practice. From (5), the update in the pre-activation due to the first gradient step reads:
(17) |
In this sum there is one term of magnitude corresponding to , and random terms of order . This second group of terms contributes to an effective “noise” of order . The first term however, since , depends on (and thus on all components of ):
(18) |
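The decomposition described in this paragraph can be written out as a sketch. The notation below is assumed for illustration, since the symbols are not reproduced in this text: a single neuron with weights $w$ and unit second-layer weight, pre-activations $\lambda_\mu = \langle w, z_\mu \rangle$, a network output close to zero at initialization, and a loss normalized by $1/n$. The paper's own scaling of the learning rate may differ.

```latex
% Assumed notation: single neuron, second-layer weight a = 1,
% network output ~ 0 at initialization, loss normalized by 1/n.
\begin{align*}
\Delta w &= \frac{\eta}{n} \sum_{\nu=1}^{n} y_\nu\, \sigma'(\lambda_\nu)\, z_\nu, \\
\Delta \lambda_\mu = \langle \Delta w, z_\mu \rangle
 &= \underbrace{\frac{\eta}{n}\, y_\mu\, \sigma'(\lambda_\mu)\, \|z_\mu\|^2}_{\text{self term}\ \approx\ (\eta d / n)\, y_\mu\, \sigma'(\lambda_\mu)}
 \;+\; \underbrace{\frac{\eta}{n} \sum_{\nu \neq \mu} y_\nu\, \sigma'(\lambda_\nu)\, \langle z_\nu, z_\mu \rangle}_{\text{fluctuation with variance}\ \approx\ \eta^2 (d/n)\, \mathbb{E}\left[y^2 \sigma'(\lambda)^2\right]}.
\end{align*}
```

With $n$ proportional to $d$, both contributions are of order one under this normalization: the self term carries the label $y_\mu$ of the very sample being evaluated, while the cross terms behave as approximately zero-mean noise when the target has no linear component.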
Due to this dependence between and , in the subsequent steps i.e. , the term in the update (14) can now influence the direction of the gradient along the teacher subspace, leading to gaining correlations with new directions in . This can be seen as follows: let , it follows from the GD updates that
(19) |
Now, suppose that is not learned in the first step. However, due to the hidden progress, is now dependent on , thus allowing the new expectation of the projection of the update along given by to be non-zero. This explains how the dependence of the pre-activations on can allow learning of new directions even when the weights have not gained components along the teacher subspace.
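The hidden-progress effect can be observed in a small simulation. The sketch below is illustrative only and rests on assumptions not fixed by the text: a single neuron with `tanh` activation and unit second-layer weight, a network output treated as zero at initialization, a `1/n` normalization of the gradient, and the hypothetical single-index target He_3 along the first coordinate. After one full-batch step, it measures how strongly the change in pre-activation correlates with the label-dependent factor y·σ'(λ), once on the training samples and once on fresh samples.

```python
import numpy as np

def hidden_progress_correlations(d=1000, n=1000, eta=1.0, seed=0):
    """Correlation between the first-step change in pre-activation and
    y * sigma'(lambda), on the training batch and on a fresh batch."""
    rng = np.random.default_rng(seed)
    w_star = np.zeros(d)
    w_star[0] = 1.0                                   # teacher direction (assumed)
    w0 = rng.standard_normal(d)
    w0 /= np.linalg.norm(w0)                          # uninformed initialization

    def target(u):
        return u ** 3 - 3.0 * u                       # He_3: no linear correlation

    def dsigma(x):
        return 1.0 - np.tanh(x) ** 2                  # derivative of tanh

    Z = rng.standard_normal((n, d))                   # training batch
    y = target(Z @ w_star)
    factor_train = y * dsigma(Z @ w0)
    # First full-batch step on the squared loss, network output taken as 0
    delta_w = (eta / n) * (factor_train @ Z)

    Z_new = rng.standard_normal((n, d))               # fresh, unseen batch
    factor_new = target(Z_new @ w_star) * dsigma(Z_new @ w0)

    corr_train = np.corrcoef(Z @ delta_w, factor_train)[0, 1]
    corr_fresh = np.corrcoef(Z_new @ delta_w, factor_new)[0, 1]
    return corr_train, corr_fresh

corr_train, corr_fresh = hidden_progress_correlations()
print(f"train: {corr_train:.3f}, fresh: {corr_fresh:.3f}")
```

Under these assumptions the training-sample correlation is large while the fresh-sample correlation is close to zero, even though the weight update itself has almost no component along the teacher direction. This is the sense in which the pre-activations on a reused batch carry information that a fresh batch does not.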
This learning mechanism, however, fails when the target function is symmetric along . Indeed, for such a direction, retains an even dependence on , which implies that the expectation of the term remains for all time steps , with . Such directions are therefore not learned with a finite number of time-steps and batch-size even upon re-using the batches.
The rigorous control of all these quantities is a difficult task a priori. One cannot, in particular, express the above sum as an expectation w.r.t independent samples since the weights now depend on all the samples. Fortunately, this is precisely the difficulty solved by the DMFT equations through an effective stochastic process on the pre-activations that are decoupled across training examples. The rigorous analysis is detailed in App. A. The main lines of the DMFT equations are in Sec. 4.2.
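The role of symmetry can be made concrete for polynomial single-index targets. The sketch below assumes, as the argument in App. A.5 suggests but in a simplified one-dimensional form, that the relevant quantities are correlations of the form E[f(x)^k x] with x a standard Gaussian teacher pre-activation. It computes them exactly from Gaussian moments; the helper names are hypothetical.

```python
def gaussian_moment(k):
    """E[x^k] for x ~ N(0, 1): zero for odd k, (k - 1)!! for even k."""
    if k % 2 == 1:
        return 0
    result = 1
    for j in range(1, k, 2):
        result *= j
    return result


def poly_mul(p, q):
    """Multiply polynomials given as coefficient lists [c0, c1, ...]."""
    out = [0] * (len(p) + len(q) - 1)
    for i, a in enumerate(p):
        for j, b in enumerate(q):
            out[i + j] += a * b
    return out


def gaussian_expectation(p):
    """E[p(x)] for x ~ N(0, 1)."""
    return sum(c * gaussian_moment(k) for k, c in enumerate(p))


def power_correlation(target, k):
    """E[target(x)^k * x] for x ~ N(0, 1)."""
    acc = [1]
    for _ in range(k):
        acc = poly_mul(acc, target)
    return gaussian_expectation(poly_mul(acc, [0, 1]))


def first_nonzero_power(target, max_power=6):
    """Smallest k <= max_power with E[target(x)^k * x] != 0, else None."""
    for k in range(1, max_power + 1):
        if power_correlation(target, k) != 0:
            return k
    return None


HE2 = [-1, 0, 1]      # He_2(x) = x^2 - 1, even in x
HE3 = [0, -3, 0, 1]   # He_3(x) = x^3 - 3x, odd in x

if __name__ == "__main__":
    for name, target in [("He_2", HE2), ("He_3", HE3)]:
        print(name, [power_correlation(target, k) for k in range(1, 5)])
```

For He_3 the correlation with x itself vanishes, which is why a single pass makes no progress in this toy setting, but the third power of the target has a non-zero correlation with x. For the even target He_2 every power remains even, so all these correlations vanish, in line with the failure of the mechanism along symmetric directions.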
Finally, note that while our proof uses the Gaussian data assumption, the heuristic argument hints that this is not crucial. Additionally, in any real dataset samples are very correlated, and thus a given sample (or a very similar one) may appear many times. In this case, even single-pass algorithms will behave as predicted by our approach. We thus believe it describes a more realistic scenario than the pure single pass theories with fresh i.i.d. data.
4.2 Characterization of the dynamics
Re-using batches at each gradient step requires keeping track of the pre-activations of the parameters. Since the number of pre-activations and the dimensions of the parameters grow with , we need a low-dimensional effective dynamics characterizing the quantities of interest, such as the overlaps between the student and target parameters. DMFT provides such an effective dynamics through a set of coupled stochastic processes and representing the joint distributions of the student and teacher parameters and of the student and teacher pre-activations , respectively.
We derive the equations and prove their applicability to our setting using existing results in (Celentano et al., 2021; Gerbelot et al., 2023). Asymptotically, for with , the joint distribution of the student and teacher pre-activations (for each sample), and converge in distribution to samples from the stochastic process and the standard normal variable . Similarly, the joint distribution of each component of the student and teacher weights with converge in distribution to samples from the stochastic process and the standard normal variable .
(20) | ||||
(21) |
Notice that the formula above is the high dimensional equivalent of the gradient descent update (5). Here and are zero mean Gaussian Process with covariances and respectively, with
the matrix can be viewed as an “effective regularization” on the parameters. and the projected gradient converge in probability to:
(22) | |||
(23) |
The memory kernels , , are defined as:
(24) | |||
(25) |
and , . Finally, the low dimensional projections of the weights will obey
(26) |
Notice that these definitions are well-posed because of the causal structure of the gradient descent updates, and by extension of (20): the distribution of is completely determined by and the auxiliary quantities in eqs. (22, 4.2). Iterating backwards we reach the initial condition , which is a simple function of the data distribution and the initial conditions of the weights. For additional details we refer to App. A. Notice that it is also possible to write this set of equations as a function of a single stochastic process on , as in App. C.
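Since the Gaussian processes entering the effective dynamics live on a finite number of time steps, sampling them numerically reduces to drawing a Gaussian vector with a prescribed covariance matrix. A minimal sketch of this step via a Cholesky factorization is given below. The covariance used here is a placeholder chosen only because it is positive definite; it is not the covariance of the DMFT equations, and the function names are hypothetical.

```python
import math
import random


def cholesky(C):
    """Lower-triangular L with L L^T = C for a symmetric positive definite C."""
    T = len(C)
    L = [[0.0] * T for _ in range(T)]
    for i in range(T):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(C[i][i] - s)
            else:
                L[i][j] = (C[i][j] - s) / L[j][j]
    return L


def sample_gaussian_process(L, rng):
    """One zero-mean sample path (omega_0, ..., omega_{T-1}) with covariance L L^T."""
    xi = [rng.gauss(0.0, 1.0) for _ in L]
    return [sum(L[t][s] * xi[s] for s in range(t + 1)) for t in range(len(L))]


def example_covariance(T):
    """Placeholder covariance exp(-|s - t| / 2); NOT the DMFT covariance."""
    return [[math.exp(-0.5 * abs(s - t)) for t in range(T)] for s in range(T)]


if __name__ == "__main__":
    rng = random.Random(0)
    L = cholesky(example_covariance(4))
    print(sample_gaussian_process(L, rng))
```

Because the factor is lower triangular, the value at time t only uses the noise variables up to time t, which matches the causal structure described above.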
Sketch of proof of the hidden progress —
Finally, we explain how the DMFT equations relate to the phenomenon in Sec. 4.1 and allow us to prove Th. 3.2. The term in (21) precisely corresponds to the contribution to pre-activation of a point (App. A.4) from the gradient at the same point . As we discussed in Section 4.1, this term induces a dependence between and even when the overlaps are . At time , the response term simplifies to and the pre-activations can be expressed as the random variable with added Gaussian noise. Analogous to section 4.1, we denote by the limiting value of the overlaps for some with . Propagating the equations over the first two steps, and using Equation (26), we show that can be expressed as an expectation w.r.t the pre-activations of a function dependent on the target , the second layer , and the activation function :
(27) |
The function is described in App. A.5, Eq. (90). Finally, we show an equivalence between the condition and the condition for general in Definition 3.1.
General multi-pass schemes —
While Theorem 3.2 considers a finite number of updates with the same batch of data for each step, it can be naturally generalized to other setups involving multiple passes over a finite number of mini-batches of size . For instance, one can cycle over distinct minibatches with each cycle constituting one epoch or pass through the dataset. Theorem 3.2 remains valid under such a setup with the onset of weak recovery shifting to the start of the second epoch instead of the second gradient step. We provide a sketch of this extension in Appendix B. On the other hand, if the minibatches are sampled with replacement from the dataset, the weak recovery still starts at the second gradient step. We illustrate this in Fig. 4 (in the appendix). Furthermore, we empirically observe that the phenomenon holds even when considering the limit of mini-batch size (Figure 5 in the appendix). Proving this, however, remains out of reach of the present technique.
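The difference between these schemes comes down to the first step at which some sample is used a second time. The sketch below builds three hypothetical batch schedules as lists of index sets and reports that step. It only tracks sample indices, not the training dynamics, and in the resampled schedule the indices inside one batch are distinct while batches are drawn independently across steps, which is an assumption about how the sampling is done.

```python
import random


def full_batch_schedule(n, steps):
    """The same batch of all n samples at every step."""
    return [set(range(n)) for _ in range(steps)]


def cyclic_schedule(n, num_batches, epochs):
    """Cycle over num_batches disjoint minibatches, one cycle per epoch."""
    size = n // num_batches
    batches = [set(range(b * size, (b + 1) * size)) for b in range(num_batches)]
    return [batches[b] for _ in range(epochs) for b in range(num_batches)]


def resampled_schedule(n, batch_size, steps, seed=0):
    """A fresh random minibatch at every step, drawn from the same dataset."""
    rng = random.Random(seed)
    return [set(rng.sample(range(n), batch_size)) for _ in range(steps)]


def first_reuse_step(schedule):
    """1-indexed step at which some sample is used a second time, or None."""
    seen = set()
    for step, batch in enumerate(schedule, start=1):
        if seen & batch:
            return step
        seen |= batch
    return None


if __name__ == "__main__":
    print("full batch:", first_reuse_step(full_batch_schedule(12, 3)))
    print("cyclic    :", first_reuse_step(cyclic_schedule(12, 3, 2)))
    print("resampled :", first_reuse_step(resampled_schedule(100, 60, 3)))
```

With three disjoint minibatches the first reuse happens at step four, the start of the second epoch, whereas two independently drawn batches that each cover more than half of the dataset must already share samples at the second step.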
5 Conclusions
Our study analyzes the training dynamics of two-layer neural networks for learning multi-index target functions, distinctively focusing on multi-pass gradient descent which involves reusing batches multiple times. We find that this enables gradient descent to exceed the constraints imposed by information and leap exponents.
Gradient descent is found to achieve a positive correlation with the target function across a broader class than previously anticipated, with only two data batch repetitions. Our analysis further demonstrates that the limitations associated with information and leap exponents, staircase learning, and CSQ lower bounds are restricted to online/single pass SGD and do not describe the class of functions inherently easy or hard to learn by gradient-based methods for neural networks.
Our conclusions follow from rigorous mathematical proofs based on Dynamical Mean-Field Theory, which also yields a closed-form analytical description of the dynamics of the low-dimensional weight projections. We illustrate our theoretical findings with numerical experiments.
6 Acknowledgements
We thank Cedric Gerbelot, Bruno Loureiro and Ludovic Stephan for insightful discussions. We also acknowledge funding from the Swiss National Science Foundation grant SNFS OperaGOST (grant number ), and SMArtNet (grant number ).
References
- Abbe et al. [2021] E. Abbe, E. Boix-Adsera, M. S. Brennan, G. Bresler, and D. Nagaraj. The staircase property: How hierarchical structure can guide deep learning. Advances in Neural Information Processing Systems, 34:26989–27002, 2021.
- Abbe et al. [2022] E. Abbe, E. Boix-Adsera, and T. Misiakiewicz. The merged-staircase property: a necessary and nearly sufficient condition for sgd learning of sparse functions on two-layer neural networks. In Conference on Learning Theory, pages 4782–4887. PMLR, 2022.
- Abbe et al. [2023] E. Abbe, E. Boix-Adsera, and T. Misiakiewicz. Sgd learning on neural networks: leap complexity and saddle-to-saddle dynamics, 2023.
- Agoritsas et al. [2018] E. Agoritsas, G. Biroli, P. Urbani, and F. Zamponi. Out-of-equilibrium dynamical mean-field equations for the perceptron model. Journal of Physics A: Mathematical and Theoretical, 51(8):085002, 2018.
- Andrews [2004] G. E. Andrews. Special functions. Cambridge University Press, 2004.
- Aubin et al. [2019] B. Aubin, A. Maillard, J. Barbier, F. Krzakala, N. Macris, and L. Zdeborová. The committee machine: computational to statistical gaps in learning a two-layers neural network. Journal of Statistical Mechanics: Theory and Experiment, 2019(12):124023, Dec. 2019. ISSN 1742-5468. doi: 10.1088/1742-5468/ab43d2. URL http://dx.doi.org/10.1088/1742-5468/ab43d2.
- Ba et al. [2022] J. Ba, M. A. Erdogdu, T. Suzuki, Z. Wang, D. Wu, and G. Yang. High-dimensional asymptotics of feature learning: How one gradient step improves the representation. In S. Koyejo, S. Mohamed, A. Agarwal, D. Belgrave, K. Cho, and A. Oh, editors, Advances in Neural Information Processing Systems, volume 35, pages 37932–37946. Curran Associates, Inc., 2022.
- Ba et al. [2023] J. Ba, M. A. Erdogdu, T. Suzuki, Z. Wang, and D. Wu. Learning in the presence of low-dimensional structure: a spiked random matrix perspective. In Neurips 2023, 2023.
- Barbier et al. [2019] J. Barbier, F. Krzakala, N. Macris, L. Miolane, and L. Zdeborová. Optimal errors and phase transitions in high-dimensional generalized linear models. Proceedings of the National Academy of Sciences, 116(12):5451–5460, 2019.
- Bayati and Montanari [2011] M. Bayati and A. Montanari. The dynamics of message passing on dense graphs, with applications to compressed sensing. IEEE Transactions on Information Theory, 57(2):764–785, 2011.
- Ben Arous et al. [1997] G. Ben Arous, A. Guionnet, et al. Symmetric langevin spin glass dynamics. The Annals of Probability, 25(3):1367–1422, 1997.
- Ben Arous et al. [2021] G. Ben Arous, R. Gheissari, and A. Jagannath. Online stochastic gradient descent on non-convex losses from high-dimensional inference. Journal of Machine Learning Research, 22(106):1–51, 2021.
- Ben Arous et al. [2022] G. Ben Arous, R. Gheissari, and A. Jagannath. High-dimensional limit theorems for sgd: Effective dynamics and critical scaling. Advances in Neural Information Processing Systems, 35:25349–25362, 2022.
- Bietti et al. [2023] A. Bietti, J. Bruna, and L. Pillaud-Vivien. On learning gaussian multi-index models with gradient flow. arXiv preprint arXiv:2310.19793, 2023.
- Bolthausen [2014] E. Bolthausen. An iterative construction of solutions of the tap equations for the sherrington–kirkpatrick model. Communications in Mathematical Physics, 325(1):333–366, 2014.
- Bordelon et al. [2020] B. Bordelon, A. Canatar, and C. Pehlevan. Spectrum dependent learning curves in kernel regression and wide neural networks. In H. D. III and A. Singh, editors, Proceedings of the 37th International Conference on Machine Learning, volume 119 of Proceedings of Machine Learning Research, pages 1024–1034. PMLR, 13–18 Jul 2020.
- Bouchaud et al. [1998] J.-P. Bouchaud, L. F. Cugliandolo, J. Kurchan, and M. Mézard. Out of equilibrium dynamics in spin-glasses and other glassy systems. Spin glasses and random fields, 12:161, 1998.
- Celentano et al. [2021] M. Celentano, C. Cheng, and A. Montanari. The high-dimensional asymptotics of first order methods with random data. arXiv:2112.07572, 2021.
- Chen and Meka [2020] S. Chen and R. Meka. Learning polynomials in few relevant dimensions. In Conference on Learning Theory, pages 1161–1227. PMLR, 2020.
- Chen et al. [2021] S. Chen, A. Klivans, and R. Meka. Efficiently learning one hidden layer relu networks from queries. Advances in Neural Information Processing Systems, 34:24087–24098, 2021.
- Chen et al. [2022] S. Chen, A. Gollakota, A. Klivans, and R. Meka. Hardness of noise-free learning for two-hidden-layer neural networks. Advances in Neural Information Processing Systems, 35:10709–10724, 2022.
- Chizat and Bach [2018] L. Chizat and F. Bach. On the global convergence of gradient descent for over-parameterized models using optimal transport. Advances in neural information processing systems, 31, 2018.
- Cugliandolo [2003] L. F. Cugliandolo. Dynamics of glassy systems. In Slow Relaxations and nonequilibrium dynamics in condensed matter. Springer, 2003.
- Cui et al. [2021] H. Cui, B. Loureiro, F. Krzakala, and L. Zdeborová. Generalization error rates in kernel regression: The crossover from the noiseless to noisy regime. In M. Ranzato, A. Beygelzimer, Y. Dauphin, P. Liang, and J. W. Vaughan, editors, Advances in Neural Information Processing Systems, volume 34, pages 10131–10143. Curran Associates, Inc., 2021.
- Damian et al. [2022] A. Damian, J. Lee, and M. Soltanolkotabi. Neural networks can learn representations with gradient descent. In P.-L. Loh and M. Raginsky, editors, Proceedings of Thirty Fifth Conference on Learning Theory, volume 178 of Proceedings of Machine Learning Research, pages 5413–5452. PMLR, 02–05 Jul 2022.
- Damian et al. [2023] A. Damian, E. Nichani, R. Ge, and J. D. Lee. Smoothing the Landscape Boosts the Signal for SGD: Optimal Sample Complexity for Learning Single Index Models. Technical report, Princeton, May 2023. arXiv:2305.10633 [cs, math, stat] type: article.
- Dandi et al. [2023] Y. Dandi, F. Krzakala, B. Loureiro, L. Pesce, and L. Stephan. How two-layer neural networks learn, one (giant) step at a time, 2023.
- Diakonikolas et al. [2020] I. Diakonikolas, D. M. Kane, V. Kontonis, and N. Zarifis. Algorithms and sq lower bounds for pac learning one-hidden-layer relu networks. In Conference on Learning Theory, pages 1514–1539. PMLR, 2020.
- Dietrich et al. [1999] R. Dietrich, M. Opper, and H. Sompolinsky. Statistical mechanics of support vector networks. Phys. Rev. Lett., 82:2975–2978, Apr 1999. doi: 10.1103/PhysRevLett.82.2975.
- Eissfeller and Opper [1992] H. Eissfeller and M. Opper. New method for studying the dynamics of disordered spin systems without finite-size effects. Physical review letters, 68(13):2094, 1992.
- Eissfeller and Opper [1994] H. Eissfeller and M. Opper. Mean-field Monte Carlo approach to the Sherrington-Kirkpatrick model with asymmetric couplings. Physical Review E, 50(2):709, 1994.
- Georges et al. [1996] A. Georges, G. Kotliar, W. Krauth, and M. J. Rozenberg. Dynamical mean-field theory of strongly correlated fermion systems and the limit of infinite dimensions. Reviews of Modern Physics, 68(1):13, 1996.
- Gerbelot et al. [2023] C. Gerbelot, E. Troiani, F. Mignacco, F. Krzakala, and L. Zdeborova. Rigorous dynamical mean field theory for stochastic gradient descent methods, 2023.
- Ghorbani et al. [2019] B. Ghorbani, S. Mei, T. Misiakiewicz, and A. Montanari. Limitations of lazy training of two-layers neural network. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019.
- Ghorbani et al. [2020] B. Ghorbani, S. Mei, T. Misiakiewicz, and A. Montanari. When do neural networks outperform kernel methods? In H. Larochelle, M. Ranzato, R. Hadsell, M. Balcan, and H. Lin, editors, Advances in Neural Information Processing Systems, volume 33, pages 14820–14830. Curran Associates, Inc., 2020.
- Goel et al. [2020] S. Goel, A. Gollakota, Z. Jin, S. Karmalkar, and A. Klivans. Superpolynomial lower bounds for learning one-layer neural networks using gradient descent. In International Conference on Machine Learning, pages 3587–3596. PMLR, 2020.
- Kearns [1998] M. Kearns. Efficient noise-tolerant learning from statistical queries. Journal of the ACM (JACM), 45(6):983–1006, 1998.
- Loureiro et al. [2021] B. Loureiro, C. Gerbelot, H. Cui, S. Goldt, F. Krzakala, M. Mezard, and L. Zdeborová. Learning curves of generic features maps for realistic datasets with a teacher-student model. In M. Ranzato, A. Beygelzimer, Y. Dauphin, P. Liang, and J. W. Vaughan, editors, Advances in Neural Information Processing Systems, volume 34, pages 18137–18151. Curran Associates, Inc., 2021.
- Maillard et al. [2020] A. Maillard, B. Loureiro, F. Krzakala, and L. Zdeborová. Phase retrieval in high dimensions: Statistical and computational phase transitions, 2020.
- Mannelli and Urbani [2021] S. S. Mannelli and P. Urbani. Just a momentum: Analytical study of momentum-based acceleration methods in paradigmatic high-dimensional non-convex problems. NeurIPS, 2021.
- Mannelli et al. [2019a] S. S. Mannelli, G. Biroli, C. Cammarota, F. Krzakala, and L. Zdeborová. Who is afraid of big bad minima? analysis of gradient-flow in spiked matrix-tensor models. In Advances in Neural Information Processing Systems, pages 8676–8686, 2019a.
- Mannelli et al. [2019b] S. S. Mannelli, F. Krzakala, P. Urbani, and L. Zdeborova. Passed & spurious: Descent algorithms and local minima in spiked matrix-tensor models. In international conference on machine learning, pages 4333–4342, 2019b.
- Mannelli et al. [2020] S. S. Mannelli, G. Biroli, C. Cammarota, F. Krzakala, P. Urbani, and L. Zdeborová. Marvels and pitfalls of the langevin algorithm in noisy high-dimensional inference. Physical Review X, 10(1):011057, 2020.
- Mei et al. [2018] S. Mei, A. Montanari, and P.-M. Nguyen. A mean field view of the landscape of two-layer neural networks. Proceedings of the National Academy of Sciences, 115(33):E7665–E7671, 2018.
- Mignacco and Urbani [2022] F. Mignacco and P. Urbani. The effective noise of stochastic gradient descent. Journal of Statistical Mechanics: Theory and Experiment, 2022(8):083405, aug 2022. doi: 10.1088/1742-5468/ac841d. URL https://doi.org/10.1088/1742-5468/ac841d.
- Mignacco et al. [2020] F. Mignacco, F. Krzakala, P. Urbani, and L. Zdeborová. Dynamical mean-field theory for stochastic gradient descent in gaussian mixture classification. Advances in Neural Information Processing Systems, 33:9540–9550, 2020.
- Mignacco et al. [2021] F. Mignacco, P. Urbani, and L. Zdeborová. Stochasticity helps to navigate rough landscapes: comparing gradient-descent-based algorithms in the phase retrieval problem. Machine Learning: Science and Technology, 2(3):035029, 2021.
- Moniri et al. [2023] B. Moniri, D. Lee, H. Hassani, and E. Dobriban. A theory of non-linear feature learning with one gradient step in two-layer neural networks, 2023.
- Montanari and Saeed [2022] A. Montanari and B. N. Saeed. Universality of empirical risk minimization. In P.-L. Loh and M. Raginsky, editors, Proceedings of Thirty Fifth Conference on Learning Theory, volume 178 of Proceedings of Machine Learning Research, pages 4310–4312. PMLR, 02–05 Jul 2022.
- Mousavi-Hosseini et al. [2023] A. Mousavi-Hosseini, D. Wu, T. Suzuki, and M. A. Erdogdu. Gradient-based feature learning under structured data, 2023.
- Rotskoff and Vanden-Eijnden [2022] G. Rotskoff and E. Vanden-Eijnden. Trainability and accuracy of artificial neural networks: An interacting particle system approach. Communications on Pure and Applied Mathematics, 75(9):1889–1935, 2022. doi: https://doi.org/10.1002/cpa.22074.
- Roy et al. [2019] F. Roy, G. Biroli, G. Bunin, and C. Cammarota. Numerical implementation of dynamical mean field theory for disordered systems: application to the lotka–volterra model of ecosystems. Journal of Physics A: Mathematical and Theoretical, 52(48):484001, Nov. 2019. ISSN 1751-8121. doi: 10.1088/1751-8121/ab1f32. URL http://dx.doi.org/10.1088/1751-8121/ab1f32.
- Saad and Solla [1995] D. Saad and S. A. Solla. On-line learning in soft committee machines. Physical Review E, 52(4):4225–4243, Oct. 1995. doi: 10.1103/PhysRevE.52.4225.
- Sirignano and Spiliopoulos [2020] J. Sirignano and K. Spiliopoulos. Mean field analysis of neural networks: A central limit theorem. Stochastic Processes and their Applications, 130(3):1820–1852, 2020.
- Sompolinsky and Zippelius [1981] H. Sompolinsky and A. Zippelius. Dynamic theory of the spin-glass phase. Phys. Rev. Lett., 47:359–362, Aug 1981.
- Sompolinsky et al. [1988] H. Sompolinsky, A. Crisanti, and H. J. Sommers. Chaos in random neural networks. Phys. Rev. Lett., 61:259–262, Jul 1988.
- Zweig and Bruna [2023] A. Zweig and J. Bruna. Symmetric single index learning, 2023.
Appendix A Mathematical Proofs
A.1 Notations
We use the asymptotic notation to denote for some constants and large enough . Similarly, denotes for any constant and large enough . We use to denote convergence in probability and convergence in distribution respectively as with . We denote subspaces and linear operators, matrices on through uppercase letters . For any subspace , we denote by its orthogonal complement, i.e., the subspace of vectors orthogonal to all .
A.2 DMFT and iterative conditioning
Unlike online SGD, the preactivations after multiple steps no longer remain Gaussian since the weights become dependent on the data. This prevents marginalizing over the orthogonal components over the preactivations and relating the learning of new directions to the Hermite decomposition of the target function. Our proof circumvents these issues by utilizing a simpler effective process that decouples the pre-activations for different samples. The effective process is obtained using a rigorous version of the Dynamical Mean Field Theory derived in [Montanari and Saeed, 2022] and [Gerbelot et al., 2023].
The derivation of Dynamical Mean Field Theory in the above works has the following essential elements:
1. Iterative conditioning: The proof in [Gerbelot et al., 2023, Montanari and Saeed, 2022] for obtaining the DMFT equations relies on the observation that the gradient descent algorithm in Equation (28) for a finite number of iterations can be described completely through projections of the input design matrix along a finite number of vectors in . The iterative conditioning technique [Bolthausen, 2014, Bayati and Montanari, 2011] then involves replacing the components of along directions orthogonal to these projections by independent Gaussian random variables. This leads to a non-Markovian structure in the effective processes for the activations and parameters.
2. The concentration of finite-dimensional order parameters such as overlaps of the neuron parameters with the teacher neurons/subspace as well as expectations w.r.t the empirical measure of the pre-activations and parameters.
Using the above elements, DMFT provides a low-dimensional effective dynamics characterizing the limiting joint empirical measure of the student parameters, as well as the pre-activations. We illustrate the proof for the activations after the first gradient step, highlighting the relationship with the “hidden progress” described in section 4.1.
Let denote the matrix of pre-activations at time . Similarly, let denote the matrix of input activations in the target function. We denote by the matrix derivative of w.r.t the corresponding entries of the pre-activations matrix . After each gradient update, and the preactivations gain a dependence on . The iterative conditioning technique works around this dependence by conditioning on the sigma algebra generated by and instead of on . Since interacts with and only through projections (along right with and left with respectively), the conditioning allows the components of orthogonal to and to be replaced by independent Gaussian entries.
For the first-gradient step, we only require conditioning on .
From (5), we obtain the following update for :
(28) |
Let By the equivalence of projection and conditioning for Gaussian random variables, we have that the following equality holds in distribution:
(29) |
where is independent of and denotes the projection operator along , defined as:
Substituting in Equation (28), we obtain:
(30) |
Since the projection is along a low-dimensional subspace of dimension at most , we have . One can therefore show that converges in probability to for any deterministic with . Applying it to the vector , conditioned on , we obtain that:
(31) |
Now, the diagonal entries of converge in probability to due to the concentration of norms of Gaussian random vectors. This results in the term . This term is precisely the one responsible for the “hidden progress” explained in section 4.1, corresponding to the term in Equation 18. Since is independent of and , by the central limit theorem for sub-Gaussian random variables, the remaining off-diagonal terms can be shown to converge to Gaussian noise independent of with variance .
Lastly, the third term in Equation (30) can be shown to converge to Gaussian noise correlated with the corresponding entries of . Specifically, by removing the conditioning on , the law of large numbers and Stein’s Lemma imply that the term converges in probability to . Therefore, we obtain:
(32) |
Proceeding similarly, one obtains low-dimensional effective processes for for any time . In the following section, we derive the resulting DMFT dynamics for the setup considered in Section 1 through a reduction to the result in [Gerbelot et al., 2023]. We refer to [Gerbelot et al., 2023, Celentano et al., 2021] for detailed proofs based on the above technique.
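The conditioning step used in this section can be illustrated in its simplest form, with a single unit vector in place of the full set of conditioning directions. The sketch below keeps the projection of every input along that vector and redraws everything orthogonal to it from fresh Gaussians, so the pre-activations along the vector are unchanged. The function names are hypothetical, and the single-vector setting is a simplification of the conditioning on several directions used in the proof.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def random_unit_vector(d, rng):
    v = [rng.gauss(0.0, 1.0) for _ in range(d)]
    norm = math.sqrt(dot(v, v))
    return [x / norm for x in v]


def resample_orthogonal_part(Z, w, rng):
    """Keep <z, w> for every row z of Z and redraw the part orthogonal to w.

    w must have unit norm. Each new row is h * w + (I - w w^T) z_tilde,
    with h = <z, w> and z_tilde a fresh standard Gaussian vector.
    """
    new_rows = []
    for z in Z:
        h = dot(z, w)
        z_tilde = [rng.gauss(0.0, 1.0) for _ in w]
        t = dot(z_tilde, w)
        new_rows.append([h * wi + (zt - t * wi) for wi, zt in zip(w, z_tilde)])
    return new_rows


if __name__ == "__main__":
    rng = random.Random(0)
    d, n = 50, 5
    w = random_unit_vector(d, rng)
    Z = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
    Z_new = resample_orthogonal_part(Z, w, rng)
    for z, z_new in zip(Z, Z_new):
        print(dot(z, w), dot(z_new, w))
```

For Gaussian inputs the redrawn matrix has the same law as the original one conditioned on the retained projections, which is the property that allows the orthogonal components to be treated as independent noise.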
A.3 Derivation of the exact asymptotics
We start by stating a general consequence of the main result in [Gerbelot et al., 2023].
Theorem A.1 (Corollary of Theorem 3.2 in [Gerbelot et al., 2023]).
Let be a sequence of matrices such that the overlap matrix satisfies:
(33) |
where denotes a fixed matrix. Consider a dynamics of the form:
(34) |
where is pseudo-Lipschitz of finite-order and are i.i.d vectors distributed as , such that with . Then the empirical measure of the weights converges in distribution to the weight process and the empirical measure of the preactivations converges in distribution to that of the preactivation process , defined as
(35) |
(36) |
where we have
(37) |
(38) |
(39) |
Finally, and are zero-mean Gaussian processes respectively with covariances given by and :
(40) |
(41) |
The convergence in distribution of the empirical measures holds in the following sense: For any , and any pseudo-Lipschitz functions and :
(42) | |||
(43) |
where denotes the column of
The above result follows directly by substituting as in Theorem 3.2 of Gerbelot et al. [2023].
The definitions of and in Theorem A.1 require differentiating through the non-Markovian processes defined by Equations 35 and 36. Fortunately, can be equivalently described through an explicit set of recursive updates, which we state below for convenience:
(44) |
with boundary conditions
(45) | |||
(46) |
while
(47) |
where is a collection of stochastic processes with distribution
(48) |
and boundary conditions
(49) | |||
(50) |
To obtain the limiting equations under the setting of gradient descent with teacher weights in section 1, we utilize the generality of the update in Theorem A.1, which allows for a portion of the parameters () to remain unaffected. We obtain the following result, which generalizes the former theorem to the setting of our paper:
Theorem A.2.
Consider the distribution over data defined in section 2 and an update rule on the weights of the form (5), i.e:
(51) |
Then under the assumptions of Theorem 3.2, as with , the joint empirical measure of the coordinates of the student weights and the teacher weights converges in distribution to the stochastic process and the standard normal variable , in the sense of Theorem A.1. Similarly, the joint empirical measure of the student and teacher preactivations , converge in distribution to the stochastic process and the standard normal variable . and are defined recursively through the following equations:
(52) |
(53) |
Here and are zero mean Gaussian Process with covariances and respectively, given by:
(54) | |||
(55) |
where are defined as:
(56) |
The effective regularisation and the projected gradient concentrate to
(57) | |||
(58) |
The memory kernels , , for are defined through the partial derivatives with respect to the noise:
(59) | |||
(60) |
and , . Finally, satisfies the update equation
(61) |
where is defined as:
(62) |
Proof.
Analogous to the embedding of planted vectors in [Celentano et al., 2021], we start by considering a lifted dynamics defined by concatenating and . First, define the extended parameters with update rule:
(63) |
The above form of updates can be seen to be a special case of Theorem A.1 with and given by:
(64) |
The assumptions on imply that is pseudo-Lipschitz of finite-order while standard concentration results for sub-exponential random variables when applied to imply that the overlap matrices at initialization converge almost surely. Therefore, Theorem A.1 applies, with the effective process for the weights and the pre-activations being described by:
(65) |
(66) |
Notice the redundancy in the above equations due to not being updated in (63). This allows us to further simplify (66) and (65), obtaining:
(67) |
(68) |
where we noticed that . These equations are the same as in A.1, with just two extra terms in (67), and . An application of Stein’s Lemma further simplifies the term to in the Theorem as follows:
(69) |
∎
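The form of Stein's Lemma used in this simplification states that, for jointly Gaussian zero-mean variables u and v and a differentiable function f, E[u f(v)] = Cov(u, v) E[f'(v)]. The Monte Carlo sketch below checks this identity for f = tanh and unit-variance variables with correlation rho. The choice of function, the correlation value, and the sample size are illustrative assumptions and are not taken from the paper.

```python
import math
import random


def stein_check(rho=0.6, num_samples=100_000, seed=0):
    """Monte Carlo estimates of E[u tanh(v)] and rho * E[tanh'(v)].

    u and v are standard Gaussians with correlation rho, so Cov(u, v) = rho.
    """
    rng = random.Random(seed)
    lhs = 0.0
    rhs = 0.0
    for _ in range(num_samples):
        u = rng.gauss(0.0, 1.0)
        v = rho * u + math.sqrt(1.0 - rho ** 2) * rng.gauss(0.0, 1.0)
        lhs += u * math.tanh(v)
        rhs += rho * (1.0 - math.tanh(v) ** 2)  # tanh'(v) = 1 - tanh(v)^2
    return lhs / num_samples, rhs / num_samples


if __name__ == "__main__":
    lhs, rhs = stein_check()
    print("E[u f(v)]          ~", lhs)
    print("Cov(u, v) E[f'(v)] ~", rhs)
```

The two estimates agree up to Monte Carlo error, which is the mechanism by which an expectation of a Gaussian variable times a nonlinear function of the process is traded for a derivative term.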
The above effective process characterizes the limits of several quantities determined by the weights and pre-activations. In particular, it provides the limits of the student-teacher overlaps:
Proof.
Observe that can be expressed as an expectation of a pseudo-Lipschitz function w.r.t the joint empirical measure over the coordinates of with the value at the coordinate given by . Therefore 3.2 implies that converges in probability to the expected overlaps of the effective process which equal by definition. ∎
We also include a useful corollary, describing the evolution of the overlaps of the weights.
Lemma A.4.
Under the assumptions of A.2 the covariance
(71) | |||
(72) |
This is a consequence of linearity of expectation on (52). Concretely, viewing as a function of the Gaussian random variables , we apply the multi-variate Stein’s Lemma to obtain:
(73) |
In particular, we obtain the following expression for the covariances up to the first time-steps:
Lemma A.5.
The covariances satisfy:
(74) |
(75) |
A.4 Pre-activations at the end of the first gradient update
For , Equation (53) simplifies to:
(76) |
We now show that the first term exactly corresponds to the contributions considered in section 4.1.
Lemma A.6.
Proof.
We simply apply the conditioning by projection technique described in Section A.2 to by expressing it as: , where is independent of . The result then follows from convergence in probability of to . ∎
Next, we characterize . We consider two cases:
- •
- • : In this case the first-layer develops an overlap along . By initialization, we have that . Equation (62) implies that is given by:
(78) |
Due to the choice of symmetric initialization (Equation 4), we have . Therefore, . We thus obtain
(79) |
where denotes element-wise multiplication and we used the independence of . Therefore, the rows of gain a rank-one spike along . This matches the corresponding results for single batch gradient steps under the online-setting [Ba et al., 2022, Dandi et al., 2023].
A.5 Proof of Theorem 3.2
To illustrate the learning of directions solely due to the hidden progress explained in section 4.1, we first focus on the case where , i.e., when the parameters develop no overlap along the target subspace in the first step.
From Corollary A.3, and Slutsky’s theorem, we have that:
(81) |
where from Equation (76), can be expressed as a combination of and a Gaussian random variable independent of . Furthermore, Lemma A.4 implies that the regularization strength and step-size can be set such that the entries of have unit-variance. Now, suppose for some fixed vector . First, consider the case when lies in the subspace as defined in definition 3.1.
Therefore, the overlap for the neuron can be expressed as :
(83) |
We focus on the first term in the RHS. By assumption, converge in probability to . Therefore, using Equation (76), we obtain:
(84) | ||||
(85) |
Recall that by the choice of initialization, are diagonal with entries except for the off-diagonal entries corresponding to pairing of neurons through the symmetric initialization. Furthermore, by initialization, and by assumption .
From Lemma A.4 and the definitions of , we have that by setting , the covariances simplify to:
(86) |
(87) |
By definition, have diagonal entries proportional to respectively. Therefore, we can further set such that the diagonal entry of equals . By case in section A.4, we further have that is independent of .
Substituting , we obtain the precise condition on for the neuron to learn direction at the second timestep. The condition is given by:
(88)
where and are independent Gaussian random variables. Since matches in distribution , the above condition can equivalently be expressed as the following condition on :
(89) |
where:
(90) |
where u is a standard normal variable, corresponding to . The above expectation is an analytic function of . To show that it is not identically zero, we consider the derivative w.r.t at . We have, using the dominated convergence theorem:
where denotes the Hermite coefficients of . Similarly, iterating times, we have:
(91)
where denotes the derivative of . Note that for to not be identically zero, it is sufficient that is non-zero for some . Since by assumption, , and since the monomials span the space of polynomials, we have that there exists a such that: . The conditions on further imply that . Therefore is not identically .
Since is an analytic function that is not identically zero, and the law of is absolutely continuous w.r.t the Lebesgue measure, we have that is non-zero almost surely over the initialization. Now, the second term in Equation (82) is again an analytic function in , distinct from , and can therefore be almost surely absorbed into the non-zero overlap. This proves the first part of Theorem 3.2 for developing an overlap along a fixed direction in when . We now proceed to show that the weights span .
Let denote the dimension of the subspace . Suppose that form an orthonormal basis of . Let denote the matrices with columns and respectively.
Analogous to Equation (83), we obtain:
(92)
Following the derivation of Equation (89), we obtain that the rows of the matrix are independent for neurons for (due to the symmetric initialization in Equation (4)). Furthermore, each row of the matrix is absolutely continuous w.r.t the Lebesgue measure on . This implies that has full row-rank almost surely for large enough .
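The almost-sure rank statement above can be illustrated numerically. The following is only a sketch under assumptions that are not part of the proof: the rows are drawn i.i.d. standard Gaussian (one convenient law that is absolutely continuous w.r.t the Lebesgue measure), and the names `overlap_rank`, `num_neurons` and `subspace_dim` are hypothetical.

```python
import numpy as np

def overlap_rank(num_neurons, subspace_dim, seed=0):
    """Rank of a random (num_neurons x subspace_dim) matrix with independent rows."""
    rng = np.random.default_rng(seed)
    # Assumption: Gaussian rows stand in for any row law that is absolutely
    # continuous w.r.t. the Lebesgue measure on R^subspace_dim.
    M = rng.standard_normal((num_neurons, subspace_dim))
    return int(np.linalg.matrix_rank(M))

if __name__ == "__main__":
    # Print the rank for a few (hypothetical) numbers of neurons.
    for p in (2, 3, 8, 64):
        print(p, overlap_rank(p, subspace_dim=3))
```

Once the number of rows is at least the subspace dimension, the rank of such a matrix equals that dimension with probability one, which is the sense in which the neurons jointly span the subspace.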
Now, suppose that instead lies in the even-symmetric subspace . By induction and closure properties of analytic functions, we have that can be expressed as:
(93) |
for an analytic mapping . Now, similar to Equation (83), we have that:
(94)
Using Fubini’s theorem, we may take expectation w.r.t to express each entry of as:
(95)
for some analytic . This ensures that the expectation in (82) remains for all time . This proves the second part of Theorem 3.2.
A.6 Effect of previously learned directions
We now consider the case when , i.e., when the first layer develops an overlap along . As shown in section 4.1, the rows of lie along the same direction given by . Without loss of generality, we assume that the direction corresponds to in the input space and that has rows along the standard basis . Note that itself lies in by setting in definition 3.1.
From Equation (80), we obtain:
(96) |
where denotes a constant dependent on . Since is now correlated with , the condition in Equation (88) is modified to:
(97) |
Again, differentiating w.r.t , we obtain:
(98) |
Similar to section A.5, we have that is sufficient for the first term to be non-zero almost surely over . If the second term is non-zero, we have that is learned through the staircase mechanism, since it implies that contains terms dependent on and linearly coupled with . In either case, we obtain that almost surely obtains an overlap along . This concludes the proof of the first part of Theorem 3.2.
More generally, suppose that denote a basis of the directions in learned up to time . Then, the modified condition for learning a new direction at time is:
(99)
for a polynomial . Therefore, new directions can be learned through a combination of the staircase and hidden-progress mechanism.
A.7 Typical examples where
For several target functions of interest, the class can be shown to cover the entire target space . We list some of them below:
• Single-index odd polynomials with all non-negative/non-positive coefficients. This follows since decomposes into sums of non-negative/non-positive terms.
• Single-index odd Hermite polynomials. We prove this below in Lemma A.7.
• Staircase function . This follows directly by evaluating for .
In general, for a polynomial , the condition:
(100)
specifies an overdetermined system of infinitely many homogeneous polynomial equations on the coefficients of . Therefore, we expect the condition to fail almost surely for typical choices of . We leave an investigation of this using algebraic tools to future work.
Lemma A.7.
For any odd Hermite polynomial for :
(101) |
where
Proof.
Using Stein’s Lemma, we have:
(102) |
Next, we recall the following relation between Hermite polynomials and their derivatives:
(103) |
Substituting in Equation (102), we obtain:
(104) |
The above expectation can be obtained analytically using the linearization formulas for Hermite polynomials [Andrews, 2004] to show that for all . ∎
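The two ingredients of this proof, Stein's lemma and the derivative relation for Hermite polynomials, can be checked numerically. The sketch below assumes the probabilists' convention He_k (for which the relation reads He_k' = k He_{k-1}); the extracted text does not make the convention explicit, and the helper names `he`, `gauss_mean`, `derivative_identity_holds` and `stein_sides` are hypothetical. It does not reproduce the expectation in Equation (104).

```python
import numpy as np
from numpy.polynomial import hermite_e as H

def he(k):
    """Coefficient vector of He_k in the probabilists' Hermite basis."""
    c = np.zeros(k + 1)
    c[k] = 1.0
    return c

# Gauss-Hermite nodes/weights for the weight exp(-z^2 / 2); dividing by
# sqrt(2*pi) turns weighted sums into expectations under N(0, 1).
NODES, WEIGHTS = H.hermegauss(40)
WEIGHTS = WEIGHTS / np.sqrt(2.0 * np.pi)

def gauss_mean(coefs):
    """E[f(z)] for z ~ N(0, 1), with f given by its HermiteE coefficients."""
    return float(np.sum(WEIGHTS * H.hermeval(NODES, coefs)))

def derivative_identity_holds(k):
    """Check He_k' = k * He_{k-1} at the level of basis coefficients."""
    return bool(np.allclose(H.hermeder(he(k)), k * he(k - 1)))

def stein_sides(coefs):
    """Return (E[z f(z)], E[f'(z)]); Stein's lemma states they coincide."""
    return gauss_mean(H.hermemulx(coefs)), gauss_mean(H.hermeder(coefs))

if __name__ == "__main__":
    print(all(derivative_identity_holds(k) for k in range(1, 8)))
    # Example with f(z) = z^3, converted from the monomial basis.
    print(stein_sides(H.poly2herme([0.0, 0.0, 0.0, 1.0])))
```

Working with coefficient vectors keeps the checks exact up to floating-point error, since Gauss-Hermite quadrature with 40 nodes integrates these low-degree polynomials exactly.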
A.8 Proof of Proposition 3.5
Suppose that , i.e., is orthogonally even-symmetric w.r.t for some transformation . Let . Then, by the invariance of the Gaussian measure under orthogonal transformations, we have:
(105) |
However, the expectation on the right can equivalently be expressed as:
where in the second equality we used and in the third the definition 3.4. Therefore, for any , we have:
(106) |
Furthermore, it is straightforward to see that remains orthogonally even-symmetric w.r.t the composition . Therefore, we have that .
A.9 Illustration of non-even symmetric hard directions
![Refer to caption](x3.png)
We now show the existence of target functions where . Without loss of generality, we assume that the rows of lie along the standard Euclidean basis.
Lemma A.8.
Suppose that . Let . Then but .
Proof.
follows directly by noting that is even-symmetric along and . We further have that a target satisfying must satisfy . Therefore, since , cannot be even-symmetric along . Next, we show that . Since span , and is symmetric w.r.t permutations of , it suffices to show that condition in definition 3.4 holds for . The orthogonal complement is given by . Therefore the transformation defined by is a valid orthogonal transformation as per definition 3.4. We have:
This shows that lies in . Similarly, we have by symmetry
We present a numerical illustration of another such example in figure 3. ∎
One can in fact construct a family of functions with a direction , for instance lying in but in general not in . To see this, let be a function , depending only on projections of along , and let be an involutory orthogonal transformation on , i.e., an orthogonal transformation satisfying or equivalently . Now, let be an odd function. Then, consider the function:
(107) |
We observe that:
where we used that and . Therefore, for any such function , .
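As a concrete instance of an involutory orthogonal transformation, one may take a Householder reflection. The sketch below is an illustration only and is not taken from the paper: the helper names are hypothetical, and it checks the defining properties without reproducing the function in Equation (107). It also uses the standard fact that an orthogonal matrix is involutory exactly when it is symmetric.

```python
import numpy as np

def householder_reflection(u):
    """Reflection R = I - 2 u u^T / ||u||^2 across the hyperplane orthogonal to u."""
    u = np.asarray(u, dtype=float)
    return np.eye(u.size) - 2.0 * np.outer(u, u) / np.dot(u, u)

def is_involutory_orthogonal(R, tol=1e-10):
    """True if R is orthogonal (R^T R = I) and involutory (R R = I)."""
    identity = np.eye(R.shape[0])
    orthogonal = np.allclose(R.T @ R, identity, atol=tol)
    involutory = np.allclose(R @ R, identity, atol=tol)
    return bool(orthogonal and involutory)

if __name__ == "__main__":
    R = householder_reflection([1.0, 2.0, 3.0])
    print(is_involutory_orthogonal(R), np.allclose(R, R.T))
    # A planar rotation by 90 degrees is orthogonal but not involutory.
    rot = np.array([[0.0, -1.0], [1.0, 0.0]])
    print(is_involutory_orthogonal(rot))
```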
A.10 Implications for generalization
Since the specific guarantees of such results depend on the choice of activation and target functions, we illustrate this for the case of single-index target functions with matching activations:
Corollary A.9.
Consider a single-index target and a student network with matching activations, i.e. , where is a polynomial of finite degree satisfying the following assumption:
(108)
where denotes the derivative. Let be the parameters obtained after two steps of gradient descent with batch size using as in Theorem 3.2. Then, almost surely over the initialization , for any , there exists a step size such that online SGD on squared loss reaches generalization error in time .
We verify numerically that the above assumption holds in particular for all odd Hermite polynomials up to order . The corollary implies that such target functions can be learned with sample complexity using gradient descent alone, without resorting to specialized algorithms and techniques such as spectral initialization.
Proof.
Let denote the single direction in the teacher subspace with . We note that Equation (108) is proportional to the derivative of defined in Equation (89). Therefore, the condition is sufficient to ensure that is not identically zero and the student neuron almost surely develops an overlap along . The result then follows from Proposition 2.1 in [Ben Arous et al., 2021], which proves that upon weak recovery, i.e., a non-zero overlap with the target direction , online SGD on a differentiable activation with polynomially bounded derivatives converges to strong recovery. Concretely, for any starting non-zero overlap , for any , there exists and small-enough step-size such that online SGD with time achieves overlap along . Due to the matching activations, this suffices to obtain arbitrarily small generalization error. ∎
Appendix B General Multi-Pass Schemes
![Refer to caption](x4.png)
B.1 Sketch of Proof for Extending Theorem 3.2 to Cycling over Epochs
Let denote independent minibatches of size such that with being finite. The effective dynamics for a finite number of epochs can be obtained by noting that Theorem 3.2 in Gerbelot et al. [2023] allows generalizing Theorem A.1 to dynamics of the form:
(109) |
The above form of the dynamics allows a different update to be utilized for data corresponding to different blocks . In particular, setting to whenever and otherwise results in a cycling schedule over the mini-batches . Subsequently, one can show that the update from in the first epoch leads to the hidden-progress effect on when the model re-uses in the second epoch.
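The cycling schedule described above can be sketched as follows. This is a minimal illustration under an assumption of ours: the active block at step t is taken to be t modulo the number of mini-batches, which is one natural reading of the condition whose symbols were lost in this text. The names `cycling_schedule` and `block_mask` are hypothetical.

```python
def cycling_schedule(num_batches, num_epochs):
    """Index of the mini-batch used at each step when cycling over epochs."""
    # Assumption: batches are visited in a fixed order, once per epoch.
    return [t % num_batches for t in range(num_batches * num_epochs)]

def block_mask(t, num_batches):
    """One-hot list: entry l is 1 if block l is active at step t, else 0."""
    return [1 if t % num_batches == l else 0 for l in range(num_batches)]

if __name__ == "__main__":
    print(cycling_schedule(num_batches=3, num_epochs=2))
    print([block_mask(t, 3) for t in range(6)])
```

Each block is revisited exactly once per epoch, so the second visit to a block is where the re-use effect discussed above would enter.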
We believe a similar result would hold for samples in the minibatch, as displayed in Figure 5.
![Refer to caption](x5.png)
Appendix C Details on the numerics
C.1 DMFT equations with a single stochastic process
In this section, we present a set of exact equations equivalent to the ones in the main text, but that depend on a single stochastic process. It is possible to show that asymptotically in the proportional limit, i.e. for and , the pre-activations of the student are distributed as , with the constraint:
Here is a zero mean Gaussian Process with covariance
and the effective regularisation concentrates to
(110) |
The memory kernel is identically zero for while for it concentrates to
(111) |
Finally, the low-dimensional projections of the weights will obey the relation
(112) |
The procedure is explained in detail in appendix D of [Gerbelot et al., 2023], and can be equivalently derived using non-rigorous field theory techniques [Agoritsas et al., 2018].
C.2 Remark on the numerical integration of the DMFT equations
DMFT is an invaluable tool in itself to probe the behaviour of gradient-based algorithms. It trades the update equation over heavily coupled weights in (5) for one over completely decoupled preactivations (C.1), which implies that a Monte Carlo estimation based on (C.1) is vastly more efficient and trivially parallelisable. Furthermore, equation (C.1) is exact in the limit of large , which completely removes all finite-size effects. In practice, an implementation of the DMFT equations consists of drawing samples from the initial-condition distribution of the preactivations and iterating forward. The Gaussian process is sampled by rotating white Gaussian noise by the LU factor of the covariance. Sampling the Gaussian process is by far the costliest operation, as each time step has a complexity , for a total complexity considering all the steps up to . Notice that this is a much more direct implementation than what is done in the literature [Roy et al., 2019, Mignacco et al., 2020], which usually starts with a guess for all the quantities and proceeds with a damped fixed-point iteration until convergence, with an overall complexity , where is the number of fixed-point iterations. While it could appear that simply iterating forward is suboptimal, it is a much more stable and reliable procedure: when using processes and iterating forward, one is sure to have, at each time step, the best possible Monte Carlo estimate of the samples.
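The sampling step described above, rotating white Gaussian noise by a triangular factor of the covariance, can be sketched as follows. This is not the repository's implementation: it uses a Cholesky factor, which is the usual triangular factorization for a symmetric positive-definite covariance, and the function name, the `jitter` parameter and the example kernel are assumptions introduced for illustration.

```python
import numpy as np

def sample_gaussian_process(cov, num_samples, rng, jitter=1e-10):
    """Draw zero-mean Gaussian paths with covariance `cov` over T time steps."""
    T = cov.shape[0]
    # Lower-triangular factor L with L @ L.T = cov (jitter aids stability).
    L = np.linalg.cholesky(cov + jitter * np.eye(T))
    # White noise: one independent standard normal per sample and time step.
    white = rng.standard_normal((num_samples, T))
    # Each row white[i] @ L.T has covariance L @ L.T = cov.
    return white @ L.T

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    times = np.arange(5)
    # Hypothetical covariance used only to exercise the sampler.
    cov = np.exp(-0.5 * np.abs(times[:, None] - times[None, :]))
    paths = sample_gaussian_process(cov, num_samples=200_000, rng=rng)
    print(np.max(np.abs(paths.T @ paths / paths.shape[0] - cov)))
```

Because the factor is lower triangular, the first t coordinates of each path depend only on the first t noise variables, which is compatible with iterating the dynamics forward in time.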
C.3 Details on the numerical simulations
In all the figures the continuous lines are from the numerical integration of the DMFT equations while the dots are from a direct simulation of the gradient descent dynamics. The specific hyperparameters for each setting are near each figure.
For both we fixed the second layer weights to , as for the cases under consideration this choice is equivalent to that of Gaussian second layer weights . For the DMFT integration we used a minimum of Monte Carlo samples in order to obtain accurate lines. The error bars are too small to be visualised. The direct simulation of the gradient descent dynamics was performed either using PyTorch or a direct implementation in NumPy. In all plots we used a minimum size for the input dimension, and averaged over at least independent instances of the dynamics.
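A direct NumPy simulation of the kind described above could look like the following minimal sketch. It is not the code from the repository, and every specific choice is an assumption made for illustration: a tanh activation, squared loss, a 1/p output scaling, second-layer weights fixed to ones, a cubic Hermite single-index target, plain Gaussian initialization (without the symmetric pairing of Equation (4)), and the sizes, step size and function names.

```python
import numpy as np

def forward(W, a, X):
    """Two-layer network f(x) = (1/p) * sum_j a_j * tanh(w_j . x)."""
    return np.tanh(X @ W.T) @ a / W.shape[0]

def loss(W, a, X, y):
    """Half mean-squared error on the batch (X, y)."""
    return 0.5 * np.mean((forward(W, a, X) - y) ** 2)

def grad_W(W, a, X, y):
    """Gradient of `loss` w.r.t. the first-layer weights W (shape p x d)."""
    n, p = X.shape[0], W.shape[0]
    pre = X @ W.T                          # preactivations, shape (n, p)
    resid = np.tanh(pre) @ a / p - y       # residuals, shape (n,)
    dact = 1.0 - np.tanh(pre) ** 2         # tanh'(pre), shape (n, p)
    return (resid[:, None] * dact * a[None, :] / p).T @ X / n

def multi_pass_gd(W0, a, X, y, lr, steps):
    """Full-batch gradient descent that reuses the same batch at every step."""
    W = W0.copy()
    for _ in range(steps):
        W = W - lr * grad_W(W, a, X, y)
    return W

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n, d, p = 512, 256, 8                  # hypothetical sizes
    w_star = np.zeros(d)
    w_star[0] = 1.0                        # hypothetical target direction
    X = rng.standard_normal((n, d))
    z = X @ w_star
    y = z ** 3 - 3.0 * z                   # hypothetical target: He_3
    W0 = rng.standard_normal((p, d)) / np.sqrt(d)
    W2 = multi_pass_gd(W0, np.ones(p), X, y, lr=float(p), steps=2)
    print(np.abs(W2 @ w_star))             # per-neuron overlap with the target
```

The quantity printed at the end is the per-neuron overlap with the target direction after two passes over the same batch; averaging it over independent runs would give the analogue of the dots in the figures.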
In Figure 2 we plot the overlap matrix projected on two different directions: one parallel to the subspace that is learned in the first step and one in the orthogonal complement of this subspace. The projection operator is computed by explicitly performing the integrals in (78).
The code is made available through the following GitHub repository: https://github.com/IdePHICS/benefit-reusing-batch.