Learning time-scales in two-layers neural networks
Abstract
Gradient-based learning in multi-layer neural networks displays a number of striking features. In particular, the decrease rate of empirical risk is non-monotone even after averaging over large batches. Long plateaus in which one observes barely any progress alternate with intervals of rapid decrease. These successive phases of learning often take place on very different time scales. Finally, models learnt in an early phase are typically ‘simpler’ or ‘easier to learn’ although in a way that is difficult to formalize.
Although theoretical explanations of these phenomena have been put forward, each of them captures at best certain specific regimes. In this paper, we study the gradient flow dynamics of a wide two-layer neural network in high-dimension, when data are distributed according to a single-index model (i.e., the target function depends on a one-dimensional projection of the covariates). Based on a mixture of new rigorous results, non-rigorous mathematical derivations, and numerical simulations, we propose a scenario for the learning dynamics in this setting. In particular, the proposed evolution exhibits separation of timescales and intermittency. These behaviors arise naturally because the population gradient flow can be recast as a singularly perturbed dynamical system.
Keywords: Deep learning, Neural network, Gradient flow, Dynamical system, Non-convex optimization, Incremental learning
Mathematics Subject Classification: 34E15, 37N40, 68T07
Communicated by Joan Bruna
Contents
- 1 Introduction
- 2 Setting and canonical learning order
- 3 Further related work
- 4 The large-network, high-dimensional limit
- 5 Numerical solution
- 6 Timescales hierarchy in the gradient flow dynamics
- 7 Stochastic gradient descent and finite sample size
- 8 Discussion
- A Proof of Proposition 1
- B Appendix to Section 4
- C Calculations for the analysis of mean-field gradient flow
- D Proofs of Theorem 2 and 3: learning with projected SGD
- E Counterexamples to the canonical learning order
1 Introduction
It is a recurring empirical observation that the training dynamics of neural networks exhibits a whole range of surprising behaviors:
- 1. Plateaus. Plotting the training and test error as a function of SGD steps, using either small stepsize or large batches to average out stochasticity, reveals striking patterns. These error curves display long plateaus where barely anything seems to be happening, which are followed by rapid drops [41, 48, 39].
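- 2. Time-scales separation. Successive phases of learning often take place on very different time scales.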
- 3. Incremental learning. Models learnt in the first phases of learning appear to be simpler than in later phases. Among others, [5] demonstrated that easier examples in a dataset are learned earlier; [28] showed that models learnt in the first phase of training correlate well with linear models; [22] showed that, in many simplified models, the dynamics of gradient descent explores the solution space in an incremental order of complexity; [39] demonstrated that, in certain settings, a function that approximates well the target is only learnt past the point of overfitting.
Understanding these phenomena is not merely a matter of intellectual curiosity. In particular, incremental learning plays a key role in our understanding of generalization in deep learning. Indeed, in this scenario, stopping the learning at a certain time amounts to controlling the complexity of the model learnt. The notion of complexity corresponds to the order in which the space of models is explored.
While a number of groups have developed models to explain these phenomena, it is fair to say that a complete picture is still lacking. An exhaustive overview of these works is out of place here. We will outline three possible explanations that have been developed in the past, and provide more pointers in Section 3.
Theory 1: Dynamics near singular points.
Several early works [41, 17, 44] pointed out that the parametrization of multi-layer neural networks presents symmetries and degeneracies. For instance, the function represented by a multi-layer perceptron is invariant under permutations of the neurons in the same layer. As a consequence, the population risk has multiple local minima connected through saddles or other singular sub-manifolds. Dynamics near these sub-manifolds naturally exhibits plateaus. Further, random or agnostic initializations typically place the network close to such submanifolds.
Theory 2: Linear networks.
Following the pioneering work of [7], a number of authors, most notably [43, 30], studied the behavior of deep neural networks with linear activations. While such networks can only represent linear functions, the training dynamics is highly non-linear. As demonstrated in [43], learning happens through stages that correspond to the singular value decomposition of the input-output covariance. Time scales are determined by the singular values.
Theory 3: Kernel regime.
Following an initial insight of [26], a number of groups proved that, for certain initializations, the training dynamics and the model learnt by overparametrized neural networks are well approximated by certain linearly parametrized models. In the limit of very wide networks, the training dynamics of these models converges in turn to the training dynamics of kernel ridge(less) regression (KRR) with respect to a deterministic kernel (independent of the random initialization). We refer to [9] for an overview and pointers to this literature. Recently, [21] showed that, in high dimension, the learning dynamics of KRR also exhibits plateaus and waterfalls, and learns functions of increasing complexity over a diverging sequence of timescales.
While each of these theories offers useful insights, it is important to realize that they do not agree on the basic mechanism that explains plateaus, time-scales separation, and incremental learning. In theory 1, plateaus are associated with singular manifolds and high-dimensional saddles, while in theories 2 and 3 they are related to a hierarchy of singular values of a certain matrix. In theory 2, the relevant singular values are those of the input-output covariance, and the fact that these singular values are well separated is postulated to be a property of the data distribution. In contrast, in theory 3 the relevant singular values are the eigenvalues of the kernel operator, and hence completely independent of the output (the target function). In this case, well-separated eigenvalues are proved to exist under natural high-dimensional distributions.
Not only do these theories propose different explanations, but they are also motivated by very different simplified models. Theory 1 has been developed only for networks with a small number of hidden units. Theory 2 only applies to networks with multiple output units, because otherwise the input-output covariance is a single-column matrix and hence has only one non-trivial singular value. Finally, theory 3 applies under the conditions of the linear (a.k.a. lazy) regime, namely large overparametrization and suitable initialization (see, e.g., [9]).
In order to better understand the origin of plateaus, time-scales separation, and incremental learning, we attempt a detailed analysis of gradient flow for two-layer neural networks. We consider a simple data-generation model, and propose a precise scenario for the behavior of learning dynamics. We do not assume any of the simplifying features of the theories described above: activations are non-linear; the number of hidden neurons is large; we place ourselves outside the linear (lazy) regime.
Our analysis is based on methods from dynamical systems theory: singular perturbation theory and matched asymptotic expansions. Unfortunately, we fall short of providing a general rigorous proof of the proposed scenario, but we can nevertheless prove it in several special cases and provide a heuristic argument supporting its generality.
The rest of the paper is organized as follows. Section 2 describes our data distribution, learning model, and the proposed scenario for the learning dynamics. We review further related work in Section 3. Section 4 describes the reduction of the gradient flow to a ‘mean field’ dynamics that will be the starting point of our analysis. Section 5 presents numerical evidence of the proposed learning scenario. Finally, Sections 6 and 7 present our analysis of the learning dynamics.
Notations.
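In this paper, we use the classical asymptotic notations. The notations $f = o(g)$ or $f \ll g$ both denote that $f/g \to 0$ in the limit under consideration. The notations $f = O(g)$ or $f \lesssim g$ both denote that the ratio $|f/g|$ remains upper bounded in the limit. The notations $f = \Theta(g)$ or $f \asymp g$ denote that $f = O(g)$ and $g = O(f)$ both hold. Finally, $f \sim g$ denotes that $f/g \to 1$ in the limit.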
2 Setting and canonical learning order
We are given pairs , where is a feature vector and is a response variable. We are interested in cases in which the feature vector is high-dimensional but does not contain strong structure, but the response depends on a low-dimensional projection of the data. We assume the simplest model of this type, the so-called single-index model:
(1) |
where is a link function, denotes the standard multivariate Gaussian distribution in dimension , and . We study the ability to learn model (1) using a two-layers neural network with hidden neurons:
(2) |
where collectively denotes all the model’s parameters and is the activation function of the neural network. The factor in the definition is relevant for the initialization and learning rate. We anticipate that we will initialize the ’s to be of order one, which results in second-layer coefficients .
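To make the setup concrete, the following is a minimal sketch of the data model and the network, not the paper's own code. It assumes noise-free responses, a mean-field prefactor of 1/m in front of the sum over neurons, a ReLU activation, and a hypothetical link function `phi`; the names `sample_single_index` and `two_layer` and all sizes are illustrative choices.

```python
import numpy as np

def sample_single_index(n, d, u_star, phi, rng):
    """Draw n pairs (x, y) with x ~ N(0, I_d) and y = phi(<u_star, x>) (no noise assumed)."""
    x = rng.standard_normal((n, d))
    return x, phi(x @ u_star)

def two_layer(x, a, u, sigma):
    """f(x) = (1/m) * sum_i a_i * sigma(<u_i, x>), with u of shape (m, d)."""
    return sigma(x @ u.T) @ a / a.shape[0]

rng = np.random.default_rng(0)
d, m, n = 50, 200, 2000

u_star = np.zeros(d)
u_star[0] = 1.0                                  # hypothetical unit-norm target direction

def phi(z):
    return z + 0.5 * (z**2 - 1.0)                # hypothetical link function

def relu(z):
    return np.maximum(z, 0.0)

u = rng.standard_normal((m, d))
u /= np.linalg.norm(u, axis=1, keepdims=True)    # first-layer weights on the unit sphere
a = np.ones(m)                                   # order-one a_i, i.e. output weights a_i / m

x, y = sample_single_index(n, d, u_star, phi, rng)
risk = 0.5 * np.mean((y - two_layer(x, a, u, relu)) ** 2)
print(f"empirical risk at initialization: {risk:.3f}")
```

With order-one `a`, each neuron contributes a term of size 1/m to the output, which is the scaling contrasted with standard initializations in Remark 2.1.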
Remark 2.1.
Standard initializations in deep learning frameworks yield second-layer coefficients [29, 23, 24]. However, it is increasingly clear that this initialization presents fundamental limitations for large . Notably, two-layer networks with this initialization converge to kernel methods [37], and the latter cannot learn ridge functions from polynomially many samples [20, 47].
It is well understood that, in order to drive the learning process outside the kernel regime (for ), it is necessary to set . This is often referred to as the ‘mean-field initialization’ [32, 13, 19, 1]. We notice that suitable generalizations of the mean-field initialization are currently used in state-of-the-art implementations [45, 46].
The bulk of our work will be devoted to the analysis of projected gradient flow in on the population risk
(3) | ||||
(4) |
In Section 7, we will bound the distance between stochastic gradient descent (SGD) and gradient flow in population risk. As a consequence, we will establish finite sample generalization guarantees for SGD learning.
Projected gradient flow with respect to the risk is defined by the following ordinary differential equations (ODEs):
(5) | ||||
(6) |
Here, can be viewed as the relative step size, namely the ratio between the first and second-layer step sizes. It is useful to make a few remarks about the definition of gradient flow:
- The projection ensures that remains on the unit sphere .
- The overall scaling of time is arbitrary, and the matching to SGD steps will be carried out in Section 7. The factors on the right-hand side are introduced for mathematical convenience, since the partial derivatives are of order .
- As mentioned above, the factor introduced in the gradient flow of the ’s plays the role of the relative step size. Throughout the paper, we will keep as a free parameter independent of , and study the evolution of gradient flow for small . This corresponds to a setting in which the second-layer coefficients are learned much faster than the first-layer weights. We emphasize however that the small limit is taken after the large limits. Thus, although the second-layer weights are learnt faster, the evolution of the first-layer weights will be crucial, and will lead to true feature learning.
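The remarks above can be illustrated by a discrete-time sketch of projected gradient descent. It is a sketch under stated assumptions, not the paper's Eqs. (5) and (6): we assume a squared loss with a factor 1/2, gradients rescaled by the number of neurons m, a relative step size `eps` that speeds up the second-layer update by 1/eps, and a renormalization step that returns each first-layer weight to the unit sphere. The function name `projected_step`, the tanh activation, the linear link, and all sizes are hypothetical.

```python
import numpy as np

def projected_step(a, u, x, y, eta, eps, sigma, dsigma):
    """One explicit Euler step of projected gradient descent on a fixed batch.

    a: (m,) second-layer weights; u: (m, d) unit-norm first-layer weights.
    Gradients are multiplied by m, and the a-update is sped up by 1/eps.
    """
    m, n = a.shape[0], x.shape[0]
    z = x @ u.T                                            # (n, m) pre-activations
    resid = sigma(z) @ a / m - y                           # (n,) prediction errors
    grad_a = (resid[:, None] * sigma(z)).mean(axis=0)      # m * dR/da_i
    grad_u = a[:, None] * ((resid[:, None] * dsigma(z)).T @ x) / n  # m * dR/du_i
    grad_u -= np.sum(grad_u * u, axis=1, keepdims=True) * u  # project on tangent space
    a_new = a - (eta / eps) * grad_a
    u_new = u - eta * grad_u
    u_new /= np.linalg.norm(u_new, axis=1, keepdims=True)  # return to the unit sphere
    return a_new, u_new

def empirical_risk(a, u, x, y, sigma):
    return 0.5 * np.mean((sigma(x @ u.T) @ a / a.shape[0] - y) ** 2)

rng = np.random.default_rng(0)
d, m, n = 20, 30, 500
u_star = np.zeros(d)
u_star[0] = 1.0
x = rng.standard_normal((n, d))
y = x @ u_star                                             # hypothetical linear link, no noise

def dtanh(z):
    return 1.0 - np.tanh(z) ** 2

u = rng.standard_normal((m, d))
u /= np.linalg.norm(u, axis=1, keepdims=True)
a = np.ones(m)
risk_before = empirical_risk(a, u, x, y, np.tanh)
for _ in range(200):
    a, u = projected_step(a, u, x, y, eta=0.05, eps=0.5, sigma=np.tanh, dsigma=dtanh)
risk_after = empirical_risk(a, u, x, y, np.tanh)
print(f"risk before: {risk_before:.3f}, after: {risk_after:.3f}")
```

Decreasing `eps` at fixed `eta` makes the second layer track its optimum more quickly relative to the first layer, which is the regime studied in the rest of the paper; in this explicit scheme `eta / eps` must stay small for stability.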
We assume the initialization to be random with i.i.d. components :
(7) |
where is a probability measure on . The unique solution of the gradient flow ODEs with this initialization will be denoted by . We will be interested in the case of large networks () in high dimension (). As shown below, the two limits commute (over fixed time horizons).
Our main finding is that, in a number of cases, is learnt incrementally. Namely, the function evolves over time according to a sequence of polynomial approximations of . These polynomial approximations are given by the decomposition of in , where is the standard normal density: . (For notational simplicity, we will use the shorthand instead of in the sequel.)
In order to describe the polynomial approximations learnt during the training more explicitly, we decompose and into normalized Hermite polynomials:
(8) |
Here, denotes the -th Hermite polynomial, normalized so that .
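The coefficients of such a decomposition can be computed numerically. The sketch below assumes that the normalized polynomials are the probabilists' Hermite polynomials divided by the square root of k!, so that they have unit norm under the standard Gaussian measure; the helper name `hermite_coefficients` and the test function are illustrative.

```python
import numpy as np
from math import factorial
from numpy.polynomial.hermite_e import hermegauss, hermeval

def hermite_coefficients(f, k_max, n_nodes=40):
    """Return f_k = E[f(G) He_k(G)] / sqrt(k!) for k = 0..k_max, with G ~ N(0, 1)."""
    nodes, weights = hermegauss(n_nodes)         # Gauss quadrature for weight exp(-z^2 / 2)
    weights = weights / np.sqrt(2.0 * np.pi)     # rescale so that the weights sum to 1
    coeffs = []
    for k in range(k_max + 1):
        basis = np.zeros(k + 1)
        basis[k] = 1.0                           # selects He_k in the HermiteE basis
        he_k = hermeval(nodes, basis) / np.sqrt(factorial(k))
        coeffs.append(np.sum(weights * f(nodes) * he_k))
    return np.array(coeffs)

coeffs = hermite_coefficients(lambda z: z**2, k_max=4)
print(np.round(coeffs, 6))
```

For the test function z², the identity z² = He_2(z) + He_0(z) gives coefficients (1, 0, √2, 0, 0) in the normalized basis. The quadrature is exact for polynomial integrands of moderate degree; for non-smooth activations such as ReLU it is only approximate.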
As we will see, the incremental learning behavior arises for small . By the law of large numbers (see below), the following almost sure limit exists (provided is square integrable)
(9) |
We are now in position to describe the scenario that we will study in the rest of the paper.
Definition 1.
We say that the canonical learning order holds up to level for a certain target function , activation , and distribution , if the following hold:
- 1. The limit below exists:
(10)
- 2. There exist constants such that the following asymptotic holds as , :
Figure 1 provides a cartoon illustration of the canonical learning order.
At first sight, the setting of Eq. (2) is overly restrictive because we require and we do not have offsets in the activations. Therefore, it might seem that and is required in order to approximate arbitrarily well the target function. In contrast, the next proposition shows that the network (2) enjoys universal approximation properties.
Proposition 1.
Assume that is Lipschitz continuous and generic in the following sense: the decomposition of into Hermite polynomials does not have any coefficient equal to . For any Lipschitz function , , and such that , there exists a sequence and with such that
This result is not surprising in view of the arguments in the next sections, which suggest that indeed gradient flow constructs such an approximation for a broad class of functions of the form . We nevertheless give an independent proof in Appendix A.
A specific realization of our general setup is determined by the triple . In the rest of the paper, we will provide evidence showing that the canonical learning order holds in a number of cases. Nevertheless, we can also construct examples in which it does not hold:
- If one or more of the Hermite coefficients of the activation vanish, then the canonical learning order does not hold for general . Specifically, if , then for any the function remains orthogonal to . In particular, if then the risk remains bounded away from zero for every . We refer to Appendix E.1 for a formal statement.
- If the first Hermite coefficients of vanish, , , then the canonical learning order does not hold. (See Appendix E.2 for the proof.)
- In fact, we expect the canonical learning order might fail every time one or more of the coefficients vanish, for . Appendix E.3 provides some heuristic justification for this failure.
Remark 2.2.
We can compare the canonical learning order described here to the ones in earlier literature, described as theories 1, 2, 3 in the introduction. There are points of contact, but also important differences, with both theory 1 and theory 3:
- As in theory 1, the plateaus and separation of time scales arise because the trajectory of gradient flow is approximated by a sequence of motions along submanifolds in the space of parameters . Along the -th such submanifold is well-approximated by a degree- polynomial. Escaping each submanifold takes an increasingly longer time.
This is reminiscent of the motion between saddles investigated in earlier work [41, 17, 44]. However, unlike in earlier work, we will see that this applies to networks with a large (possibly diverging) number of hidden neurons. Also, we identify the subsequent phases of learning with the polynomial decomposition of Eq. (8).
- As in theory 3, subsequent phases of learning correspond to increasingly accurate polynomial approximations of the target function . However, the underlying mechanism and time scales are completely different. In the linear regime, the different time scales emerge because of increasingly small eigenvalues of the neural tangent kernel. In that case, the time required to learn degree- polynomials is of order [21].
In contrast, in the canonical learning order, polynomials of degree are learnt on a time scale of order one in (and only depending on the learning rate ). This of course has important implications when approximating gradient flow by SGD. Within the linear regime, the sample size required to learn a polynomial of order scales like [21], while in the canonical learning order, it is only of order (see Section 7).
3 Further related work
As we mentioned in the introduction, plateaus and time scales in the learning dynamics of kernel models were analyzed by [21]. A sharp analysis for the related random features model was developed by [12].
Our analysis builds upon the mean-field description of learning in two-layer neural networks, which was developed in a sequence of works, see, e.g., [32, 40, 13, 33]. In particular, we leverage the fact that, for the data distribution (1), the population risk function is invariant under rotations around the axis , and this allows for a dimensionality reduction in the mean field description. Similar symmetry arguments were used by [32] and, more recently, by [1].
The single-index model can be learnt using simpler methods than large two-layer networks. Limiting ourselves to the case of gradient descent algorithms, [31] proved that gradient descent with respect to the non-convex empirical risk converges to a near global optimum, provided is strictly increasing. [4] considered online SGD under more challenging learning scenarios and characterized the time (sample size) for to become significantly larger than for a random unit vector .
Learning in overparametrized two-layer networks under model (1) (or its variations) has been studied recently by several groups. In particular, [6] considers a training procedure which runs a single step of gradient descent followed by freezing the first layer and performing ridge regression with respect to the second layer. This scheme is amenable to a precise characterization of the generalization error. [11] consider a similar scheme in which a first phase of gradient descent is run to achieve positive correlation with the unknown direction . [14] also consider a two-phase scheme, and prove consistency and excess risk bounds for a more general class of target functions whereby the first equation in (1) is replaced by
(11) |
with . In particular, near optimal error bounds are obtained under a non-degeneracy condition on .
[1] consider a similar model whereby , and where , and (i.e., contains the coordinates of indexed by entries of ). Under a structural assumption on (the ‘merged staircase property’), and for fixed, they prove that the two-stage algorithm learns the target function with sample complexity of order . This paper is technically related to ours in that it uses mean-field theory to obtain a characterization of learning in terms of a PDE in a reduced -dimensional space.
A similar model was studied by [8] that bounds the sample complexity by for learning parities on bits using gradient descent with large batches (if , [8] require steps with batch size ).
Let us emphasize that our objective is quite different from these works. We implement a simple online SGD algorithm with additional projection steps, and try to derive a precise picture of the successive phases of learning (in particular, we do not consider two-stage schemes or layer-by-layer learning). On the other hand, we focus on a relatively simple model.
To clarify the difference, it is perhaps useful to rephrase our claims in terms of sample complexity. While previous works show that the target function can be learnt with samples, we claim that it is learnt by online SGD with test error from about samples and characterize the dependence of on for small . (Falling short of a proof in the general case.)
4 The large-network, high-dimensional limit
The first step of our analysis is a reduction of the system of ODEs (5), (6), with dimension to a system of ODEs in dimensions. We will achieve this reduction in two steps:
- First, we reduce to a system in dimensions for the variables , , . This reduction is exact and is quite standard. It is done in Section 4.1.
In order to define formally the reduced system, we define the functions via:
(12) | ||||
(13) |
Note that the above identities follow from [36, Proposition 11.31]. Throughout this section, we will make the following assumptions.
- A1. The distribution of weights at initialization, is supported on .
- A2. The activation function is bounded: . Additionally, the functions and are bounded and of class , with uniformly bounded first and second derivatives over . A sufficient condition for this is
- A3. Responses are bounded, i.e., .
Remark 4.1.
We hereby briefly explain the sufficiency of -boundedness of derivatives of and as claimed in Assumption A2. Suppose for example that and , then we have
(14) |
where follows from Gaussian integration by parts and follows from Cauchy-Schwarz inequality.
4.1 Reduction to -independent flow
Our first statement establishes the first reduction mentioned above. The proof of this fact is presented in Appendix B.1.
Proposition 2 (Reduction to -independent flow).
The input dimension does not appear in the reduced ODEs, Eqs. (16) to (19), and only plays a role in the initialization of the ’s and the ’s. Namely, since , we can represent with . By concentration of , this implies that, for , , are approximately .
This discussion immediately yields the following consequence.
Corollary 1.
Let be the solution of the gradient flow ODEs (5), (6) with initialization (7), and let be the unique solution of Eqs. (16) to (19), with initialization , , for . Then, for any fixed (possibly dependent on but not on ), the following holds with probability at least over the i.i.d. initialization :
(20) | |||
(21) | |||
(22) |
Here are absolute constants and only depends on the ’s in Assumptions A1-A3.
The proof of Corollary 1 is deferred to Appendix B.2. From now on, we will assume the initialization , for , but drop the superscript for notational simplicity. We notice in passing that the right-hand sides of Eqs. (20) to (22) are independent of : this approximation step holds uniformly over . (Note that the left-hand sides are normalized by so as to yield the root mean square error per entry.)
4.2 Elimination of the products
In order to state the reduction outlined above, we define the mean field risk as
(23) |
Further, we denote by the solution to the following ODEs:
(24) |
Note that (24) would be identical to (16)-(17) if we had . A priori, this is not the case. However, the two systems of equations are close to each other for large as made precise by our next proposition, which formalizes reduction .
The intuitive explanation for the approximation is quite interesting. For large , due to ‘propagation of chaos’, the neuron weights are approximately independent. Further, because of the symmetry of the problem under rotations that keep fixed, weights are approximately uniformly distributed conditional on . As a consequence, decomposing , we have , with , approximately uniform on and independent. Therefore, in high dimensions we have .
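The concentration argument above can be checked numerically. In the sketch below we assume the decomposition of each weight vector into a component of size s_i along the target direction plus an orthogonal component that is uniform on the unit sphere of the orthogonal subspace and independent across neurons. The names `s`, `w`, `u_star`, and the chosen alignments are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
d, m = 20_000, 4
u_star = np.zeros(d)
u_star[0] = 1.0                                   # hypothetical target direction
s = np.array([0.1, 0.3, -0.5, 0.8])               # hypothetical alignments <u_i, u_star>

# Orthogonal parts: independent, uniform on the unit sphere orthogonal to u_star.
w = rng.standard_normal((m, d))
w[:, 0] = 0.0
w /= np.linalg.norm(w, axis=1, keepdims=True)

u = s[:, None] * u_star + np.sqrt(1.0 - s[:, None] ** 2) * w
gram = u @ u.T                                    # all inner products <u_i, u_j>
off_diag = ~np.eye(m, dtype=bool)
max_gap = np.abs(gram - np.outer(s, s))[off_diag].max()
print(f"max |<u_i, u_j> - s_i s_j| over i != j: {max_gap:.4f}")
```

The gap comes entirely from the inner products of the orthogonal parts, which are of order d^{-1/2}, so it shrinks as the dimension grows.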
Proposition 3 (Reduction to flow in ).
Let be the unique solution of the ODEs (16)-(19) with initialization , for all . Let be the unique solution of the ODEs (24) with initialization , for all .
If assumptions A1-A3 hold, then for any there exists a constant
(25) |
(with depending on the constants appearing in Assumptions A1-A3 only) such that:
Consequently,
The proof of this proposition is deferred to Appendix B.3. Now, combining the propositions and corollaries in this section, we deduce that with high probability over the i.i.d. initialization,
(26) |
4.3 Connection with mean field theory
Consider the empirical distributions of the neurons:
(27) | ||||
(28) |
with , as in the statement of Proposition 3, i.e., solving (respectively) Eqs. (16)-(19) and Eq. (24) with initial conditions as given there.
Then, it is immediate to show that solves (in the weak sense) the following continuity partial differential equation (PDE); we refer to [2, 42] for the definition of weak solutions and their basic properties, and to Appendix B.4 for a short derivation:
(29) | ||||
(30) |
where is given by
(31) | ||||
(32) |
This equation can be extended to a flow in the whole space (all probability measures on equipped with the second Wasserstein distance), and interpreted as gradient flow with respect to this metric in the following risk:
(33) |
which is the obvious extension of of Eq. (23) to general probability distributions. Proposition 3 implies that for any , and under the above initial conditions,
(34) |
If we further denote by the empirical distribution of , , when , , a further application of Corollary 1 yields
(35) |
Starting with [32, 13, 40], several authors used continuity PDEs of the form (29) to study the learning dynamics of two-layer neural networks. Following the physics tradition, this is referred to as the ‘mean-field theory’ of two-layer neural networks. Appendix B.5 sketches an alternative approach to prove bounds of the form (26), (35) using the results of [32, 33]. The present derivation has the advantages of yielding a sharper bound and of being self-contained.
4.4 A general formulation
As mentioned above, the system of ODEs in Eq. (24) is a special case of the Wasserstein gradient flow of Eq. (29) whereby we set . In order to study the solutions of Eq. (29) (hence Eq. (24)) we adopt the following framework. Let denote a probability space. Let and (, ) be two measurable functions satisfying (dropping dependencies in below)
(36) |
If endowed with the uniform measure, we obtain the equations (24). In general, the push-forward of the measure through the map satisfies the mean-field equation (29). As a consequence, the dynamics (36) can be viewed as a gradient flow on the risk
(37) |
We next characterize the landscape of the risk function . In particular, we establish that under certain conditions, the global infimum of is .
Proposition 4.
The risk function can be expressed as
(38) |
Assume that for all , and that
Then, for any , there exists a triple such that , , and .
This proposition is proved in Appendix B.6.
Remark 4.2.
Proposition 4 complements Proposition 1 which establishes approximability of the target function using the networks (2) (Proposition 4 can be seen as an version of the latter). We note that the proofs of these propositions also provide insight into the structure of approximators. Namely, we can take the weights to be i.i.d. with distribution that is symmetric under rotations around , and , is concentrated close to (on a scale that can rely on the desired approximation error).
Indeed, the analysis of gradient flow in Section 6 reveals that the solutions found by gradient flow are of this nature. Namely, neurons develop a small but strictly positive alignment with . The distribution and size of the alignment evolves over time.
Remark 4.3.
The results in this section can be generalized to multi-index models: where , the space of orthogonal matrices. Further, the corresponding limiting dynamics become
Here, represents , and for , :
The definition of is the same as before.
5 Numerical solution
In Figure 2, we present the result of an Euler discretization of Eqs. (24) where is a degree- polynomial and is the ReLU activation: ,
(39) |
These plots clearly display two of the features emphasized in the introduction: plateaus separated by periods of rapid improvement of the risk; increasingly long timescales (notice the logarithmic time axis in the second and third row).
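A minimal Euler scheme of this kind can be sketched as follows. Everything in it is an assumption rather than a transcription of Eq. (24): we take a squared loss with a factor 1/2, replace inner products between neurons by products of their alignments s_i s_j, and obtain the dynamics $\varepsilon\,\dot a_i = V(s_i) - \frac{1}{m}\sum_j a_j U(s_i s_j)$ and $\dot s_i = a_i (1 - s_i^2)\big(V'(s_i) - \frac{1}{m}\sum_j a_j s_j U'(s_i s_j)\big)$, with $V(s) = \sum_k \varphi_k \sigma_k s^k$ and $U(s) = \sum_k \sigma_k^2 s^k$. The Hermite coefficients `phi_c` and `sig_c` are hypothetical and do not correspond to the ReLU experiment of Figure 2.

```python
import numpy as np
from numpy.polynomial import polynomial as P

def risk(a, s, phi_c, sig_c):
    """0.5 * sum_k (phi_k - sigma_k * mean_i a_i s_i^k)^2, the risk in Hermite form."""
    k = np.arange(len(phi_c))
    moments = np.mean(a[:, None] * s[:, None] ** k, axis=0)
    return 0.5 * np.sum((phi_c - sig_c * moments) ** 2)

def euler_step(a, s, phi_c, sig_c, eps, dt):
    """One explicit Euler step of the assumed (a_i, s_i) dynamics."""
    v_c, u_c = phi_c * sig_c, sig_c**2            # power-series coefficients of V and U
    ss = np.outer(s, s)                           # products s_i * s_j
    da = P.polyval(s, v_c) - P.polyval(ss, u_c) @ a / a.size
    drift = P.polyval(s, P.polyder(v_c)) - P.polyval(ss, P.polyder(u_c)) @ (a * s) / a.size
    ds = a * (1.0 - s**2) * drift
    # Clipping is a numerical safeguard (an addition of this sketch), not part of the flow.
    return a + (dt / eps) * da, np.clip(s + dt * ds, -1.0, 1.0)

rng = np.random.default_rng(0)
m, d, eps, dt, n_steps = 50, 100, 0.1, 0.002, 5000
phi_c = np.array([0.0, 1.0, 0.5, 0.25])           # hypothetical target coefficients
sig_c = np.array([1.0, 0.8, 0.6, 0.4])            # hypothetical activation coefficients
a = np.ones(m)
s = rng.standard_normal(m) / np.sqrt(d)           # initial alignments of order d^{-1/2}

history = [risk(a, s, phi_c, sig_c)]
for _ in range(n_steps):
    a, s = euler_step(a, s, phi_c, sig_c, eps, dt)
    history.append(risk(a, s, phi_c, sig_c))
print(f"risk: {history[0]:.3f} -> {history[-1]:.3f}")
```

Plotting `history` against time on a logarithmic axis, for decreasing `eps`, is the natural way to look for plateaus. The explicit scheme requires `dt / eps` to be small, which is one reason simulations become costly when `eps` is small.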
In order to examine the incremental learning structure, we rewrite the risk of Eq. (23) by decomposing and in the basis of Hermite polynomials
(40) |
We observe that, for small , the Hermite coefficients of are learned sequentially, in the order of their degree. When is sufficiently small (right plots), this incremental learning happens in well separated phases. The plateaus and waterfalls in the plots of correspond to the network learning increasingly higher degree polynomials.
In Figure 3 we plot the evolution of the values of the and , for . We observe that the overall order of magnitude of the ’s and the ’s increases when passing through the different phases of the incremental learning process. Meanwhile, some of the ’s and ’s undergo a sign change during the learning process, which is characterized by a sudden decrease and subsequent rapid increase in their magnitude.
Altogether, the results of Figures 2 and 3 are consistent with the canonical learning order up to level as per Definition 1. While we conjecture that incremental learning also occurs for higher-order polynomials, we found this hard to observe in numerical simulations: we would need to take much smaller than in Figure 2, resulting in prohibitively large simulation costs.
First, as predicted in Definition 1, the times at which the components are learned are closer on a logarithmic scale as the degree increases. It is therefore increasingly difficult to observe time scales corresponding to higher degrees.
Second, we expect there to be a choice of the initialization , activation and target function, for which not all the components of are actually learnt. We observed empirically that this happens easily for small .
To conclude this section, in Figure 4 we compare the simplified neuron dynamics (MF) of Eq. (24) and the evolution of projected gradient descent for the original two-layer neural network (NN). From the plots we observe two remarkable phenomena: (1) the evolution of the risk for NN and MF are close to each other during the entire learning process for both large learning rate ratio () and small (), and their risk curves have the same qualitative behavior even if is small (); (2) as we increase the value of from to , the alignment between the learning curves of NN and MF improves significantly. These observations justify our argument in Section 4.2 that the inter-neuron correlations are well approximated by for wide networks.
6 Timescales hierarchy in the gradient flow dynamics
We are interested in the behavior of the solution of the ODEs (36), initialized from for all (as per Proposition 3). The canonical learning order of Definition 1 concerns the behavior of solutions for . This type of question can be addressed within the theory of dynamical systems using singular perturbation theory [25]. Here, ‘singular’ refers to the fact that multiplies one of the highest-order derivatives. In Eq. (36), multiplies the differential term , so that the ODE system becomes singular in the limit . In particular, it degenerates to the following system of differential-algebraic equations:
(41) |
Due to singularity, the qualitative behavior of the above system is dramatically different from that of Eq. (36) with small but non-zero. This is in stark contrast to regular perturbation problems, for which the limiting dynamics will still be a system of differential equations with the same order and similar qualitative behavior as the perturbed system.
As a side remark, we note that the system (36) can be seen as a slow-fast dynamical system, where the ’s are the fast variables and the ’s are the slow variables [10]. Formally, the time derivative of the ’s is multiplied by a factor . From a dynamical systems perspective, the present case is made complicated because of a bifurcation when the ’s become non-zero.
The canonical learning order provides a detailed description of this bifurcation. We will motivate this scenario using a classical, but non-rigorous, technique of singular perturbation theory, called the matched asymptotic expansion [25, Chapter 2]. This technique decomposes the approximation of the solution in several time scales on which a regular approximation holds. These time scales are traditionally called layers in the literature; however, we avoid this terminology due to the potential confusion with the layers of the neural network.
We will work mainly using the Hermite representation of the dynamical ODEs (36), which we write down for the reader’s convenience:
(42) |
The rest of this section is organized as follows. We first give a brief overview of the method of matched asymptotic expansions and a summary of our main results regarding the learning timescales in Section 6.1. Sections 6.2-6.4 respectively describe the first three time scales of the matched asymptotic expansion of (42). This gives, for each time scale, an approximation of the , . In Appendix C.2, we detail how these sections induce an evolution of the risk alternating plateaus and rapid decreases, and support the standing learning scenario of Definition 1. Finally, in Section 6.5, we conjecture the behavior on longer time scales.
Notations.
We denote the constant function . Denote the dot product on and the associated norm. For , we denote the orthogonal projection of on the hyperplane of of functions orthogonal to :
We denote and thus is the orthogonal projection of on .
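As a concrete illustration of this projection, here is a minimal sketch that assumes the dot product is an L² inner product with respect to a probability measure, discretized as finitely many atoms with weights summing to one. The names `inner`, `project_out_constant` and `weights` are hypothetical and not part of the paper's notation. Under this assumption the constant function has unit norm, so projecting onto its orthogonal complement amounts to subtracting the weighted mean.

```python
def inner(a, b, weights):
    """Weighted L2 inner product <a, b> = sum_i w_i * a_i * b_i (weights sum to 1)."""
    return sum(w * x * y for w, x, y in zip(weights, a, b))


def project_out_constant(a, weights):
    """Orthogonal projection of a onto the functions orthogonal to the constant 1.

    Since <1, 1> = 1 for a probability measure, the projection is a - <a, 1> * 1.
    """
    ones = [1.0] * len(a)
    mean = inner(a, ones, weights)
    return [x - mean for x in a]


if __name__ == "__main__":
    weights = [0.2, 0.3, 0.5]
    a = [1.0, -2.0, 4.0]
    p = project_out_constant(a, weights)
    print(p, inner(p, [1.0, 1.0, 1.0], weights))
```

The squared norm then splits as the squared mean plus the squared norm of the projected part, which is the decomposition used repeatedly when separating the constant component from the rest.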
6.1 Matched asymptotic expansions
The method of matched asymptotic expansions is a common approach to finding approximate solutions of perturbed differential equations. In the present paper, we are mainly interested in applying this technique to approximate the solution to the specific singularly perturbed ODE system of Eq. (42). (Although we keep calling this an ODE system, it is important to keep in mind that it takes place in an infinite-dimensional space.) Denoting by the independent variable and by the perturbation parameter, the method of matched asymptotic expansions consists of the following three steps: (1) Divide the domain of (generally a subinterval of ) into several subdomains, which may overlap each other and depend on the perturbation parameter ; (2) Within each subdomain, find an accurate approximation to the perturbed system. This is usually achieved by expanding the perturbed system in powers of , and keeping only terms that are relevant to the current domain; (3) The approximate solutions obtained in Step (2) might not be valid in the overlap of two adjacent subdomains. To resolve this issue, these approximate solutions are then combined together through a process called “matching” to produce an approximation that is valid on the entire domain.
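The three steps can be sketched on a linear toy problem that is much simpler than Eq. (42). The system `eps * x' = y - x`, `y' = -x` with `x(0) = 0`, `y(0) = 1`, and all function names below, are assumptions made for illustration only. Step (1): the subdomains are the fast scale `t = O(eps)` and the slow scale `t = O(1)`. Step (2): on the fast scale, with `tau = t / eps`, the leading order is `y = 1`, `x = 1 - exp(-tau)`; on the slow scale the leading order is `x = y = C * exp(-t)`. Step (3): matching the limit `tau -> infinity` of the fast approximation with the limit `t -> 0` of the slow one gives `C = 1`, and adding the two approximations while subtracting their common limit yields a composite approximation.

```python
import math


def reference_solution(eps, t_end=2.0, steps_per_eps=50):
    """RK4 reference for the toy system eps*x' = y - x, y' = -x, x(0)=0, y(0)=1."""
    n_steps = int(round(t_end * steps_per_eps / eps))
    dt = t_end / n_steps

    def rhs(x, y):
        return (y - x) / eps, -x

    x, y = 0.0, 1.0
    trajectory = [(0.0, x, y)]
    for i in range(n_steps):
        k1x, k1y = rhs(x, y)
        k2x, k2y = rhs(x + 0.5 * dt * k1x, y + 0.5 * dt * k1y)
        k3x, k3y = rhs(x + 0.5 * dt * k2x, y + 0.5 * dt * k2y)
        k4x, k4y = rhs(x + dt * k3x, y + dt * k3y)
        x += dt * (k1x + 2 * k2x + 2 * k3x + k4x) / 6.0
        y += dt * (k1y + 2 * k2y + 2 * k3y + k4y) / 6.0
        trajectory.append(((i + 1) * dt, x, y))
    return trajectory


def inner_solution(t, eps):
    """Leading-order approximation on the fast scale tau = t / eps."""
    return 1.0 - math.exp(-t / eps), 1.0


def outer_solution(t):
    """Leading-order approximation on the slow scale, with C = 1 fixed by matching."""
    return math.exp(-t), math.exp(-t)


def composite_solution(t, eps):
    """Inner + outer - common limit (the common limit is x = y = 1)."""
    xi, yi = inner_solution(t, eps)
    xo, yo = outer_solution(t)
    return xi + xo - 1.0, yi + yo - 1.0


def max_composite_error(eps):
    """Largest deviation between the composite approximation and the reference."""
    worst = 0.0
    for t, x, y in reference_solution(eps):
        xc, yc = composite_solution(t, eps)
        worst = max(worst, abs(x - xc), abs(y - yc))
    return worst


if __name__ == "__main__":
    for eps in (0.05, 0.02, 0.01):
        print(eps, max_composite_error(eps))
```

Neither approximation is uniformly accurate on its own: the slow-scale one violates the initial condition on `x`, and the fast-scale one misses the decay on times of order one. In this toy problem the composite error shrinks roughly in proportion to `eps` over the whole interval.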
In our setting, the singularly perturbed system (42) takes the form of
We will carry out explicit calculations for the first three time scales in Sections 6.2-6.4, respectively. Here is a summary of our main findings:
• In Section 6.2 we explore the learning of the constant component of the target function, which happens at the timescale . At the end of this phase, the mean-field risk (see (37)) evolves to
(43)
In other words, during this phase, gradient flow learns the constant term in . At the end of this time scale we have and .
• Then, in Section 6.3 we investigate the second time scale , , during which the ’s and ’s increase to a different order in . The result of this time scale is mainly technical and needed to understand the transition to the time scale of Section 6.4. We also perform the matching procedure to combine the approximate solution within this time scale to the one obtained in Section 6.2. At the end of this time scale we have and .
• To understand the evolution of the risk relevant to learning the linear component of , we introduce a new time scale in Section 6.4, and show that the linear component can be learned within this time scale. To be more accurate, at the end of this time scale we have
(44)
and and .
Finally, in Section 6.5, we conjecture the behavior of the approximate solutions and induced risks for longer time scales.
6.2 First time scale: constant component
We define a “fast” time variable and replace it in Eq. (42). We expand the solutions and in powers of :
(45) | ||||
(46) |
where are implicitly functions of . They are initialized at
(47) | ||||||
(48) |
to be consistent with the initial condition and .
We substitute the expansion in (42):
(49) | |||
(50) | |||
(51) | |||
(52) | |||
(53) | |||
(54) | |||
(55) |
The basic assumption of matched asymptotic expansions is that terms of the same order in can be identified (with some limitations that we develop below). For now, let us identify terms of order :
(56) | ||||
(57) |
From (57) and (48), we have : time is too short for the to be of order .
Substituting in (56), we obtain
(58) |
Recall that is the dot product on , denotes the constant function and is the orthogonal projection of on . Equation (58) can be rewritten as
which gives after integration (using (47)):
(59) | ||||
At this point, we have determined and , and thus and up to a precision, which is sufficient to obtain a -approximation of the risk (see Section C.2). However, note that we could obtain more precise estimates by identifying higher-order terms in (49)-(55). For instance, identifying the terms in (52)-(55), we obtain . This shows that the become non-zero, though only of order on the time scale ; the inner-layer weights develop an infinitesimal correlation with the true direction thanks to the linear component of and .
The approximation constructed above should be considered as valid on the time scale . As , we obtain the following approximation of the risk (see Eq. (37) for definition, and Appendix C.2 for a detailed derivation):
This approximation breaks down when we reach a new time scale, at which the are large enough for the to be affected (at leading order) by the linear part of the functions. We detail the new time scale and its resolution in the next section.
6.3 Second time scale: linear component I
In this section, we seek a second, slower time scale, for which the behavior of the asymptotic expansion is different.
Identification of the scale.
Consider , where is to be determined. We rewrite the system (42) using , and expand the solutions and :
(60) | ||||
(61) |
where the exponent is also to be determined. (Since within the previous time scale we obtained , it is natural to assume .)
Let us pause to comment on our method.
Similarly to what has been done in the previous time scale, we will substitute the expansions (60)-(61) in the equations (42) in order to compute the different terms in the expansion. However, this step also allows us to compute the exponents and , that give respectively the new time scale and the size of the ’s.
Note that we should have proceeded similarly for the first time scale, by introducing a first time variable , expanding in powers , and determining and a posteriori. This would have led, indeed, to and . However, for simplicity, we preferred to fix these values that are natural a priori.
Finally, note that the expansions (45)-(46) and (60)-(61) are different, because they are valid on different time scales. In fact, the only coherence condition that we require below is that the expansions match in a joint asymptotic regime where and . We thus build different approximations for each one of the time scales, with some matching conditions; this justifies the name of matched asymptotic expansion.
and thus
(62) | ||||
(63) | ||||
(64) |
For the first time scale, we chose , so that the terms of order were negligible compared to in (62). This means that the linear components of the functions had no effect on the at leading order. We are now interested in a new time scale where and are of the same order, i.e., ; then the linear components play a role in the dynamics.
Further, for to be non-zero, we need both sides of (64) to be of the same order, thus . Putting together, this gives .
Derivation of the ODEs for this time scale.
Let us summarize the equations. For and
(65) | ||||
(66) | ||||
(67) |
First, we identify the terms of order :
(68) |
This means that the trajectory remains in the affine hyperplane defined by . Intuitively, the constant component of remains fitted by the neural network in this second time scale.
Second, we identify the terms of order in (65)-(67):
(69) | ||||
(70) |
Note that, in Eqs. (69)–(70), the time derivative of does not appear, and therefore the evolution of is not determined by these equations. In fact, is best interpreted as the Lagrange multiplier associated with the constraint (68). Namely, this is a free term that can be adjusted so that the solution of the system (69)–(70) satisfies the constraint (68). We can check that the unknown term in (69) leaves exactly the degree of freedom needed for this to be the case: we have
In this last expression, the first unknown term can always compensate the second term so that the constraint is satisfied. The entire evolution of is determined by higher orders in the expansion.
To eliminate this Lagrange multiplier, we use again the compact notations:
(71) | ||||
(72) |
and thus
(73) | ||||
(74) |
Matching.
The initialization of the ODEs (71)-(72) for the second time scale is determined by a classical procedure that matches with the previous time scale. In this paragraph, we denote the approximation obtained in the first time scale (Section 6.2), and the approximation in the second time scale, described above.
Consider an intermediate time scale , , and assume so that
In this intermediate regime, we want the approximations provided on the first and the second time scales to match: and (resp. and ) should match to leading order.
From the first time scale approximation,
(75) | ||||
(76) | ||||
(77) | ||||
(78) | ||||
(79) |
From the second time scale approximation,
(80) | ||||
(81) |
By matching, Equations (79) and (81) should be coherent. Thus the ODE for the second time scale should be initialized from .
Similarly, the matching procedure gives that the ODE for the second time scale should be initialized from .
Solution.
Having completed the matching procedure, we now consider the solution in the second time scale only, which we again denote by , as in (71), (72). The matching procedure motivates us to consider the solution of (73)-(74) initialized at , . This gives
To conclude, we note that is constrained by (68). Further, from (70),
thus .
Putting together, these equations give:
(82) |
We observe that and diverge as . This implies that our approximation on the second time scale must break down at a certain point. Indeed, we analyzed this time scale under the assumption that both and are of order . However, since and diverge exponentially as , as per Eq. (82), this assumption breaks down when .
More precisely, in (65) (resp. (67)), the term includes a term of the form
When and become of order , this term becomes of order , which is then of the same order as the term in (65) (resp. the term in (67)). At this point, these terms can no longer be neglected. From (82), we have
Therefore, and become of order at the time , at which the approximation on the second time scale breaks down. We thus introduce a new time scale centered at this critical point.
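The way such a breakdown time depends on the small initial scale can be checked on a generic exponential. This is only an illustration: the quantity `q(t) = delta * exp(rate * t)` and the names `delta`, `rate` and `threshold` are hypothetical stand-ins, not the specific quantities of Eq. (82). A quantity that starts at a small scale `delta` and grows exponentially reaches a fixed threshold at a time proportional to `log(1 / delta)`, so the breakdown time diverges only logarithmically as `delta` tends to zero.

```python
import math


def crossing_time(delta, rate, threshold=1.0):
    """Time at which q(t) = delta * exp(rate * t) reaches the given threshold."""
    return math.log(threshold / delta) / rate


def q(t, delta, rate):
    """Generic exponentially growing quantity (illustrative stand-in)."""
    return delta * math.exp(rate * t)


if __name__ == "__main__":
    for delta in (1e-2, 1e-4, 1e-8):
        t_star = crossing_time(delta, rate=2.0)
        print(delta, t_star, q(t_star, delta, 2.0))
```

Squaring `delta` doubles the crossing time, which is why a window centered at the crossing time is the natural place to introduce the next time variable.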
6.4 Third time scale: linear component II
We now introduce the time . As is only a translation from , the ODEs in terms of are the same as the ones in terms of . However, in this time scale, and have diverged. Consistently with the discussion above, we seek expansions of the form
(83) | ||||
(84) |
Similarly to the second time scale, we substitute (83)-(84) in (42) and obtain
First, we identify the terms of order :
(85) |
This means that has no component diverging in in the direction of .
Second, we identify the terms of order :
(86) |
Put together with (85), this equation ensures that the constant component of remains learned on this third time scale.
Third, we identify the terms of order :
(87) |
Again, the term is best interpreted as the Lagrange multiplier associated to the constraints (85), (86). Using the compact notations,
where in the last equality we use (85). Thus we can rewrite (87) as
(88) |
and thus
(89) |
In Appendix C.1, we solve this system of ODEs and determine the initial condition by matching with the previous time scale. The result is that
(90) |
where is the function
(91) |
This solution completes the description of how the linear part of the function is learned. Plugging it into the equations for and , we get
which converges to as . Consequently, we obtain the following approximation for within this time scale (again, see Appendix C.2 for details):
6.5 Conjectured behavior for larger time scales
The analysis of the previous sections naturally suggests the existence of a sequence of cutoffs. At each time scale, a new polynomial component of is learned within a window that is much shorter than the time elapsed before that phase started. Along this sequence, we expect and to grow to increasingly larger scales in (but remains while diverges).
More precisely, we assume that during the -th phase, the network learns the degree- component , and various quantities satisfy the following scaling behavior:
(92) |
where is an increasing sequence and are decreasing sequences. Further, while learning of this component takes place when , the actual evolution of the risk (and of the neural network) takes place on much shorter scales, namely:
(93) |
where is also decreasing, with . The goal of this section is to provide heuristic arguments to conjecture the values of , , and . We will base this conjecture on a rigorous analysis of a simplified model.
The simplified model is motivated by the expectation (supported by the heuristics and simulations in the previous sections) that learning each component happens independently from the details of the evolution on previous time scales. In the simplified model, the activation function is proportional to the -th Hermite polynomial, namely . This is the component of that we expect to be relevant on the -th time scale. The gradient flow equations (42) then read:
(94) |
with corresponding risk component
We capture the effect of learning dynamics on the previous time scales by the overall magnitude of the ’s and ’s at initialization. Namely, we choose the scale of initialization of the simplified model to be given by the end of the -th time scale, i.e., and . Further, in order for the -th component to be learned, namely
(95) |
we require so that . Analogously, we assume .
Based on this consideration, we introduce the rescaled variables
Rewriting Eq. (94) in terms of ’s and ’s, and using , we get that
(96) |
In order for the ’s and ’s to be learned simultaneously, we need , which implies . Making a further change of the time variable , where , it follows that
(97) |
Moreover, rewriting the risk in terms of the rescaled variables , satisfies the ODE:
(98) |
Note that with our choice of and , we have . This means that the ’s and ’s are initialized at the same scale, namely
(99) |
The theorem below describes quantitatively the dynamics of the simplified model for small , and determines the value of (recall that ):
Theorem 1 (Evolution of the simplified gradient flow).
Assume and let be the unique solution of the ODE system (97), initialized as per Eq. (99) (note in particular that ). Then the following hold:
-
Let us denote
(100) and assume . For , define
(101) Then, for any fixed we have as . Further, if is a discrete probability measure, then there exists and, for any a constant independent of such that
(102) (103) namely the -th component is learnt in an time window around .
-
Similarly, we denote
(104) If , then the same claims as in hold.
-
If neither of the conditions at points , holds, and
(105) for almost every . Then, for such and each , there exists a constant such that
(106) meaning that converges to eventually.
We further note that with , and with .
Remark 6.1.
Under the conditions of cases and , we see that the degree- component of the target function is learnt within an time window around , which is consistent with the timescales conjectured in Definition 1.
Remark 6.2.
Case corresponds to becoming close to in time , and staying at . In other words, the neurons become orthogonal to the target direction and play no role in learning higher-degree components any longer.
Informally, case couples the learning of different polynomial components. It can happen that the learning phase induces an effective initialization within the domain of case .
We expect this not to be the case for suitable choices of initialization (or equivalently ), , and . Establishing this would amount to establishing that the canonical learning order holds.
7 Stochastic gradient descent and finite sample size
So far we focused on analyzing the projected gradient flow (GF) dynamics with respect to the population risk, as defined in Eqs. (5)-(6). In this section, we extract the implications of our analysis of GF on online projected stochastic gradient descent, which is a projected version of the SGD dynamics (162).
For simplicity of notation, we denote by a datapoint and by the parameters of neuron . For and , we define
The projected SGD dynamics is specified as follows:
(107) |
where for and compact , , and . Note that the ’s here are different from the ’s in Section 6.
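One common form of a projected SGD step can be sketched as follows. This is a hedged illustration and not a transcription of Eq. (107): the `tanh` activation, the squared loss, the mean-field `1/m` output scaling with gradients rescaled by `m`, the two separate step sizes `lr_a` and `lr_u`, and all function names are assumptions. Each step uses one fresh sample, takes a gradient step on the second-layer weights and on the first-layer weights, and then projects every first-layer weight vector back onto the unit sphere.

```python
import math
import random


def normalize(v):
    """Project a non-zero vector onto the unit sphere."""
    norm = math.sqrt(sum(c * c for c in v))
    return [c / norm for c in v]


def projected_sgd_step(a, U, x, y, lr_a, lr_u):
    """One online step for f(x) = (1/m) * sum_i a_i * tanh(<u_i, x>), squared loss.

    Gradients are rescaled by m (a mean-field convention, assumed here), and
    each first-layer vector u_i is renormalized after its gradient step.
    """
    m = len(a)
    pre = [sum(uc * xc for uc, xc in zip(u, x)) for u in U]
    act = [math.tanh(z) for z in pre]
    residual = y - sum(ai * h for ai, h in zip(a, act)) / m
    new_a = [ai + lr_a * residual * h for ai, h in zip(a, act)]
    new_U = []
    for ai, u, z in zip(a, U, pre):
        slope = 1.0 - math.tanh(z) ** 2  # derivative of tanh at z
        moved = [uc + lr_u * residual * ai * slope * xc for uc, xc in zip(u, x)]
        new_U.append(normalize(moved))
    return new_a, new_U


def train(d=5, m=10, n_steps=200, lr_a=0.5, lr_u=0.05, seed=0):
    """Run online projected SGD on a hypothetical single-index target."""
    rng = random.Random(seed)
    u_star = normalize([rng.gauss(0.0, 1.0) for _ in range(d)])
    a = [1.0] * m
    U = [normalize([rng.gauss(0.0, 1.0) for _ in range(d)]) for _ in range(m)]
    for _ in range(n_steps):
        x = [rng.gauss(0.0, 1.0) for _ in range(d)]
        z = sum(uc * xc for uc, xc in zip(u_star, x))
        y = math.tanh(2.0 * z)  # illustrative target function of <u_star, x>
        a, U = projected_sgd_step(a, U, x, y, lr_a, lr_u)
    return a, U, u_star


if __name__ == "__main__":
    a, U, u_star = train()
    print([round(sum(uc * sc for uc, sc in zip(u, u_star)), 3) for u in U])
```

Because each step consumes a new sample, the number of steps equals the number of samples used, which is the sense in which the analysis of this section speaks of sample size.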
We prove that, for small , the projected SGD of Eq. (107) is close to the gradient flow of Eqs. (5)-(6). Throughout this section, we make the following assumptions similar to those assumed in Section 4:
- A1.
-
is supported on . Hence, for all .
- A2.
-
The activation function is bounded: . Additionally, define for :
(108) (109) We then require the functions and to be bounded and differentiable, with uniformly bounded and Lipschitz continuous gradients for all :
(110) (111) Similar to Remark 4.1, we can show that a sufficient condition for Eqs. (110) and (111) is
where the constant depends only on .
- A3.
-
Assume , then we require that almost surely. Moreover, we assume that for all , both and are -sub-Gaussian.
The following theorem upper bounds the distance between gradient flow and projected stochastic gradient descent dynamics.
Theorem 2 (Difference between GF and Projected SGD).
The proof is presented in Appendix D and follows the same scheme as the proof of Theorem 1, part (B), in [33]. The main difference with respect to that theorem is that here we are interested in projected SGD (and GF) instead of plain SGD (and GF), hence an additional step of approximation is required, and the ’s and ’s need to be treated separately. We next draw implications of the last result on learning by online SGD within the canonical learning order.
Theorem 3.
Fix any . Assume and the initialization be such that the canonical learning order of Definition 1 holds up to level for some , and that
(117) |
Then, there exist constants , , and that depend on (together with and ) such that the following happens. Assume and are such that , , and the step size and number of samples (equivalently, number of steps) satisfy
(118) | ||||
(119) |
Then, with probability at least , the projected SGD algorithm of Eq. (107) achieves population risk smaller than :
(120) |
8 Discussion
We conclude by discussing some of our findings as well as potential extensions of our work. As mentioned in the introduction, our initial motivation was to understand certain ubiquitous phenomena in the learning dynamics of multi-layer neural networks. A particularly striking phenomenon that we could reproduce in the present mathematical setting is the coexistence of plateaus in which the risk barely changes and sudden drops.
In the next paragraphs, we will briefly emphasize results or future directions that were not anticipated at the beginning of this work.
Implicit bias in function space.
We provided evidence towards the canonical learning order of Definition 1. According to this scenario, the target function is learnt according to its decomposition into Hermite polynomials, with lower degree components learnt first. This theory applies to online SGD via Theorem 2 and Theorem 3. In this setting, the number of SGD steps corresponds to the number of samples. Therefore, at a small sample size, SGD will fit a low degree polynomial approximation of the target function, with the degree increasing with the number of samples.
A similar phenomenon is observed with (rotationally invariant) kernel methods [34], with one important difference. Here the number of samples always scales linearly in the degree, while for kernel methods, different polynomial degrees correspond to different scalings with the dimension.
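The Hermite decomposition underlying this discussion can be computed numerically for a given function. The following minimal sketch assumes the normalized probabilists' Hermite polynomials, orthonormal under the standard Gaussian measure; the function names, the trapezoid quadrature and the example target are illustrative choices, not taken from the paper. By Parseval's identity, the sum of squared coefficients of degree larger than `l` is the squared L² error of the best polynomial approximation of degree `l`, which is the quantity a low degree fit leaves unexplained.

```python
import math


def hermite_normalized(k, x):
    """Normalized probabilists' Hermite polynomial He_k(x) / sqrt(k!)."""
    if k == 0:
        return 1.0
    prev, cur = 1.0, x  # He_0 and He_1
    for j in range(1, k):
        # Three-term recurrence He_{j+1}(x) = x * He_j(x) - j * He_{j-1}(x)
        prev, cur = cur, x * cur - j * prev
    return cur / math.sqrt(math.factorial(k))


def gaussian_expectation(f, lo=-10.0, hi=10.0, n=4000):
    """Trapezoid approximation of E[f(G)] for a standard Gaussian G."""
    h = (hi - lo) / n
    total = 0.0
    for i in range(n + 1):
        x = lo + i * h
        weight = 0.5 if i in (0, n) else 1.0
        total += weight * f(x) * math.exp(-0.5 * x * x)
    return total * h / math.sqrt(2.0 * math.pi)


def hermite_coefficients(phi, max_degree):
    """Coefficients phi_k = E[phi(G) * h_k(G)] for k = 0, ..., max_degree."""
    return [gaussian_expectation(lambda x, k=k: phi(x) * hermite_normalized(k, x))
            for k in range(max_degree + 1)]


def residual_energy(coeffs, degree):
    """Sum of squared coefficients of degree strictly larger than `degree`."""
    return sum(c * c for c in coeffs[degree + 1:])


if __name__ == "__main__":
    coeffs = hermite_coefficients(lambda x: x * x + x, 4)
    print([round(c, 6) for c in coeffs])
    print([round(residual_energy(coeffs, l), 6) for l in range(4)])
```

For the example target `x**2 + x`, the residual energy drops in two steps as the degree passes one and then two, mirroring the picture of components being fitted one degree at a time.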
Implicit bias in parameter space.
Our analysis tracks the evolution of the weights as well. As explained in Section 6, in order for the degree- component of the target function to be well approximated (in the limit), it is sufficient that . Here is an abstract neuron index, is the second-layer weight and is the projection of the first layer weight along the target direction .
Naively, one would expect that, in order for learning to take place, first layer weights should be well aligned with , i.e. should concentrate close to one. However this is not the only way to satisfy the constraints . Indeed, our analysis in Section 6 indicates that gradient flow satisfies this constraint with and with , (so that will be of order one) as . In other words, the alignment is small, and second layer weights are large. (In general, weights on multiple scales coexist.)
The role of the learning rate .
The initialization of parameters and relative step-sizes play a key role in modern (non-convex) machine learning. The combination of the two scalings (initialization and relative stepsize) affects the learning dynamics. In order to clarify this point, we can consider a general parametrization (we keep )
and gradient flow dynamics
(Note that the learning rate in the second equation can be set to without loss of generality, by rescaling the time axis.) Rewriting this in terms of the coefficients , so that the function representation is kept fixed, we have
while the second equation remains unchanged. This parametrization allows us to compare various scalings in a uniform fashion.
- •
-
•
In this paper: , , after .
- •
As mentioned already, mean field scaling can exhibit better feature learning properties. In particular, the class of functions studied in the present paper can require much larger sample size to learn under the classical scaling [37, 20, 47]. The choice of initialization in this paper is the same as in the mean field literature, with the difference that the relative learning rate is a factor smaller, hence making it, in a sense, slightly closer to the classical scaling. It would be interesting to explore other scalings as well.
We also note that, while the limit of small is interesting, setting directly leads to a singular behavior (no matter how we rescale time, in this case learning takes place instantly, up to a certain critical degree). Formally, setting corresponds to keeping second layer weights equal to their optimal values: a correct analysis of this case requires accounting for the role of the stepsize, rather than relying on the gradient flow approximation alone.
More complex network models.
The choice of the neural network model in this paper was mainly dictated by the desire to avoid inessential technicalities. It would be important to move towards more realistic models.
First, we used projected gradient descent to constrain the weights’ norms . While this is a common theoretical device in studying single-index models [4, 11], we believe that techniques developed here can be extended to the more general case. Analogously, we could add biases to the network architecture and hence replace Eq. (2) by
(121) |
With this change, the limiting mean-field dynamics will be an autonomous ODE system of where . We expect that its evolution will be qualitatively similar to that of the simplified dynamics considered in the paper.
Second, the single-index model studied here is a simple example of target function which requires feature learning. An obvious generalization is to consider multi-index models, as already discussed in Remark 4.3.
Finally, it would be interesting to generalize our analysis to classification losses.
Acknowledgments
This work was supported by the NSF through award DMS-2031883, the Simons Foundation through Award 814639 for the Collaboration on the Theoretical Foundations of Deep Learning, the NSF grant CCF-2006489 and the ONR grant N00014-18-1-2729, and a grant from Eric and Wendy Schmidt at the Institute for Advanced Studies. Part of this work was carried out while Andrea Montanari was on partial leave from Stanford and a Chief Scientist at Ndata Inc dba Project N. The present research is unrelated to AM’s activity while on leave.
References
- Abbe et al. [2022] Emmanuel Abbe, Enric Boix Adsera, and Theodor Misiakiewicz. The merged-staircase property: a necessary and nearly sufficient condition for sgd learning of sparse functions on two-layer neural networks. In Conference on Learning Theory, pages 4782–4887. PMLR, 2022.
- Ambrosio et al. [2005] Luigi Ambrosio, Nicola Gigli, and Giuseppe Savaré. Gradient flows: in metric spaces and in the space of probability measures. Springer Science & Business Media, 2005.
- Arnaboldi et al. [2023] Luca Arnaboldi, Ludovic Stephan, Florent Krzakala, and Bruno Loureiro. From high-dimensional & mean-field dynamics to dimensionless odes: A unifying approach to sgd in two-layers networks. arXiv preprint arXiv:2302.05882, 2023.
- Arous et al. [2021] Gerard Ben Arous, Reza Gheissari, and Aukosh Jagannath. Online stochastic gradient descent on non-convex losses from high-dimensional inference. The Journal of Machine Learning Research, 22(1):4788–4838, 2021.
- Arpit et al. [2017] Devansh Arpit, Stanisław Jastrzębski, Nicolas Ballas, David Krueger, Emmanuel Bengio, Maxinder S Kanwal, Tegan Maharaj, Asja Fischer, Aaron Courville, Yoshua Bengio, et al. A closer look at memorization in deep networks. In International conference on machine learning, pages 233–242. PMLR, 2017.
- Ba et al. [2022] Jimmy Ba, Murat A Erdogdu, Taiji Suzuki, Zhichao Wang, Denny Wu, and Greg Yang. High-dimensional asymptotics of feature learning: How one gradient step improves the representation. In Advances in Neural Information Processing Systems, 2022.
- Baldi and Hornik [1989] Pierre Baldi and Kurt Hornik. Neural networks and principal component analysis: Learning from examples without local minima. Neural networks, 2(1):53–58, 1989.
- Barak et al. [2022] Boaz Barak, Benjamin L Edelman, Surbhi Goel, Sham Kakade, Eran Malach, and Cyril Zhang. Hidden progress in deep learning: Sgd learns parities near the computational limit. arXiv:2207.08799, 2022.
- Bartlett et al. [2021] Peter L Bartlett, Andrea Montanari, and Alexander Rakhlin. Deep learning: a statistical viewpoint. Acta numerica, 30:87–201, 2021.
- Berglund [2001] Nils Berglund. Perturbation theory of dynamical systems. arXiv preprint math/0111178, 2001.
- Bietti et al. [2022] Alberto Bietti, Joan Bruna, Clayton Sanford, and Min Jae Song. Learning single-index models with shallow neural networks. Advances in Neural Information Processing Systems, 35:9768–9783, 2022.
- Bodin and Macris [2021] Antoine Bodin and Nicolas Macris. Model, sample, and epoch-wise descents: exact solution of gradient flow in the random feature model. Advances in Neural Information Processing Systems, 34:21605–21617, 2021.
- Chizat and Bach [2018] Lenaic Chizat and Francis Bach. On the global convergence of gradient descent for over-parameterized models using optimal transport. Advances in neural information processing systems, 31, 2018.
- Damian et al. [2022] Alexandru Damian, Jason Lee, and Mahdi Soltanolkotabi. Neural networks can learn representations with gradient descent. In Conference on Learning Theory, pages 5413–5452. PMLR, 2022.
- [15] Encyclopedia of Mathematics. Bernoulli equation. http://encyclopediaofmath.org/index.php?title=Bernoulli_equation&oldid=40764.
- Frye and Efthimiou [2012] Christopher Frye and Costas J Efthimiou. Spherical harmonics in p dimensions. arXiv preprint arXiv:1205.3548, 2012.
- Fukumizu and Amari [2000] Kenji Fukumizu and Shun-ichi Amari. Local minima and plateaus in hierarchical structures of multilayer perceptrons. Neural networks, 13(3):317–327, 2000.
- Ghorbani et al. [2020a] Behrooz Ghorbani, Song Mei, Theodor Misiakiewicz, and Andrea Montanari. Discussion of:“nonparametric regression using deep neural networks with relu activation function”. The Annals of Statistics, 48(4), 2020a.
- Ghorbani et al. [2020b] Behrooz Ghorbani, Song Mei, Theodor Misiakiewicz, and Andrea Montanari. When do neural networks outperform kernel methods? Advances in Neural Information Processing Systems, 33:14820–14830, 2020b.
- Ghorbani et al. [2021] Behrooz Ghorbani, Song Mei, Theodor Misiakiewicz, and Andrea Montanari. Linearized two-layers neural networks in high dimension. The Annals of Statistics, 49(2):1029–1054, 2021.
- Ghosh et al. [2021] Nikhil Ghosh, Song Mei, and Bin Yu. The three stages of learning dynamics in high-dimensional kernel methods. In International Conference on Learning Representations, 2021.
- Gissin et al. [2019] Daniel Gissin, Shai Shalev-Shwartz, and Amit Daniely. The implicit bias of depth: How incremental learning drives generalization. arXiv preprint arXiv:1909.12051, 2019.
- Glorot and Bengio [2010] Xavier Glorot and Yoshua Bengio. Understanding the difficulty of training deep feedforward neural networks. In Proceedings of the thirteenth international conference on artificial intelligence and statistics, pages 249–256. JMLR Workshop and Conference Proceedings, 2010.
- He et al. [2015] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Delving deep into rectifiers: Surpassing human-level performance on imagenet classification. In Proceedings of the IEEE international conference on computer vision, pages 1026–1034, 2015.
- Holmes [2013] Mark Holmes. Introduction to Perturbation Methods. Springer Texts in Applied Mathematics, 2013.
- Jacot et al. [2018] Arthur Jacot, Franck Gabriel, and Clément Hongler. Neural tangent kernel: Convergence and generalization in neural networks. Advances in neural information processing systems, 31, 2018.
- Jin et al. [2019] Chi Jin, Praneeth Netrapalli, Rong Ge, Sham M Kakade, and Michael I Jordan. A short note on concentration inequalities for random vectors with subgaussian norm. arXiv preprint arXiv:1902.03736, 2019.
- Kalimeris et al. [2019] Dimitris Kalimeris, Gal Kaplun, Preetum Nakkiran, Benjamin Edelman, Tristan Yang, Boaz Barak, and Haofeng Zhang. Sgd on neural networks learns functions of increasing complexity. Advances in neural information processing systems, 32, 2019.
- LeCun et al. [2002] Yann LeCun, Léon Bottou, Genevieve B Orr, and Klaus-Robert Müller. Efficient backprop. In Neural networks: Tricks of the trade, pages 9–50. Springer, 2002.
- Li et al. [2020] Zhiyuan Li, Yuping Luo, and Kaifeng Lyu. Towards resolving the implicit bias of gradient descent for matrix factorization: Greedy low-rank learning. arXiv preprint arXiv:2012.09839, 2020.
- Mei et al. [2018a] Song Mei, Yu Bai, and Andrea Montanari. The landscape of empirical risk for nonconvex losses. The Annals of Statistics, 46(6A):2747–2774, 2018a.
- Mei et al. [2018b] Song Mei, Andrea Montanari, and Phan-Minh Nguyen. A mean field view of the landscape of two-layer neural networks. Proceedings of the National Academy of Sciences, 115(33):E7665–E7671, 2018b.
- Mei et al. [2019] Song Mei, Theodor Misiakiewicz, and Andrea Montanari. Mean-field theory of two-layers neural networks: dimension-free bounds and kernel limit. In Conference on Learning Theory, pages 2388–2464. PMLR, 2019.
- Mei et al. [2022] Song Mei, Theodor Misiakiewicz, and Andrea Montanari. Generalization error of random feature and kernel methods: hypercontractivity and kernel matrix concentration. Applied and Computational Harmonic Analysis, 59:3–84, 2022.
- Montanari and Zhong [2022] Andrea Montanari and Yiqiao Zhong. The interpolation phase transition in neural networks: Memorization and generalization under lazy training. The Annals of Statistics, 50(5):2816–2847, 2022.
- O’Donnell [2014] Ryan O’Donnell. Analysis of boolean functions. Cambridge University Press, 2014.
- Oymak and Soltanolkotabi [2020] Samet Oymak and Mahdi Soltanolkotabi. Toward moderate overparameterization: Global convergence guarantees for training shallow neural networks. IEEE Journal on Selected Areas in Information Theory, 1(1):84–105, 2020.
- Pinkus [1999] Allan Pinkus. Approximation theory of the MLP model in neural networks. Acta Numerica, 8:143–195, 1999.
- Power et al. [2022] Alethea Power, Yuri Burda, Harri Edwards, Igor Babuschkin, and Vedant Misra. Grokking: Generalization beyond overfitting on small algorithmic datasets. arXiv:2201.02177, 2022.
- Rotskoff and Vanden-Eijnden [2018] Grant Rotskoff and Eric Vanden-Eijnden. Parameters as interacting particles: long time convergence and asymptotic error scaling of neural networks. Advances in Neural Information Processing Systems, 31, 2018.
- Saad and Solla [1995] David Saad and Sara A Solla. On-line learning in soft committee machines. Physical Review E, 52(4):4225, 1995.
- Santambrogio [2015] Filippo Santambrogio. Optimal transport for applied mathematicians. Birkhäuser, NY, 2015.
- Saxe et al. [2013] Andrew M Saxe, James L McClelland, and Surya Ganguli. Exact solutions to the nonlinear dynamics of learning in deep linear neural networks. arXiv preprint arXiv:1312.6120, 2013.
- Wei et al. [2008] Haikun Wei, Jun Zhang, Florent Cousseau, Tomoko Ozeki, and Shun-ichi Amari. Dynamics of learning near singularities in layered networks. Neural computation, 20(3):813–843, 2008.
- Yang and Hu [2020] Greg Yang and Edward J Hu. Feature learning in infinite-width neural networks. arXiv:2011.14522, 2020.
- Yang and Hu [2021] Greg Yang and Edward J Hu. Tensor programs IV: Feature learning in infinite-width neural networks. In International Conference on Machine Learning, pages 11727–11737. PMLR, 2021.
- Yehudai and Shamir [2019] Gilad Yehudai and Ohad Shamir. On the power and limitations of random features for understanding neural networks. Advances in Neural Information Processing Systems, 32, 2019.
- Yoshida and Okada [2019] Yuki Yoshida and Masato Okada. Data-dependence of plateau phenomenon in learning with neural network—statistical mechanical analysis. Advances in Neural Information Processing Systems, 32, 2019.
Appendix A Proof of Proposition 1
By standard approximation theory arguments [38], it is sufficient to show that there exists an integrable function such that
(122) |
(We denote by the uniform probability measure over .)
Denote by the Gegenbauer polynomial of order and degree (see, e.g., [34]). Namely, form an orthogonal system with respect to the measure with density , . Recall that for fixed of norm , the polynomials are spherical harmonics satisfying
(123) |
Also, is the dimension of the space of spherical harmonics of degree , whence form an orthonormal set. We will denote by the -th coefficient of the expansion of in this basis, and similarly for , with coefficients , namely
As shown for instance in [34], is the -th Hermite coefficient of and similarly for . In particular, for all large enough. For a large integer let
By Eq. (123), we have, for ,
Denoting by a uniform random vector on the sphere of radius , and , we have
where in we used concentration of chi-squared random variables, Lipschitzness of and , and that is the projection of orthogonal to polynomials of degree at most (with respect to the measure with density proportional to on ). Therefore
The claim (122) follows by taking .
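The Hermite coefficients that appear in the argument above can be computed numerically. The following is a minimal sketch, assuming the normalization c_k = E[σ(G) He_k(G)] / sqrt(k!) with G a standard Gaussian and He_k the probabilists' Hermite polynomial; this normalization, the function name `hermite_coefficients`, and the two test activations are illustrative assumptions and are not taken from the proof.

```python
import math
import numpy as np
from numpy.polynomial.hermite_e import hermegauss, hermeval


def hermite_coefficients(sigma, max_degree, num_nodes=80):
    """Return c_k = E[sigma(G) He_k(G)] / sqrt(k!) for k = 0..max_degree, G ~ N(0, 1)."""
    # Gauss-Hermite nodes and weights for the weight function exp(-x^2 / 2).
    nodes, weights = hermegauss(num_nodes)
    weights = weights / math.sqrt(2.0 * math.pi)  # normalize to the N(0, 1) density
    values = sigma(nodes)
    coeffs = []
    for k in range(max_degree + 1):
        basis = np.zeros(k + 1)
        basis[k] = 1.0  # coefficient vector selecting He_k
        he_k = hermeval(nodes, basis)
        inner = float(np.sum(weights * values * he_k))
        coeffs.append(inner / math.sqrt(math.factorial(k)))
    return np.array(coeffs)


# Illustrative activations (assumptions): a polynomial and an odd function.
square_coeffs = hermite_coefficients(lambda x: x**2, max_degree=4)
tanh_coeffs = hermite_coefficients(np.tanh, max_degree=5)
```

With this normalization, the squared coefficients of x² sum to E[G⁴] = 3, which gives a quick sanity check of the quadrature.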
Appendix B Appendix to Section 4
B.1 Proof of Proposition 2
B.2 Proof of Corollary 1
First, note that in the proof of Lemma 1, we obtain the following a priori estimate on the magnitude of the ’s:
(125) |
where only depends on the ’s in Assumptions A1-A3. Using an argument similar to that in the proof of Proposition 3, we obtain that for any and ,
and for ,
Therefore, we deduce that
Defining
then we know that . Applying Grönwall’s inequality yields
Since and for any , . Using standard concentration inequalities, we know that
(126) |
with probability at least , where and are both absolute constants. Therefore,
(127) | ||||
(128) | ||||
(129) |
Next, we upper bound the risk difference. By direct calculation,
with probability at least , where the constant only depends on the ’s from Assumptions A1-A3. The conclusion now follows from taking the supremum over all . This completes the proof of Corollary 1.
B.3 Proof of Proposition 3
We consider , the dot product between and that is out of the relevant subspace spanned by . We show that these variables satisfy the ODEs
(130) |
By definition of , we readily see that
Plugging in Eq.s (17) to (19) gives that
This proves Eq. (130).
Lemma 1.
If Assumptions A1-A3 hold, then we have for any fixed :
Proof.
To begin with, using Eq. (130), we obtain that
Using the ODEs for the ’s, we obtain that
where follows from our assumptions and the fact that , since by gradient flow equations. Moreover, the constant only depends on the ’s. Since for all , we know that for all , thus leading to the following estimate:
where the constant only depends on the ’s in our assumptions. At initialization, we know that . Applying Grönwall’s inequality yields that
which further implies that
This completes the proof. ∎
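Grönwall's inequality is used repeatedly in these proofs in its differential form: if u'(t) ≤ a·u(t) + b with a > 0, then u(t) ≤ u(0)·e^{at} + (b/a)(e^{at} − 1). The sketch below checks this bound numerically on a toy scalar ODE; the constants, the slack term, and the function names are hypothetical and are not the quantities appearing in Lemma 1.

```python
import math


def gronwall_bound(u0, a, b, t):
    """Upper bound on u(t) when u'(s) <= a*u(s) + b on [0, t], with a > 0."""
    return u0 * math.exp(a * t) + (b / a) * (math.exp(a * t) - 1.0)


def simulate(u0, a, b, horizon, steps):
    """Euler-integrate u' = a*u + b - slack(t, u), where slack >= 0."""
    h = horizon / steps
    u, path = u0, [(0.0, u0)]
    for k in range(steps):
        t = k * h
        slack = 0.5 * (1.0 + math.sin(t)) * u  # nonnegative as long as u >= 0
        u = u + h * (a * u + b - slack)
        path.append((t + h, u))
    return path


path = simulate(u0=0.1, a=1.0, b=0.2, horizon=3.0, steps=3000)
# Largest amount by which the trajectory exceeds the Gronwall bound (should be <= 0).
max_violation = max(u - gronwall_bound(0.1, 1.0, 0.2, t) for t, u in path)
```

The bound is attained only when the differential inequality holds with equality; any nonnegative slack keeps the trajectory strictly below it for t > 0.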
We show that
(131) |
To this end, we define . By our assumption, . Moreover, using the same technique as in the proof of Lemma 1, we know that for all . According to Eq.s (16)-(19) and Eq. (24), we deduce that
thus leading to the following estimate:
where in we use the Cauchy-Schwarz inequality and the inequality of arithmetic and geometric means, and follows from the conclusion of Lemma 1. Similarly, we obtain that
which further implies that
Combining the above estimates, we finally deduce that
Applying Grönwall’s inequality immediately implies
(132) |
which further leads to Eq. (131) and concludes the proof of Proposition 3. The “consequently” part can be shown via direct calculation, but we include it here for the sake of completeness. By definition, for any we have
Therefore,
(133) |
as desired.
B.4 Derivation of the mean field dynamics (29)
For any bounded continuous , we have
where follows from the ODE satisfied by the ’s, and in we use integration by parts. We thus obtain that
which recovers Eq. (29).
B.5 Details of the alternative mean field approach
Let
(134) |
where is the solution of (5)–(6). is a measure on solving the continuity PDE
(135) |
where is given by
A remarkable property of the equation (135) is that it preserves invariance to rotations orthogonal to . Indeed, assume that is invariant to rotations orthogonal to . In this case, we show that and depend only on and . Let (resp. ) denote the component of (resp. ) orthogonal to . Let denote a random uniform rotation orthogonal to . By the rotation invariance of ,
The random variable is a one dimensional projection of a random variable uniform on the unit sphere of the hyperplane orthogonal to ; thus it has the density (see, e.g., [16, Lemma 4.17]). Denote
then we have
(136) |
Further, we compute
In the equation above, we have and as a.s., we have
Thus we obtain
Note that
and thus we have
(137) |
Of course, a discrete measure of the form (134) cannot be invariant to rotations orthogonal to . However, if the are initialized uniformly on the unit sphere, then the measure converges to a rotation-invariant measure as . One can then apply the results of [33] to control the deviations from this limit. Let us thus assume that satisfies this rotation invariance. Define the map . Then, from (136), (137), the push-forward of through the map satisfies the continuity equation
where is given by
When , converges weakly to the Dirac mass . As a consequence,
Therefore, in the limit , we recover the equations (29)–(32). Moreover, if , then converges weakly to as .
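The derivation above relies on the law of a one-dimensional projection of a uniform point on a sphere. As a sanity check, the sketch below uses the classical fact that, for a point uniform on the unit sphere of R^n, one coordinate has density proportional to (1 − s²)^{(n−3)/2} on [−1, 1], and hence second moment 1/n. Here `n` is a hypothetical ambient dimension chosen for illustration; the function names are likewise assumptions.

```python
import numpy as np


def sphere_coordinate_samples(n, num_samples, rng):
    """First coordinate of points drawn uniformly from the unit sphere in R^n."""
    g = rng.standard_normal((num_samples, n))
    return g[:, 0] / np.linalg.norm(g, axis=1)


def density_second_moment(n, grid_size=200001):
    """Second moment of the density proportional to (1 - s^2)^((n - 3) / 2) on [-1, 1]."""
    s = np.linspace(-1.0, 1.0, grid_size)
    w = (1.0 - s**2) ** ((n - 3) / 2.0)
    # The grid spacing cancels in the ratio of the two Riemann sums.
    return float(np.sum(s**2 * w) / np.sum(w))


rng = np.random.default_rng(0)
n = 10  # hypothetical dimension
samples = sphere_coordinate_samples(n, 200_000, rng)
empirical = float(np.mean(samples**2))
from_density = density_second_moment(n)
```

Both `empirical` and `from_density` should be close to 1/n, which is the scaling behind the concentration of the projection as the dimension grows.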
B.6 Proof of Proposition 4
First, note that the potential functions and admit the following expansion:
As a consequence, we deduce that
Now we show that the above risk can be arbitrarily small. We will choose to be the Lebesgue measure on and so that . Now, we define the following set of sequences
Since and , we know that , i.e., is a linear subspace of . Now it suffices to show that is dense in , which is equivalent to , namely
Fix any such and take such that for all , for some . We then have
where the last step follows from the dominated convergence theorem. Indeed, by Hölder’s inequality,
As a consequence, the function series converges absolutely and uniformly to the continuous function on . The above argument then implies that for any , , which further implies that . Therefore, for all . Since for all , we must have for all , i.e., . This completes the proof of the density of in , and thus the proof of Proposition 4.
Appendix C Calculations for the analysis of mean-field gradient flow
C.1 Solution of Eq. (89)
In order to solve the system (89), we start from an associated one-dimensional ODE.
Lemma 2.
The solution of the ODE
(138) |
with initial condition is
(139) |
Proof.
For simplicity, denote , and . Then
This is a Bernoulli differential equation (see, e.g., [15]). In this situation, the classical trick is to reduce the problem to a linear inhomogeneous first-order equation by considering
Solving this linear inhomogeneous first-order equation gives
and thus
which is the claimed result. ∎
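To illustrate the substitution used in the proof, consider the generic Bernoulli equation y' = a·y − b·y³, which is an illustrative stand-in and not necessarily the ODE (138). Setting v = y⁻² gives the linear equation v' = −2a·v + 2b, whose solution is v(t) = b/a + (v(0) − b/a)·e^{−2at}. The sketch below compares this closed form with a Runge-Kutta integration; the constants and function names are hypothetical.

```python
import math


def bernoulli_closed_form(y0, a, b, t):
    """Solution of y' = a*y - b*y**3 obtained via the substitution v = y**(-2)."""
    # v solves v' = -2*a*v + 2*b, hence v(t) = b/a + (v0 - b/a) * exp(-2*a*t).
    v0 = y0 ** (-2)
    v = b / a + (v0 - b / a) * math.exp(-2.0 * a * t)
    return v ** (-0.5)


def rk4(y0, a, b, t, steps=2000):
    """Classical Runge-Kutta integration of the same ODE, used as an independent check."""
    def f(y):
        return a * y - b * y**3

    h, y = t / steps, y0
    for _ in range(steps):
        k1 = f(y)
        k2 = f(y + 0.5 * h * k1)
        k3 = f(y + 0.5 * h * k2)
        k4 = f(y + h * k3)
        y += h * (k1 + 2 * k2 + 2 * k3 + k4) / 6.0
    return y


exact = bernoulli_closed_form(y0=0.01, a=1.0, b=2.0, t=8.0)
numeric = rk4(y0=0.01, a=1.0, b=2.0, t=8.0)
```

In this toy example, a small initial value y(0) stays near zero for a time of order log(1/y(0))/a before rapidly approaching the fixed point sqrt(a/b), which is qualitatively the plateau-then-jump behavior discussed in the main text.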
Let be a solution of (138) and consider
(140) |
Then are solutions of the constrained ODE system (85), (88). Indeed,
thus the constraint (85) is satisfied. Further
A similar computation shows that the differential equation for is also satisfied. This shows that (140) is a valid candidate solution on the third time scale.
Matching.
To determine the value of the initialization we perform a matching procedure with the previous time scale. In this paragraph, we denote the approximation obtained in the second time scale (Section 6.3), and the approximation in the third time scale (Section 6.4 and above).
Consider an intermediate time scale with . Assume . Then
From the approximation (82) on the second time scale,
(141) |
From the approximation on the third time scale,
Note that as , from (139),
Thus
(142) |
By matching, Equations (141) and (142) should be consistent. This gives
and thus
(143) |
One could check similarly that also satisfies the matching conditions under the same constraint, and thus that (140) are indeed the solutions of the third time scale.
C.2 Induced approximation of the risk
In this section, we show that the behavior of and derived in Sections 6.2–6.4 leads to an evolution of the risk that alternates between plateaus and rapid decreases, in agreement with the canonical learning order of Definition 1. For the convenience of the reader, we recall the expression (37) of the risk
First time scale (Section 6.2).
On this time scale, we have and . Thus for all , whence .
Second time scale (Section 6.3).
On this time scale, we have and . Thus for all , .
Further, using (68),
Thus as ,
This second time scale does not induce any transition of the risk (but it was necessary to understand the divergence of and ).
Third time scale (Section 6.4).
On this time scale, we have and . Thus for all , .
C.3 Proof of Theorem 1
Throughout the proof, we will use the shorthand to represent . First, note that according to the ODE satisfied by (Eq. (98)), we know that must be non-increasing, thus for small enough ,
Hence, we obtain the estimates:
According to the comparison theorem for systems of ODEs, we know that , for all where
and
(144) |
The above system of ODEs can be solved analytically via integration. First, we note that
which implies that (further note )
(145) |
The ODE system then reduces to , which admits the solution
(146) |
Since , we know that until , which means that until . As a consequence,
until . This means that the learning of the -th component will not begin until , namely for any fixed . Note that the above argument applies to all of the settings in the theorem statement.
Next, we show that for any fixed , , which means that the -th component can be learnt in time. To prove our claim by contradiction, assume that there exists and a sequence , such that
(147) |
By definition of , we know that ,
Now, assume the condition of setting (a) holds and denote
(148) |
Then by definition and our assumption that is of the same order as , we know that . Since , there exists such that . Note that here we can choose and to be arbitrarily small since the set is non-increasing in and . For and , we have
Moreover, we know that at initialization, . Using the ODE comparison theorem and an argument similar to that used in proving , we deduce that for sufficiently large such that , there exist constants that do not depend on satisfying the following: For all and ,
This further implies that at time ,
(149) |
According to Eq. (98), we know that will decrease to exponentially fast in an time window after , which contradicts our assumption (147). This proves that under setting (a). Next, we show that setting (b) can be reduced to setting (a). Under setting (b), let us denote
Then similar to the previous argument, there exists such that , and further we can choose and to be arbitrarily small. For , we have
Hence, both and will decrease at initialization. Moreover, Eq. (97) implies that
Integrating both sides of the above equation, we obtain that
(150) |
which is close to as long as . To be precise, let us define
then we know that and under the assumption (147), where the latter claim can be proved by making the change of variables and . Note that after the time point , the sign of changes. Hence, , and and will begin to increase for . Similarly, we can show that in time after , both and become of order , and we still have . This reduces setting (b) to setting (a).
We have proven that under settings (a) and (b), for any fixed . This means that some of the neurons become of order and the -th component of the target function is learnt at a timescale of order . Next, we show that if the probability measure is discrete, then the evolution of actually happens in an time window. It suffices to prove that, for any small constant (),
(151) |
as . Note that by continuity and monotonicity of , we have
By definition of , we know that ,
Denote by the realizations of under the discrete measure , and by the point masses of . Then, we know that
(152) |
which implies that , s.t. . Applying Lemma 3 yields
(153) | ||||
(154) |
It then follows from Eq. (98) that will decrease to exponentially fast, and Eq. (151) holds consequently. This completes the proof for settings (a) and (b).
We now turn to case (c). By our assumption, for almost every there exists (which may depend on ) such that
for sufficiently small . Therefore, and will keep decreasing until one of them reaches , which means that
(155) |
According to Eq. (150) and the inequality , will not reach until reaches . Furthermore, for any ,
thus leading to
(156) | ||||
(157) |
Using again the comparison theorem for ODEs, we get that
(158) |
Since , it follows immediately that for any , there exists a constant such that
(159) |
This completes the discussion for case (c), thus concluding the proof of Theorem 1.
Lemma 3.
Let be a constant that does not depend on . Then there exists a constant that only depends on and such that the following holds: For any , satisfying and , we have
(160) |
Proof.
If , then we immediately get
Otherwise, , and consequently
where the last line follows from the AM-GM inequality. This completes the proof. ∎
Appendix D Proofs of Theorem 2 and 3: learning with projected SGD
We prove Theorem 2, which bounds the distance between gradient flow (GF) and projected SGD, in Sections D.1 through D.3; Section D.4 is devoted to the proof of Theorem 3. Throughout this section, we use to refer to any constant that only depends on the ’s from Assumptions A1-A3; the value of can change from line to line. We start with an elementary lemma that establishes the Lipschitz continuity of the gradient flow trajectory:
Lemma 4 (A priori estimate).
There exists a constant that only depends on the ’s, such that for all , is supported on , namely for all . Moreover, for any , we have
Proof.
First, notice that along the trajectory of gradient flow, the risk must be non-increasing. In fact, we have
Therefore, we obtain that
where the last line follows from our assumption. Since , we know that , and . Moreover, according to Eq. (6), we have
thus leading to
This completes the proof. ∎
In what follows we define two discretized versions of Eq.s (5) and (6), namely the gradient descent (GD) and stochastic gradient descent (SGD) dynamics. They will serve as important intermediate objects for our proof.
- •
- One-pass stochastic gradient descent: With the same choice of step size and initialization, let be i.i.d. samples from , where
The iteration equations for one-pass SGD read:
(162) |
Note that Eq. (162) can also be written as:
D.1 Difference between GF and GD
For notational simplicity, we denote for and , and
Similarly, , and
Moreover, for and , we define the following two functionals:
and . Then, Eq.s (5) and (6) and Eq. (161) can be rewritten as
respectively. The lemma below will be used several times in the proof.
Lemma 5.
Denote and . If and for all (where is any fixed absolute constant; for example, here we can take ), then we have
(163) | ||||
(164) | ||||
(165) |
where the constant only depends on the ’s. As a consequence, we obtain that
Proof.
First, by triangle inequality, we have
Second, using again triangle inequality, we deduce that
where follows from the inequality , which is a result of the following direct calculation:
This completes the proof of Lemma 5, since the “as a consequence” part follows naturally from the upper bounds obtained earlier. ∎
Lemma 6.
Following the notation and assumption of Lemma 5, we have
Proof.
By definition of the risk function and triangle inequality, we deduce that
This concludes the proof. ∎
First, let us define the error function
and the stopping time . For and , we have the following estimate:
For any , by Lemmas 4 and 5 we have (denote , and notice that we can take since )
Using Lemmas 4 and 5 again, we obtain that
thus leading to
For , we have . Hence,
Applying Grönwall’s inequality yields
Therefore, for all and , we have
This proves , and consequently
which immediately implies that
Finally, with the aid of Lemma 6, we get the following upper bound on the difference between the risk of gradient flow and gradient descent:
To summarize, we have the following:
Theorem 4 (Difference between GF and GD).
There exists a constant that only depends on the ’s, such that for any and
the following holds for all :
(166) | ||||
(167) | ||||
(168) |
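The flavor of this comparison can be checked on a toy problem. The sketch below is an illustration under stated assumptions, not the setting of Theorem 4: it uses a hypothetical two-parameter risk R(a, u) = ½(a·tanh(u) − 1)², approximates the gradient flow by a fine Runge-Kutta integration, and measures the largest gap to the gradient descent iterates at matching times. Halving the step size should roughly halve the gap, as expected for a first-order discretization.

```python
import numpy as np


def grad_risk(theta):
    """Gradient of the toy risk R(a, u) = 0.5 * (a * tanh(u) - 1)^2."""
    a, u = theta
    residual = a * np.tanh(u) - 1.0
    return np.array([residual * np.tanh(u), residual * a * (1.0 - np.tanh(u) ** 2)])


def rk4_step(theta, h):
    """One Runge-Kutta step of the gradient flow theta' = -grad R(theta)."""
    k1 = -grad_risk(theta)
    k2 = -grad_risk(theta + 0.5 * h * k1)
    k3 = -grad_risk(theta + 0.5 * h * k2)
    k4 = -grad_risk(theta + h * k3)
    return theta + h * (k1 + 2 * k2 + 2 * k3 + k4) / 6.0


def sup_gap(step_size, horizon=2.0, substeps=20):
    """Largest distance between GD iterates and the flow at times k * step_size."""
    gd = np.array([0.5, 0.5])    # hypothetical initialization
    flow = np.array([0.5, 0.5])
    gap = 0.0
    for _ in range(int(round(horizon / step_size))):
        gd = gd - step_size * grad_risk(gd)
        for _ in range(substeps):
            flow = rk4_step(flow, step_size / substeps)
        gap = max(gap, float(np.linalg.norm(gd - flow)))
    return gap


gap_coarse = sup_gap(0.02)
gap_fine = sup_gap(0.01)
ratio = gap_coarse / gap_fine
```

The ratio is only expected to be close to 2 when both step sizes are small relative to the inverse smoothness constant of the risk; for larger steps, higher-order terms become visible.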
D.2 Difference between GD and SGD
The proof in this section is almost identical to that of Appendix C.5 in [33]. The only difference is that here we need to verify that is an -sub-Gaussian random vector. This follows from the identity and Assumption A3. We thus obtain the following interpolation bound between GD and SGD:
Theorem 5 (Difference between GD and SGD).
There exists a constant that only depends on the ’s, such that for any and
the following happens with probability at least : For all , we have
(169) | ||||
(170) | ||||
(171) |
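As a rough numerical illustration of this kind of bound, the sketch below compares population gradient descent with one-pass SGD on a toy linear model with isotropic Gaussian covariates, for which the population gradient is available in closed form. The model, the dimension, the noise level, and the step sizes are all assumptions made for illustration; they do not reproduce the two-layer setting or the constants of Theorem 5.

```python
import numpy as np


def gd_vs_sgd_gap(step_size, horizon, dim, noise_std, rng):
    """Largest distance between population GD and one-pass SGD up to time `horizon`."""
    target = np.zeros(dim)
    target[0] = 1.0  # hypothetical ground-truth direction
    w_gd = np.zeros(dim)
    w_sgd = np.zeros(dim)
    gap = 0.0
    for _ in range(int(round(horizon / step_size))):
        # Population gradient of 0.5 * E[(<w, x> - y)^2] for x ~ N(0, I) is w - target.
        w_gd = w_gd - step_size * (w_gd - target)
        # One-pass SGD: a fresh sample (x, y) at every step, never reused.
        x = rng.standard_normal(dim)
        y = x @ target + noise_std * rng.standard_normal()
        w_sgd = w_sgd - step_size * (x @ w_sgd - y) * x
        gap = max(gap, float(np.linalg.norm(w_gd - w_sgd)))
    return gap


rng = np.random.default_rng(0)
gap_large_step = gd_vs_sgd_gap(0.02, horizon=5.0, dim=20, noise_std=0.5, rng=rng)
gap_small_step = gd_vs_sgd_gap(0.0005, horizon=5.0, dim=20, noise_std=0.5, rng=rng)
```

In this toy model the fluctuation of SGD around GD shrinks as the step size decreases at a fixed time horizon, which is the qualitative content of a GD-to-SGD coupling bound.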
D.3 Difference between SGD and projected SGD
The aim of this section is to prove a coupling bound between the trajectory of SGD and that of projected SGD, thus finally leading to an upper bound on the difference between the risk of projected gradient flow and projected SGD. To begin with, let us fix and choose
as in Theorem 2, where is a large enough constant (to be determined later). Define
then for and , we have (note that here )
Denoting , we know from Assumption A3 that, conditionally on , is an -sub-Gaussian random vector. By well-known results on the Euclidean norm of sub-Gaussian random vectors (see, e.g., [27]), we know that there exists a constant satisfying
Choosing and applying a union bound gives
Therefore, with probability at least , for all and , we have
The above bound also holds for the trajectory of SGD, namely after replacing with . Now, let us define the approximation error for and , then we get the following decomposition:
where has zero mean. With our choice of , one can verify that as long as , Lemma 7 is applicable to
Hence, we deduce from the definition of that
thus leading to the following estimate:
where is due to the fact that , and . According to the definition of , we obtain that
thus leading to (using the same argument as in the proof of Lemma 5)
and
Moreover, by (conditional) sub-Gaussianity of the ’s, we know that
Combining the above estimates, it then follows that
Using the same proof technique as in Appendix C.5 of [33], we conclude that
As in the proof of Theorem 4, we define
Then, for , we have
Proceeding with the same argument, it follows that
Therefore, we finally conclude that
Applying Grönwall’s inequality (discrete version) yields that
as long as with . Note that the above inequality holds for all with probability at least , which further implies that , and consequently
Applying again Lemma 6, we deduce that
Combining the above estimates gives the following:
Theorem 6 (Difference between SGD and projected SGD).
There exists a constant that only depends on the ’s, such that for any and
the following happens with probability at least : For all , we have
(172) | ||||
(173) | ||||
(174) |
Lemma 7.
Let , , where and . Then we have
Proof.
Using Taylor expansion, we know that
which implies
The proof is completed by noting that
∎
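A Taylor-expansion estimate of this type can be illustrated in the simplest case. The sketch below assumes, purely for illustration, that the projection is the normalization onto the unit sphere and that the update δ is tangent to the sphere at a unit vector u; these hypotheses are not taken from the statement of Lemma 7. In that case the projected point differs from u + δ by exactly sqrt(1 + ‖δ‖²) − 1 ≤ ‖δ‖²/2, so the projection only contributes a second-order error.

```python
import numpy as np


def project_to_sphere(v):
    """Euclidean projection of a nonzero vector onto the unit sphere."""
    return v / np.linalg.norm(v)


def projection_gap(u, delta):
    """Distance between the unprojected update u + delta and its projection."""
    moved = u + delta
    return float(np.linalg.norm(project_to_sphere(moved) - moved))


rng = np.random.default_rng(0)
u = project_to_sphere(rng.standard_normal(50))
raw = rng.standard_normal(50)
tangent = raw - (raw @ u) * u  # remove the radial part so that <u, tangent> = 0
tangent /= np.linalg.norm(tangent)

sizes = [0.2, 0.1, 0.05]  # hypothetical update norms
gaps = [projection_gap(u, s * tangent) for s in sizes]
```

Halving the size of the tangent update divides the gap by roughly four, which is why the difference between SGD and projected SGD can be absorbed into the step-size-dependent error terms.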
D.4 Proof of Theorem 3
By our assumption, we know that the canonical learning order holds up to level , and that
Then, according to Definition 1, there exists , such that for all and , one has
Moreover, from Section 4 we know that with probability at least over the i.i.d. initialization,
(175) |
where only depends on . Now we choose and . It then follows that
(176) |
According to Theorem 2, we know that with probability at least ,
(177) |
with . We now take
Then, by our choice of and , we know that . Further, taking
(178) |
we obtain that
(179) |
The above happens with probability . Hence, our conclusion follows naturally from the assumption .
Appendix E Counterexamples to the canonical learning order
E.1 Case 1: for some
For any fixed , we have
Moreover, the risk is always lower bounded by
where follows from orthogonality between and .
E.2 Case 2: for some
We consider the reduced mean-field equations (24):
Note that if , then for some continuous function . Denoting and , the above equation regarding the evolution of the ’s can be written as
where is a matrix-valued function satisfying
Using an a priori estimate similar to that in the proof of Lemma 1, we can show that
for any finite time , which immediately implies that for . Therefore, we cannot learn any component of with degree .
E.3 Case 3: for some
We may assume , and analyze the simplified ODE system (97), which reduces to
(180) |
We thus obtain the following equations:
(181) |
which means that for any ,
(182) |
Therefore, most of the neurons cannot evolve to the magnitude of in the process of learning the -th component, and hence fail to provide an effective initialization for learning the next component .