The Benefits of Reusing Batches for Gradient Descent in Two-Layer Networks: Breaking the Curse of Information and Leap Exponents

Yatin Dandi
Information Learning and Physics Laboratory and Statistical Physics Of Computation Laboratory, École Polytechnique Fédérale de Lausanne (EPFL)
Emanuele Troiani
Statistical Physics Of Computation Laboratory, École Polytechnique Fédérale de Lausanne (EPFL)
Luca Arnaboldi
Information Learning and Physics Laboratory, École Polytechnique Fédérale de Lausanne (EPFL)
Luca Pesce
Information Learning and Physics Laboratory, École Polytechnique Fédérale de Lausanne (EPFL)
Lenka Zdeborová
Statistical Physics Of Computation Laboratory, École Polytechnique Fédérale de Lausanne (EPFL)
Florent Krzakala
Information Learning and Physics Laboratory, École Polytechnique Fédérale de Lausanne (EPFL)
Abstract

We investigate the training dynamics of two-layer neural networks when learning multi-index target functions. We focus on multi-pass gradient descent (GD) that reuses the batches multiple times and show that it significantly changes the conclusion about which functions are learnable compared to single-pass gradient descent. In particular, multi-pass GD with finite stepsize is found to overcome the limitations of gradient flow and single-pass GD given by the information exponent (Ben Arous et al., 2021) and leap exponent (Abbe et al., 2023) of the target function. We show that upon re-using batches, the network achieves in just two time steps an overlap with the target subspace even for functions not satisfying the staircase property (Abbe et al., 2021). We characterize the (broad) class of functions efficiently learned in finite time. The proof of our results is based on the analysis of the Dynamical Mean-Field Theory (DMFT). We further provide a closed-form description of the dynamical process of the low-dimensional projections of the weights, and numerical experiments illustrating the theory.

1 Introduction

Recent years have witnessed significant theoretical advances in understanding the dynamics of training neural networks with gradient descent, with the aim of unraveling the learning mechanisms of these networks, particularly how they adapt to data and identify pivotal features for predicting the target function. Significant progress has been made over the last few years in the case of two-layer networks, in large part thanks to the so-called mean-field analysis (Mei et al., 2018; Chizat and Bach, 2018; Rotskoff and Vanden-Eijnden, 2022; Sirignano and Spiliopoulos, 2020). Most of the theoretical efforts, in particular, focused either on one-pass optimization algorithms, where each iteration involves a fresh batch of data, or on the limit of gradient flow on the population loss. For high-dimensional synthetic Gaussian data and a low-dimensional target function (a multi-index model), the class of functions efficiently learned by these one-pass methods has been thoroughly analyzed in a series of recent works, and has been shown to be limited by the so-called information exponent (Ben Arous et al., 2021) and leap exponent (Abbe et al., 2022; 2023) of the target. These analyses have sparked many follow-up theoretical works over the last few months; see, e.g., (Damian et al., 2022; 2023; Dandi et al., 2023; Bietti et al., 2023; Ba et al., 2023; Moniri et al., 2023; Mousavi-Hosseini et al., 2023; Zweig and Bruna, 2023).

However, a common practice in machine learning involves repeatedly traversing the same mini-batch of data. This paper, therefore, aims to go beyond the constraints of single-pass algorithms and to evaluate whether multiple-pass training overcomes these inherent flaws of single-pass methods. We focus on gradient descent, certainly the most straightforward procedure in this family. The theoretical framework we use to prove our main results is based on Dynamical Mean Field Theory (DMFT), which was developed in the statistical physics community (Sompolinsky et al., 1988) to analyze correlated systems, and recently made rigorous in the context of high dimensional machine-learning problems in (Celentano et al., 2021; Gerbelot et al., 2023).

Our findings significantly alter the prevailing narrative in the literature. We demonstrate that gradient descent surpasses the limitations imposed by the information and leap exponents, achieving a positive correlation with the target function for a much broader class than the staircase functions (Abbe et al., 2021), even with minimal (that is, two) repetition of data batches. We characterize the (broad) class of functions efficiently learned in finite time. Among the exceptions are symmetric functions that remain a challenge due to their extended symmetry-breaking times (a natural feature of physical dynamics (Bouchaud et al., 1998)).

With independent Gaussian datapoints as inputs, both one-pass SGD and gradient flow on population loss result in pre-activations remaining distributed as Gaussian random variables. This Gaussianity underlies the analysis of such settings, starting from the seminal work of Saad and Solla (1995). In contrast, upon re-using batches, the pre-activations develop non-Gaussian components correlated with the targets. This non-Gaussianity is a crucial aspect of the stark contrast in the learning of directions compared to the one-pass setting. While we establish our results for discrete steps with extensive batches of size $\mathcal{O}(d)$, where $d$ denotes the input dimension, we expect our conclusions about the learning of new directions to also hold while performing gradient flow on the empirical loss, since such a setup would lead to the development of similar non-Gaussian components in the pre-activations, in contrast to gradient flow on the population loss where the pre-activations remain Gaussian.

Our results demonstrate that contrary to the common wisdom “the more data the better”, gradient descent on the same batch can surpass one-pass SGD on different batches, even when one-pass SGD utilizes a larger number of datapoints. More generally, we believe that our analysis provides insights into incremental learning of features in the presence of correlations between datapoints across batches, which is typical in most high dimensional datasets having a small number of “latent” factors. Our conclusions follow from a rigorous mathematical proof rooted in DMFT, which we also use to provide an analytic description of the dynamic processes of low-dimensional weight projections. This analysis has interest on its own.

2 Setting and main contributions

Let $\mathcal{D}$ be the set of data $\{\mathbf{z}_{\nu}\in\mathbb{R}^{d},\, y_{\nu}\}_{\nu\in[n]}$. The input data $\mathbf{z}_{\nu}$ are i.i.d. standard Gaussian vectors, while the labels are generated by a teacher, or target, function $y_{\nu}=f^{\star}(\mathbf{z}_{\nu})$. We consider multi-index target functions, which depend on a low-dimensional subspace of the input space:

$$y_{\nu}=f^{\star}(\mathbf{z}_{\nu})=g^{\star}\left(\frac{W^{\star}\mathbf{z}_{\nu}}{\sqrt{d}}\right),\quad \mathbf{z}_{\nu}\sim\mathcal{N}(0,\mathbbm{1}_{d}). \tag{1}$$

We assume for convenience that $W^{\star}=\{\mathbf{w}^{\star}_{r}\}_{r\in[k]}\in\mathbb{R}^{k\times d}$ is normalized row-wise on the sphere $\mathcal{S}_{d-1}(\sqrt{d})$, with orthogonal rows, i.e., $\langle\mathbf{w}^{\star}_{l},\mathbf{w}^{\star}_{m}\rangle=0$ for $l\neq m\in[k]$.
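As a minimal sketch of the data model in Eq. (1), and not the authors' implementation, the snippet below samples a teacher with orthogonal rows of norm $\sqrt{d}$ and the corresponding labels. The helper names `sample_teacher` and `sample_data`, the link function `g_star`, and all sizes are assumptions chosen only for illustration.

```python
import numpy as np

def sample_teacher(k, d, rng):
    """Return W_star (k x d) with mutually orthogonal rows, each of norm sqrt(d)."""
    # Reduced QR of a d x k Gaussian matrix yields k orthonormal columns.
    q, _ = np.linalg.qr(rng.standard_normal((d, k)))
    return np.sqrt(d) * q.T

def sample_data(n, W_star, g_star, rng):
    """Draw z ~ N(0, I_d) and labels y = g_star(W_star z / sqrt(d))."""
    d = W_star.shape[1]
    Z = rng.standard_normal((n, d))
    H = Z @ W_star.T / np.sqrt(d)   # (n, k) teacher pre-activations
    return Z, g_star(H)

rng = np.random.default_rng(0)
d, k, alpha = 200, 2, 3             # illustrative sizes, n = alpha * d
W_star = sample_teacher(k, d, rng)
# Hypothetical link function, used only as an example of a multi-index target.
g_star = lambda h: h[:, 0] + h[:, 0] * h[:, 1]
Z, y = sample_data(alpha * d, W_star, g_star, rng)
```

With this normalization each teacher pre-activation $\langle\mathbf{w}^{\star}_{r},\mathbf{z}\rangle/\sqrt{d}$ is a standard Gaussian, and different rows give independent pre-activations.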

The data are fed to a two-layer network (the student) $f$ with first-layer weights $W=\{\mathbf{w}_{i}\}_{i\in[p]}\in\mathbb{R}^{p\times d}$ ($p$ is the number of neurons in the hidden layer), second-layer weights $\mathbf{a}\in\mathbb{R}^{p}$, and an activation function $\sigma$, that is:

$$f\left(\frac{W\mathbf{z}}{\sqrt{d}}\right)=\sum_{j=1}^{p}a_{j}\,\sigma\left(\frac{\langle\mathbf{w}_{j},\mathbf{z}\rangle}{\sqrt{d}}\right). \tag{2}$$

Our main goal is to analyze the dynamics of gradient descent that minimizes the empirical Mean Squared Error (MSE) loss $\mathcal{L}$ at time $t\in[T]$:

$$\begin{aligned}
\mathcal{R}_{\rm empirical} &= \sum_{\nu=1}^{n}\mathcal{L}\left(\frac{W^{(t)}\mathbf{z}_{\nu}}{\sqrt{d}},\, f^{\star}(\mathbf{z}_{\nu})\right) \\
&= \frac{1}{2}\sum_{\nu=1}^{n}\left[g^{\star}\left(\frac{W^{\star}\mathbf{z}_{\nu}}{\sqrt{d}}\right)-f\left(\frac{W^{(t)}\mathbf{z}_{\nu}}{\sqrt{d}}\right)\right]^{2},
\end{aligned} \tag{3}$$

in the high-dimensional limit where $n,d\to\infty$ with $n/d=\alpha=\Theta(1)$. We use a common assumption that is amenable to rigorous theoretical guarantees: we keep the second-layer weights $\mathbf{a}$ fixed at initialization. For convenience, we further impose the constraint of symmetric initialization common in such analyses (Dandi et al., 2023; Damian et al., 2022). Concretely, we assume that the number of neurons $p$ is even and the weights satisfy at initialization:

$$a_{i}=-a_{p-i+1},\quad \mathbf{w}_{i}^{(0)}=\mathbf{w}_{p-i+1}^{(0)}\quad\text{for all } i\in[p/2], \tag{4}$$

which ensures that the output $f\left(\frac{W^{(0)}\mathbf{z}_{\nu}}{\sqrt{d}}\right)$ equals $0$ at initialization. For $i\in[p/2]$, the weights are initialized as $a_{i}\sim\frac{1}{p}\mathcal{N}(0,1)$, $\mathbf{w}_{i}^{(0)}\sim\mathcal{N}(0,\mathbbm{1}_{d})$. Subsequently, with $\mathbf{a}$ fixed, the first-layer weights $W=\{\mathbf{w}_{i}\}_{i\in[p]}$ are learned using gradient descent, producing the following sequence of iterates up to a final time $T$:

$$\mathbf{w}_{i}^{(t+1)}=(1-\eta\lambda)\,\mathbf{w}^{(t)}_{i}-\eta\sum_{\nu=1}^{n}\nabla_{\mathbf{w}^{(t)}_{i}}\mathcal{L}\left(\frac{W^{(t)}\mathbf{z}_{\nu}}{\sqrt{d}},\, f^{\star}(\mathbf{z}_{\nu})\right), \tag{5}$$

where $\eta\in\mathbb{R}$ is the learning rate and $\lambda\in\mathbb{R}$ is the explicit regularization strength. We refer to these steps as the representation learning steps, in which the first-layer weights adapt to the low-dimensional structure identified by the teacher subspace $W^{\star}$.
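The symmetric initialization of Eq. (4) and the update of Eq. (5) can be sketched as follows for the squared loss. This is an illustrative sketch under stated assumptions rather than the code used for the experiments: the function names are hypothetical, the activation is taken to be ReLU as in the figures, and the values of `d`, `p`, `eta` and `lam` are arbitrary.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def relu_prime(x):
    return (x > 0).astype(float)

def symmetric_init(p, d, rng):
    """Eq. (4): a_i = -a_{p-i+1} and w_i = w_{p-i+1}, so the output is 0 at t = 0."""
    assert p % 2 == 0
    a_half = rng.standard_normal(p // 2) / p
    W_half = rng.standard_normal((p // 2, d))
    a = np.concatenate([a_half, -a_half[::-1]])
    W = np.concatenate([W_half, W_half[::-1]])
    return a, W

def forward(W, a, Z):
    """Eq. (2): f = sum_j a_j sigma(<w_j, z> / sqrt(d)); also returns pre-activations."""
    H = Z @ W.T / np.sqrt(Z.shape[1])
    return relu(H) @ a, H

def gd_step(W, a, Z, y, eta, lam):
    """One step of Eq. (5) with the second layer a kept fixed."""
    d = Z.shape[1]
    f, H = forward(W, a, Z)
    # Gradient of 0.5 * (y - f)^2 w.r.t. w_i: -(y - f) * a_i * sigma'(h_i) * z / sqrt(d)
    coeff = -(y - f)[:, None] * a[None, :] * relu_prime(H)   # shape (n, p)
    grad = coeff.T @ Z / np.sqrt(d)                          # shape (p, d), summed over samples
    return (1.0 - eta * lam) * W - eta * grad

rng = np.random.default_rng(0)
d, p, alpha, eta, lam = 100, 4, 3, 0.1, 1.0   # illustrative values only
n = alpha * d
w_star = rng.standard_normal(d)
w_star *= np.sqrt(d) / np.linalg.norm(w_star)  # teacher direction of norm sqrt(d)
Z = rng.standard_normal((n, d))
h = Z @ w_star / np.sqrt(d)
y = h**3 - 3 * h                               # single-index target g* = He_3

a, W0 = symmetric_init(p, d, rng)
f0, _ = forward(W0, a, Z)                      # zero up to rounding, by symmetry

W = W0
for t in range(2):                             # multi-pass: the same batch (Z, y) is reused
    W = gd_step(W, a, Z, y, eta, lam)
overlap = np.abs(W @ w_star) / d               # per-neuron overlap with the teacher
```

A one-pass variant would redraw `Z` and `y` inside the loop, so that each step sees a fresh batch; the only difference between the two procedures is whether the batch is resampled.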

Our main contributions in this paper are the following:

  • We characterize the class of multi-index targets that can be learned efficiently by two-layer networks trained with a finite number of iterations of gradient descent in the high-dimensional limit $(d\to\infty)$ with large batch sizes ($n=\alpha d$, $\alpha=O(1)$). We establish a strong separation between what can be learned with one-pass algorithms (that use fresh batches at every step) and multi-pass gradient approaches that can use the same batch many times (see Figs. 1 and 2 for examples).

  • We show that while both gradient flow (Bietti et al., 2023) and single-pass algorithms suffer from the curse of the information exponent (Ben Arous et al., 2021), and are limited to staircase learning (Abbe et al., 2023), requiring a diverging number of iterations for non-staircase functions, some of these problems become trivial when samples can be reused multiple times, and features can be learned in just $T=2$ iterations. This disproves, in particular, a recent conjecture by Abbe et al. (2023).

  • The simplest examples of directions that cannot be learned in a finite number of steps relate to symmetries in the target function. This includes phase retrieval (Maillard et al., 2020) or the specialization transition in committees, as discussed in the Bayes optimal approaches of single-index (Barbier et al., 2019) and multi-index (Aubin et al., 2019) models.

  • The proof of our results is based on the concept of “hidden progress”, and crucially uses the rigorous Dynamical Mean Field Theory (DMFT) (Celentano et al., 2021; Gerbelot et al., 2023). This has an interest on its own as it provides a sharp example of how DMFT can help to understand batch reusing to go beyond the current state-of-the-art results.

  • Finally, we use DMFT to provide a closed-form description of the dynamics of gradient descent for two-layer nets. Keeping track of the correlations induced by re-using the same batch leads to a set of integro-differential equations. We provide rigorous theoretical guarantees in the correlated samples regime without assuming the resampling of a fresh batch for each iteration of the algorithm. We corroborate the theoretical claims with numerical simulations (see https://github.com/IdePHICS/benefit-reusing-batch).

Other Related works –

A major issue in machine learning theory is figuring out how well two-layer neural networks adapt to low-dimensional structures in the data. Different results have tightly characterized the limitations of networks in which the first layer of weights W𝑊Witalic_W is kept fixed, i.e. equivalent to kernel approaches (Dietrich et al., 1999; Ghorbani et al., 2019; 2020; Bordelon et al., 2020; Loureiro et al., 2021; Cui et al., 2021). This class of learning algorithms, although amenable to theoretical analysis, is unable to learn features in the data. Therefore, one central avenue of research in this context is to understand the efficiency of the representation learning (or feature learning) when training with gradient-based algorithms to overcome the limitations of the kernel regime. Sharp separation results between the performance of neural networks at initialization (random features) and trained with only one step of gradient descent (with a large learning rate) have been offered (Ba et al., 2022; Damian et al., 2022; Dandi et al., 2023).

The class of features efficiently learned with multiple steps of one-pass SGD with one sample per batch is characterized by the information exponent ($\mathrm{IE}$) (Ben Arous et al., 2021) of the target function. In the context of single-index learning, denoting by $\ell$ the $\mathrm{IE}$ of the target, the algorithm needs $T=O(d^{\ell-1})$ steps to perform weak recovery of the teacher direction, i.e., to obtain an overlap between the learned weights and $\mathbf{w}^{\star}$ better than random guessing (Ben Arous et al., 2021). Recently, these results have been improved up to the Correlational Statistical Query (CSQ) lower bound of $d^{\max(\frac{\ell}{2},1)}$, by considering an appropriate smoothing of the loss (Damian et al., 2023). A generalization to large-batch one-pass SGD is given in (Dandi et al., 2023).
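For intuition, the information exponent of a polynomial link function can be computed exactly from Gaussian moments. The sketch below assumes the usual definition from Ben Arous et al. (2021), namely the smallest $k\geq 1$ such that $\mathbb{E}_{z\sim\mathcal{N}(0,1)}[g^{\star}(z)\,\mathrm{He}_{k}(z)]\neq 0$, with $\mathrm{He}_{k}$ the probabilists' Hermite polynomials. The helper names are hypothetical, polynomials are stored as integer coefficient lists (lowest degree first), and non-polynomial targets such as $\tanh$ are outside the scope of this sketch.

```python
def gaussian_moment(m):
    """E[z^m] for z ~ N(0, 1): 0 if m is odd, (m - 1)!! if m is even."""
    if m % 2:
        return 0
    out = 1
    for j in range(1, m, 2):
        out *= j
    return out

def poly_mul(a, b):
    """Product of two polynomials given as coefficient lists."""
    out = [0] * (len(a) + len(b) - 1)
    for i, ai in enumerate(a):
        for j, bj in enumerate(b):
            out[i + j] += ai * bj
    return out

def hermite(k):
    """Coefficients of He_k via the recurrence He_{j+1} = z He_j - j He_{j-1}."""
    prev, cur = [1], [0, 1]
    if k == 0:
        return prev
    for j in range(1, k):
        nxt = [0] + cur                 # z * He_j
        for i, c in enumerate(prev):
            nxt[i] -= j * c             # minus j * He_{j-1}
        prev, cur = cur, nxt
    return cur

def expectation(poly):
    """E[poly(z)] for z ~ N(0, 1), computed exactly."""
    return sum(c * gaussian_moment(m) for m, c in enumerate(poly))

def information_exponent(g, k_max=10):
    """Smallest k in [1, k_max] with E[g(z) He_k(z)] != 0, or None if none is found."""
    for k in range(1, k_max + 1):
        if expectation(poly_mul(g, hermite(k))) != 0:
            return k
    return None

print(information_exponent(hermite(3)))   # 3
print(information_exponent(hermite(4)))   # 4
```

The two printed values correspond to the targets $\mathrm{He}_{3}$ and $\mathrm{He}_{4}$ used in the center and right panels of Figure 1.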

Similarly, multi-index feature learning presents an unavoidable computational barrier for one-pass algorithms. Abbe et al. (2021) first characterized a hierarchical picture of learning in the Boolean data case: informally, the features efficiently learned at each step of the one-pass algorithm need to be linearly connected with the previously learned features. This concept is formalized by the definition of the staircase property (Abbe et al., 2021). This hierarchical picture of learning is extended to large-batch SGD and non-Boolean data in (Abbe et al., 2022; 2023; Dandi et al., 2023). Moreover, Abbe et al. (2023) conjecture that re-using the batch can reduce the sample complexity of a target with leap $\ell$ only up to $O(d^{\max(\frac{\ell}{2},1)})$, corresponding to the lower bound for Correlational Statistical Query (CSQ) algorithms.

We disprove this conjecture and show that the sample complexity for a large class of functions can be reduced to $O(d)$ independently of the leap exponent $\ell$. More generally, our results show that CSQ lower bounds and the notions of staircase property and information exponent are limited to online SGD on Gaussian/Boolean data, and do not describe the class of functions inherently easy or hard to learn by gradient-based methods. We also show that learning non-even single-index functions does not require techniques such as spectral warm-start (Chen and Meka, 2020).

Dynamical Mean Field Theory has a long history in statistical physics. Early theories of dynamics in complex systems were pioneered in soft spin glass models (Sompolinsky and Zippelius, 1981) and toy models of random feature deep networks (Sompolinsky et al., 1988). The DMFT approach used in this paper was first proposed as a way to study “hard spins” in spin glass models (Eissfeller and Opper, 1992; 1994), and was later generalized to “soft spins” (Cugliandolo, 2003) and more realistic models in condensed matter (Georges et al., 1996). In the context of learning, DMFT was used for optimization problems (Mannelli et al., 2019a; b; 2020; Mannelli and Urbani, 2021) and for analyzing the behavior and the noise of gradient-based algorithms (Mignacco et al., 2020; 2021; Mignacco and Urbani, 2022). From the mathematics point of view, these DMFT equations were first proven rigorously in the seminal work of (Ben Arous et al., 1997) in the context of spin glasses. Important progress was achieved recently with rigorous proofs of the DMFT equations for multi-index models (Celentano et al., 2021; Gerbelot et al., 2023) that we use to prove our main results.

Figure 1: One-pass and multi-pass GD for single-index models – The overlap $\left|\frac{\langle\mathbf{w}^{\star},\hat{\mathbf{w}}\rangle}{d}\right|$ between the learned weight and the target/teacher direction is plotted as a function of the iteration time for both single-pass (red) and multi-pass (blue) GD. Continuous lines are given by theory; dots are simulations. Left: Easy finite-T learnable single-index target $g^{\star}=\tanh$: both one-pass and multi-pass GD obtain positive correlation after a finite number of iterations, as the information exponent of the target is $\ell=1$. Center: Multi-pass finite-T learnable single-index target $g^{\star}=\mathrm{He}_{3}$: multi-pass GD achieves a non-zero correlation in just two steps, but the one-pass algorithm learns nothing. Right: Finite-time non-learnable single-index target $g^{\star}=\mathrm{He}_{4}$: the target function is even and thus, as stated in Thm. 3.2, breaking this symmetry is hard in a finite number of steps, resulting in a vanishing correlation with the teacher direction $\mathbf{w}^{\star}$ for both algorithms in any finite time. (Simulations are averaged over $32$ runs, $d=5000$, with $\sigma=\mathrm{relu}$, $n=3d$, $p=1$, $\eta=0.1$.)
Figure 2: One-pass and multi-pass GD for multi-index models – The overlaps of the student weights with the first direction learned, namely $\mathbf{C}[f^{\star}]$, and with its orthogonal complement are plotted versus the number of iterations for three different classes of functions. Left: Easy finite-T learnable multi-index target: both algorithms learn all the relevant directions when an “easy” function is used as a target ($p=8$). Center: Multi-pass finite-T learnable multi-index target: both algorithms learn the first Hermite direction $\mathbf{C}[f^{\star}]$, but only multi-pass GD achieves a non-zero correlation along the orthogonal complement. This illustrates how reusing samples allows us to surpass the staircase limitation of single-pass approaches ($p=2$). Right: Finite-time non-learnable multi-index target: neither of the two algorithms can learn $\mathbf{C}[f^{\star}]^{\bot}$ with this target ($p=2$). (Simulations are averaged over $32$ runs, $d=5000$, with $\sigma=\mathrm{relu}$, $n=3d$, $\eta=0.1$.)

3 Statement of the results

Here, we introduce the main results covering the theoretical learning guarantees with gradient descent and contrast them with the known one-pass results. We exploit the rigorous DMFT construction to prove the first key result: two-layer networks efficiently learn a large class of multi-index targets in only $T=2$ iterations, breaking the curse of one-pass algorithms dictated by the information and leap exponents.

3.1 Finite-T Learnable and Non-learnable directions

We first identify which target directions are hard to learn for multi-pass gradient descent. Define $U^{\star}$ to be the subspace spanned by the rows of the target weights $W^{\star}$. The “hard” directions are those along which no transformation of the output $f^{\star}(\mathbf{z})$ leads to a linear correlation. We now define the subspace of such directions:

Definition 3.1.

We define $P^{\star}$ as the subspace of directions $\mathbf{v}^{\star}\in U^{\star}$ such that for any polynomial $F:\mathbb{R}\rightarrow\mathbb{R}$ with coefficients in $\mathbb{R}$, the following condition is satisfied:

$$\mathbb{E}_{\mathbf{z}}\left[F(f^{\star}(\mathbf{z}))\,\langle\mathbf{v}^{\star},\mathbf{z}\rangle\right]=0. \tag{6}$$

Similarly, we denote by $A^{\star}$ the subspace of directions where the above condition is satisfied for all real-valued analytic functions $F$.
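As a worked illustration of Definition 3.1 (added here as a sketch, using the single-index targets of Figure 1), write $h=\langle\mathbf{w}^{\star},\mathbf{z}\rangle/\sqrt{d}\sim\mathcal{N}(0,1)$, so that condition (6) for $\mathbf{v}^{\star}=\mathbf{w}^{\star}$ reduces, up to a factor $\sqrt{d}$, to $\mathbb{E}[F(g^{\star}(h))\,h]=0$. For $g^{\star}=\mathrm{He}_{3}$, with $\mathrm{He}_{3}(h)=h^{3}-3h$, the choices $F(y)=y$ and $F(y)=y^{2}$ give zero, but $F(y)=y^{3}$ does not:

```latex
% Gaussian moments used: E[h^4] = 3, E[h^6] = 15, E[h^8] = 105, E[h^{10}] = 945
\begin{aligned}
\mathbb{E}\big[\mathrm{He}_3(h)^3\, h\big]
  &= \mathbb{E}\big[h^{4}\,(h^{2}-3)^{3}\big]
   = \mathbb{E}\big[h^{10} - 9h^{8} + 27h^{6} - 27h^{4}\big] \\
  &= 945 - 945 + 405 - 81 = 324 \neq 0 .
\end{aligned}
```

Hence $\mathbf{w}^{\star}\notin P^{\star}$ for $g^{\star}=\mathrm{He}_{3}$, even though its information exponent is $3$. For $g^{\star}=\mathrm{He}_{4}$, in contrast, $F(\mathrm{He}_{4}(h))$ is an even function of $h$ for every $F$, so $\mathbb{E}[F(\mathrm{He}_{4}(h))\,h]=0$ by symmetry and $\mathbf{w}^{\star}$ lies in $P^{\star}$ and in $A^{\star}$.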

One part of our main result shows that directions in $A^{\star}$ cannot be learned in a finite number of gradient steps, even by re-using batches of size $\mathcal{O}(d)$. Furthermore, under suitable conditions on $\sigma$ and $\mathbf{a}$ (discussed in Theorem 3.2 and Appendix A.5), we show that after two gradient steps, the first layer learns all directions in the complement $P^{\star}_{\perp}$. We are now ready to state our main result:

Theorem 3.2.

Suppose that n/d=α>0𝑛𝑑𝛼0n/d=\alpha>0italic_n / italic_d = italic_α > 0. Let 𝐯Psuperscript𝐯subscriptsuperscript𝑃perpendicular-to\mathbf{v}^{\star}\in P^{\star}_{\perp}bold_v start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT ∈ italic_P start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT start_POSTSUBSCRIPT ⟂ end_POSTSUBSCRIPT denote an arbitrary direction in the orthogonal complement of the subspace Psuperscript𝑃P^{\star}italic_P start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT defined in definition 3.1 with norm d𝑑\sqrt{d}square-root start_ARG italic_d end_ARG and a fixed representation in the basis Wsuperscript𝑊W^{\star}italic_W start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT. Suppose further that the activation function σ𝜎\sigmaitalic_σ is analytic, with polynomially bounded derivatives satisfying 𝔼z𝒩(0,1)[σ(z)]0subscript𝔼similar-to𝑧𝒩01delimited-[]superscript𝜎𝑧0\mathbb{E}_{z\sim\mathcal{N}(0,1)}\left[\sigma^{\prime}(z)\right]\neq 0blackboard_E start_POSTSUBSCRIPT italic_z ∼ caligraphic_N ( 0 , 1 ) end_POSTSUBSCRIPT [ italic_σ start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_z ) ] ≠ 0 and σk(0)0ksuperscript𝜎𝑘00for-all𝑘\sigma^{k}(0)\neq 0\ \forall k\in\mathbb{N}italic_σ start_POSTSUPERSCRIPT italic_k end_POSTSUPERSCRIPT ( 0 ) ≠ 0 ∀ italic_k ∈ blackboard_N. Then, for any gsuperscript𝑔g^{\star}italic_g start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT with derivatives bounded by polynomials, there exist η>0,λ>0formulae-sequence𝜂0𝜆0\eta>0,\lambda>0italic_η > 0 , italic_λ > 0 such that almost surely over the choice of 𝐚𝐚\mathbf{a}bold_a, we have:

$$\left\lVert\frac{\mathbf{W}^{(2)}\mathbf{v}^{\star}}{d}\right\rVert=\Theta_{d}(1)\,, \tag{7}$$

with high probability as $n,d\rightarrow\infty$. Furthermore, for large enough $p$, $\mathbf{W}^{(2)}$ asymptotically spans $P^{\star}_{\perp}$:

$$\inf_{\mathbf{v}^{\star}\in P^{\star}_{\perp}}\left\lVert\frac{\mathbf{W}^{(2)}\mathbf{v}^{\star}}{d}\right\rVert=\Theta_{d}(1)\,, \tag{8}$$

with high probability as $n,d\rightarrow\infty$. In other words, directions $\mathbf{v}^{\star}\in P^{\star}_{\perp}$ are learned in $T=2$ gradient steps.

Suppose, however, that the teacher subspace satisfies $U^{\star}=A^{\star}$. Then:

$$\sup_{\mathbf{v}^{\star}\in U^{\star}}\left\lVert\frac{\mathbf{W}^{(t)}\mathbf{v}^{\star}}{d}\right\rVert=o_{d}(1)\,, \tag{9}$$

with high probability as $n,d\rightarrow\infty$, for any finite time $t$. Thus, none of the directions are learned in any finite number of GD steps.

The proof, given in App. A, is based on the analysis of the DMFT equations discussed in Sec. 4.2; we also provide an informal heuristic derivation in Sec. 4.1. While the above negative result requires all directions in $U^{\star}$ to be in $A^{\star}$ and thus in $P^{\star}$, in App. A.6 we discuss the more general setup where learning of certain directions in $P^{\star}_{\perp}$ can affect the learning of directions in $U^{\star}$ in subsequent timesteps.

When the expectation in Equation 6 is non-zero for $F$ being the identity mapping, i.e. $F=\mathrm{id}$, $\mathbf{v}^{\star}$ is in fact learned in the first gradient step (Ba et al., 2022; Dandi et al., 2023) or through online SGD (Ben Arous et al., 2022; Abbe et al., 2023). We discuss this further in Section 3.3.

Our analysis reveals that the effect of re-using batches is to implicitly transform the output in the subsequent steps, allowing a larger set of directions to be learned. However, for directions in $A^{\star}$, such transformations are still insufficient.

3.2 Characterization of hard directions through symmetries

While Definition 3.1 characterizes the subspace of hard directions $A^{\star}$, it requires checking that the equality in Equation 6 holds for any real analytic transformation $F$. We now show that a sufficient condition for $\mathbf{v}^{\star}\in A^{\star}$ is for $f^{\star}$ to possess certain symmetries along $\mathbf{v}^{\star}$. This leads us to identify subspaces of hard directions, contained in $A^{\star}$, linked to symmetries w.r.t. certain transformations. We characterize such subspaces below. The simplest such symmetry is defined through reflection along $\mathbf{v}^{\star}$:

Definition 3.3.

For any direction $\mathbf{v}^{\star}\neq 0\in\mathbb{R}^{d}$, let $R_{\mathbf{v}^{\star}}$ denote the reflection operator along $\mathbf{v}^{\star}$, i.e. $R_{\mathbf{v}^{\star}}=\mathbf{I}-2\frac{1}{\left\lVert\mathbf{v}^{\star}\right\rVert^{2}}\mathbf{v}^{\star}{\mathbf{v}^{\star}}^{T}$. We say that a direction $\mathbf{v}^{\star}\in U^{\star}$ is even-symmetric w.r.t. $f^{\star}$ if for any $\mathbf{z}\in\mathbb{R}^{d}$:

$$f^{\star}(R_{\mathbf{v}^{\star}}\mathbf{z})=f^{\star}(\mathbf{z}) \tag{10}$$

We denote by $E^{\star}$ the subspace spanned by all even-symmetric directions in $U^{\star}$.

It is straightforward to see that any $\mathbf{v}^{\star}\in E^{\star}$ leads to Equation 6 being satisfied for any transformation $F$, since $\mathbf{v}^{\star}$ remains even w.r.t. the function $F\left(f^{\star}(\cdot)\right)$. Therefore, $E^{\star}\subseteq A^{\star}$. However, the set of non-learnable directions can be larger due to the presence of additional symmetries. We now define such a larger subspace of hard directions arising due to a symmetry w.r.t. reflections along $\mathbf{v}^{\star}$ coupled with orthogonal transformations along the orthogonal subspace:

Definition 3.4.

For any direction $\mathbf{v}^{\star}\neq 0$ in $U^{\star}$, let $R_{\mathbf{v}^{\star}}$ be as defined in Definition 3.3. Let $O_{\perp}$ be a matrix in the orthogonal group on the $(d-1)$-dimensional subspace $\{\mathbf{v}^{\star}\}_{\perp}$, i.e. the orthogonal complement of the linear subspace spanned by $\mathbf{v}^{\star}$. We say that a direction $\mathbf{v}^{\star}$ is orthogonally-even-symmetric w.r.t. $f^{\star}$ if there exists an $O_{\perp}\in O(\{\mathbf{v}^{\star}\}_{\perp})$ such that for any $\mathbf{z}\in\mathbb{R}^{d}$:

$$f^{\star}(O_{\perp}R_{\mathbf{v}^{\star}}\mathbf{z})=f^{\star}(\mathbf{z}) \tag{11}$$

We denote by $OE^{\star}$ the subspace spanned by all orthogonally-even-symmetric directions in $U^{\star}$.

By setting $O_{\perp}$ as the identity mapping in the above definition, we recover the condition for $\mathbf{v}^{\star}\in E^{\star}$. Therefore, we have that $E^{\star}\subseteq OE^{\star}$. While $OE^{\star}$ is the largest set of directions we have identified as being hard, the true set of hard directions may be larger still and is given by $A^{\star}$ in Definition 3.1. We show in Appendix A.8 that the directions in $OE^{\star}$ are indeed hard as per Definition 3.1:

Proposition 3.5.

Let the subspaces $A^{\star}$ and $OE^{\star}$ be as defined in Defs. 3.1 and 3.4 respectively. Then, $OE^{\star}\subseteq A^{\star}$, i.e. all directions in $E^{\star},OE^{\star}$ are hard as per Thm. 3.2.

App. A.7 gives several examples where $P^{\star}_{\perp}=E_{\perp}^{\star}=U^{\star}$, such as single-index targets with odd Hermite activations, staircase functions, etc. Interestingly, we show that there exist functions $f^{\star}$ where the set $OE^{\star}$ is strictly larger than $E^{\star}$. Consequently, for such functions, $E^{\star}$ is strictly contained in $A^{\star},P^{\star}$. We discuss such target functions in Appendix A.9.
For example, we show in Appendix A.9 that for the target function $f^{\star}(\mathbf{z})=z_{1}z_{2}z_{3}$, the direction $\mathbf{v}^{\star}=\mathbf{e}_{1}+\mathbf{e}_{2}+\mathbf{e}_{3}$ does not lie in $E^{\star}$ but lies in $OE^{\star}$ and thus in $A^{\star},P^{\star}$.
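The difference between Definitions 3.3 and 3.4 can be checked numerically on this target. The sketch below is only an illustration under our own choices: it tests the coordinate direction $\mathbf{e}_{1}$ rather than the sum direction, and the matrix `O_perp` (a sign flip of $z_{2}$) is a hypothetical witness, not necessarily the one used in Appendix A.9. A check on random inputs can refute a symmetry but not prove it.

```python
import numpy as np

def reflection(v):
    """Reflection operator R_v = I - 2 v v^T / ||v||^2 (Definition 3.3)."""
    v = np.asarray(v, dtype=float)
    return np.eye(v.size) - 2.0 * np.outer(v, v) / np.dot(v, v)

def f_star(z):
    """Target f*(z) = z_1 z_2 z_3, evaluated on the rows of z."""
    return z[:, 0] * z[:, 1] * z[:, 2]

rng = np.random.default_rng(0)
z = rng.standard_normal((1000, 3))

e1 = np.array([1.0, 0.0, 0.0])
R = reflection(e1)                   # flips the sign of z_1
# Assumed witness: orthogonal, leaves e1 fixed, flips z_2 inside {e1}_perp.
O_perp = np.diag([1.0, -1.0, 1.0])

# Rows of z are mapped by a matrix M via z @ M.T.
plain_reflection_ok = np.allclose(f_star(z @ R.T), f_star(z))
coupled_reflection_ok = np.allclose(f_star(z @ (O_perp @ R).T), f_star(z))
print(plain_reflection_ok, coupled_reflection_ok)  # prints: False True
```

Reflection along $\mathbf{e}_{1}$ alone flips the sign of $f^{\star}$, while the coupled map restores it, so $\mathbf{e}_{1}$ satisfies Eq. (11) but not Eq. (10). The same construction works for $\mathbf{e}_{2}$ and $\mathbf{e}_{3}$, so all three coordinate directions are orthogonally-even-symmetric and $\mathbf{e}_{1}+\mathbf{e}_{2}+\mathbf{e}_{3}$ lies in their span.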

3.3 Comparison between one-pass and multi-pass GD

Our results are particularly interesting in the context of a recent line of work on the limitations of one-pass algorithms (Ben Arous et al., 2021; Abbe et al., 2021; 2022; 2023; Dandi et al., 2023; Bietti et al., 2023; Zweig and Bruna, 2023). We can demonstrate, in particular, a sharp performance separation between one-pass and multiple-pass protocols.

Learning single-index targets –

First, we consider single-index targets. Targets that are hard to learn for one-pass algorithms starting from uninformed initialization in high dimension are characterized by the Information Exponent ($\mathrm{IE}$). Informally, the $\mathrm{IE}$ is the index of the first non-zero coefficient in the Hermite expansion of the target activation.

Definition 3.6 (Information Exponent).

(Ben Arous et al., 2021) Let $\mathrm{He}_{j}$ be the $j$-th Hermite polynomial. Using the definition for the target of eq. (1), reading in the single-index case as $f^{\star}=g^{\star}(\langle\mathbf{w}^{\star},\mathbf{z}\rangle)$, the $\mathrm{IE}$ is defined as:

$$\mathrm{IE}=\min\{j\in\mathbb{N}:\mathbb{E}_{\xi\sim\mathcal{N}(0,1)}\left[g^{\star}(\xi)\mathrm{He}_{j}(\xi)\right]\neq 0\} \tag{12}$$
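To make Eq. (12) concrete, the following sketch estimates the information exponent of a given link function by Gauss-Hermite quadrature. It is a minimal illustration, not part of the original analysis: the function name `information_exponent`, the truncation `max_degree`, the tolerance `tol`, and the choice to start the search at $j=1$ (ignoring the constant term) are all our assumptions.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss, hermeval

def information_exponent(g, max_degree=10, n_nodes=40, tol=1e-8):
    """Smallest j >= 1 with E[g(xi) He_j(xi)] != 0 for xi ~ N(0, 1).

    Returns None if no coefficient up to max_degree exceeds tol.
    """
    nodes, weights = hermegauss(n_nodes)       # rule for the weight exp(-x^2 / 2)
    weights = weights / np.sqrt(2.0 * np.pi)   # normalise to the N(0, 1) density
    g_vals = g(nodes)
    for j in range(1, max_degree + 1):
        basis = np.zeros(j + 1)
        basis[j] = 1.0                         # coefficient vector selecting He_j
        coeff = np.sum(weights * g_vals * hermeval(nodes, basis))
        if abs(coeff) > tol:
            return j
    return None

he3 = lambda x: x**3 - 3 * x           # probabilists' Hermite polynomial He_3
he4 = lambda x: x**4 - 6 * x**2 + 3    # probabilists' Hermite polynomial He_4

print(information_exponent(np.tanh))   # 1
print(information_exponent(he3))       # 3
print(information_exponent(he4))       # 4
```

The numerical threshold `tol` stands in for the exact condition "$\neq 0$"; for non-polynomial links the quadrature is only approximate, so coefficients close to the threshold should be treated with care.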

Higher $\mathrm{IE}$s are associated with harder problems for one-pass training protocols. Indeed, (Ben Arous et al., 2021) provably show that one-pass SGD, with one sample per batch, weakly recovers the teacher direction $\mathbf{w}^{\star}$ only upon iterating the training schedule for $T(\mathrm{IE})$ iterations:

$$T(\mathrm{IE})=\begin{cases}\mathcal{O}(d^{\mathrm{IE}-1}) & \text{if } \mathrm{IE}>2\\ \mathcal{O}(d\log d) & \text{if } \mathrm{IE}=2\\ \mathcal{O}(d) & \text{if } \mathrm{IE}=1\end{cases} \tag{13}$$
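To get a feel for the scaling in Eq. (13), the sketch below evaluates the three regimes for a given dimension. Constants hidden in the $\mathcal{O}(\cdot)$ are dropped, so the numbers are orders of magnitude only; the helper name `one_pass_sgd_steps` is ours.

```python
import math

def one_pass_sgd_steps(ie, d):
    """Order-of-magnitude step count T(IE) from Eq. (13), constants dropped."""
    if ie < 1:
        raise ValueError("information exponent must be >= 1")
    if ie == 1:
        return float(d)
    if ie == 2:
        return d * math.log(d)
    return float(d) ** (ie - 1)

for ie in (1, 2, 3, 4):
    print(ie, f"{one_pass_sgd_steps(ie, d=1000):.3g}")
```

For $d=1000$, moving from $\mathrm{IE}=1$ to $\mathrm{IE}=3$ raises the step count from the order of $10^{3}$ to the order of $10^{6}$.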

Recently, the time complexity has been improved up to the Correlational Statistical Query (CSQ) lower bound of $d^{\max(\frac{\mathrm{IE}}{2},1)}$, by considering an appropriate smoothing of the loss (Damian et al., 2023). Definition 3.6 has been extended to larger batch sizes in (Abbe et al., 2022; Dandi et al., 2023), without changing the overall picture; more precisely, even with $n=o(d^{\mathrm{IE}})$ fresh samples per batch, one-pass training procedures are still not able to weakly recover the signal in finite iteration time. The case $\mathrm{IE}=1$ corresponds to the expectation in Equation 6 being non-zero for $F=\mathrm{id}$. However, since Definition 3.1 allows for general transformations $F$ to the output $f^{\star}$, $\mathbf{w}^{\star}$ may not be in $P^{\star},A^{\star}$ even when $\mathrm{IE}>1$. The presence of general transformations $F$ in Definition 3.1 allows our algorithm to bypass CSQ bounds, which are restricted to $F=\mathrm{id}$. Such general transformations are however permitted under the framework of Statistical Query (SQ) algorithms (Kearns, 1998). We thus expect gradient descent with $\mathcal{O}(d)$ sample complexity to inherit the hardness results established for the class of SQ algorithms (Diakonikolas et al., 2020; Goel et al., 2020; Chen et al., 2021; 2022).
We emphasize, however, that unlike explicit SQ algorithms, our analysis shows that gradient descent performs such transformations implicitly, allowing it to reach the optimal complexity of SQ algorithms for a certain class of target functions.
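As a worked illustration of how a transformation of the output can create a linear correlation where none exists for $F=\mathrm{id}$, take the link $g^{\star}=\mathrm{He}_{3}$, i.e. $g^{\star}(\xi)=\xi^{3}-3\xi$, and the transformation $F(y)=y^{3}$. Both choices are ours, made for illustration, and we assume that in the single-index case Equation 6 amounts to the vanishing of $\mathbb{E}[F(g^{\star}(\xi))\,\xi]$. Using the Gaussian moments $\mathbb{E}[\xi^{4}]=3$, $\mathbb{E}[\xi^{6}]=15$, $\mathbb{E}[\xi^{8}]=105$, $\mathbb{E}[\xi^{10}]=945$:

```latex
\begin{align*}
\mathbb{E}_{\xi\sim\mathcal{N}(0,1)}\left[g^{\star}(\xi)\,\xi\right]
  &= \mathbb{E}\left[\xi^{4}-3\xi^{2}\right] = 3-3 = 0,\\
\mathbb{E}_{\xi\sim\mathcal{N}(0,1)}\left[g^{\star}(\xi)^{3}\,\xi\right]
  &= \mathbb{E}\left[\xi^{10}-9\xi^{8}+27\xi^{6}-27\xi^{4}\right]
   = 945-945+405-81 = 324 \neq 0 .
\end{align*}
```

So a target with $\mathrm{IE}=3$ correlates linearly with $\xi$ once its output is cubed. An even transformation such as $F(y)=y^{2}$ would not help here, since $g^{\star}(\xi)^{2}$ is even in $\xi$ and its correlation with $\xi$ vanishes.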

We illustrate the sharp contrast between one-pass and multiple-pass protocols with the examples depicted in Figure 1, which shows the scalar product (called overlap) between the learned weights and the teacher direction as a function of the time steps and compares simulations (dots) with theoretical predictions (continuous lines). There are 3 cases:

  • Easy finite-$T$ learnable single-index targets $\left(\mathrm{IE}=1\right)$: The left panel of Fig. 1 shows the learning curve for a problem with $\mathrm{IE}=1$. Both single-pass and multiple-pass GD correlate with the target in finite time. The non-symmetric subspace coincides with the teacher one, $E_{\perp}^{\star}=U^{\star}$ (Def. 3.3).

  • Multi-pass finite-$T$ learnable single-index targets $\left(\mathrm{IE}>1,\ \text{non-even } f^{\star}\right)$: Fig. 1 (center) depicts the learning curve for a non-even target function, with $\mathrm{IE}=3$. Here, one-pass GD is not able to achieve any significant correlation with the teacher $\mathbf{w}^{\star}$ (and it would require a number of iterations $T=\mathcal{O}(d)$ to achieve weak recovery; see eq. (13)). However, multiple-pass GD performs weak recovery in only $T=2$ steps. As before, the non-symmetric subspace corresponds to the teacher one, $E_{\perp}^{\star}=U^{\star}$ (Def. 3.3).

  • Finite-$T$ non-learnable single-index targets $\left(\mathrm{IE}>1,\ \text{even } f^{\star}\right)$: Fig. 1 (right) considers an even problem, with $\mathrm{IE}=4$. Neither of the training procedures achieves weak recovery in finite time. The computational hardness of this problem is associated with the presence of symmetry in the teacher function that requires time to break. Indeed, following Definition 3.3, the even-symmetric subspace $E^{\star}$ is equivalent to the teacher subspace $U^{\star}$. These results agree with the emergence of computational barriers in symmetric single-index problems like phase retrieval (Maillard et al., 2020). In fact, for such problems, regardless of the number of iterations, learnability requires $\alpha$ to be larger than critical values even for the most efficient known algorithms (see (Barbier et al., 2019), Sec. 3.1).

Learning multi-index targets –

The hardness of learning multi-index targets has been the subject of numerous recent studies for single-pass algorithms (Abbe et al., 2021; 2022; 2023; Bietti et al., 2023; Zweig and Bruna, 2023; Dandi et al., 2023). The class of multi-index targets efficiently learned by one-pass algorithms has been provably associated with the Leap Complexity ($\mathrm{LC}$) of the target to be learned, which generalizes the information exponent:

Remark 3.7.

To enhance the clarity of the presentation, we limit the definition of the $\mathrm{LC}$ to an informal one. We refer to Section B.2 (isoLeap) in (Abbe et al., 2023) and Definition 3 in (Dandi et al., 2023) (leap index) for details.

Informally, the learning dynamics of one-pass routines follow this behavior: initially, the network learns in the first step the first Hermite coefficient of the target $f^{\star}$. For every time $t\in[T]$ of the one-pass schedule, the network is bound to learn in finite time only features that are linearly connected to the previously learned directions; functions possessing only such linearly connected features are leap $1$ functions ($\mathrm{LC}=1$), e.g. $f^{\star}(\mathbf{z})=z_{1}+z_{1}z_{2}$. Similarly, functions that are quadratically connected to the learned features are leap $2$ ($\mathrm{LC}=2$), e.g. $f^{\star}(\mathbf{z})=z_{1}+z_{1}z_{2}z_{3}$. Higher $\mathrm{LC}$ target functions correspond to harder learning problems for one-pass algorithms: one-pass SGD, with one sample per batch, weakly recovers the teacher subspace $U^{\star}$ by iterating the training protocol for $T(\mathrm{LC})$ time steps, where the $\mathrm{LC}$ substitutes the $\mathrm{IE}$ in eq. (13) (Abbe et al., 2023).
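The two example targets above can be probed numerically. The sketch below computes, by exact tensor-product Gauss-Hermite quadrature, the first Hermite coefficient $\mathbb{E}[f^{\star}(\mathbf{z})\,\mathbf{z}]$ together with two mixed moments that serve as informal proxies for "linear" and "quadratic" connection to the learned direction $\mathbf{e}_{1}$. These proxies, and all helper names, are our assumptions for illustration; they are not the formal leap definitions of the cited works.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

def gaussian_grid(dim, n_nodes=8):
    """Tensor-product Gauss-Hermite rule for z ~ N(0, I_dim): (points, weights)."""
    x, w = hermegauss(n_nodes)
    w = w / np.sqrt(2.0 * np.pi)               # normalise to the N(0, 1) density
    mesh_x = np.meshgrid(*([x] * dim), indexing="ij")
    mesh_w = np.meshgrid(*([w] * dim), indexing="ij")
    points = np.stack([m.ravel() for m in mesh_x], axis=1)
    weights = np.prod(np.stack([m.ravel() for m in mesh_w], axis=1), axis=1)
    return points, weights

PTS, WTS = gaussian_grid(dim=3)

def expect(h):
    """Gaussian expectation of h, exact for low-degree polynomials."""
    return float(np.sum(WTS * h(PTS)))

def connections(f):
    # First Hermite coefficient: the direction picked up in the first step.
    c1 = np.array([expect(lambda z, i=i: f(z) * z[:, i]) for i in range(3)])
    # Proxy for a new direction (z_2) reachable linearly once z_1 is known.
    linear = expect(lambda z: f(z) * z[:, 0] * z[:, 1])
    # Proxy for two new directions (z_2, z_3) reachable only jointly.
    quadratic = expect(lambda z: f(z) * z[:, 0] * z[:, 1] * z[:, 2])
    return c1, linear, quadratic

leap1 = lambda z: z[:, 0] + z[:, 0] * z[:, 1]
leap2 = lambda z: z[:, 0] + z[:, 0] * z[:, 1] * z[:, 2]

for name, f in [("leap 1", leap1), ("leap 2", leap2)]:
    c1, lin, quad = connections(f)
    print(name, np.round(c1, 6), round(lin, 6), round(quad, 6))
```

For both targets the first Hermite coefficient is $\mathbf{e}_{1}$. The linear proxy equals $1$ for $z_{1}+z_{1}z_{2}$ and $0$ for $z_{1}+z_{1}z_{2}z_{3}$, where only the quadratic proxy is non-zero, matching the informal leap $1$ versus leap $2$ distinction.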

We illustrate the behavior of one-pass and multiple-pass algorithms when learning multi-index functions in Fig. 2. Using different two-index teachers ($k=2$), it shows the scalar product between the learned weights and two reference vectors: a) the first Hermite coefficient of the target $f^{\star}$, called in the following $\mathbf{C}_{1}[f^{\star}]$; b) the vector in the teacher subspace $U^{\star}$ orthogonal to $\mathbf{C}_{1}[f^{\star}]$, referred to as $\mathbf{C}_{1}[f^{\star}]^{\perp}$. The figure shows the correlation metrics as a function of time, labeled as overlap (resp. orthogonal overlap) in the upper (resp. lower) section. There are, again, 3 cases:

  • Finite-$T$ learnable multi-index targets: Fig. 2 (left) depicts a target with $\mathrm{LC}=1$. The teacher subspace $U^{\star}$ spanned by the standard basis vectors $\{\mathbf{e}_{1},\mathbf{e}_{2}\}$ is learned by both one-pass and multi-pass GD in finite time. At $T=1$, $\mathbf{e}_{1}=\mathbf{C}_{1}[f^{\star}]$ is learned; this enables the recovery of the direction $\mathbf{e}_{2}$ at $T=2$, as the target is linear in $z_{2}$ once $\mathbf{e}_{1}$ has been learned. This hierarchical picture of learning is called the staircase mechanism. Using the notation of Def. 3.3, the non-symmetric teacher subspace $E^{\star}_{\perp}$ is equivalent to the full teacher subspace $U^{\star}$.

  • Multi-pass finite-$T$ learnable multi-index targets: The central panel in Fig. 2 illustrates a teacher with $\mathrm{LC}=3$. Both algorithms are successful in weakly recovering the direction $\mathbf{C}_{1}[f^{\star}]$ in the first step. However, as the training continues, one-pass GD never recovers the full teacher subspace in finite time (exemplified by the zero orthogonal overlap in the lower panel). Conversely, multi-pass GD is able to perform weak recovery of the full teacher subspace $U^{\star}$ by achieving a non-vanishing correlation with $\mathbf{C}_{1}[f^{\star}]^{\perp}$ (non-zero orthogonal overlap in the lower section) in just $T=2$ steps. Again, the non-symmetric subspace $E_{\perp}^{\star}$ is equivalent to the full teacher subspace $U^{\star}$ (Def. 3.3).

  • Finite-$T$ non-learnable multi-index targets: The right panel of Fig. 2 considers a committee machine teacher with symmetric activation, i.e. $f^{\star}(\mathbf{z})=\sum_{r=1}^{2}\sigma^{\star}\left(\langle\mathbf{z},\mathbf{e}_{r}\rangle\right)$; here $\mathrm{LC}=2$. Both protocols, in this case, are only able to learn a single-index approximation of the target function in finite time, achieving non-zero correlation only with $\mathbf{C}_{1}[f^{\star}]\propto\mathbf{e}_{1}+\mathbf{e}_{2}$ throughout the dynamics. The computational hardness of this problem is associated with the presence of a neuron exchange symmetry. Indeed, using the notation of Def. 3.3, we observe that the even-symmetric subspace $E^{\star}=\{\frac{1}{\sqrt{2}}\left(\mathbf{e}_{2}-\mathbf{e}_{1}\right)\}$ is a non-trivial subspace of the teacher one $U^{\star}$.
Therefore, as for one-pass routines, multiple-pass ones are bound to learn only 𝐯=(𝐞1+𝐞2)/2superscript𝐯subscript𝐞1subscript𝐞22\mathbf{v}^{\star}=\left(\mathbf{e}_{1}+\mathbf{e}_{2}\right)/\sqrt{2}bold_v start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT = ( bold_e start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT + bold_e start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT ) / square-root start_ARG 2 end_ARG in finite time steps. Such difficulties have been described in the analysis of the specialization transition in the information-theoretic/Bayes optimal case of symmetric committees (Aubin et al., 2019). As for single index models, breaking the symmetry requires α𝛼\alphaitalic_α to be large enough and, even in this case, the best-known algorithms require a diverging number of iterations (see (Aubin et al., 2019), Sec. 3).
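The role of the exchange symmetry in the last example can be checked with a small Monte Carlo sketch. It is only an illustration under stated assumptions: the teacher activation $\sigma^\star$ is taken to be a ReLU (a hypothetical choice, not specified above), and the label-dependent weight `weight` is an arbitrary stand-in for the dependence on $y_\nu$ that reused data can induce. Because $f^\star$ is invariant under swapping the two teacher pre-activations, any such reweighted correlation vanishes along $\mathbf{e}_2-\mathbf{e}_1$ while it survives along $\mathbf{e}_1+\mathbf{e}_2$.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 400_000

# Teacher pre-activations (h1, h2): standard normal projections on e1 and e2.
h = rng.standard_normal((n_samples, 2))


def relu(x):
    return np.maximum(x, 0.0)


# Committee target f*(z) = sigma*(h1) + sigma*(h2); ReLU is an assumed sigma*.
y = relu(h[:, 0]) + relu(h[:, 1])

h_sym = (h[:, 0] + h[:, 1]) / np.sqrt(2.0)   # coordinate along (e1 + e2)/sqrt(2)
h_anti = (h[:, 1] - h[:, 0]) / np.sqrt(2.0)  # coordinate along (e2 - e1)/sqrt(2)

# Arbitrary label-dependent weight (hypothetical stand-in for sigma'(h) on reused data).
weight = (y > 1.0).astype(float)

drive_sym = np.mean(y * weight * h_sym)
drive_anti = np.mean(y * weight * h_anti)
print(f"correlation along e1+e2: {drive_sym:.3f}")
print(f"correlation along e2-e1: {drive_anti:.3f}")
```

The second estimate is zero up to Monte Carlo error for any weight that depends on the sample only through $y$, since the integrand is odd under the exchange $h_1 \leftrightarrow h_2$.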

3.4 From weak recovery to generalization

Th. 3.2 provides conditions for weak recovery (a finite overlap with directions in $U^\star$). Once weak recovery is achieved, it becomes straightforward to learn the function up to any desired accuracy with only $\mathcal{O}(d)$ additional samples. Indeed, strong generalization guarantees can be proven by utilizing existing results either for subsequent training with online SGD (Ben Arous et al., 2021) (to use their terminology, once you escape mediocrity, the ballistic phase is easy) or for training of the second layer using an independent batch of $\mathcal{O}(d)$ samples as in (Damian et al., 2022; Abbe et al., 2023). See App. A.10 for such generalization sample-complexity results.

4 Main proof ideas

4.1 Learning by hidden progress: heuristic argument

While we give a rigorous proof of Thm. 3.2 in App. A, we now provide an informal description of the hidden progress made in the first step of gradient descent, which allows overlaps to develop in the second step and is at the root of the difference between single-pass and multi-pass algorithms. For simplicity, we focus on the case of a single hidden neuron ($p=1$). We denote by $h^{(t)}_\nu=\langle\mathbf{w}^{(t)},\mathbf{z}_\nu\rangle/\sqrt{d}$ the pre-activation of the neuron for the $\nu$-th training point, with $\nu\in[n]$, and by $\mathbf{v}^\star$ a vector in the span of $\mathbf{W}^\star$ with $\left\lVert\mathbf{v}^\star\right\rVert=\sqrt{d}$.

From the gradient update in Eq. (5), the update lies in the span of the training inputs $\{\mathbf{z}_\nu\}_{\nu=1}^{n}$, with the gradient of the $\nu$-th training example given by $\mathbf{g}_\nu=a\,\mathcal{L}'\left(h^{(t)}_\nu,f^\star(\mathbf{z}_\nu)\right)\sigma'(h^{(t)}_\nu)\,\mathbf{z}_\nu/\sqrt{d}$. For squared loss, assuming that $f\left(\frac{W^{(t)}\mathbf{z}_\nu}{\sqrt{d}}\right)\approx 0$, the gradient reads:

$$\mathbf{g}^{(t)}_\nu\approx-a f^\star(\mathbf{z}_\nu)\,\sigma'(h^{(t)}_\nu)\,\mathbf{z}_\nu/\sqrt{d}\,. \tag{14}$$

At initialization, $h^{(0)}_\nu=\langle\mathbf{w}^{(0)},\mathbf{z}_\nu\rangle/\sqrt{d}$ and the projections along the teacher subspace (which we denote $\mathbf{h}^\star_\nu=\frac{1}{\sqrt{d}}W^\star\mathbf{z}_\nu\in\mathbb{R}^{k}$) are approximately independent, since $\mathbf{w}^{(0)}$ is approximately orthogonal to the teacher subspace as well as to the inputs $\{\mathbf{z}_\nu\}_{\nu=1}^{n}$. The projection of the gradient along the teacher subspace is given by:

$$W^\star\left(\sum_{\nu=1}^{n}\frac{\mathbf{g}^{(t)}_\nu}{\sqrt{d}}\right)\approx-\sum_{\nu=1}^{n}a f^\star(\mathbf{z}_\nu)\,\sigma'(h^{(t)}_\nu)\,\frac{\mathbf{h}^\star_\nu}{d}\,. \tag{15}$$

We expect that, due to concentration, the component of the full-batch gradient update along the teacher subspace lies along the direction given by:

$$\mathbb{E}\left[f^\star(\mathbf{z}_\nu)\,\sigma'(h^{(0)}_\nu)\,\mathbf{h}^\star_\nu\right]\approx\mathbb{E}\left[\sigma'(h^{(0)}_\nu)\right]\mathbb{E}\left[f^\star(\mathbf{z}_\nu)\,\mathbf{h}^\star_\nu\right], \tag{16}$$

where we used the approximate independence of $h^{(0)}_\nu$ and $\mathbf{h}^\star_\nu$ to factorize the expectation. Thus, the neuron parameters $\mathbf{w}^{(1)}$ at the first step are correlated with the teacher subspace only along the direction $\mathbb{E}\left[f^\star(\mathbf{z}_\nu)\,\mathbf{h}^\star_\nu\right]$.
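The factorization in Eq. (16) can be illustrated numerically. The sketch below samples the teacher pre-activations $\mathbf{h}^\star$ and an independent student pre-activation $h^{(0)}$ directly as standard Gaussians, and estimates $\mathbb{E}[f^\star\mathbf{h}^\star]$ for two toy targets. The targets (`f_lin`, `f_he3`), the $\tanh$ student activation, and all variable names are assumptions made for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 1_000_000

h_star = rng.standard_normal((n_samples, 2))  # teacher pre-activations, k = 2
h0 = rng.standard_normal(n_samples)           # student pre-activation at init, independent of h_star


def first_moment(f_vals, h_star):
    """Monte Carlo estimate of E[f*(z) h*], a vector in R^k."""
    return (f_vals[:, None] * h_star).mean(axis=0)


# Toy target with a linear component along the first teacher direction.
f_lin = h_star[:, 0] + h_star[:, 0] * h_star[:, 1]
# Toy target He_3(h_1): no linear component, so the first-step direction vanishes.
f_he3 = h_star[:, 0] ** 3 - 3.0 * h_star[:, 0]

c1_lin = first_moment(f_lin, h_star)
c1_he3 = first_moment(f_he3, h_star)

# Check the factorization of Eq. (16) with sigma = tanh (assumed activation).
sigma_prime = 1.0 - np.tanh(h0) ** 2
lhs = ((f_lin * sigma_prime)[:, None] * h_star).mean(axis=0)
rhs = sigma_prime.mean() * c1_lin

print("E[f* h*] for f_lin:", np.round(c1_lin, 3))
print("E[f* h*] for f_he3:", np.round(c1_he3, 3))
print("Eq. (16) lhs vs rhs:", np.round(lhs, 3), np.round(rhs, 3))
```

For `f_lin` the estimate concentrates on the first teacher direction, while for `f_he3` it is zero up to Monte Carlo error, which is the situation where the first step makes no visible progress along the teacher subspace.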

If $\mathbb{E}\left[f^\star(\mathbf{z}_\nu)\,\mathbf{h}^\star_\nu\right]=0$, the parameters remain orthogonal to the teacher subspace. This is true whenever the $\mathrm{LC}$ of the target function is larger than $1$. To make progress, it is thus necessary for the pre-activations $h^{(t)}_\nu=\langle\mathbf{w}^{(t)},\mathbf{z}_\nu\rangle/\sqrt{d}$ to become correlated with the teacher pre-activations $\mathbf{h}^\star_\nu$. This can happen in two different ways:

(i) By $\mathbf{w}^{(t)}$ directly gaining components along the teacher subspace $W^\star$. Under online SGD, the data is used only once for the gradient updates, so only this mechanism is possible. It allows the directions learned by $\mathbf{w}^{(t)}$ at any step to depend on the directions it has already learned. This underlies the “staircase” phenomenon in online SGD (Abbe et al., 2021; 2022; 2023) as well as the notion of information exponent when applied to a single direction (Ben Arous et al., 2021).

(ii) By $\mathbf{w}^{(t)}$ gaining components along $\mathbf{z}_\nu$. Recall that the target is defined as $y_\nu=f^\star(\mathbf{z}_\nu)$, and thus $\mathbf{w}^{(t)}$ can correlate with $\mathbf{h}^\star_\nu$. This is what happens with multi-pass gradient descent in our setting. It implies that even when $\mathbf{w}^{(1)}$ does not learn a direction $\mathbf{v}^\star$, the pre-activation $h^{(t)}_\nu$ can develop a dependence on $\mathbf{h}^\star_\nu$ through the component of the gradient update along $\mathbf{z}_\nu$.

Let us see how this phenomenon, which we call hidden progress, happens in practice. From (5), the pre-activation $h^{(1)}_\nu$ after the first gradient step reads:

$$h^{(1)}_\nu=(1-\eta\lambda)\,h^{(0)}_\nu-\eta\frac{a}{d}\sum_{\nu'=1}^{n}\mathcal{L}'\left(h^{(0)}_{\nu'},f^\star(\mathbf{z}_{\nu'})\right)\sigma'(h^{(0)}_{\nu'})\,\langle\mathbf{z}_{\nu'},\mathbf{z}_\nu\rangle\,. \tag{17}$$

In this sum there is one term of magnitude $\mathcal{O}\left(\left\lVert\mathbf{z}_\nu\right\rVert^{2}/d\right)=\mathcal{O}(1)$ corresponding to $\nu'=\nu$, and $n-1$ random terms of order $\mathcal{O}\left(\langle\mathbf{z}_{\nu'},\mathbf{z}_\nu\rangle/d\right)=\mathcal{O}\left(1/\sqrt{d}\right)$. Since $n=\mathcal{O}(d)$, this second group of terms contributes an effective “noise” of order $\mathcal{O}(1)$. The first term, however, since $\lVert\mathbf{z}_\nu\rVert_2^2\approx d$, depends on $f^\star(\mathbf{z}_\nu)$ (and thus on all components of $\mathbf{h}^\star_\nu$):

$$\mathcal{L}'\left(h^{(0)}_\nu,f^\star(\mathbf{z}_\nu)\right)\sigma'(h^{(0)}_\nu)\,. \tag{18}$$
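The scalings of the self-term and of the cross terms in Eq. (17) can be checked on synthetic Gaussian inputs. In the sketch below, the dimension `d`, the ratio `alpha`, and the random-sign coefficients `coeff` (standing in for the order-one factors $\mathcal{L}'\sigma'$, taken independent of the inputs for simplicity) are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
d, alpha = 400, 2
n = alpha * d

Z = rng.standard_normal((n, d))   # rows are the inputs z_nu
G = Z @ Z.T / d                   # G[nu, nu'] = <z_nu', z_nu> / d

self_terms = np.diag(G)                  # nu' = nu: ||z_nu||^2 / d
off_diag = G[~np.eye(n, dtype=bool)]     # nu' != nu: <z_nu', z_nu> / d

# Order-one coefficients standing in for L' * sigma' (assumed independent of Z).
coeff = rng.choice([-1.0, 1.0], size=n)
# Sum over nu' != nu of coeff[nu'] * G[nu, nu']: the effective "noise" in Eq. (17).
noise = G @ coeff - self_terms * coeff

print(f"mean self-term:          {self_terms.mean():.3f}")
print(f"sqrt(d) * std off-diag:  {np.sqrt(d) * off_diag.std():.3f}")
print(f"std of summed noise:     {noise.std():.3f}  (sqrt(alpha) = {np.sqrt(alpha):.3f})")
```

Each cross term is small, of order $1/\sqrt{d}$, but the $n-1$ of them add up to a fluctuation of order $\sqrt{n/d}=\sqrt{\alpha}$, comparable to the single self-term.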

Due to this dependence between $h^{(1)}_\nu$ and $f^\star(\mathbf{z}_\nu)$, in the subsequent step, i.e. $T=2$, the term $\sigma'(h^{(1)}_\nu)$ in the update (14) can now influence the direction of the gradient along the teacher subspace, leading to $\mathbf{w}^{(2)}$ gaining correlations with new directions in $W^\star$. This can be seen as follows: let $m_{\mathbf{v}^\star}^{(t)}=\langle\mathbf{w}^{(t)},\mathbf{v}^\star\rangle/\sqrt{d}$; it follows from the GD updates that

$$m_{\mathbf{v}^\star}^{(2)}=(1-\eta\lambda)\,m_{\mathbf{v}^\star}^{(1)}-\eta\frac{a}{d}\sum_{\nu=1}^{n}\mathcal{L}'\left(h^{(1)}_\nu,f^\star(\mathbf{z}_\nu)\right)\sigma'(h^{(1)}_\nu)\,h^{\mathbf{v}^\star}_\nu\,. \tag{19}$$

Now, suppose that $\mathbf{v}^\star$ is not learned in the first step. Due to the hidden progress, $h^{(1)}_\nu$ nevertheless depends on $h^{\mathbf{v}^\star}_\nu=\langle\mathbf{v}^\star,\mathbf{z}_\nu\rangle/\sqrt{d}$. This allows the expected projection of the update along $\mathbf{v}^\star$, given by $\mathbb{E}\left[f^\star(\mathbf{z}_\nu)\,\sigma'(h^{(1)}_\nu)\,h^{\mathbf{v}^\star}_\nu\right]$, to be non-zero.
This explains how the dependence of the pre-activations $h^{(t)}_\nu$ on $f^\star(\mathbf{z}_\nu)$ can allow learning of new directions even when the weights have not gained components along the teacher subspace.

This learning mechanism, however, fails when the target function is symmetric along $\mathbf{v}^\star$. Indeed, for such a direction, $h^{(1)}_\nu$ retains an even dependence on $h^{\mathbf{v}^\star}_\nu$, which implies that the expectation $\mathbb{E}\left[f^\star(\mathbf{z}_\nu)\,\sigma'(h^{(t)}_\nu)\,h^{\mathbf{v}^\star}_\nu\right]$ remains $0$ for all time steps $t\in[T]$, with $T=\mathcal{O}(1)$. Such directions are therefore not learned with a finite number of time steps and batch size $n=\mathcal{O}(d)$, even upon re-using the batches.

The rigorous control of all these quantities is a difficult task a priori. One cannot, in particular, express the above sum as an expectation w.r.t. independent samples $\mathbf{z}_\nu$, since the weights $W^{(1)}$ now depend on all the samples. Fortunately, this is precisely the difficulty solved by the DMFT equations through an effective stochastic process on the pre-activations that are decoupled across training examples. The rigorous analysis is detailed in App. A. The main lines of the DMFT equations are in Sec. 4.2.
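The heuristic above can be mimicked by a low-dimensional Monte Carlo sketch. It is a caricature rather than the DMFT analysis, under several simplifying assumptions: a single-index target, ReLU student activation, $a=1$, $\lambda=0$, $\mathcal{L}'\approx-f^\star$ as in Eq. (14), and only the self-term of Eq. (17) retained (the cross-term noise is dropped). The step size `eta`, the Hermite targets, and the function names are illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 2_000_000
eta = 0.5  # assumed step size

h_star = rng.standard_normal(n_samples)  # teacher pre-activation along v*
h0 = rng.standard_normal(n_samples)      # student pre-activation at init, independent of h_star


def relu_prime(x):
    return (x > 0).astype(float)


def second_step_drive(y, h1):
    """Monte Carlo estimate of E[f*(z) sigma'(h^(1)) h^{v*}]."""
    return np.mean(y * relu_prime(h1) * h_star)


def reused_preactivation(y):
    """Self-term of Eq. (17) with L' ~ -y: h^(1) = h^(0) + eta * y * sigma'(h^(0))."""
    return h0 + eta * y * relu_prime(h0)


y_he3 = (h_star**3 - 3.0 * h_star) / np.sqrt(6.0)  # odd target, no linear component
y_he2 = (h_star**2 - 1.0) / np.sqrt(2.0)           # even target, symmetric along v*

drive_reuse = second_step_drive(y_he3, reused_preactivation(y_he3))
drive_fresh = second_step_drive(y_he3, h0)  # fresh sample: pre-activation independent of y
drive_even = second_step_drive(y_he2, reused_preactivation(y_he2))

print(f"He3 target, reused batch: {drive_reuse:+.4f}")
print(f"He3 target, fresh batch:  {drive_fresh:+.4f}")
print(f"He2 target, reused batch: {drive_even:+.4f}")
```

In this caricature only the first estimate is expected to be significantly non-zero: with a fresh batch the expectation factorizes as in Eq. (16), and for the even target the integrand is odd in $h^{\mathbf{v}^\star}_\nu$.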

Finally, note that while our proof uses the Gaussian data assumption, the heuristic argument hints that this is not crucial. Additionally, in any real dataset samples are highly correlated, and thus a given sample (or a very similar one) may appear many times. In this case, even single-pass algorithms will behave as predicted by our approach. We thus believe it describes a more realistic scenario than purely single-pass theories with fresh i.i.d. data.

4.2 Characterization of the dynamics

Re-using batches at each gradient step requires keeping track of the pre-activations of the parameters. Since the number of pre-activations and the dimension of the parameters grow with $d$, we need a low-dimensional effective dynamics characterizing the quantities of interest, such as the overlaps between the student and target parameters. DMFT provides such an effective dynamics through a set of coupled stochastic processes $\bm{\theta}^{(t)}\in\mathbb{R}^{p}$ and $\mathbf{h}^{(t)}\in\mathbb{R}^{p}$ representing the joint distributions of the student and teacher parameters $W^{(t)},W^\star$ and of the student and teacher pre-activations, respectively.

We derive the equations and prove their applicability to our setting using existing results in (Celentano et al., 2021; Gerbelot et al., 2023). Asymptotically, for $d\to\infty$ with $n=\alpha d$, the joint distribution of the student and teacher pre-activations (for each sample), $\mathbf{h}^{(t)}_\nu=W^{(t)}\mathbf{z}_\nu/\sqrt{d}\in\mathbb{R}^{p}$ and $\mathbf{h}^\star_\nu=W^\star\mathbf{z}_\nu/\sqrt{d}\in\mathbb{R}^{k}$, converges in distribution to samples from the stochastic process $\mathbf{h}^{(t)}$ and the standard normal variable $\mathbf{h}^\star$.
Similarly, the joint distribution of each component of the student and teacher weights $W^{(t)}_i\in\mathbb{R}^{p},W^\star_i\in\mathbb{R}^{k}$ with $i\in[d]$ converges in distribution to samples from the stochastic process $\bm{\theta}^{(t)}$ and the standard normal variable $\bm{\theta}^\star$. These processes satisfy:

$$\bm{\theta}^{(t+1)}=\left(1-\eta\lambda-\eta\Lambda^{(t)}\right)\bm{\theta}^{(t)}+\eta\sum_{\tau=0}^{t-1}R_{\mathcal{L}}^{(t,\tau)}\bm{\theta}^{(\tau)}-\eta g^{(t)}\bm{\theta}^\star+\eta\sum_{\tau=0}^{t-1}\tilde{R}_{\mathcal{L}}^{(t,\tau)}\bm{\theta}^\star+\eta\bm{u}^{(t)} \tag{20}$$

$$\bm{h}^{(t)}=-\eta\sum_{\tau=0}^{t-1}R_{\theta}^{(t,\tau)}\nabla_{\bm{h}}\mathcal{L}(\bm{h}^{(\tau)},\bm{h}^\star)+\bm{\omega}^{(t)} \tag{21}$$

Notice that the formula above is the high-dimensional equivalent of the gradient descent update (5). Here $\mathbf{u}^{(t)}$ and $(\bm{\omega}^{(t)},\bm{\theta}^{\star})$ are zero-mean Gaussian processes with covariances $C_{\mathcal{L}}^{(t,\tau)}$ and $\Omega^{(t,\tau)}$ respectively, with

$$C_{\mathcal{L}}^{(t,\tau)}=\alpha\,\mathbb{E}_{\bm{h}^{(t)},\bm{h}^{\star}}\left[\nabla_{\bm{h}}\mathcal{L}(\bm{h}^{(t)},\bm{h}^{\star})\,\nabla_{\bm{h}}\mathcal{L}(\bm{h}^{(\tau)},\bm{h}^{\star})^{\top}\right]$$
$$\Omega^{(t,\tau)}=\begin{bmatrix}C_{\theta}^{(t,\tau)}&M^{(t)}\\ M^{(\tau)}&1\end{bmatrix}=\mathbb{E}_{\bm{\theta}^{(t)},\bm{\theta}^{\star}}\left[\begin{pmatrix}\bm{\theta}^{(t)}\\ \bm{\theta}^{\star}\end{pmatrix}\begin{pmatrix}\bm{\theta}^{(\tau)}\\ \bm{\theta}^{\star}\end{pmatrix}^{\top}\right]$$

The matrix $\Lambda^{(t)}$ can be viewed as an “effective regularization” on the parameters. $\Lambda^{(t)}$ and the projected gradient $g^{(t)}$ converge in probability to:

$$\Lambda^{(t)}=\alpha\,\mathbb{E}_{\bm{h}^{(t)},\bm{h}^{\star}}\left[\nabla^{2}_{\mathbf{h}}\mathcal{L}\left(\mathbf{h}^{(t)},\bm{h}^{\star}\right)\right], \tag{22}$$
$$g^{(t)}=\alpha\,\mathbb{E}_{\bm{h}^{(t)},\bm{h}^{\star}}\left[\nabla_{\mathbf{h}}\mathcal{L}\left(\mathbf{h}^{(t)},\bm{h}^{\star}\right)\mathbf{h}^{\star\top}\right] \tag{23}$$
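
As a rough illustration of how the expectations in (22) and (23) could be evaluated numerically, the sketch below estimates $\Lambda^{(t)}$ and $g^{(t)}$ by Monte Carlo in the scalar case $p=k=1$. It rests on assumptions that are ours and not the paper's: $(h^{(t)},h^{\star})$ are taken to be jointly Gaussian with unit variances and correlation `m` (plausible at $t=0$, where the sum in (21) is empty, but not at later times in general), and the loss is a hypothetical quadratic $\mathcal{L}(h,h^{\star})=\tfrac{1}{2}(h-h^{\star})^{2}$. The function name and its arguments are illustrative.

```python
import numpy as np

def estimate_lambda_and_g(alpha, m, n_samples=200_000, seed=0):
    """Monte Carlo estimates of Lambda (eq. 22) and g (eq. 23), scalar case.

    Illustrative assumptions (not from the paper): p = k = 1, (h, h_star)
    jointly Gaussian with unit variances and correlation m, and the loss
    L(h, h_star) = 0.5 * (h - h_star)**2.
    """
    rng = np.random.default_rng(seed)
    cov = np.array([[1.0, m], [m, 1.0]])
    h, h_star = rng.multivariate_normal(np.zeros(2), cov, size=n_samples).T

    grad = h - h_star        # first derivative of the loss w.r.t. h
    hess = np.ones_like(h)   # second derivative of the loss w.r.t. h

    lam = alpha * hess.mean()            # alpha * E[d^2 L / dh^2]
    g = alpha * (grad * h_star).mean()   # alpha * E[dL/dh * h_star]
    return lam, g

lam, g = estimate_lambda_and_g(alpha=2.0, m=0.3)
```

For this particular loss the expectations are available in closed form, $\Lambda=\alpha$ and $g=\alpha(m-1)$, which gives a simple sanity check on the estimator.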

The memory kernels $R_{\mathcal{L}}^{(t,\tau)}$, $\tilde{R}_{\mathcal{L}}^{(t,\tau)}$, $R_{\theta}^{(t,\tau)}$ are defined as:

$$R_{\theta}^{(t,\tau)}=\mathbb{E}_{\bm{\theta}^{(t)},\bm{\theta}^{\star}}\left[\frac{\partial\,\bm{\theta}^{(t)}}{\partial\,\mathbf{u}^{(\tau)}}\right],$$
$$R_{\mathcal{L}}^{(t,\tau)}=\alpha\,\mathbb{E}_{\bm{h}^{(t)},\bm{h}^{\star}}\left[\frac{\partial\,\nabla_{\mathbf{h}}\mathcal{L}\left(\mathbf{h}^{(t)},\bm{h}^{\star}\right)}{\partial\,\bm{\omega}^{(\tau)}}\right], \tag{24}$$
$$\tilde{R}_{\mathcal{L}}^{(t,\tau)}=\alpha\,\mathbb{E}_{\bm{h}^{(t)},\bm{h}^{\star}}\left[\frac{\partial\,\nabla_{\mathbf{h}}\mathcal{L}\left(\mathbf{h}^{(t)},\bm{h}^{\star}\right)}{\partial\,\left(\bm{\theta}^{\star}\right)^{(\tau)}}\right], \tag{25}$$

and $R_{\mathcal{L}}^{(t,t)}=\tilde{R}_{\mathcal{L}}^{(t,t)}=0$, $R_{\theta}^{(t,t)}=1$. Finally, the low-dimensional projections of the weights $M^{(t)}$ will obey

$$M^{(t+1)}=(1-\eta\lambda)M^{(t)}-\eta g^{(t)}. \tag{26}$$

Notice that these definitions are well-posed because of the causal structure of the gradient descent updates, and by extension of (20): the distribution of $(\bm{\theta}^{(t+1)},\mathbf{h}^{(t+1)})$ is completely determined by $\{(\bm{\theta}^{(\tau)},\mathbf{h}^{(\tau)})\}_{\tau\in[t]}$ and the auxiliary quantities in eqs. (22)-(25). Iterating backwards we reach the initial condition $\bm{\theta}^{(0)}$, which is a simple function of the data distribution and the initial conditions of the weights. For additional details we refer to App. A. Notice that it is also possible to write this set of equations as a function of a single stochastic process on $\mathbf{h}$, as in App. C.
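
The overlap recursion (26) is simple enough to iterate directly once the projected gradients are known. The following is a minimal sketch under that assumption: the sequence `g_seq` is supplied by hand here, whereas in the DMFT it would come from (23). The function name and the numerical values are hypothetical and chosen only to show the mechanics.

```python
import numpy as np

def iterate_overlap(M0, g_seq, eta, lam):
    """Iterate M^(t+1) = (1 - eta*lam) * M^(t) - eta * g^(t), eq. (26).

    M0 and every entry of g_seq must share the same shape. g_seq is taken
    as given; it is not computed from the DMFT equations in this sketch.
    """
    M = np.asarray(M0, dtype=float)
    history = [M.copy()]
    for g in g_seq:
        M = (1.0 - eta * lam) * M - eta * np.asarray(g, dtype=float)
        history.append(M.copy())
    return history

# Hypothetical inputs: zero initial overlap, a projected gradient that
# vanishes at t = 0 and is nonzero at t = 1.
traj = iterate_overlap(M0=[0.0], g_seq=[[0.0], [-0.5]], eta=1.0, lam=0.1)
# traj[1] is still zero; traj[2] equals 0.5.
```

With these inputs the overlap stays at zero after the first step and becomes nonzero only at $t=2$, since the update at each step is driven entirely by $g^{(t)}$ when $M^{(t)}=0$.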

Sketch of proof of the hidden progress —

Finally, we explain how the DMFT equations relate to the phenomenon in Sec. 4.1 and allow us to prove Th. 3.2. The term $\nabla_{\mathbf{h}^{(1)}}\mathcal{L}\left(\mathbf{h}^{(1)}\right)$ in (21) precisely corresponds to the contribution to the pre-activation of a point $\mathbf{z}^{\nu}$ (App. A.4) from the gradient at the same point $\mathbf{z}^{\nu}$. As we discussed in Section 4.1, this term induces a dependence between $\mathbf{h}^{(1)}_{\nu}$ and $\mathbf{h}^{\star}_{\nu}$ even when the overlaps $M^{(t)}$ are $0$. At time $T=1$, the response term simplifies to $R_{\theta}^{(1,0)}=\mathbf{I}_{d}$ and the pre-activations can be expressed as the random variable $\nabla_{\mathbf{h}^{(t)}}\mathcal{L}\left(\mathbf{h}^{(t)}\right)$ with added Gaussian noise.
Analogous to Section 4.1, we denote by $M_{\mathbf{v}^{*}}^{(t)}$ the limiting value of the overlaps $\frac{1}{d}W^{(t)}\mathbf{v}^{*}$ for some $\mathbf{v}^{*}=(\mathbf{u}^{*})^{\top}W^{*}$ with $\mathbf{u}^{*}\in\mathbb{R}^{p}$. Propagating the equations over the first two steps, and using Equation (26), we show that $M_{\mathbf{v}^{*}}^{(2)}$ can be expressed as an expectation w.r.t. the pre-activations $\mathbf{h}^{\star}$ of a function dependent on the target $g^{\star}$, the second layer $\mathbf{a}$, and the activation function $\sigma$:

$$M_{\mathbf{v}^{*}}^{(2)}=(1-\eta\lambda)M_{\mathbf{v}^{*}}^{(1)}+\eta\alpha\,\mathbb{E}_{\mathbf{h}^{\star}}\left[F_{\sigma,a}(g^{\star}(\mathbf{h}^{\star}))\,\mathbf{h}^{\star\top}\right]\mathbf{u}^{\star}. \tag{27}$$

The function $F_{\sigma,a}$ is described in App. A.5, Eq. (90). Finally, we show that the condition $\mathbb{E}_{\mathbf{h}^{\star}}\left[F_{\sigma,a}(g^{\star}(\mathbf{h}^{\star}))\,\mathbf{h}^{\star\top}\right]\mathbf{u}^{\star}=0$ is equivalent to the condition $\mathbb{E}_{\mathbf{z}}\left[F(f^{\star}(\mathbf{z}))\langle\mathbf{v}^{\star},\mathbf{z}\rangle\right]=0$ for general $F$ in Definition 3.1.
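
To get a feel for conditions of this type, one can evaluate the scalar quantity $\mathbb{E}_{h}\left[F(g^{\star}(h))\,h\right]$ with $h\sim\mathcal{N}(0,1)$ by Gauss-Hermite quadrature. The sketch below is only an illustration under our own assumptions: a single-index target, two hypothetical targets (the Hermite polynomials $h^{2}-1$ and $h^{3}-3h$), and two hypothetical transformations $F$ (the identity and the cube). Whether a given target satisfies Definition 3.1 depends on the class of $F$ allowed there, which is not restated here.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

def gaussian_correlation(F, g_star, n_nodes=40):
    """Approximate E[F(g_star(h)) * h] for h ~ N(0, 1) by Gauss-Hermite quadrature."""
    nodes, weights = hermegauss(n_nodes)      # weight function exp(-x^2 / 2)
    weights = weights / np.sqrt(2.0 * np.pi)  # normalize to the N(0, 1) density
    return float(np.sum(weights * F(g_star(nodes)) * nodes))

he2 = lambda h: h**2 - 1.0       # even target
he3 = lambda h: h**3 - 3.0 * h   # odd target, uncorrelated with h itself

identity = lambda y: y
cube = lambda y: y**3

linear_he3 = gaussian_correlation(identity, he3)  # 0 up to rounding
cubic_he3 = gaussian_correlation(cube, he3)       # 324 up to rounding
cubic_he2 = gaussian_correlation(cube, he2)       # 0 by symmetry
```

In this toy setting the odd target $h^{3}-3h$ has zero linear correlation with $h$, yet the correlation becomes nonzero after the nonlinear transformation $F(y)=y^{3}$. For the even target $h^{2}-1$ the integrand is odd in $h$ for any $F$, so the expectation vanishes regardless of the transformation.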

General multi-pass schemes —

While Theorem 3.2 considers a finite number of updates with the same batch of data at each step, it generalizes naturally to other setups involving multiple passes over a finite number of mini-batches of size $\mathcal{O}(d)$. For instance, one can cycle over distinct mini-batches, with each cycle constituting one epoch, i.e. one pass through the dataset. Theorem 3.2 remains valid under such a setup, with the onset of weak recovery shifting to the start of the second epoch instead of the second gradient step. We provide a sketch of this extension in Appendix B. On the other hand, if the mini-batches are sampled with replacement from the dataset, weak recovery still starts at the second gradient step. We illustrate this in Fig. 4 (in the appendix). Furthermore, we empirically observe that the phenomenon persists even in the limit of mini-batch size $1$ (Fig. 5 in the appendix). Proving this, however, remains out of reach of the present technique.
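
The three schemes above differ only in which samples each gradient step sees. The following sketch does the corresponding bookkeeping and reports the first step at which some sample is revisited. It does not simulate training, and the scheme labels, function names, and sizes are hypothetical choices made for illustration.

```python
import random

def sample_schedule(scheme, n_samples, batch_size, n_steps, seed=0):
    """Return, for each gradient step, the set of sample indices it uses.

    scheme: "repeat" (same batch at every step), "cycle" (one pass over
    disjoint batches per epoch), or "resample" (each batch drawn afresh
    from the full dataset, so samples can recur across steps).
    """
    rng = random.Random(seed)
    n_batches = n_samples // batch_size
    batches = [set(range(b * batch_size, (b + 1) * batch_size))
               for b in range(n_batches)]
    if scheme == "repeat":
        return [batches[0]] * n_steps
    if scheme == "cycle":
        return [batches[t % n_batches] for t in range(n_steps)]
    if scheme == "resample":
        return [set(rng.sample(range(n_samples), batch_size))
                for _ in range(n_steps)]
    raise ValueError(f"unknown scheme: {scheme}")

def first_reuse_step(schedule):
    """Index (from 0) of the first step that revisits a sample, or None."""
    seen = set()
    for t, batch in enumerate(schedule):
        if batch & seen:
            return t
        seen |= batch
    return None

repeat_reuse = first_reuse_step(sample_schedule("repeat", 200, 50, 8))  # 1
cycle_reuse = first_reuse_step(sample_schedule("cycle", 200, 50, 8))    # 4
```

With steps indexed from zero, the repeated-batch scheme first reuses data at step 1 (the second gradient step), while cycling over four disjoint batches first reuses data at step 4 (the start of the second epoch). When batches are resampled from the full dataset with a batch size proportional to the dataset size, two consecutive batches overlap with overwhelming probability, so reuse again begins at the second step.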

5 Conclusions

Our study analyzes the training dynamics of two-layer neural networks for learning multi-index target functions, distinctively focusing on multi-pass gradient descent which involves reusing batches multiple times. We find that this enables gradient descent to exceed the constraints imposed by information and leap exponents.

Gradient descent is found to achieve a positive correlation with the target function across a broader class than previously anticipated, with only two data batch repetitions. Our analysis further demonstrates that the limitations associated with information and leap exponents, staircase learning, and CSQ lower bounds are restricted to online/single pass SGD and do not describe the class of functions inherently easy or hard to learn by gradient-based methods for neural networks.

Our conclusions follow from rigorous proofs based on Dynamical Mean-Field Theory, which also yields a closed-form description of the dynamics of the low-dimensional weight projections. We illustrate the theoretical findings with numerical experiments.

6 Acknowledgements

We thank Cedric Gerbelot, Bruno Loureiro and Ludovic Stephan for insightful discussions. We also acknowledge funding from the Swiss National Science Foundation grant SNFS OperaGOST (grant number 200390), and SMArtNet (grant number 212049).

References

  • Abbe et al. [2021] E. Abbe, E. Boix-Adsera, M. S. Brennan, G. Bresler, and D. Nagaraj. The staircase property: How hierarchical structure can guide deep learning. Advances in Neural Information Processing Systems, 34:26989–27002, 2021.
  • Abbe et al. [2022] E. Abbe, E. Boix-Adsera, and T. Misiakiewicz. The merged-staircase property: a necessary and nearly sufficient condition for sgd learning of sparse functions on two-layer neural networks. In Conference on Learning Theory, pages 4782–4887. PMLR, 2022.
  • Abbe et al. [2023] E. Abbe, E. Boix-Adsera, and T. Misiakiewicz. Sgd learning on neural networks: leap complexity and saddle-to-saddle dynamics, 2023.
  • Agoritsas et al. [2018] E. Agoritsas, G. Biroli, P. Urbani, and F. Zamponi. Out-of-equilibrium dynamical mean-field equations for the perceptron model. Journal of Physics A: Mathematical and Theoretical, 51(8):085002, 2018.
  • Andrews [2004] G. E. Andrews. Special functions. Cambridge University Press, 2004.
  • Aubin et al. [2019] B. Aubin, A. Maillard, J. Barbier, F. Krzakala, N. Macris, and L. Zdeborová. The committee machine: computational to statistical gaps in learning a two-layers neural network. Journal of Statistical Mechanics: Theory and Experiment, 2019(12):124023, Dec. 2019. ISSN 1742-5468. doi: 10.1088/1742-5468/ab43d2. URL http://dx.doi.org/10.1088/1742-5468/ab43d2.
  • Ba et al. [2022] J. Ba, M. A. Erdogdu, T. Suzuki, Z. Wang, D. Wu, and G. Yang. High-dimensional asymptotics of feature learning: How one gradient step improves the representation. In S. Koyejo, S. Mohamed, A. Agarwal, D. Belgrave, K. Cho, and A. Oh, editors, Advances in Neural Information Processing Systems, volume 35, pages 37932–37946. Curran Associates, Inc., 2022.
  • Ba et al. [2023] J. Ba, M. A. Erdogdu, T. Suzuki, Z. Wang, and D. Wu. Learning in the presence of low-dimensional structure: a spiked random matrix perspective. In Neurips 2023, 2023.
  • Barbier et al. [2019] J. Barbier, F. Krzakala, N. Macris, L. Miolane, and L. Zdeborová. Optimal errors and phase transitions in high-dimensional generalized linear models. Proceedings of the National Academy of Sciences, 116(12):5451–5460, 2019.
  • Bayati and Montanari [2011] M. Bayati and A. Montanari. The dynamics of message passing on dense graphs, with applications to compressed sensing. IEEE Transactions on Information Theory, 57(2):764–785, 2011.
  • Ben Arous et al. [1997] G. Ben Arous, A. Guionnet, et al. Symmetric langevin spin glass dynamics. The Annals of Probability, 25(3):1367–1422, 1997.
  • Ben Arous et al. [2021] G. Ben Arous, R. Gheissari, and A. Jagannath. Online stochastic gradient descent on non-convex losses from high-dimensional inference. Journal of Machine Learning Research, 22(106):1–51, 2021.
  • Ben Arous et al. [2022] G. Ben Arous, R. Gheissari, and A. Jagannath. High-dimensional limit theorems for sgd: Effective dynamics and critical scaling. Advances in Neural Information Processing Systems, 35:25349–25362, 2022.
  • Bietti et al. [2023] A. Bietti, J. Bruna, and L. Pillaud-Vivien. On learning gaussian multi-index models with gradient flow. arXiv preprint arXiv:2310.19793, 2023.
  • Bolthausen [2014] E. Bolthausen. An iterative construction of solutions of the tap equations for the sherrington–kirkpatrick model. Communications in Mathematical Physics, 325(1):333–366, 2014.
  • Bordelon et al. [2020] B. Bordelon, A. Canatar, and C. Pehlevan. Spectrum dependent learning curves in kernel regression and wide neural networks. In H. D. III and A. Singh, editors, Proceedings of the 37th International Conference on Machine Learning, volume 119 of Proceedings of Machine Learning Research, pages 1024–1034. PMLR, 13–18 Jul 2020.
  • Bouchaud et al. [1998] J.-P. Bouchaud, L. F. Cugliandolo, J. Kurchan, and M. Mézard. Out of equilibrium dynamics in spin-glasses and other glassy systems. Spin glasses and random fields, 12:161, 1998.
  • Celentano et al. [2021] M. Celentano, C. Cheng, and A. Montanari. The high-dimensional asymptotics of first order methods with random data. arXiv:2112.07572, 2021.
  • Chen and Meka [2020] S. Chen and R. Meka. Learning polynomials in few relevant dimensions. In Conference on Learning Theory, pages 1161–1227. PMLR, 2020.
  • Chen et al. [2021] S. Chen, A. Klivans, and R. Meka. Efficiently learning one hidden layer relu networks from queries. Advances in Neural Information Processing Systems, 34:24087–24098, 2021.
  • Chen et al. [2022] S. Chen, A. Gollakota, A. Klivans, and R. Meka. Hardness of noise-free learning for two-hidden-layer neural networks. Advances in Neural Information Processing Systems, 35:10709–10724, 2022.
  • Chizat and Bach [2018] L. Chizat and F. Bach. On the global convergence of gradient descent for over-parameterized models using optimal transport. Advances in neural information processing systems, 31, 2018.
  • Cugliandolo [2003] L. F. Cugliandolo. Dynamics of glassy systems. In Slow Relaxations and nonequilibrium dynamics in condensed matter. Springer, 2003.
  • Cui et al. [2021] H. Cui, B. Loureiro, F. Krzakala, and L. Zdeborová. Generalization error rates in kernel regression: The crossover from the noiseless to noisy regime. In M. Ranzato, A. Beygelzimer, Y. Dauphin, P. Liang, and J. W. Vaughan, editors, Advances in Neural Information Processing Systems, volume 34, pages 10131–10143. Curran Associates, Inc., 2021.
  • Damian et al. [2022] A. Damian, J. Lee, and M. Soltanolkotabi. Neural networks can learn representations with gradient descent. In P.-L. Loh and M. Raginsky, editors, Proceedings of Thirty Fifth Conference on Learning Theory, volume 178 of Proceedings of Machine Learning Research, pages 5413–5452. PMLR, 02–05 Jul 2022.
  • Damian et al. [2023] A. Damian, E. Nichani, R. Ge, and J. D. Lee. Smoothing the Landscape Boosts the Signal for SGD: Optimal Sample Complexity for Learning Single Index Models. Technical report, Princeton, May 2023. arXiv:2305.10633 [cs, math, stat] type: article.
  • Dandi et al. [2023] Y. Dandi, F. Krzakala, B. Loureiro, L. Pesce, and L. Stephan. How two-layer neural networks learn, one (giant) step at a time, 2023.
  • Diakonikolas et al. [2020] I. Diakonikolas, D. M. Kane, V. Kontonis, and N. Zarifis. Algorithms and sq lower bounds for pac learning one-hidden-layer relu networks. In Conference on Learning Theory, pages 1514–1539. PMLR, 2020.
  • Dietrich et al. [1999] R. Dietrich, M. Opper, and H. Sompolinsky. Statistical mechanics of support vector networks. Phys. Rev. Lett., 82:2975–2978, Apr 1999. doi: 10.1103/PhysRevLett.82.2975.
  • Eissfeller and Opper [1992] H. Eissfeller and M. Opper. New method for studying the dynamics of disordered spin systems without finite-size effects. Physical review letters, 68(13):2094, 1992.
  • Eissfeller and Opper [1994] H. Eissfeller and M. Opper. Mean-field Monte Carlo approach to the Sherrington-Kirkpatrick model with asymmetric couplings. Physical Review E, 50(2):709, 1994.
  • Georges et al. [1996] A. Georges, G. Kotliar, W. Krauth, and M. J. Rozenberg. Dynamical mean-field theory of strongly correlated fermion systems and the limit of infinite dimensions. Reviews of Modern Physics, 68(1):13, 1996.
  • Gerbelot et al. [2023] C. Gerbelot, E. Troiani, F. Mignacco, F. Krzakala, and L. Zdeborova. Rigorous dynamical mean field theory for stochastic gradient descent methods, 2023.
  • Ghorbani et al. [2019] B. Ghorbani, S. Mei, T. Misiakiewicz, and A. Montanari. Limitations of lazy training of two-layers neural network. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019.
  • Ghorbani et al. [2020] B. Ghorbani, S. Mei, T. Misiakiewicz, and A. Montanari. When do neural networks outperform kernel methods? In H. Larochelle, M. Ranzato, R. Hadsell, M. Balcan, and H. Lin, editors, Advances in Neural Information Processing Systems, volume 33, pages 14820–14830. Curran Associates, Inc., 2020.
  • Goel et al. [2020] S. Goel, A. Gollakota, Z. **, S. Karmalkar, and A. Klivans. Superpolynomial lower bounds for learning one-layer neural networks using gradient descent. In International Conference on Machine Learning, pages 3587–3596. PMLR, 2020.
  • Kearns [1998] M. Kearns. Efficient noise-tolerant learning from statistical queries. Journal of the ACM (JACM), 45(6):983–1006, 1998.
  • Loureiro et al. [2021] B. Loureiro, C. Gerbelot, H. Cui, S. Goldt, F. Krzakala, M. Mezard, and L. Zdeborová. Learning curves of generic features maps for realistic datasets with a teacher-student model. In M. Ranzato, A. Beygelzimer, Y. Dauphin, P. Liang, and J. W. Vaughan, editors, Advances in Neural Information Processing Systems, volume 34, pages 18137–18151. Curran Associates, Inc., 2021.
  • Maillard et al. [2020] A. Maillard, B. Loureiro, F. Krzakala, and L. Zdeborová. Phase retrieval in high dimensions: Statistical and computational phase transitions, 2020.
  • Mannelli and Urbani [2021] S. S. Mannelli and P. Urbani. Just a momentum: Analytical study of momentum-based acceleration methods in paradigmatic high-dimensional non-convex problems. NeurIPS, 2021.
  • Mannelli et al. [2019a] S. S. Mannelli, G. Biroli, C. Cammarota, F. Krzakala, and L. Zdeborová. Who is afraid of big bad minima? analysis of gradient-flow in spiked matrix-tensor models. In Advances in Neural Information Processing Systems, pages 8676–8686, 2019a.
  • Mannelli et al. [2019b] S. S. Mannelli, F. Krzakala, P. Urbani, and L. Zdeborova. Passed & spurious: Descent algorithms and local minima in spiked matrix-tensor models. In international conference on machine learning, pages 4333–4342, 2019b.
  • Mannelli et al. [2020] S. S. Mannelli, G. Biroli, C. Cammarota, F. Krzakala, P. Urbani, and L. Zdeborová. Marvels and pitfalls of the langevin algorithm in noisy high-dimensional inference. Physical Review X, 10(1):011057, 2020.
  • Mei et al. [2018] S. Mei, A. Montanari, and P.-M. Nguyen. A mean field view of the landscape of two-layer neural networks. Proceedings of the National Academy of Sciences, 115(33):E7665–E7671, 2018.
  • Mignacco and Urbani [2022] F. Mignacco and P. Urbani. The effective noise of stochastic gradient descent. Journal of Statistical Mechanics: Theory and Experiment, 2022(8):083405, aug 2022. doi: 10.1088/1742-5468/ac841d. URL https://doi.org/10.1088/1742-5468/ac841d.
  • Mignacco et al. [2020] F. Mignacco, F. Krzakala, P. Urbani, and L. Zdeborová. Dynamical mean-field theory for stochastic gradient descent in gaussian mixture classification. Advances in Neural Information Processing Systems, 33:9540–9550, 2020.
  • Mignacco et al. [2021] F. Mignacco, P. Urbani, and L. Zdeborová. Stochasticity helps to navigate rough landscapes: comparing gradient-descent-based algorithms in the phase retrieval problem. Machine Learning: Science and Technology, 2(3):035029, 2021.
  • Moniri et al. [2023] B. Moniri, D. Lee, H. Hassani, and E. Dobriban. A theory of non-linear feature learning with one gradient step in two-layer neural networks, 2023.
  • Montanari and Saeed [2022] A. Montanari and B. N. Saeed. Universality of empirical risk minimization. In P.-L. Loh and M. Raginsky, editors, Proceedings of Thirty Fifth Conference on Learning Theory, volume 178 of Proceedings of Machine Learning Research, pages 4310–4312. PMLR, 02–05 Jul 2022.
  • Mousavi-Hosseini et al. [2023] A. Mousavi-Hosseini, D. Wu, T. Suzuki, and M. A. Erdogdu. Gradient-based feature learning under structured data, 2023.
  • Rotskoff and Vanden-Eijnden [2022] G. Rotskoff and E. Vanden-Eijnden. Trainability and accuracy of artificial neural networks: An interacting particle system approach. Communications on Pure and Applied Mathematics, 75(9):1889–1935, 2022. doi: 10.1002/cpa.22074.
  • Roy et al. [2019] F. Roy, G. Biroli, G. Bunin, and C. Cammarota. Numerical implementation of dynamical mean field theory for disordered systems: application to the Lotka–Volterra model of ecosystems. Journal of Physics A: Mathematical and Theoretical, 52(48):484001, Nov. 2019. ISSN 1751-8121. doi: 10.1088/1751-8121/ab1f32. URL http://dx.doi.org/10.1088/1751-8121/ab1f32.
  • Saad and Solla [1995] D. Saad and S. A. Solla. On-line learning in soft committee machines. Physical Review E, 52(4):4225–4243, Oct. 1995. doi: 10.1103/PhysRevE.52.4225.
  • Sirignano and Spiliopoulos [2020] J. Sirignano and K. Spiliopoulos. Mean field analysis of neural networks: A central limit theorem. Stochastic Processes and their Applications, 130(3):1820–1852, 2020.
  • Sompolinsky and Zippelius [1981] H. Sompolinsky and A. Zippelius. Dynamic theory of the spin-glass phase. Phys. Rev. Lett., 47:359–362, Aug 1981.
  • Sompolinsky et al. [1988] H. Sompolinsky, A. Crisanti, and H. J. Sommers. Chaos in random neural networks. Phys. Rev. Lett., 61:259–262, Jul 1988.
  • Zweig and Bruna [2023] A. Zweig and J. Bruna. Symmetric single index learning, 2023.

Appendix A Mathematical Proofs

A.1 Notations

We use the asymptotic notation $f_{2}(d)=\Theta_{d}(f_{1}(d))$ to denote $c\lvert f_{1}(d)\rvert\leq\lvert f_{2}(d)\rvert\leq C\lvert f_{1}(d)\rvert$ for some constants $c,C>0$ and large enough $d$. Similarly, $f_{2}(d)=o_{d}(f_{1}(d))$ denotes $\lvert f_{2}(d)\rvert\leq c\lvert f_{1}(d)\rvert$ for any constant $c>0$ and large enough $d$. We use $\xrightarrow[P]{d,n\rightarrow\infty}$ and $\xrightarrow[D]{d,n\rightarrow\infty}$ to denote convergence in probability and convergence in distribution, respectively, as $d,n\rightarrow\infty$ with $n/d=\alpha>0$.
We denote subspaces and linear operators or matrices on $\mathbb{R}^{d}$ by uppercase letters $A,B,C,\cdots$. For any subspace $A\subseteq\mathbb{R}^{d}$, we denote by $A_{\perp}$ its orthogonal complement, i.e., the subspace of vectors orthogonal to all $\mathbf{v}\in A$.

A.2 DMFT and iterative conditioning

Unlike online SGD, the preactivations after multiple steps no longer remain Gaussian, since the weights become dependent on the data. This prevents marginalizing over the orthogonal components of the preactivations and relating the learning of new directions to the Hermite decomposition of the target function. Our proof circumvents these issues by utilizing a simpler effective process that decouples the pre-activations for different samples. The effective process is obtained using a rigorous version of the Dynamical Mean Field Theory derived in [Montanari and Saeed, 2022] and [Gerbelot et al., 2023].

The derivation of Dynamical Mean Field Theory in the above works has the following essential elements:

  • 1. Iterative conditioning: The proof in [Gerbelot et al., 2023, Montanari and Saeed, 2022] for obtaining the DMFT equations relies on the observation that the gradient descent algorithm in Equation (28), run for a finite number of iterations, can be described completely through projections of the input design matrix $\mathbf{Z}\in\mathbb{R}^{n\times d}$ along a finite number of vectors in $\mathbb{R}^{n}$ and $\mathbb{R}^{d}$. The iterative conditioning technique [Bolthausen, 2014, Bayati and Montanari, 2011] then involves replacing the components of $\mathbf{Z}$ along directions orthogonal to these projections by independent Gaussian random variables. This leads to a non-Markovian structure in the effective processes for the activations and parameters.

  • 2. The concentration of finite-dimensional order parameters, such as the overlaps of the neuron parameters with the teacher neurons/subspace, as well as of expectations w.r.t. the empirical measure of the pre-activations and parameters.

Using the above elements, DMFT provides a low-dimensional effective dynamics characterizing the limiting joint empirical measure of the student parameters, as well as the pre-activations. We sketch the proof for the activations after the first gradient step, which illustrates the relationship with the “hidden progress” described in Section 4.1.

Let $\mathbf{H}^{(t)}=\frac{1}{\sqrt{d}}\mathbf{Z}(\mathbf{W}^{(t)})^{\top}$ denote the $n\times p$ matrix of pre-activations at time $t$. Similarly, let $\mathbf{H}^{*}\in\mathbb{R}^{n\times k}$ denote the matrix of input activations in the target function. We denote by $\nabla_{\mathbf{H}}\mathcal{L}(\mathbf{H}^{*},\mathbf{H}^{(t)})\in\mathbb{R}^{n\times p}$ the matrix of derivatives of $\mathcal{L}$ w.r.t. the corresponding entries of the pre-activation matrix $\mathbf{H}^{(t)}$. After each gradient update, $\mathbf{W}^{(t)}$ and the pre-activations $\mathbf{H}^{(t)}$ gain a dependence on $\mathbf{Z}$.
The iterative conditioning technique works around this dependence by conditioning on the sigma-algebra generated by $\mathbf{H}^{(t)},\mathbf{H}^{*}$ and $\mathbf{W}^{(t)}$ instead of on $\mathbf{Z}$. Since $\mathbf{Z}$ interacts with $\mathbf{W}^{(t)}$ and $\mathbf{H}^{(t)}$ only through projections (from the right along $\mathbf{W}^{(t)}$ and from the left along $\nabla_{\mathbf{H}}\mathcal{L}(\mathbf{H}^{*},\mathbf{H}^{(t)})$, respectively), the conditioning allows the components of $\mathbf{Z}$ orthogonal to $\nabla_{\mathbf{H}}\mathcal{L}(\mathbf{H}^{*},\mathbf{H}^{(t)})$ and $\mathbf{W}^{(t)}$ to be replaced by independent Gaussian entries.

For the first gradient step, we only require conditioning on $\mathbf{H}^{(0)},\mathbf{H}^{*},\mathbf{W}^{(0)}$.

From Equation (5), we obtain the following update for $\mathbf{H}^{(1)}$:

$$\mathbf{H}^{(1)}=\mathbf{H}^{(0)}+\frac{\eta}{d}\mathbf{Z}\mathbf{Z}^{\top}\nabla_{\mathbf{H}}\mathcal{L}(\mathbf{H}^{*},\mathbf{H}^{(0)})\mathbf{a}^{\top}. \tag{28}$$

Let $\mathbf{\bar{W}}^{(0)}=\begin{pmatrix}\mathbf{W}^{(0)}\\ \mathbf{W}^{\star}\end{pmatrix}\in\mathbb{R}^{(p+k)\times d}$. By the equivalence of projection and conditioning for Gaussian random variables, the following equality holds in distribution:

$$\mathbf{Z}|_{\mathbf{H}^{(0)},\mathbf{H}^{*},\mathbf{\bar{W}}^{(0)}}\overset{d}{=}\mathbf{Z}P_{w}^{\top}+\tilde{\mathbf{Z}}(P^{\perp}_{w})^{\top}, \tag{29}$$

where $\tilde{\mathbf{Z}}$ is independent of $\mathbf{Z}$, $P_{w}$ denotes the projection operator onto the row space of $\mathbf{\bar{W}}^{(0)}$ (the span of the rows of $\mathbf{W}^{(0)}$ and $\mathbf{W}^{\star}$), and $P^{\perp}_{w}=\mathbf{I}_{d}-P_{w}$ is the complementary projector. The projector $P_{w}$ is defined as:

$$P_{w}=(\mathbf{\bar{W}}^{(0)})^{\top}\left(\mathbf{\bar{W}}^{(0)}(\mathbf{\bar{W}}^{(0)})^{\top}\right)^{-1}\mathbf{\bar{W}}^{(0)}.$$
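The decomposition in Equation (29) can be checked numerically on a single data row. The following is a minimal sketch in plain Python, not the authors' code: a few random rows stand in for $\mathbf{\bar{W}}^{(0)}$, the component of a data row orthogonal to them is resampled, and the projections along those rows are verified to be unchanged. The names `orthonormal_basis`, `project`, `conditioned_row` and the sizes are assumptions made for this illustration; an orthonormal basis is used in place of the explicit matrix inverse, which gives the same orthogonal projector when the rows are linearly independent.

```python
import random


def dot(x, y):
    return sum(a * b for a, b in zip(x, y))


def orthonormal_basis(rows):
    """Modified Gram-Schmidt on linearly independent rows (stand-in for W_bar)."""
    basis = []
    for v in rows:
        w = list(v)
        for q in basis:
            c = dot(w, q)
            w = [wi - c * qi for wi, qi in zip(w, q)]
        norm = dot(w, w) ** 0.5
        basis.append([wi / norm for wi in w])
    return basis


def project(basis, x):
    """Apply P_w, the orthogonal projector onto span(basis), to the vector x."""
    out = [0.0] * len(x)
    for q in basis:
        c = dot(x, q)
        out = [oi + c * qi for oi, qi in zip(out, q)]
    return out


def conditioned_row(basis, z, z_tilde):
    """Row-wise version of Eq. (29): P_w z + (I - P_w) z_tilde."""
    pz = project(basis, z)
    pzt = project(basis, z_tilde)
    return [a + b - c for a, b, c in zip(pz, z_tilde, pzt)]


rng = random.Random(0)
d, num_rows = 60, 3  # hypothetical sizes: ambient dimension and p + k
w_bar = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(num_rows)]
z = [rng.gauss(0.0, 1.0) for _ in range(d)]        # original data row
z_tilde = [rng.gauss(0.0, 1.0) for _ in range(d)]  # independent copy

basis = orthonormal_basis(w_bar)
z_new = conditioned_row(basis, z, z_tilde)

# Projections along every row of w_bar are preserved by the resampling.
max_gap = max(abs(dot(z_new, w) - dot(z, w)) for w in w_bar)
print(f"largest change in a projection: {max_gap:.2e}")
```

Only the part of the row lying in the span of the conditioning directions is kept, which is why the pre-activations computed from the resampled row agree with the original ones up to floating-point error.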

Substituting in Equation (28), we obtain:

$$\mathbf{H}^{(1)}\overset{d}{=}\mathbf{H}^{(0)}+\eta\frac{1}{d}\tilde{\mathbf{Z}}(P^{\perp}_{w})^{\top}P^{\perp}_{w}\tilde{\mathbf{Z}}^{\top}\nabla_{\mathbf{H}}\mathcal{L}(\mathbf{H}^{*},\mathbf{H}^{(0)})\mathbf{a}^{\top}+\eta\frac{1}{d}\mathbf{H}_{0}\left(\mathbf{\bar{W}}^{(0)}(\mathbf{\bar{W}}^{(0)})^{\top}\right)^{-1}\mathbf{H}_{0}^{\top}\nabla_{\mathbf{H}}\mathcal{L}(\mathbf{H}^{*},\mathbf{H}^{(0)})\mathbf{a}^{\top}. \tag{30}$$
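The substitution behind Equation (30) can be spelled out as follows. This is a sketch under the assumption that $P_{w}$ is the symmetric, idempotent orthogonal projector onto the row space of $\mathbf{\bar{W}}^{(0)}$, i.e. $P_{w}=(\mathbf{\bar{W}}^{(0)})^{\top}(\mathbf{\bar{W}}^{(0)}(\mathbf{\bar{W}}^{(0)})^{\top})^{-1}\mathbf{\bar{W}}^{(0)}$, with $P^{\perp}_{w}=\mathbf{I}_{d}-P_{w}$; the normalization convention for $\mathbf{H}_{0}$ is left implicit, as in the text.

```latex
\begin{align*}
\mathbf{Z}\mathbf{Z}^{\top}
&\overset{d}{=} \left(\mathbf{Z}P_{w}^{\top}+\tilde{\mathbf{Z}}(P_{w}^{\perp})^{\top}\right)
   \left(\mathbf{Z}P_{w}^{\top}+\tilde{\mathbf{Z}}(P_{w}^{\perp})^{\top}\right)^{\top} \\
&= \mathbf{Z}P_{w}^{\top}P_{w}\mathbf{Z}^{\top}
 + \tilde{\mathbf{Z}}(P_{w}^{\perp})^{\top}P_{w}^{\perp}\tilde{\mathbf{Z}}^{\top}
 + \mathbf{Z}P_{w}^{\top}P_{w}^{\perp}\tilde{\mathbf{Z}}^{\top}
 + \tilde{\mathbf{Z}}(P_{w}^{\perp})^{\top}P_{w}\mathbf{Z}^{\top} \\
&= \mathbf{Z}P_{w}\mathbf{Z}^{\top}
 + \tilde{\mathbf{Z}}(P_{w}^{\perp})^{\top}P_{w}^{\perp}\tilde{\mathbf{Z}}^{\top},
 \qquad \text{since } P_{w}^{\top}P_{w}=P_{w} \text{ and } P_{w}P_{w}^{\perp}=0, \\
\mathbf{Z}P_{w}\mathbf{Z}^{\top}
&= \left(\mathbf{Z}(\mathbf{\bar{W}}^{(0)})^{\top}\right)
   \left(\mathbf{\bar{W}}^{(0)}(\mathbf{\bar{W}}^{(0)})^{\top}\right)^{-1}
   \left(\mathbf{Z}(\mathbf{\bar{W}}^{(0)})^{\top}\right)^{\top}.
\end{align*}
```

The last line depends on $\mathbf{Z}$ only through its projections along the rows of $\mathbf{\bar{W}}^{(0)}$, which is the structure of the third term in Equation (30), while the remaining term involves only the independent matrix $\tilde{\mathbf{Z}}$.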

Since the projection $P_{w}$ is onto a low-dimensional subspace of dimension at most $p+k$, we have $P^{\perp}_{w}\approx\mathbf{I}_{d}$. One can therefore show that $\frac{1}{d}\tilde{\mathbf{Z}}(P^{\perp}_{w})^{\top}P^{\perp}_{w}\tilde{\mathbf{Z}}^{\top}\mathbf{u}$ is asymptotically equivalent to $\frac{1}{d}\tilde{\mathbf{Z}}\tilde{\mathbf{Z}}^{\top}\mathbf{u}$, in the normalized sense made precise in Equation (31), for any deterministic $\mathbf{u}\in\mathbb{R}^{n}$ with $\lVert\mathbf{u}\rVert=\mathcal{O}(\sqrt{d})$.
Applying this to the columns of $\nabla_{\mathbf{H}}\mathcal{L}(\mathbf{H}^{*},\mathbf{H}^{(0)})$, conditioned on $\mathbf{H}^{*},\mathbf{H}^{(0)}$, we obtain that:

$$\frac{1}{\sqrt{d}}\left\lVert\frac{1}{d}\tilde{\mathbf{Z}}(P^{\perp}_{w})^{\top}P^{\perp}_{w}\tilde{\mathbf{Z}}^{\top}\nabla_{\mathbf{H}}\mathcal{L}(\mathbf{H}^{*},\mathbf{H}^{(0)})-\frac{1}{d}\tilde{\mathbf{Z}}\tilde{\mathbf{Z}}^{\top}\nabla_{\mathbf{H}}\mathcal{L}(\mathbf{H}^{*},\mathbf{H}^{(0)})\right\rVert_{F}\xrightarrow[P]{n,d\rightarrow\infty}0. \tag{31}$$
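The size of the difference in Equation (31) can be probed numerically for a single vector $\mathbf{u}$. The sketch below is an illustration in plain Python under simplifying assumptions that are not from the paper: the projected subspace is taken to be the span of the first `rank` coordinate axes (harmless here because the Gaussian matrix is rotationally invariant), $\mathbf{u}$ has random $\pm 1$ entries, and the function name `projection_gap` and the sizes are hypothetical. It uses the fact that the difference of the two terms equals $-\frac{1}{d}\tilde{\mathbf{Z}}P_{w}\tilde{\mathbf{Z}}^{\top}\mathbf{u}$.

```python
import random


def dot(x, y):
    return sum(a * b for a, b in zip(x, y))


def matvec(rows, x):
    """Multiply a matrix given as a list of rows by a vector."""
    return [dot(r, x) for r in rows]


def rmatvec(rows, u):
    """Multiply the transpose of a matrix (list of rows) by a vector."""
    out = [0.0] * len(rows[0])
    for r, ui in zip(rows, u):
        out = [o + ui * ri for o, ri in zip(out, r)]
    return out


def projection_gap(n, d, rank, seed=0):
    """Return (gap, full): normalized norms of the difference in Eq. (31)
    and of the unprojected term, for one vector u with norm sqrt(n)."""
    rng = random.Random(seed)
    z_tilde = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
    u = [rng.choice((-1.0, 1.0)) for _ in range(n)]
    v = rmatvec(z_tilde, u)              # Z~^T u, a vector in R^d
    pv = v[:rank] + [0.0] * (d - rank)   # P_w v for the coordinate subspace
    full_vec = matvec(z_tilde, v)        # Z~ Z~^T u
    gap_vec = matvec(z_tilde, pv)        # Z~ P_w Z~^T u
    scale = d * d ** 0.5                 # the 1/d and 1/sqrt(d) factors
    gap = dot(gap_vec, gap_vec) ** 0.5 / scale
    full = dot(full_vec, full_vec) ** 0.5 / scale
    return gap, full


for d in (100, 400):
    gap, full = projection_gap(n=d // 2, d=d, rank=2)
    print(f"d={d}: normalized gap={gap:.4f}, unprojected term={full:.4f}")
```

With the aspect ratio $n/d$ held fixed, the unprojected term stays of order one while the gap is expected to shrink roughly like $1/\sqrt{d}$, although a single random draw fluctuates noticeably because only `rank` Gaussian projections contribute to it.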

Now, the diagonal entries of $\frac{1}{d}\tilde{\mathbf{Z}}\tilde{\mathbf{Z}}^{\top}$ converge in probability to $1$ due to the concentration of norms of Gaussian random vectors. This results in the term $\nabla_{\mathbf{H}}\mathcal{L}(\mathbf{H}^{*},\mathbf{H}^{(0)})$. This term is precisely the one responsible for the “hidden progress” explained in Section 4.1, corresponding to the term in Equation (18). Since $\tilde{\mathbf{Z}}$ is independent of $\mathbf{W}^{(0)}$ and $\mathbf{H}^{(0)},\mathbf{H}^{*}$, by the central limit theorem for sub-Gaussian random variables, the remaining off-diagonal terms can be shown to converge to Gaussian noise independent of $\mathbf{H}^{(0)},\mathbf{H}^{*}$ with variance $\lVert\nabla_{\mathbf{h}}\mathcal{L}(\mathbf{h}^{*},\mathbf{h}^{(0)})\rVert^{2}$.
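The two scales in this argument, diagonal entries close to $1$ and off-diagonal entries of order $1/\sqrt{d}$, can be observed directly. The following is a small sketch in plain Python; the function name `gram_statistics` and the sizes are assumptions chosen for illustration, with $n$ kept small so that the pure-Python loop stays fast.

```python
import random


def gram_statistics(n, d, seed=0):
    """Return (max |diag - 1|, RMS of off-diagonal entries) of Z Z^T / d
    for an n x d matrix Z with i.i.d. standard Gaussian entries."""
    rng = random.Random(seed)
    rows = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
    diag = [sum(x * x for x in r) / d for r in rows]
    off = [
        sum(a * b for a, b in zip(rows[i], rows[j])) / d
        for i in range(n)
        for j in range(i + 1, n)
    ]
    max_diag_dev = max(abs(x - 1.0) for x in diag)
    rms_off = (sum(x * x for x in off) / len(off)) ** 0.5
    return max_diag_dev, rms_off


n, d = 40, 2000  # hypothetical sizes
max_diag_dev, rms_off = gram_statistics(n, d)
print(f"max |diag - 1|          : {max_diag_dev:.3f}")
print(f"RMS off-diag * sqrt(d)  : {rms_off * d ** 0.5:.3f}")
```

Each off-diagonal entry is individually small, but a row contains $n-1$ of them with $n$ proportional to $d$, so their weighted sum against a vector with entries of order one remains of order one, which is the Gaussian noise term described above.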

Lastly, the third term in Equation (30) can be shown to converge to Gaussian noise correlated with the corresponding entries of $\mathbf{H}_{0},\mathbf{H}^{*}$. Specifically, after removing the conditioning on $\mathbf{H}_{0}$, the law of large numbers and Stein’s Lemma imply that the term $\frac{1}{d}\mathbf{H}_{0}^{\top}\nabla_{\mathbf{H}}\mathcal{L}(\mathbf{H}^{*},\mathbf{H}^{(0)})$ converges in probability to $(\mathbf{W}^{(0)}(\mathbf{W}^{(0)})^{\top})\mathbb{E}\left[\nabla^{2}_{\mathbf{h}}\mathcal{L}(\mathbf{h}^{*},\mathbf{h}^{(0)})\right]$. Therefore, we obtain:

$$\frac{1}{\sqrt{d}}\left\lVert\frac{1}{d}\mathbf{H}_{0}\left(\mathbf{W}^{(0)}(\mathbf{W}^{(0)})^{\top}\right)^{-1}\mathbf{H}_{0}^{\top}\nabla_{\mathbf{H}}\mathcal{L}(\mathbf{H}^{*},\mathbf{H}^{(0)})\mathbf{a}^{\top}-\mathbf{H}_{0}\mathbb{E}\left[\nabla^{2}_{\mathbf{h}}\mathcal{L}(\mathbf{h}^{*},\mathbf{h}^{(0)})\right]\mathbf{a}^{\top}\right\rVert\xrightarrow[P]{n,d\rightarrow\infty}0. \tag{32}$$
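The step from an empirical average of the form $\mathbf{h}\,\nabla\mathcal{L}(\mathbf{h})$ to an expected second derivative rests on Stein's Lemma, which for a scalar standard Gaussian $g$ and a smooth function $f$ states $\mathbb{E}[g f(g)]=\mathbb{E}[f'(g)]$. The snippet below is a minimal Monte Carlo check of this identity in plain Python. The choice $f=\tanh$ is an arbitrary smooth test function and not the loss gradient of the paper, and the name `stein_check` is hypothetical.

```python
import math
import random


def stein_check(num_samples=200_000, seed=0):
    """Monte Carlo estimates of E[g f(g)] and E[f'(g)] for g ~ N(0, 1), f = tanh."""
    rng = random.Random(seed)
    lhs = 0.0
    rhs = 0.0
    for _ in range(num_samples):
        g = rng.gauss(0.0, 1.0)
        t = math.tanh(g)
        lhs += g * t        # g * f(g)
        rhs += 1.0 - t * t  # f'(g) = 1 - tanh(g)^2
    return lhs / num_samples, rhs / num_samples


lhs, rhs = stein_check()
print(f"E[g f(g)] ~ {lhs:.3f},  E[f'(g)] ~ {rhs:.3f}")
```

The two estimates agree up to Monte Carlo error. In the multivariate setting used in the text, the same identity carries a covariance factor, which is where the matrix $\mathbf{W}^{(0)}(\mathbf{W}^{(0)})^{\top}$ in the limit comes from.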

Proceeding similarly, one obtains low-dimensional effective processes for $\mathbf{W}^{(t)},\mathbf{H}^{(t)}$ for any time $t\in\mathbb{N}$. In the following section, we derive the resulting DMFT dynamics for the setup considered in Section 1 through a reduction to the result in [Gerbelot et al., 2023]. We refer to [Gerbelot et al., 2023, Celentano et al., 2021] for detailed proofs based on the above technique.

A.3 Derivation of the exact asymptotics

We start by stating a general consequence of the main result in [Gerbelot et al., 2023].

Theorem A.1 (Corollary of Theorem 3.2 in [Gerbelot et al., 2023]).

Let $W^{0}\in\mathbb{R}^{q\times d}$ be a sequence of matrices such that the overlap matrix satisfies:

$$\frac{1}{d}W^{0}(W^{0})^{\top}\xrightarrow[a.s.]{d\rightarrow\infty}Q^{(0)}, \tag{33}$$

where $Q^{(0)}\in\mathcal{S}^{+}_{q}$ denotes a fixed matrix. Consider a dynamics of the form:

$$W^{(t+1)}=W^{(t)}-\eta\lambda W^{(t)}-\eta\frac{1}{\sqrt{d}}\sum_{\nu=1}^{n}F\left(\frac{W^{(t)}\mathbf{z}_{\nu}}{\sqrt{d}}\right)\mathbf{z}_{\nu}^{\top}, \tag{34}$$

where $F:\mathbb{R}^{q}\rightarrow\mathbb{R}^{q}$ is pseudo-Lipschitz of finite order and $\{\mathbf{z}_{\nu}\}_{\nu=1}^{n}$ are i.i.d. vectors distributed as $\mathbf{z}_{\nu}\sim\mathcal{N}(0,\mathbb{I}_{d})$, with $n,d\rightarrow\infty$ and $n/d=\alpha>0$. Then the empirical measure of the weights $\mathbf{w}^{(t)}_{i}$ converges in distribution to that of the weight process $\boldsymbol{\theta}^{(t)}$, and the empirical measure of the preactivations $\mathbf{h}^{(t)}_{\nu}$ converges in distribution to that of the preactivation process $\mathbf{h}^{(t)}$, defined as

𝜽(t+1)𝜽(t)=η(λ+Λ(t))𝜽(t)+ητ=0t1R(t,τ)𝜽(τ)+η𝒖(t)superscript𝜽𝑡1superscript𝜽𝑡𝜂𝜆superscriptΛ𝑡superscript𝜽𝑡𝜂superscriptsubscript𝜏0𝑡1superscriptsubscript𝑅𝑡𝜏superscript𝜽𝜏𝜂superscript𝒖𝑡\bm{\theta}^{(t+1)}-\bm{\theta}^{(t)}=-\,\eta\,\left(\lambda+\Lambda^{(t)}% \right)\bm{\theta}^{(t)}+\,\eta\sum_{\tau=0}^{t-1}R_{\ell}^{(t,\tau)}\bm{% \theta}^{(\tau)}+\,\eta\bm{u}^{(t)}bold_italic_θ start_POSTSUPERSCRIPT ( italic_t + 1 ) end_POSTSUPERSCRIPT - bold_italic_θ start_POSTSUPERSCRIPT ( italic_t ) end_POSTSUPERSCRIPT = - italic_η ( italic_λ + roman_Λ start_POSTSUPERSCRIPT ( italic_t ) end_POSTSUPERSCRIPT ) bold_italic_θ start_POSTSUPERSCRIPT ( italic_t ) end_POSTSUPERSCRIPT + italic_η ∑ start_POSTSUBSCRIPT italic_τ = 0 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_t - 1 end_POSTSUPERSCRIPT italic_R start_POSTSUBSCRIPT roman_ℓ end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ( italic_t , italic_τ ) end_POSTSUPERSCRIPT bold_italic_θ start_POSTSUPERSCRIPT ( italic_τ ) end_POSTSUPERSCRIPT + italic_η bold_italic_u start_POSTSUPERSCRIPT ( italic_t ) end_POSTSUPERSCRIPT (35)
𝒉(t)=ητ=0t1Rθ(t,τ)F(𝒉(τ))+𝝎(t)superscript𝒉𝑡𝜂superscriptsubscript𝜏0𝑡1superscriptsubscript𝑅𝜃𝑡𝜏𝐹superscript𝒉𝜏superscript𝝎𝑡\bm{h}^{(t)}=-\eta\sum_{\tau=0}^{t-1}R_{\theta}^{(t,\tau)}F(\bm{h}^{(\tau)})+% \bm{\omega}^{(t)}bold_italic_h start_POSTSUPERSCRIPT ( italic_t ) end_POSTSUPERSCRIPT = - italic_η ∑ start_POSTSUBSCRIPT italic_τ = 0 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_t - 1 end_POSTSUPERSCRIPT italic_R start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ( italic_t , italic_τ ) end_POSTSUPERSCRIPT italic_F ( bold_italic_h start_POSTSUPERSCRIPT ( italic_τ ) end_POSTSUPERSCRIPT ) + bold_italic_ω start_POSTSUPERSCRIPT ( italic_t ) end_POSTSUPERSCRIPT (36)

were we have

Λ(t)=α𝔼[𝒉F(𝒉(t))]superscriptΛ𝑡𝛼𝔼delimited-[]subscript𝒉𝐹superscript𝒉𝑡\Lambda^{(t)}=\alpha\mathbb{E}\left[\nabla_{\bm{h}}F(\bm{h}^{(t)})\right]roman_Λ start_POSTSUPERSCRIPT ( italic_t ) end_POSTSUPERSCRIPT = italic_α blackboard_E [ ∇ start_POSTSUBSCRIPT bold_italic_h end_POSTSUBSCRIPT italic_F ( bold_italic_h start_POSTSUPERSCRIPT ( italic_t ) end_POSTSUPERSCRIPT ) ] (37)
Rθ(t,τ)=𝔼[θ(t)τ(τ)]superscriptsubscript𝑅𝜃𝑡𝜏𝔼delimited-[]superscript𝜃𝑡superscript𝜏𝜏\displaystyle R_{\theta}^{(t,\tau)}=\mathbb{E}\left[\frac{\partial\,\mathbf{% \theta}^{(t)}}{\partial\,\mathbf{\tau}^{(\tau)}}\right]italic_R start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ( italic_t , italic_τ ) end_POSTSUPERSCRIPT = blackboard_E [ divide start_ARG ∂ italic_θ start_POSTSUPERSCRIPT ( italic_t ) end_POSTSUPERSCRIPT end_ARG start_ARG ∂ italic_τ start_POSTSUPERSCRIPT ( italic_τ ) end_POSTSUPERSCRIPT end_ARG ] (38)
$$R_{\ell}^{(t,\tau)}=\alpha\mathbb{E}\left[\frac{\partial\,F(\mathbf{h}^{(t)})}{\partial\,\bm{\omega}^{(\tau)}}\right]\,. \tag{39}$$

Finally, $\bm{u}^{(t)}$ and $\bm{\omega}^{(t)}$ are zero-mean Gaussian processes with covariances given by $C_{\ell}^{(t,\tau)}$ and $C_{\theta}^{(t,\tau)}$, respectively:

$$C_{\ell}^{(t,\tau)}=\alpha\mathbb{E}\left[\bm{u}^{(t)}\left(\bm{u}^{(\tau)}\right)^{\top}\right]=\alpha\mathbb{E}\left[F(\bm{r}^{(t)})F(\bm{r}^{(\tau)})^{\top}\right] \tag{40}$$
$$C_{\theta}^{(t,\tau)}=\mathbb{E}\left[\bm{\omega}^{(t)}\left(\bm{\omega}^{(\tau)}\right)^{\top}\right]=\mathbb{E}\left[\bm{\theta}^{(t)}\left(\bm{\theta}^{(\tau)}\right)^{\top}\right] \tag{41}$$

The convergence in distribution of the empirical measures holds in the following sense: for any $t\in\mathbb{N}$ and any pseudo-Lipschitz functions $\psi:\mathbb{R}^{p(t+1)}\to\mathbb{R}$ and $\phi:\mathbb{R}^{pt}\to\mathbb{R}$:

$$\frac{1}{d}\sum_{i=1}^{d}\psi\left((W_{i}^{(0)},\dots,W_{i}^{(t)})\right)\xrightarrow[n,d\to\infty]{{\rm w.h.p.}}\mathbb{E}\left[\psi(\bm{\theta}^{(0)},\dots,\bm{\theta}^{(t)})\right], \tag{42}$$
$$\frac{1}{n}\sum_{\nu=1}^{n}\phi\left((\mathbf{h}_{\nu}^{(0)},\dots,\mathbf{h}_{\nu}^{(t-1)})\right)\xrightarrow[n,d\to\infty]{{\rm w.h.p.}}\mathbb{E}\left[\phi(\mathbf{h}^{(0)},\dots,\mathbf{h}^{(t-1)})\right], \tag{43}$$

where $W_{i}^{(t)}$ denotes the $i$-th column of $W^{(t)}$.

The above result follows directly by substituting $\frac{1}{\sqrt{d}}\mathbf{Z}$ as $\mathbf{X}$ in Theorem 3.2 of Gerbelot et al. [2023].
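As an illustration of the left-hand side of (42), the following minimal Python sketch averages a test function over the $d$ columns of a stored weight trajectory. The function name `empirical_average`, the array layout, and the i.i.d. Gaussian stand-in for the iterates are assumptions made for this example only; they are not taken from the paper.

```python
import numpy as np

def empirical_average(trajectory, psi):
    """Average psi over the d columns of a weight trajectory.

    trajectory: array of shape (t + 1, p, d) holding W^(0), ..., W^(t).
    psi: callable mapping an array of shape (t + 1, p) to a float.
    """
    d = trajectory.shape[-1]
    # Column i of every iterate, stacked over time, has shape (t + 1, p).
    return sum(psi(trajectory[:, :, i]) for i in range(d)) / d

rng = np.random.default_rng(0)
p, d, steps = 2, 5000, 3
# Hypothetical stand-in for a trajectory: i.i.d. standard normal iterates.
trajectory = rng.standard_normal((steps + 1, p, d))

# Test function: squared norm of the last iterate (pseudo-Lipschitz of order 2).
second_moment = empirical_average(
    trajectory, lambda cols: float(np.sum(cols[-1] ** 2))
)
```

For this Gaussian stand-in the limiting expectation is simply $p$, so `second_moment` should be close to 2 for large $d$; for an actual gradient descent trajectory the limit is instead given by the effective process $\bm{\theta}^{(t)}$.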

The definitions of $R_{\theta}$ and $R_{\ell}$ in Theorem A.1 require differentiating through the non-Markovian processes defined by Equations 35 and 36. Fortunately, $R_{\theta},R_{\ell}$ can be equivalently described through an explicit set of recursive updates, which we state below for convenience:

$$R_{\theta}^{(t+1,\tau)}-R_{\theta}^{(t,\tau)}=-\,\eta\,\left(\lambda+\Lambda^{(t)}\right)R_{\theta}^{(t,\tau)}+\,\eta\sum_{s=\tau}^{t-1}R_{\ell}^{(t,s)}R_{\theta}^{(\tau,s)} \tag{44}$$

with boundary conditions

$$R_{\theta}^{(t,t)}=1\,, \tag{45}$$
$$R_{\theta}^{(t+1,t)}=1-\,\eta\,\Lambda^{(t)}\,, \tag{46}$$

while

$$R_{\ell}^{(t,\tau)}=\alpha\mathbb{E}\left[\nabla_{\bm{h}}F(\bm{h}^{(t)})\,T_{\ell}^{(t,\tau)}\right], \tag{47}$$

where $T_{\ell}^{(t,\tau)}$ is a collection of stochastic processes with distribution

$$T_{\ell}^{(t,\tau)}=R_{\theta}^{(t,\tau)}\nabla_{\bm{h}}F(\bm{h}^{(\tau)})+\sum_{s=\tau+1}^{t-1}R_{\theta}^{(t,s)}T_{\ell}^{(s,\tau)} \tag{48}$$

and boundary conditions

$$T_{\ell}^{(t,t)}=0\,, \tag{49}$$
$$T_{\ell}^{(t+1,t)}=1-\,\eta\,\Lambda^{(t)}\,. \tag{50}$$

To obtain the limiting equations under the setting of gradient descent with teacher weights $W^{*}$ in section 1, we utilize the generality of the update $F$ in Theorem A.1, which allows for a portion of the parameters ($W^{*}$) to remain unaffected. We obtain the following result, which generalizes the former theorem to the setting of our paper:

Theorem A.2.

Consider the distribution over data defined in section 2 and an update rule on the weights of the form (5), i.e.:

$$\mathbf{w}_{i}^{(t+1)}=\mathbf{w}^{(t)}_{i}-\eta\lambda\mathbf{w}^{(t)}_{i}-\eta\sum_{\nu=1}^{n}\nabla_{\mathbf{w}^{(t)}_{i}}\,\mathcal{L}\left(\frac{W^{(t)}\mathbf{z}_{\nu}}{\sqrt{d}},\frac{W^{\star}\mathbf{z}_{\nu}}{\sqrt{d}}\right)\,. \tag{51}$$

Then, under the assumptions of Theorem 3.2, as $d\rightarrow\infty$ with $n/d=\alpha>0$, the joint empirical measure of the coordinates of the student weights $\mathbf{w}_{i}^{(t)}$ and the teacher weights $W^{\star}$ converges in distribution to the stochastic process $\bm{\theta}^{(t)}$ and the standard normal variable $\bm{\theta}^{\star}$, in the sense of Theorem A.1. Similarly, the joint empirical measure of the student and teacher preactivations $\frac{\mathbf{w}_{i}^{(t)}\mathbf{z}_{\nu}}{\sqrt{d}}$, $\frac{\mathbf{w}_{i}^{\star}\mathbf{z}_{\nu}}{\sqrt{d}}$ converges in distribution to the stochastic process $\mathbf{h}^{(t)}$ and the standard normal variable $\mathbf{h}^{\star}$.
The processes $\bm{\theta}^{(t)}$ and $\mathbf{h}^{(t)}$ are defined recursively through the following equations:

$$\bm{\theta}^{(t+1)}-\bm{\theta}^{(t)}=-\,\eta\,\left(\lambda+\Lambda^{(t)}\right)\bm{\theta}^{(t)}+\,\eta\sum_{\tau=0}^{t-1}R_{\ell}^{(t,\tau)}\bm{\theta}^{(\tau)}-\,\eta g^{(t)}\bm{\theta}^{\star}+\,\eta\sum_{\tau=0}^{t}\tilde{R}_{\ell}^{(t,\tau)}\bm{\theta}^{\star}+\,\eta\bm{u}^{(t)} \tag{52}$$
$$\bm{h}^{(t)}=-\eta\sum_{\tau=0}^{t-1}R_{\theta}^{(t,\tau)}\nabla_{\bm{h}}\ell(\bm{h}^{(\tau)},\bm{h}^{\star})+\bm{\omega}^{(t)} \tag{53}$$

Here $\mathbf{u}^{(t)},\bm{\theta}^{\star}$ and $(\bm{\omega}^{(t)},\bm{h}^{\star})$ are zero-mean Gaussian processes with covariances $C_{\ell}^{(t,\tau)}$ and $\Omega^{(t,\tau)}$ respectively, given by:

$$C_{\ell}^{(t,\tau)}=\alpha\mathbb{E}\left[\bm{u}^{(t)}\left(\bm{u}^{(\tau)}\right)^{\top}\right]=\alpha\mathbb{E}_{\bm{h}^{(t)},\bm{h}^{\star}}\left[\nabla_{\bm{h}}\mathcal{L}(\bm{h}^{(t)},\bm{h}^{\star})\nabla_{\bm{h}}\mathcal{L}(\bm{h}^{(\tau)},\bm{h}^{\star})^{\top}\right] \tag{54}$$
$$\Omega^{(t,\tau)}=\mathbb{E}\left[\begin{pmatrix}\bm{\omega}^{(t)}\\ \bm{h}^{\star}\end{pmatrix}\begin{pmatrix}\bm{\omega}^{(\tau)}\\ \bm{h}^{\star}\end{pmatrix}^{\top}\right]=\begin{bmatrix}C_{\theta}^{(t,\tau)}&M^{(t)}\\ M^{(t)}&1\end{bmatrix}, \tag{55}$$

where $C_{\theta}^{(t,\tau)},M^{(t)}$ are defined as:

$$\begin{bmatrix}C_{\theta}^{(t,\tau)}&M^{(t)}\\ M^{(t)}&1\end{bmatrix}=\mathbb{E}_{\bm{\theta}^{(t)},\bm{\theta}^{\star}}\left[\begin{pmatrix}\bm{\theta}^{(t)}\\ \bm{\theta}^{\star}\end{pmatrix}\begin{pmatrix}\bm{\theta}^{(\tau)}\\ \bm{\theta}^{\star}\end{pmatrix}^{\top}\right] \tag{56}$$

The effective regularisation $\Lambda^{(t)}$ and the projected gradient $g^{(t)}$ concentrate to

$$\Lambda^{(t)}=\alpha\mathbb{E}_{\bm{h}^{(t)},\bm{h}^{\star}}\left[\nabla^{2}_{\mathbf{h}}\mathcal{L}\left(\mathbf{h}^{(t)},\bm{h}^{\star}\right)\right]\,, \tag{57}$$
$$g^{(t)}=\alpha\mathbb{E}_{\bm{h}^{(t)},\bm{h}^{\star}}\left[\nabla_{\mathbf{h}}\mathcal{L}\left(\mathbf{h}^{(t)},\bm{h}^{\star}\right)\mathbf{h}^{\star\top}\right] \tag{58}$$

The memory kernels $R_{\ell}^{(t,\tau)}$, $\tilde{R}_{\ell}^{(t,\tau)}$, $R_{\theta}^{(t,\tau)}$ for $t>\tau$ are defined through the partial derivatives with respect to the noise:

$$R_{\theta}^{(t,\tau)}=\mathbb{E}_{\bm{\theta}^{(t)},\bm{\theta}^{\star}}\left[\frac{\partial\,\bm{\theta}^{(t)}}{\partial\,\bm{u}^{(\tau)}}\right]\,,$$
$$R_{\ell}^{(t,\tau)}=\alpha\mathbb{E}_{\bm{h}^{(t)},\bm{h}^{\star}}\left[\frac{\partial\,\nabla_{\mathbf{h}}\mathcal{L}\left(\mathbf{h}^{(t)},\bm{h}^{\star}\right)}{\partial\,\bm{\omega}^{(\tau)}}\right]\,, \tag{59}$$
$$\tilde{R}_{\ell}^{(t,\tau)}=\alpha\mathbb{E}_{\bm{h}^{(t)},\bm{h}^{\star}}\left[\frac{\partial\,\nabla_{\mathbf{h}}\mathcal{L}\left(\mathbf{h}^{(t)},\bm{h}^{\star}\right)}{\partial\,\left(\bm{\omega}^{\star}\right)^{(\tau)}}\right]\,, \tag{60}$$

and $R_{\ell}^{(t,t)}=\tilde{R}_{\ell}^{(t,t)}=0$, $R_{\theta}^{(t,t)}=1$. Finally, $M^{(t)}$ satisfies the update equation

$$M^{(t+1)}=(1-\eta\lambda)M^{(t)}-\eta g^{(t)}\,, \tag{61}$$

where $g^{(t)}$ is defined as:

$$\alpha\mathbb{E}\left[\nabla_{\bm{h}}\ell(\bm{h}^{(t)})\left(\bm{h}^{*}\right)^{\top}\right] \tag{62}$$
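The overlap update $M^{(t+1)}=(1-\eta\lambda)M^{(t)}-\eta g^{(t)}$ stated in the theorem is a linear recursion once the projected gradients are known. The Python sketch below iterates it for a user-supplied sequence of $g^{(t)}$. The helper name `iterate_overlap` and the constant values of $g^{(t)}$ are hypothetical; in the theorem $g^{(t)}$ is an expectation over the effective process and is not available in closed form in general.

```python
import numpy as np

def iterate_overlap(M0, g_seq, eta, lam):
    """Iterate M^(t+1) = (1 - eta * lam) * M^(t) - eta * g^(t).

    M0: initial overlap of shape (p, k).
    g_seq: iterable of projected gradients g^(t), each of shape (p, k).
    Returns the list [M^(0), M^(1), ..., M^(T)].
    """
    M = np.asarray(M0, dtype=float)
    history = [M.copy()]
    for g in g_seq:
        M = (1.0 - eta * lam) * M - eta * np.asarray(g, dtype=float)
        history.append(M.copy())
    return history

p, k = 2, 1
eta, lam = 0.5, 0.1
M0 = np.zeros((p, k))
# Hypothetical constant projected gradient, chosen only for illustration.
g_seq = [-np.ones((p, k))] * 3
history = iterate_overlap(M0, g_seq, eta, lam)
```

With these illustrative values the overlap grows from zero as soon as $g^{(t)}$ is non-zero, while the factor $(1-\eta\lambda)$ shrinks the previous overlap at every step.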
Proof.

Analogous to the embedding of planted vectors in [Celentano et al., 2021], we start by considering a lifted dynamics defined by concatenating $W^{(t)}$ and $W^{*}$. First, define the extended parameters $\tilde{W}^{(t)}\in\mathbb{R}^{(p+k)\times d}$ with update rule:

$$\begin{bmatrix}W^{(t+1)}\\ W^{*}\end{bmatrix}=\begin{bmatrix}W^{(t)}\\ W^{*}\end{bmatrix}-\eta\begin{bmatrix}\lambda W^{(t)}\\ 0\end{bmatrix}-\eta\sum_{\nu=1}^{n}\begin{bmatrix}\nabla_{W^{(t)}}\,\mathcal{L}\left(\frac{W^{(t)}\mathbf{z}^{\nu}}{\sqrt{d}}\right)\\ 0\end{bmatrix}\,. \tag{63}$$

The above updates are a special case of Theorem A.1 with $q=p+k$ and $F:\mathbb{R}^{p+k}\rightarrow\mathbb{R}^{p+k}$ given by:

$$
F:\begin{pmatrix}\mathbf{h}\\ \mathbf{h}^{\star}\end{pmatrix}\rightarrow\begin{pmatrix}\nabla_{\mathbf{h}}\mathcal{L}(\mathbf{h},\mathbf{h}^{\star})\\ \mathbf{0}\end{pmatrix}. \tag{64}
$$
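The lifted update can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: the function name `lifted_gd_step`, the list-of-rows representation, and the callable `grad_h` (the gradient of the per-sample loss with respect to the student pre-activations) are hypothetical and not taken from the paper, and the chain-rule factor $\mathbf{z}/\sqrt{d}$ is the one implied by the argument of $\mathcal{L}$ in (63).

```python
import math

def lifted_gd_step(W, W_star, Z, grad_h, eta, lam):
    """One full-batch step of the lifted update (hypothetical helper).

    W       : p x d student weights (list of rows)
    W_star  : k x d teacher weights, carried along but never updated
    Z       : n x d data matrix
    grad_h  : callable (h, h_star) -> length-p gradient of the per-sample
              loss with respect to the student pre-activations h
    """
    d = len(W[0])
    scale = 1.0 / math.sqrt(d)
    # Top block: weight decay, i.e. the -eta * lam * W term.
    new_W = [[(1.0 - eta * lam) * w for w in row] for row in W]
    for z in Z:
        # Pre-activations are computed from the current (not yet updated) W.
        h = [scale * sum(w * x for w, x in zip(row, z)) for row in W]
        h_star = [scale * sum(w * x for w, x in zip(row, z)) for row in W_star]
        g = grad_h(h, h_star)
        # Chain rule: d h_j / d w_j = z / sqrt(d).
        for j, row in enumerate(new_W):
            for i in range(d):
                row[i] -= eta * g[j] * scale * z[i]
    # Bottom block: the update is identically zero, so W_star is copied.
    return new_W, [row[:] for row in W_star]
```

The teacher block is returned unchanged, which is exactly the redundancy that makes the second block row of the effective process trivial.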

The assumptions on $g^{*},\sigma$ imply that $F$ is pseudo-Lipschitz of finite order, while standard concentration results for sub-exponential random variables, applied to $W^{(0)},W^{*}$, imply that the overlap matrices at initialization converge almost surely. Therefore, Theorem A.1 applies, with the effective process for the weights and the pre-activations being described by:

$$
\begin{bmatrix}\bm{\theta}^{(t+1)}\\ \bm{\theta}^{*}\end{bmatrix}-\begin{bmatrix}\bm{\theta}^{(t)}\\ \bm{\theta}^{*}\end{bmatrix}=-\eta\begin{bmatrix}\lambda+\Lambda^{(t)}&\tilde{\Lambda}^{(t)}\\ 0&0\end{bmatrix}\begin{bmatrix}\bm{\theta}^{(t)}\\ \bm{\theta}^{*}\end{bmatrix}+\eta\sum_{\tau=0}^{t}\begin{bmatrix}R_{\ell}^{(t,\tau)}&\tilde{R}_{\ell}^{(t,\tau)}\\ 0&0\end{bmatrix}\begin{bmatrix}\bm{\theta}^{(\tau)}\\ \bm{\theta}^{*}\end{bmatrix}+\eta\begin{bmatrix}\bm{u}^{(t)}\\ 0\end{bmatrix} \tag{65}
$$

$$
\begin{bmatrix}\bm{h}^{(t)}\\ \bm{h}^{*}\end{bmatrix}=-\eta\sum_{\tau=0}^{t-1}\begin{bmatrix}R_{\theta}^{(t,\tau)}&\tilde{R}_{\theta}^{(t,\tau)}\\ 0&1\end{bmatrix}\begin{bmatrix}\nabla_{\bm{h}}\mathcal{L}(\bm{h}^{(\tau)},\bm{h}^{\star})\\ 0\end{bmatrix}+\begin{bmatrix}\bm{\omega}^{(t)}\\ \bm{\omega}^{*}\end{bmatrix} \tag{66}
$$

Notice the redundancy in the above equations due to $W^{*}$ not being updated in (63). This allows us to further simplify (65) and (66), obtaining:

$$
\bm{\theta}^{(t+1)}-\bm{\theta}^{(t)}=-\eta\left(\lambda+\Lambda^{(t)}\right)\bm{\theta}^{(t)}+\eta\sum_{\tau=0}^{t-1}R_{\ell}^{(t,\tau)}\bm{\theta}^{(\tau)}-\eta\,\tilde{\Lambda}^{(t)}\bm{\theta}^{*}+\eta\sum_{\tau=0}^{t}\tilde{R}_{\ell}^{(t,\tau)}\bm{\theta}^{*}+\eta\,\bm{u}^{(t)} \tag{67}
$$

$$
\bm{h}^{(t)}=-\eta\sum_{\tau=0}^{t-1}R_{\theta}^{(t,\tau)}\nabla_{\bm{h}}\mathcal{L}(\bm{h}^{(\tau)},\bm{h}^{\star})+\bm{\omega}^{(t)} \tag{68}
$$

where we noticed that $\bm{h}^{*}\sim\bm{\omega}^{*}$. These equations are the same as in A.1, with just two extra terms in (67), $\tilde{\Lambda}^{(t)}$ and $\tilde{R}_{\ell}^{(t,\tau)}$. An application of Stein's Lemma further simplifies the term $\tilde{\Lambda}^{(t)}$ to $g^{(t)}$ in the Theorem as follows:

$$
\tilde{\Lambda}^{(t)}=\alpha\,\mathbb{E}\left[\nabla_{\bm{h}^{*}}\nabla_{\bm{h}}\ell(\bm{h}^{(t)},\bm{h}^{\star})\right]=\alpha\,\mathbb{E}\left[\nabla_{\bm{h}}\ell(\bm{h}^{(t)},\bm{h}^{\star})\left(\bm{h}^{*}\right)^{\top}\right]=g^{(t)}. \tag{69}
$$
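The Stein identity used in (69) can be checked numerically in its simplest scalar form, $\mathbb{E}[\phi'(z)]=\mathbb{E}[z\,\phi(z)]$ for $z\sim\mathcal{N}(0,1)$. The sketch below is a Monte Carlo illustration only: the helper name `stein_check`, the test function $\tanh$, the sample count, and the seed are assumptions chosen for the example and do not appear in the paper.

```python
import math
import random

def stein_check(phi, dphi, n_samples=200_000, seed=0):
    """Estimate both sides of E[phi'(z)] = E[z * phi(z)] for z ~ N(0, 1)."""
    rng = random.Random(seed)
    lhs = rhs = 0.0
    for _ in range(n_samples):
        z = rng.gauss(0.0, 1.0)
        lhs += dphi(z)      # derivative side
        rhs += z * phi(z)   # correlation side
    return lhs / n_samples, rhs / n_samples

# tanh is an arbitrary smooth test function with derivative 1 - tanh^2.
lhs, rhs = stein_check(math.tanh, lambda z: 1.0 - math.tanh(z) ** 2)
print(lhs, rhs)  # the two estimates agree up to Monte Carlo error
```

In (69) the same identity is applied coordinate-wise in $\bm{h}^{*}$, turning a derivative with respect to the teacher pre-activations into a correlation with them.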

The above effective process characterizes the limits of several quantities determined by the weights and pre-activations. In particular, it provides the limits of the student-teacher overlaps:

Corollary A.3.

Under the assumptions of Theorem 3.2,

$$
W^{(t)}(W^{\star})^{\top}/d\xrightarrow[P]{n,d\rightarrow\infty}M^{(t)}, \tag{70}
$$

where $M^{(t)}$ is defined as in Theorem A.2.

Proof.

Observe that $\langle\mathbf{w}^{(t)}_{i},\mathbf{w}^{\star}\rangle/d$ can be expressed as an expectation of a pseudo-Lipschitz function w.r.t. the joint empirical measure over the coordinates of $\mathbf{w}^{(t)}_{i},\mathbf{w}^{\star}$, with the value at the $j$-th coordinate given by $\{\mathbf{w}^{(t)}_{i}\}_{j}\{\mathbf{w}^{\star}\}_{j}$. Therefore, Theorem 3.2 implies that $W^{(t)}(W^{\star})^{\top}/d$ converges in probability to the expected overlaps of the effective process $\bm{\theta}^{(t)},\bm{\theta}^{\star}$, which equal $M^{(t)}$ by definition. ∎
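To make the empirical-measure viewpoint concrete, the following sketch computes the overlap matrix $W(W^{\star})^{\top}/d$ as a coordinate-wise average of products. The helper `overlap`, the dimension `d`, and the planted value `m` are illustrative assumptions, not quantities from the paper.

```python
import random

def overlap(W, W_star):
    """Return the p x k matrix W W_star^T / d for row lists W (p x d), W_star (k x d)."""
    d = len(W[0])
    # Entry (i, r) is the average over coordinates j of W[i][j] * W_star[r][j].
    return [[sum(a * b for a, b in zip(w, ws)) / d for ws in W_star] for w in W]

rng = random.Random(0)
d, m = 20_000, 0.3
w_star = [rng.gauss(0.0, 1.0) for _ in range(d)]
noise = [rng.gauss(0.0, 1.0) for _ in range(d)]
# Student row with a planted overlap m and unit second moment per coordinate.
w = [m * s + (1.0 - m * m) ** 0.5 * x for s, x in zip(w_star, noise)]
# First row carries the planted overlap, second row is pure noise.
M_hat = overlap([w, noise], [w_star])
print(M_hat)  # first entry close to m, second close to 0 for large d
```

Because each entry is an average of $d$ products, it concentrates as $d$ grows, which is the mechanism the proof relies on.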

We also include a useful lemma describing the evolution of the overlaps of the weights.

Lemma A.4.

Under the assumptions of Theorem A.2, the covariance $C_{\theta}^{(t,\tau)}$ satisfies:

\begin{align}
C_{\theta}^{(t+1,\tau)}-C_{\theta}^{(t,\tau)}={}&-\eta\left(\lambda+\Lambda^{(t)}\right)C_{\theta}^{(t,\tau)}+\eta\sum_{s=0}^{t-1}R_{\ell}^{(t,s)}C_{\theta}^{(s,\tau)}+\eta\sum_{s=0}^{\tau-1}R_{\theta}^{(t,s)}C_{\ell}^{(s,\tau)} \tag{71}\\
&-\eta\left(g^{(t)}-\sum_{s=0}^{t}\tilde{R}_{\ell}^{(t,s)}\right)\left(M^{(\tau)}\right)^{\top} \tag{72}
\end{align}

This is a consequence of linearity of expectation applied to (52). Concretely, viewing $\bm{\theta}^{(t)}$ as a function of the Gaussian random variables $\{\bm{u}^{(\tau)}\}_{\tau=1}^{t}$, we apply the multivariate Stein's Lemma to obtain:

$$
\mathbb{E}\left[\bm{\theta}^{(t)}\bm{u}^{(\tau)}\right]=\sum_{s=\tau}^{t-1}R_{\theta}^{(t,s)}C_{\ell}^{(s,\tau)}\,. \tag{73}
$$

In particular, we obtain the following expressions for the covariances up to the first time steps:

Lemma A.5.

The covariances $C_{\theta}^{(0,1)},C_{\theta}^{(0,0)}$ satisfy:

$$
C_{\theta}^{(0,1)}-C_{\theta}^{(0,0)}=-\eta\left(\lambda+\Lambda^{(0)}\right)C_{\theta}^{(0,0)}-\eta\left(g^{(0)}-\Lambda^{(0)}M^{(0)}\right)\left(M^{(0)}\right)^{\top} \tag{74}
$$

$$
C_{\theta}^{(1,1)}-C_{\theta}^{(0,1)}=-\eta\left(\lambda+\Lambda^{(0)}\right)C_{\theta}^{(0,1)}+\eta\,C_{\ell}^{(0,1)}-\eta\left(g^{(0)}-\Lambda^{(0)}M^{(0)}\right)\left(M^{(1)}\right)^{\top} \tag{75}
$$
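As a quick sanity check of Lemma A.5, consider the special case $M^{(0)}=0$, which is the initialization used in Section A.4. This is a worked specialization rather than a new result; $I$ denotes the identity matrix and $\lambda$ is read as $\lambda I$. The last term of (74) vanishes and the $\Lambda^{(0)}M^{(0)}$ contribution drops out of (75), leaving:

```latex
\begin{align*}
C_{\theta}^{(0,1)} &= \left(I-\eta\left(\lambda+\Lambda^{(0)}\right)\right)C_{\theta}^{(0,0)},\\
C_{\theta}^{(1,1)}-C_{\theta}^{(0,1)} &= -\eta\left(\lambda+\Lambda^{(0)}\right)C_{\theta}^{(0,1)}
  +\eta\,C_{\ell}^{(0,1)}-\eta\,g^{(0)}\left(M^{(1)}\right)^{\top}.
\end{align*}
```

In this case the first step simply contracts the initial covariance by $I-\eta(\lambda+\Lambda^{(0)})$, and the teacher enters the second relation only through $g^{(0)}$ and $M^{(1)}$.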

A.4 Pre-activations at the end of the first gradient update

For $T=1$, Equation (53) simplifies to:

$$
\bm{h}^{(1)}=-\eta\nabla_{\bm{h}}\ell(\bm{h}^{(0)},\mathbf{h}^{\star})+\bm{\omega}^{(1)} \tag{76}
$$

We now show that the first term corresponds exactly to the contributions considered in Section 4.1.

Lemma A.6.

Under the notation of Section 4.1 and the assumptions of Theorem 3.2:

$$
\frac{a}{d}\,\ell^{\prime}\left(\mathbf{h}^{(0)}_{\nu},\mathbf{h}^{\star}_{\nu}\right)\sigma^{\prime}(\mathbf{h}^{(0)}_{\nu})\langle\mathbf{z}_{\nu},\mathbf{z}_{\nu}\rangle\xrightarrow[D]{n,d\rightarrow\infty}\nabla_{\bm{h}}\ell(\bm{h}^{(0)},\mathbf{h}^{\star}) \tag{77}
$$

where $\bm{h}^{(0)}\in\mathbb{R}^{p},\bm{h}^{*}\in\mathbb{R}^{k}$ are independent Gaussian random variables distributed as in Theorem A.2.

Proof.

We simply apply the conditioning by projection technique described in Section A.2 to $\mathbf{z}_{\nu}$ by expressing it as $\mathbf{z}_{\nu}=\mathbf{h}^{(0)}_{\nu}+\frac{d-1}{d}\mathbf{z}^{\prime}_{\nu}$, where $\mathbf{z}^{\prime}_{\nu}$ is independent of $\mathbf{h}^{(0)}_{\nu}$. The result then follows from the convergence in probability of $\frac{1}{d}\langle\mathbf{z}^{\prime}_{\nu},\mathbf{z}^{\prime}_{\nu}\rangle$ to $1$. ∎
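The concentration fact used at the end of the proof, namely that $\frac{1}{d}\langle\mathbf{z},\mathbf{z}\rangle\to 1$ for a standard Gaussian vector, is easy to observe numerically. The sketch below is illustrative only; the helper name, the dimensions, and the seed are assumptions.

```python
import random

def normalized_sq_norm(d, rng):
    """Return <z, z> / d for a fresh standard Gaussian vector z in R^d."""
    return sum(rng.gauss(0.0, 1.0) ** 2 for _ in range(d)) / d

rng = random.Random(0)
# Fluctuations around 1 shrink at rate sqrt(2 / d).
values = {d: normalized_sq_norm(d, rng) for d in (10, 1_000, 100_000)}
for d, v in values.items():
    print(d, v)
```

The deviation from $1$ has standard deviation $\sqrt{2/d}$, so it is already below one percent at $d=10^{5}$ in typical runs.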

Next, we characterize $\bm{\omega}^{(1)}$. We consider two cases:

  • $M^{(1)}=0$: In this case, Corollary A.3 implies that the first layer does not develop any overlap with directions in $U^{*}$. Equation (55) then implies that $\bm{\omega}^{(1)}$ is uncorrelated with $\bm{\omega}^{*}$.

  • $M^{(1)}\neq 0$: In this case, the first layer develops an overlap along $U^{*}$. By initialization, we have that $M^{(0)}=0$. Equation (62) implies that $M^{(1)}$ is given by:

    $$
    M^{(1)}=-\eta\alpha\,\mathbb{E}\left[\nabla_{\bm{h}}\ell(\bm{h}^{(0)})\left(\bm{h}^{*}\right)^{\top}\right] \tag{78}
    $$

    Due to the choice of symmetric initialization (Equation 4), we have $f(\bm{h}^{(0)})=0$. Therefore, $\nabla_{\bm{h}}\mathcal{L}(\bm{h}^{(0)})=-g^{*}(\mathbf{h}^{*})\,\mathbf{a}\odot\sigma^{\prime}(\bm{h}^{(0)})$. Substituting into (78), we thus obtain

    $$
    M^{(1)}=\eta\alpha\,\mathbb{E}\left[g^{*}(\mathbf{h}^{*})\left(\mathbf{a}\odot\sigma^{\prime}(\bm{h}^{(0)})\right)(\bm{h}^{*})^{\top}\right]=\eta\alpha\left(\mathbf{a}\odot\mathbb{E}\left[\sigma^{\prime}(\bm{h}^{(0)})\right]\right)\mathbb{E}\left[g^{*}(\mathbf{h}^{*})(\bm{h}^{*})^{\top}\right], \tag{79}
    $$

    where $\odot$ denotes element-wise multiplication and we used the independence of $\mathbf{h}^{*},\bm{h}^{(0)}$. Therefore, $M^{(1)}$ gains a rank-one spike, with every row aligned with $\mathbb{E}\left[g^{*}(\mathbf{h}^{*})\bm{h}^{*}\right]$. This matches the corresponding results for single $\mathcal{O}(d)$ batch gradient steps in the online setting [Ba et al., 2022, Dandi et al., 2023].

    By Equation (55), $\bm{\omega}^{(1)}$ can be expressed as:

    $$
    \bm{\omega}^{(1)}=\bm{\omega}_{\perp}^{(1)}+M^{(1)}\bm{h}^{*}\,, \tag{80}
    $$

    where $\bm{\omega}_{\perp}^{(1)}$ is independent of $\bm{h}^{*}$.
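The rank-one structure of $M^{(1)}$ can be illustrated with a small Monte Carlo sketch: (79) writes $M^{(1)}$ as the outer product of a vector of per-neuron coefficients with the direction $\mathbb{E}[g^{*}(\mathbf{h}^{*})\bm{h}^{*}]$. Everything specific below is an assumption made for illustration: the helper `first_step_overlap`, the coefficient vector `coeffs` (standing in for the per-neuron prefactors), and the toy target $g^{*}(\mathbf{h}^{*})=h^{*}_{1}+h^{*}_{1}h^{*}_{2}$ with $k=2$.

```python
import random

def first_step_overlap(coeffs, g_star, k, n_samples=100_000, seed=0):
    """Estimate v = E[g*(h*) h*] for h* ~ N(0, I_k) and return (c v^T, v)."""
    rng = random.Random(seed)
    v = [0.0] * k
    for _ in range(n_samples):
        h_star = [rng.gauss(0.0, 1.0) for _ in range(k)]
        g = g_star(h_star)
        for r in range(k):
            v[r] += g * h_star[r]
    v = [x / n_samples for x in v]
    # Outer product: row j of M is coeffs[j] * v, so M has rank at most one.
    M = [[c * x for x in v] for c in coeffs]
    return M, v

coeffs = [0.5, -0.5, 0.25]  # hypothetical per-neuron prefactors
M, v = first_step_overlap(coeffs, lambda h: h[0] + h[0] * h[1], k=2)
print(v)  # close to (1, 0) for this toy target
```

For this toy target the estimated direction is close to $(1,0)$: only the first teacher direction receives a linear signal, so $M^{(1)}$ has no component along the second one after a single step.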

A.5 Proof of Theorem 3.2

To illustrate the learning of directions solely due to the hidden progress explained in Section 4.1, we first focus on the case where $M^{(1)}=0$, i.e., when the parameters develop no overlap along the target subspace in the first step.

From Corollary A.3 and Slutsky's theorem, we have that:

$$
\frac{1}{d}W^{(2)}(W^{\star})^{\top}-\frac{1}{d}W^{(1)}(W^{\star})^{\top}\xrightarrow[P]{n,d\rightarrow\infty}-\eta\alpha\,\mathbb{E}\left[\nabla_{\bm{h}}\mathcal{L}(\bm{h}^{(1)},\mathbf{h}^{\star})\left(\bm{h}^{\star}\right)^{\top}\right], \tag{81}
$$

where, from Equation (76), $\bm{h}^{(1)}$ can be expressed as a combination of $\nabla_{\bm{h}}\mathcal{L}(\bm{h}^{(0)},\mathbf{h}^{\star})$ and a Gaussian random variable $\bm{\omega}^{(1)}$ independent of $\mathbf{h}^{\star}$. Furthermore, Lemma A.4 implies that the regularization strength $\lambda$ and step size $\eta$ can be set such that the entries of $\bm{\omega}^{(1)}$ have unit variance. Now, suppose $\mathbf{v}^{\star}=(W^{\star})^{\top}\mathbf{u}^{\star}$ for some fixed vector $\mathbf{u}^{\star}\in\mathbb{R}^{k}$. First, consider the case when $\mathbf{v}^{\star}$ lies in the subspace $P^{*}_{\perp}$ as defined in Definition 3.1.

Projecting Equation (81) along $\mathbf{u}^{\star}$, we obtain:

$$
\frac{1}{d}W^{2}\mathbf{v}^{\star}-\frac{1}{d}W^{1}\mathbf{v}^{\star}\xrightarrow[P]{n,d\rightarrow\infty}-\eta\alpha\,\mathbb{E}\left[\nabla_{\bm{h}}\mathcal{L}(\bm{h}^{(1)},\mathbf{h}^{\star})\left(\bm{h}^{\star}\right)^{\top}\mathbf{u}^{\star}\right]. \tag{82}
$$

For the squared loss, we have $-\nabla_{\bm{h}_{j}}\mathcal{L}(\bm{h}^{(t)},\mathbf{h}^{\star})=a_{j}\left(g^{\star}(\mathbf{h}^{\star})-f(\bm{h}^{(t)})\right)\sigma^{\prime}(\bm{h}^{(t)}_{j})$.
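As a quick sanity check of this gradient expression, the sketch below compares it with a central finite-difference gradient. It assumes the loss $\frac{1}{2}(g^{\star}-f)^{2}$, a network output $f(\bm{h})=\sum_{j}a_{j}\sigma(h_{j})$ without any extra normalization, and $\sigma=\tanh$. These choices, as well as all function names, are illustrative assumptions and not necessarily the exact setup used in the main text.

```python
import numpy as np

def network_output(h, a):
    # Assumed form f(h) = sum_j a_j * sigma(h_j), with sigma = tanh.
    return float(np.dot(a, np.tanh(h)))

def squared_loss(h, a, target):
    # Assumed normalization: L = 0.5 * (g_star - f(h))^2.
    return 0.5 * (target - network_output(h, a)) ** 2

def negative_gradient(h, a, target):
    # Closed form: -dL/dh_j = a_j * (g_star - f(h)) * sigma'(h_j).
    residual = target - network_output(h, a)
    return a * residual * (1.0 - np.tanh(h) ** 2)

def finite_difference_gradient(h, a, target, eps=1e-6):
    # Central differences approximate dL/dh_j coordinate by coordinate.
    grad = np.zeros_like(h)
    for j in range(h.size):
        step = np.zeros_like(h)
        step[j] = eps
        grad[j] = (squared_loss(h + step, a, target)
                   - squared_loss(h - step, a, target)) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
h = rng.normal(size=4)          # hypothetical preactivations
a = rng.normal(size=4) / 2.0    # hypothetical second-layer weights
target = 0.7                    # hypothetical value of g_star(h_star)

# The closed form is the NEGATIVE gradient, so the sum should vanish.
max_error = np.max(np.abs(negative_gradient(h, a, target)
                          + finite_difference_gradient(h, a, target)))
print(max_error)
```

The printed discrepancy should be at the level of finite-difference error, which supports the sign convention used in the equations that follow.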

Therefore, the overlap for the $j$-th neuron can be expressed as:

$$
\frac{1}{d}\langle\mathbf{w}^{(2)}_{j},\mathbf{v}^{\star}\rangle-\frac{1}{d}\langle\mathbf{w}^{(1)}_{j},\mathbf{v}^{\star}\rangle\xrightarrow[P]{n,d\rightarrow\infty}\eta\alpha\,\mathbb{E}\left[a_{j}g^{\star}(\mathbf{h}^{\star})\sigma^{\prime}(\bm{h}^{(1)}_{j})\left(\bm{h}^{\star}\right)^{\top}\mathbf{u}^{\star}\right]-\eta\alpha\,\mathbb{E}\left[a_{j}f(\bm{h}^{(1)})\sigma^{\prime}(\bm{h}^{(1)}_{j})\left(\bm{h}^{\star}\right)^{\top}\mathbf{u}^{\star}\right]. \tag{83}
$$

We focus on the first term on the right-hand side. By assumption, $\frac{1}{d}\langle\mathbf{w}^{(0)}_{j},\mathbf{v}^{\star}\rangle$ and $\frac{1}{d}\langle\mathbf{w}^{(1)}_{j},\mathbf{v}^{\star}\rangle$ converge in probability to $0$. Therefore, using Equation (76), we obtain:

\begin{align}
\frac{1}{d}\langle\mathbf{w}^{(2)}_{j},\mathbf{v}^{\star}\rangle &\xrightarrow[P]{n,d\rightarrow\infty}\eta\alpha\,\mathbb{E}\left[a_{j}g^{\star}(\mathbf{h}^{\star})\sigma^{\prime}(\bm{h}^{(1)}_{j})\left(\bm{h}^{\star}\right)^{\top}\mathbf{u}^{\star}\right] \tag{84}\\
&=\eta\alpha\,\mathbb{E}\left[a_{j}g^{\star}(\mathbf{h}^{\star})\sigma^{\prime}\left(-\nabla_{\bm{h}}\mathcal{L}(\bm{h}^{(0)})+\bm{\omega}^{(1)}\right)\left(\bm{h}^{\star}\right)^{\top}\mathbf{u}^{\star}\right]. \tag{85}
\end{align}

Recall that, by the choice of initialization, $C_{\theta}^{(0,0)}$ is diagonal with entries $1$, except for the off-diagonal entries corresponding to the pairing of neurons through the symmetric initialization. Furthermore, by initialization, $M^{(0)}=\mathbf{0}$, and by assumption $M^{(1)}=\mathbf{0}$.

From Lemma A.4 and the associated definitions, we have that by setting $\eta\gamma=1$, the covariances $C_{\theta}^{(0,1)}$ and $C_{\theta}^{(1,1)}$ simplify to:

\begin{align}
C_{\theta}^{(0,1)}&=-\eta\Lambda^{(0)}C_{\theta}^{(0,0)}, \tag{86}\\
C_{\theta}^{(1,1)}&=\Lambda^{(0)}C_{\theta}^{(0,1)}+\eta C_{\ell}^{(0,1)}. \tag{87}
\end{align}

By definition, $\Lambda^{(0)}$ and $C_{\ell}^{(0,1)}$ have diagonal entries proportional to $a_{j}$ and $a_{j}^{2}$ respectively. Therefore, we can further set $\eta>0$ such that the $j$-th diagonal entry of $C_{\theta}^{(1,1)}$ equals $1+a_{j}^{2}$. By case $1$ in Section A.4, we further have that $\bm{\omega}^{(1)}$ is independent of $\mathbf{h}^{\star}$.

Substituting $\nabla_{\bm{h}}\mathcal{L}(\bm{h}^{(0)})=-a\,g^{\star}(\mathbf{h}^{\star})\sigma^{\prime}(\bm{h}^{(0)})$, we obtain the precise condition on $\sigma$ for the $j$-th neuron to learn the direction $\mathbf{v}^{\star}$ at the second timestep. The condition is given by:

$$
\phi(a_{j})=\mathbb{E}\left[g^{\star}(\mathbf{h}^{\star})\,\sigma^{\prime}\left(\eta a_{j}g^{\star}(\mathbf{h}^{\star})\sigma^{\prime}(h^{(0)}_{j})+a_{j}\xi\right)\langle\mathbf{h}^{\star},\mathbf{u}^{\star}\rangle\right]\neq 0, \tag{88}
$$

where $\mathbf{h}^{\star}$ and $\xi$ are independent Gaussian random variables. Since $\mathbf{h}^{\star}$ matches in distribution $\frac{1}{\sqrt{d}}\mathbf{W}^{\star}\mathbf{z}$, the above condition can equivalently be expressed as the following condition on $f^{\star}$:

$$
\phi(a_{j})=\mathbb{E}_{\mathbf{z}}\left[F_{\sigma,\mathbf{a}}(f^{\star}(\mathbf{z}))\langle\mathbf{v}^{\star},\mathbf{z}\rangle\right]\neq 0, \tag{89}
$$

where:

$$
F_{\sigma,\mathbf{a}}(x)=\mathbb{E}_{u,\xi}\left[x\,\sigma^{\prime}\left(\eta a_{j}\sigma^{\prime}(u)x+a_{j}\xi\right)\right], \tag{90}
$$

where $u$ is a standard normal variable corresponding to $h_{j}^{(0)}$. The above expectation $\phi$ is an analytic function of $a_{j}$. To show that it is not identically zero, we consider its derivative w.r.t. $a_{j}$ at $a_{j}=0$. Using the dominated convergence theorem, we have:

$$
\begin{aligned}
\phi^{\prime}(0)&=\eta\,\mathbb{E}\left[\sigma^{\prime}(u)\,g^{\star}(\mathbf{h}^{\star})^{2}\,\sigma^{\prime\prime}(0)\,\langle\mathbf{h}^{\star},\mathbf{u}^{\star}\rangle\right]\\
&=\eta\,\mathbb{E}\left[g^{\star}(\mathbf{h}^{\star})^{2}\langle\mathbf{h}^{\star},\mathbf{u}^{\star}\rangle\right]\mathbb{E}\left[\sigma^{\prime}(u)\right]\sigma^{\prime\prime}(0)\\
&=\eta\,\mathbb{E}\left[g^{\star}(\mathbf{h}^{\star})^{2}\langle\mathbf{h}^{\star},\mathbf{u}^{\star}\rangle\right]\nu_{1}(\sigma)\,\sigma^{\prime\prime}(0),
\end{aligned}
$$

where $\nu_{1}(\sigma)$ denotes the first Hermite coefficient of $\sigma$. The contribution of $\xi$ to the derivative vanishes because $\xi$ is centered and independent of $\mathbf{h}^{\star}$ and $u$. Similarly, iterating $k$ times, we have:

$$
D_{a_{j}}^{k}\phi(\mathbf{a})=\eta^{k+1}\,\mathbb{E}\left[(g^{\star}(\mathbf{h}^{\star}))^{k+1}\langle\mathbf{h}^{\star},\mathbf{u}^{\star}\rangle\right]\nu_{1}^{k}(\sigma)\,\sigma^{(k+1)}(0), \tag{91}
$$

where $\sigma^{(k+1)}$ denotes the $(k+1)$-th derivative of $\sigma$. Note that for $\phi(\mathbf{a})$ to not be identically zero, it is sufficient that $D_{a_{j}}^{k-1}\phi$ is non-zero at $a_{j}=0$ for some $k\in\mathbb{N}$. Since by assumption $\mathbf{v}^{\star}\in P^{*}_{\perp}$, and since the monomials $1,x,x^{2},\cdots$ span the space of polynomials, there exists a $k\in\mathbb{N}$ such that $\mathbb{E}\left[(g^{\star}(\mathbf{h}^{\star}))^{k}\langle\mathbf{h}^{\star},\mathbf{u}^{\star}\rangle\right]\neq 0$. The conditions on $\sigma$ further imply that $\nu_{1}^{k}(\sigma)\,\sigma^{(k+1)}(0)\neq 0$ for all $k\in\mathbb{N}$. Therefore, $\phi(a_{j})$ is not identically $0$.
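To make the moment condition concrete, the following sketch evaluates $\mathbb{E}\left[g(h)^{k}\,h\right]$ exactly for a polynomial $g$ and a scalar $h\sim\mathcal{N}(0,1)$, using the closed-form Gaussian moments. It covers only the single-index case with $\mathbf{u}^{\star}=1$, which is an assumption made purely for illustration, and the helper names are hypothetical. For the third Hermite polynomial $g(h)=h^{3}-3h$, the moments with $k=1$ and $k=2$ vanish while $k=3$ does not. For the even polynomial $g(h)=h^{2}-1$, every such moment vanishes by parity.

```python
def gaussian_moment(n):
    """E[h^n] for h ~ N(0, 1): zero for odd n, (n - 1)!! for even n."""
    if n % 2 == 1:
        return 0
    result = 1
    for m in range(n - 1, 0, -2):
        result *= m
    return result

def poly_mul(p, q):
    """Multiply polynomials given as coefficient lists [c0, c1, c2, ...]."""
    out = [0] * (len(p) + len(q) - 1)
    for i, pi in enumerate(p):
        for j, qj in enumerate(q):
            out[i + j] += pi * qj
    return out

def moment_with_h(g, k):
    """Exact value of E[g(h)^k * h] for a polynomial g and h ~ N(0, 1)."""
    power = [1]
    for _ in range(k):
        power = poly_mul(power, g)
    # Multiplying by h raises every degree by one.
    return sum(c * gaussian_moment(deg + 1) for deg, c in enumerate(power))

he3 = [0, -3, 0, 1]   # h^3 - 3h, an odd target with no linear component
he2 = [-1, 0, 1]      # h^2 - 1, an even target

print([moment_with_h(he3, k) for k in (1, 2, 3)])  # [0, 0, 324]
print([moment_with_h(he2, k) for k in (1, 2, 3)])  # [0, 0, 0]
```

The first output is consistent with a direction that single-pass correlations miss but that some power of the target reveals. The second matches the even-symmetric case, where no power of $g$ correlates with $h$.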

Since $\phi(a_{j})$ is an analytic function that is not identically zero, and the law of $\mathbf{a}^{(0)}\sim\mathcal{N}(0,\frac{1}{p}\mathbbm{1}_{p})$ is absolutely continuous w.r.t. the Lebesgue measure, $\phi(a_{j})$ is non-zero almost surely over the initialization. Now, the second term in Equation (83) is again an analytic function of $\mathbf{a}$, distinct from $\phi(a_{j})$, and can therefore almost surely be absorbed into the non-zero overlap. This proves the first part of Theorem 3.2 for developing an overlap along a fixed direction in $P^{*}_{\perp}$ when $M^{(1)}=0$. We now proceed to show that the weights $W^{2}$ span $P^{*}_{\perp}$.

Let $r$ denote the dimension of the subspace $P^{*}_{\perp}$. Suppose that $\mathbf{v}^{\star}_{1}=(W^{\star})^{\top}\mathbf{u}_{1}^{\star},\ \mathbf{v}^{\star}_{2}=(W^{\star})^{\top}\mathbf{u}_{2}^{\star},\ \cdots,\ \mathbf{v}^{\star}_{r}=(W^{\star})^{\top}\mathbf{u}_{r}^{\star}$ form an orthonormal basis of $P^{*}_{\perp}$. Let $V^{\star}\in\mathbb{R}^{d\times r}$ and $M_{u}^{*}\in\mathbb{R}^{k\times r}$ denote the matrices with columns $\mathbf{v}_{1}^{\star},\cdots,\mathbf{v}_{r}^{\star}$ and $\mathbf{u}_{1}^{\star},\cdots,\mathbf{u}_{r}^{\star}$ respectively.

Analogous to Equation (83), we obtain:

$$
\frac{1}{d}W^{2}V^{\star}-\frac{1}{d}W^{1}V^{\star}\xrightarrow[P]{n,d\rightarrow\infty}-\eta\alpha\,\mathbb{E}\left[\nabla_{\bm{h}}\mathcal{L}(\bm{h}^{(1)},\mathbf{h}^{\star})\left(\bm{h}^{\star}\right)^{\top}M_{u}^{*}\right]. \tag{92}
$$

Following the derivation of Equation (89), we obtain that the rows of the matrix $\mathbb{E}\left[\nabla_{\bm{h}}\mathcal{L}(\bm{h}^{(1)},\mathbf{h}^{\star})\left(\bm{h}^{\star}\right)^{\top}M_{u}^{*}\right]$ are independent for neurons $i,j$ with $j\neq p-i+1$ (due to the symmetric initialization in Equation (4)). Furthermore, the law of each row of the matrix is absolutely continuous w.r.t. the Lebesgue measure on $\mathbb{R}^{r}$. This implies that $\frac{1}{d}W^{2}V^{\star}-\frac{1}{d}W^{1}V^{\star}$ has full row-rank almost surely for large enough $p$.
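The rank argument can be illustrated numerically with a toy stand-in for the limiting matrix. The sketch below draws independent Gaussian rows in $\mathbb{R}^{r}$ and mirrors each one with a sign flip to mimic the pairing $j=p-i+1$ induced by a symmetric initialization. The Gaussian rows, the sign-flip pairing, and the function name are assumptions chosen for illustration only. The actual rows are nonlinear functions of $a_{j}$ as described above.

```python
import numpy as np

def paired_random_rows(p, r, rng):
    """Stack p/2 independent Gaussian rows with their sign-flipped mirror images."""
    half = rng.normal(size=(p // 2, r))
    # Row p - 1 - i (0-indexed) is the negative of row i, so paired rows are dependent.
    return np.vstack([half, -half[::-1]])

rng = np.random.default_rng(0)
r = 3
overlap_increment = paired_random_rows(p=8, r=r, rng=rng)

# Despite the pairing, the p/2 independent rows already span R^r once p/2 >= r.
rank = np.linalg.matrix_rank(overlap_increment)
print(overlap_increment.shape, rank)
```

In this toy setting the pairing halves the number of independent rows, which is why the statement requires $p$ to be large enough relative to $r$.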

Now, suppose that $\mathbf{v}^{\star}$ instead lies in the even-symmetric subspace $A^{\star}$. By induction and the closure properties of analytic functions, $\bm{h}^{(t)}$ can be expressed as:

$$
\bm{h}^{(t)}=\mathcal{F}_{t}\left(\bm{h}^{(*)},\omega_{1},\bm{\omega}^{(1)},\cdots,\bm{\omega}^{(t)}\right), \tag{93}
$$

for an analytic mapping $\mathcal{F}_{t}$. Now, similar to Equation (83), we have that:

$$
\frac{1}{d}W^{t}\mathbf{v}^{\star}-\frac{1}{d}W^{t-1}\mathbf{v}^{\star}\xrightarrow[P]{n,d\rightarrow\infty}-\eta\alpha\,\mathbb{E}\left[\nabla_{\bm{h}}\mathcal{L}(\bm{h}^{(t)},\mathbf{h}^{\star})\left(\bm{h}^{\star}\right)^{\top}\mathbf{u}^{\star}\right]. \tag{94}
$$

Using Fubini's theorem, we may integrate out all variables other than $\mathbf{z}$ to express each entry of $\mathbb{E}\left[\nabla_{\bm{h}}\mathcal{L}(\bm{h}^{(t)},\mathbf{h}^{\star})\left(\bm{h}^{\star}\right)^{\top}\mathbf{u}^{\star}\right]$ as:

$$\mathbb{E}_{\mathbf{z}}\left[F_{t,\mathbf{a}}(f^{\star}(\mathbf{z}))\langle\mathbf{v}^{\star},\mathbf{z}\rangle\right], \tag{95}$$

for some analytic $F_{t,\mathbf{a}}$. This ensures that the expectation in (82) remains $0$ for all times $t$. This proves the second part of Theorem 3.2.

A.6 Effect of previously learned directions

We now consider the case when $M^{(1)}\neq 0$, i.e. when the first layer develops an overlap along $U^{*}$. As shown in 4.1, the rows of $M^{(1)}$ lie along the same direction given by $\mathbb{E}\left[g^{*}(\mathbf{h}^{*})(\bm{h}^{*})\right]$. Without loss of generality, we assume that the direction $\mathbb{E}\left[g^{*}(\mathbf{h}^{*})(\bm{h}^{*})\right]$ corresponds to $\mathbf{e}_{1}$ in the input space $\mathbb{R}^{d}$ and that $W^{*}$ has rows along the standard basis $\mathbf{e}_{1},\cdots,\mathbf{e}_{k}$. Note that $\mathbf{e}_{1}$ itself lies in $P^{*}_{\perp}$, as can be seen by setting $F(x)=x$ in Definition 3.1.

From Equation (80), we obtain:

$$\bm{\omega}_{j}^{(1)}=\bm{{\omega_{\perp}}_{j}}^{(1)}+\eta Ca_{j}h^{*}_{1}, \tag{96}$$

where $C$ denotes a constant dependent on $g^{*}$. Since $\bm{\omega}^{(1)}$ is now correlated with $h^{*}_{1}$, the condition in Equation (88) is modified to:

$$\phi(a_{j})=\mathbb{E}\left[g^{\star}(\mathbf{h}^{\star})\,\sigma^{\prime}\!\left(\eta a_{j}\sigma^{\prime}(u)g^{\star}(\mathbf{h}^{\star})+\xi_{1}+a_{j}\xi_{2}+\eta Ca_{j}h^{*}_{1}\right)\langle\mathbf{h}^{\star},\mathbf{u}^{\star}\rangle\right]\neq 0. \tag{97}$$

Again, differentiating w.r.t. $a_{j}$, we obtain:

$$\phi^{\prime}(0)=\eta\,\mathbb{E}\left[g^{\star}(\mathbf{h}^{\star})^{2}\langle\mathbf{h}^{\star},\mathbf{u}^{\star}\rangle\right]\nu_{1}(\sigma)\nu_{2}(\sigma)+\eta\,\mathbb{E}\left[g^{\star}(\mathbf{h}^{\star})h^{*}_{1}\langle\mathbf{h}^{\star},\mathbf{u}^{\star}\rangle\right]. \tag{98}$$

Similar to section A.5, we have that $\mathbf{v}^{*}\in P^{*}_{\perp}$ is sufficient for the first term to be non-zero almost surely over $a_{j}$. If the second term is non-zero, then $\mathbf{u}^{\star}$ is learned through the staircase mechanism, since this implies that $g^{\star}(\mathbf{h}^{\star})$ contains terms dependent on $h^{*}_{1}$ and linearly coupled with $\langle\mathbf{h}^{\star},\mathbf{u}^{\star}\rangle$. In either case, $W^{(2)}$ almost surely obtains an overlap along $\mathbf{v}^{*}$. This concludes the proof of the first part of Theorem 3.2.

More generally, suppose that $\mathbf{e}_{1},\cdots,\mathbf{e}_{m}$ denote a basis of the directions in $U^{*}$ learned up to time $t$. Then, the modified condition for learning a new direction $\mathbf{v}^{*}$ at time $t+1$ is:

$$\mathbb{E}_{\mathbf{z}}\left[F(f^{\star}(\mathbf{z}),z_{1},\cdots,z_{m})\langle\mathbf{v}^{\star},\mathbf{z}\rangle\right]\neq 0, \tag{99}$$

for a polynomial $F:\mathbb{R}^{m+1}\rightarrow\mathbb{R}$. Therefore, new directions can be learned through a combination of the staircase and hidden-progress mechanisms.

A.7 Typical examples where $E_{\perp}^{*}=P^{*}_{\perp}=U^{*}$

For several target functions of interest, the class $P^{*}_{\perp}$ can be shown to cover the entire target space $U^{*}$. We list some of them below:

  • Single-index odd polynomials with all non-negative/non-positive coefficients. This follows since $\mathbb{E}_{\mathbf{z}}\left[(f^{\star}(\mathbf{z}))^{k}\langle\mathbf{v}^{\star},\mathbf{z}\rangle\right]$ decomposes into sums of non-negative/non-positive terms.

  • Single-index odd Hermite polynomials. We prove this below in Lemma A.7.

  • Staircase function $f^{*}(\mathbf{z})=z_{1}+z_{1}z_{2}+z_{1}z_{2}z_{3}$. This follows directly by evaluating $\mathbb{E}_{\mathbf{z}}\left[(f^{\star}(\mathbf{z}))^{2}z_{i}\right]$ for $i=2,3$.
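The staircase example in the last item can be checked numerically. The following is a minimal sketch, assuming independent standard Gaussian inputs and using tensor-product Gauss-Hermite quadrature, which is exact for the low-degree polynomials involved; the helper names `staircase` and `correlation` are hypothetical and not taken from the paper.

```python
import numpy as np
from numpy.polynomial import hermite_e as He


def staircase(z1, z2, z3):
    # Target from the list above: f*(z) = z1 + z1 z2 + z1 z2 z3
    return z1 + z1 * z2 + z1 * z2 * z3


def correlation(power, index, n_nodes=8):
    """Return E[(f*(z))^power * z_index] for z ~ N(0, I_3) (index is 0-based).

    Uses a tensor Gauss-Hermite rule; 8 nodes per axis integrate
    polynomials up to degree 15 per variable exactly.
    """
    x, w = He.hermegauss(n_nodes)        # nodes/weights for weight exp(-x^2 / 2)
    w = w / np.sqrt(2.0 * np.pi)         # normalise to the N(0, 1) density
    grid = np.meshgrid(x, x, x, indexing="ij")
    weights = w[:, None, None] * w[None, :, None] * w[None, None, :]
    return float(np.sum(weights * staircase(*grid) ** power * grid[index]))


first_order = [correlation(1, i) for i in range(3)]
second_order = [correlation(2, i) for i in range(3)]
print("E[f z_i]   :", np.round(first_order, 10))
print("E[f^2 z_i] :", np.round(second_order, 10))
```

Under these assumptions the linear correlations with $z_2$ and $z_3$ vanish, whereas the correlations with $(f^{\star})^{2}$ do not, which is the sense in which $F(x)=x^{2}$ exposes these two directions.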

In general, for polynomial $f^{\star}$, the condition:

$$\mathbb{E}_{\mathbf{z}}\left[(f^{\star}(\mathbf{z}))^{k}\langle\mathbf{v}^{\star},\mathbf{z}\rangle\right]=0,\quad\forall k\in\mathbb{N}, \tag{100}$$

specifies an overdetermined system of infinitely many homogeneous polynomial equations on the coefficients of $f^{\star}$. Therefore, we expect the condition to fail almost surely for typical choices of $f^{\star}$. We leave an investigation of this using algebraic tools to future work.

Lemma A.7.

For any odd Hermite polynomial $H_{2k^{\prime}+1}$ with $k^{\prime}\in\mathbb{N}$:

$$\mathbb{E}_{z}\left[(H_{2k^{\prime}+1}(z))^{3}z\right]\neq 0, \tag{101}$$

where $z\sim\mathcal{N}(0,1)$.

Proof.

Using Stein’s Lemma, we have:

$$\mathbb{E}_{z}\left[(H_{2k^{\prime}+1}(z))^{3}z\right]=3\,\mathbb{E}_{z}\left[(H_{2k^{\prime}+1}(z))^{2}\frac{d}{dz}H_{2k^{\prime}+1}(z)\right]. \tag{102}$$

Next, we recall the following relation between Hermite polynomials and their derivatives:

$$\frac{d}{dz}H_{n}(z)=nH_{n-1}(z),\quad\forall n\in\mathbb{N}. \tag{103}$$

Substituting in Equation (102), we obtain:

$$\mathbb{E}_{z}\left[(H_{2k^{\prime}+1}(z))^{3}z\right]=3(2k^{\prime}+1)\,\mathbb{E}_{z}\left[(H_{2k^{\prime}+1}(z))^{2}H_{2k^{\prime}}(z)\right]. \tag{104}$$

The above expectation can be obtained analytically using the linearization formulas for Hermite polynomials [Andrews, 2004] to show that $\mathbb{E}_{z}\left[(H_{2k^{\prime}+1}(z))^{2}H_{2k^{\prime}}(z)\right]\neq 0$ for all $k^{\prime}\in\mathbb{N}$. ∎
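The statement of Lemma A.7 can be sanity-checked numerically for small $k^{\prime}$. The sketch below assumes the probabilists' normalisation of the Hermite polynomials (the one for which relation (103) holds as written) and evaluates the expectation with Gauss-Hermite quadrature; the function names are hypothetical, and the check covers only the few values of $k^{\prime}$ it loops over, so it illustrates the lemma rather than proving it.

```python
import numpy as np
from numpy.polynomial import hermite_e as He


def gaussian_expectation(func, n_nodes=60):
    """E[func(z)] for z ~ N(0, 1) via Gauss-Hermite quadrature.

    With n_nodes points the rule is exact for polynomials of degree
    up to 2 * n_nodes - 1, which covers every case used below.
    """
    x, w = He.hermegauss(n_nodes)            # weight function exp(-x^2 / 2)
    return float(np.sum(w * func(x)) / np.sqrt(2.0 * np.pi))


def cubic_moment(k_prime):
    """E[(He_{2k'+1}(z))^3 z] with probabilists' Hermite polynomials."""
    n = 2 * k_prime + 1
    coeffs = np.zeros(n + 1)
    coeffs[n] = 1.0                          # selects He_n in the Hermite basis
    return gaussian_expectation(lambda z: He.hermeval(z, coeffs) ** 3 * z)


for k_prime in range(4):
    print(k_prime, cubic_moment(k_prime))
```

For $k^{\prime}=0$ the polynomial is $z$ itself, so the moment reduces to $\mathbb{E}[z^{4}]=3$, which gives a quick way to confirm the normalisation.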

A.8 Proof of Proposition 3.5

Suppose that $\mathbf{v}^{\star}\in OE^{\star}$, i.e. $\mathbf{v}^{\star}$ is orthogonally even-symmetric w.r.t. $f^{\star}$ for some transformation $O_{\perp}\in O(\{\mathbf{v}^{\star}\}_{\perp})$. Let $\mathbf{z}^{\prime}=O_{\perp}R_{\mathbf{v}^{\star}}\mathbf{z}$. Then, by the invariance of the Gaussian measure under orthogonal transformations, we have:

$$\mathbb{E}_{\mathbf{z}}\left[f^{\star}(\mathbf{z})\langle\mathbf{v}^{\star},\mathbf{z}\rangle\right]=\mathbb{E}_{\mathbf{z}^{\prime}}\left[f^{\star}(\mathbf{z}^{\prime})\langle\mathbf{v}^{\star},\mathbf{z}^{\prime}\rangle\right]. \tag{105}$$

However, the expectation on the right can equivalently be expressed as:

$$\begin{aligned}
\mathbb{E}_{\mathbf{z}^{\prime}}\left[f^{\star}(\mathbf{z}^{\prime})\langle\mathbf{v}^{\star},\mathbf{z}^{\prime}\rangle\right]&=-\mathbb{E}_{\mathbf{z}}\left[f^{\star}(O_{\perp}R_{\mathbf{v}^{\star}}\mathbf{z})\langle\mathbf{v}^{\star},\mathbf{z}\rangle\right]\\
&=-\mathbb{E}_{\mathbf{z}}\left[f^{\star}(\mathbf{z})\langle\mathbf{v}^{\star},\mathbf{z}\rangle\right],
\end{aligned}$$

where in the second equality we used $\langle\mathbf{v}^{\star},\mathbf{z}^{\prime}\rangle=-\langle\mathbf{v}^{\star},\mathbf{z}\rangle$ and in the third Definition 3.4. Therefore, for any $\mathbf{v}^{\star}\in OE^{\star}$, we have:

$$\mathbb{E}_{\mathbf{z}}\left[f^{\star}(\mathbf{z})\langle\mathbf{v}^{\star},\mathbf{z}\rangle\right]=0. \tag{106}$$

Furthermore, it is straightforward to see that $\mathbf{v}^{\star}$ remains orthogonally even-symmetric w.r.t. the composition $F(f^{\star}(\cdot))$. Therefore, we have that $E^{\star}\subseteq OE^{\star}\subseteq A^{\star}\subseteq P^{\star}$.

A.9 Illustration of non-even symmetric hard directions

Figure 3: An illustration of a hard, non-even target $f^{\star}(\mathbf{z})=z_{1}z_{2}z_{3}+\mathrm{He}_{3}(z_{4})$ being learned by a student with $p=4$ hidden units. We can see that, even when reusing the batch, the student can only learn the direction associated with $z_{4}$, while keeping a zero overlap otherwise. The continuous lines are from the DMFT numerical integration; the dots are simulations with $d=10000$. In the legend, the overlap with the $n$-th direction is the projection of the student weights onto the subspace associated with $z_{n}$. For this figure we have $\sigma=\mathrm{relu}$, $n=5d$, $\eta=0.2$.

We now show the existence of target functions where $E^{*}\neq OE^{*}$. Without loss of generality, we assume that the rows of $W^{*}$ lie along the standard Euclidean basis $\mathbf{e}_{1},\mathbf{e}_{2},\cdots,\mathbf{e}_{k}$.

Lemma A.8.

Suppose that $f^{*}(\mathbf{z})=z_{1}z_{2}z_{3}$. Let $\mathbf{v}^{*}=\mathbf{e}_{1}+\mathbf{e}_{2}+\mathbf{e}_{3}$. Then $\mathbf{v}^{*}\notin E^{*}$ but $OE^{*}=U^{*}$.

Proof.

$\mathbf{v}^{*}\notin E^{*}$ follows directly by noting that $f^{*}(\mathbf{z})$ is even-symmetric along $\mathbf{e}_{1}-\mathbf{e}_{2}$ and $\mathbf{e}_{1}-\mathbf{e}_{3}$. We further have that a target $f^{*}$ satisfying $E^{*}=U^{*}$ must satisfy $f^{*}(-\mathbf{z})=f^{*}(\mathbf{z})\ \forall\mathbf{z}\in\mathbb{R}^{d}$. Therefore, since $f^{*}(-\mathbf{z})=-f^{*}(\mathbf{z})$, $f^{*}$ cannot be even-symmetric along $\mathbf{e}_{1}+\mathbf{e}_{2}+\mathbf{e}_{3}$. Next, we show that $OE^{*}=U^{*}$. Since $\mathbf{e}_{1},\mathbf{e}_{2},\mathbf{e}_{3}$ span $U^{*}$, and $f^{*}$ is symmetric w.r.t. permutations of $z_{1},z_{2},z_{3}$, it suffices to show that the condition in Definition 3.4 holds for $\mathbf{v}^{*}=\mathbf{e}_{1}$. The orthogonal complement $\{\mathbf{v}^{*}\}_{\perp}$ is given by $\operatorname{span}(\mathbf{e}_{2},\mathbf{e}_{3})$. Therefore the transformation $O_{2}$ defined by $z_{2}\rightarrow-z_{2}$ is a valid orthogonal transformation $O_{\perp}$ as per Definition 3.4. We have:

$$\begin{aligned}
f^{*}(O_{2}R_{\mathbf{v}^{*}}\mathbf{z})&=(-z_{1})(-z_{2})z_{3}\\
&=z_{1}z_{2}z_{3}=f^{*}(\mathbf{z}).
\end{aligned}$$

This shows that $\mathbf{e}_{1}$ lies in $OE^{*}$. Similarly, by symmetry, we have $\mathbf{e}_{2},\mathbf{e}_{3}\in OE^{*}$.

We present a numerical illustration of another such example in figure 3. ∎
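The two facts used in the proof of Lemma A.8 can be illustrated numerically. The sketch below assumes that $R_{\mathbf{e}_{1}}$ acts as the sign flip $z_{1}\rightarrow-z_{1}$ and that $O_{2}$ is the sign flip $z_{2}\rightarrow-z_{2}$, as in the display above; it then checks the invariance on random points and evaluates $\mathbb{E}\left[(f^{*}(\mathbf{z}))^{k}z_{1}\right]$ for a few powers $k$ by Gauss-Hermite quadrature. All function and variable names are hypothetical.

```python
import numpy as np
from numpy.polynomial import hermite_e as He


def f_star(z):
    # Target of Lemma A.8: f*(z) = z1 z2 z3 (last axis indexes coordinates)
    return z[..., 0] * z[..., 1] * z[..., 2]


def flip(z, axis):
    """Return a copy of z with the sign of one coordinate reversed."""
    out = z.copy()
    out[..., axis] *= -1.0
    return out


# Invariance check: f*(O_2 R_{e_1} z) == f*(z) on random Gaussian points.
rng = np.random.default_rng(0)
z = rng.standard_normal((1000, 3))
is_invariant = bool(np.allclose(f_star(flip(flip(z, 0), 1)), f_star(z)))

# Moment check: E[(f*(z))^k z1] for k = 1..4 on a tensor quadrature grid.
x, w = He.hermegauss(12)
w = w / np.sqrt(2.0 * np.pi)                 # normalise to the N(0, 1) density
grid = np.stack(np.meshgrid(x, x, x, indexing="ij"), axis=-1)
weights = w[:, None, None] * w[None, :, None] * w[None, None, :]
moments = [float(np.sum(weights * f_star(grid) ** k * grid[..., 0]))
           for k in range(1, 5)]

print("invariant under O_2 R_{e_1}:", is_invariant)
print("E[f^k z1], k=1..4:", np.round(moments, 10))
```

Only monomial choices $F(x)=x^{k}$ are probed here, so the vanishing moments are consistent with, but do not by themselves establish, the hardness of the direction $\mathbf{e}_{1}$.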

One can in fact construct a family of functions with a direction $\mathbf{v}^{*}$, for instance $\mathbf{v}^{*}=\mathbf{e}_{1}$, lying in $OE^{*}$ but in general not in $E^{*}$. To see this, let $f_{1}$ be a function $\mathbb{R}^{d}\rightarrow\mathbb{R}$ depending only on projections of $\mathbf{z}$ along $\{\mathbf{e}_{1}\}_{\perp}$, and let $O_{\perp}$ be an involutory orthogonal transformation on $\{\mathbf{e}_{1}\}_{\perp}$, i.e. an orthogonal transformation satisfying $O^{2}_{\perp}=\mathbf{I}$ or, equivalently, $O_{\perp}=(O_{\perp})^{\top}$. Now, let $f_{2}:\mathbb{R}\rightarrow\mathbb{R}$ be an odd function. Then, consider the function:

$$f^{*}(\mathbf{z})=\left(f_{1}(O_{\perp}\mathbf{z})-f_{1}(\mathbf{z})\right)f_{2}(z_{1}). \tag{107}$$

We observe that:

f(OR𝐞1𝐳)superscript𝑓𝑂subscript𝑅subscript𝐞1𝐳\displaystyle f^{*}(OR_{\mathbf{e}_{1}}\mathbf{z})italic_f start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT ( italic_O italic_R start_POSTSUBSCRIPT bold_e start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT end_POSTSUBSCRIPT bold_z ) =(f1(O2𝐳)f(O𝐳))f2(z1)\displaystyle=(f_{1}(O^{2}_{\perp}\mathbf{z})-f_{(}O_{\perp}\mathbf{z}))f_{2}(% -z_{1})= ( italic_f start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT ( italic_O start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT start_POSTSUBSCRIPT ⟂ end_POSTSUBSCRIPT bold_z ) - italic_f start_POSTSUBSCRIPT ( end_POSTSUBSCRIPT italic_O start_POSTSUBSCRIPT ⟂ end_POSTSUBSCRIPT bold_z ) ) italic_f start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT ( - italic_z start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT )
=(f1(O𝐳)f(𝐳))f2(z1)\displaystyle=(f_{1}(O_{\perp}\mathbf{z})-f_{(}\mathbf{z}))f_{2}(z_{1})= ( italic_f start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT ( italic_O start_POSTSUBSCRIPT ⟂ end_POSTSUBSCRIPT bold_z ) - italic_f start_POSTSUBSCRIPT ( end_POSTSUBSCRIPT bold_z ) ) italic_f start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT ( italic_z start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT )
=f(𝐳),absentsuperscript𝑓𝐳\displaystyle=f^{*}(\mathbf{z}),= italic_f start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT ( bold_z ) ,

where we used that O2=𝐈subscriptsuperscript𝑂2perpendicular-to𝐈O^{2}_{\perp}=\mathbf{I}italic_O start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT start_POSTSUBSCRIPT ⟂ end_POSTSUBSCRIPT = bold_I and f2(z1)=f2(z1)subscript𝑓2subscript𝑧1subscript𝑓2subscript𝑧1f_{2}(-z_{1})=-f_{2}(z_{1})italic_f start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT ( - italic_z start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT ) = - italic_f start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT ( italic_z start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT ). Therefore, for any such function f(𝐳)superscript𝑓𝐳f^{*}(\mathbf{z})italic_f start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT ( bold_z ), 𝐞1OEsubscript𝐞1𝑂superscript𝐸\mathbf{e}_{1}\in OE^{*}bold_e start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT ∈ italic_O italic_E start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT.
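The invariance above can be checked numerically. The following is a minimal sketch under stated assumptions: $O$ is taken to act as $O_{\perp}$ on $\{\mathbf{e}_{1}\}_{\perp}$ and as the identity on $\mathbf{e}_{1}$, $R_{\mathbf{e}_{1}}$ is taken to be the reflection $z_{1}\mapsto -z_{1}$, and the particular choices of `f1`, `f2`, and the Householder reflection used for $O_{\perp}$ are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 5

# Involutory orthogonal map on the complement of e_1: a Householder
# reflection on coordinates 2..d (symmetric, and it squares to the identity).
u = rng.standard_normal(d - 1)
u /= np.linalg.norm(u)
O_perp = np.eye(d - 1) - 2.0 * np.outer(u, u)

# Assumption: O acts as O_perp on {e_1}_perp and trivially on e_1,
# while R_e1 flips the sign of the first coordinate.
O = np.eye(d)
O[1:, 1:] = O_perp
R_e1 = np.eye(d)
R_e1[0, 0] = -1.0

def f1(z):
    # Illustrative function of the coordinates orthogonal to e_1 only.
    return np.sin(z[1]) + z[2] * z[3] ** 2 + np.cos(z[4])

def f2(x):
    # Illustrative odd function of z_1.
    return x ** 3 - 3.0 * x

def f_star(z):
    # f*(z) = (f1(O_perp z) - f1(z)) * f2(z_1)
    return (f1(O @ z) - f1(z)) * f2(z[0])

z = rng.standard_normal(d)
lhs = f_star(O @ R_e1 @ z)
rhs = f_star(z)
print(lhs, rhs)  # the two values agree up to floating-point error
```

Any other symmetric orthogonal matrix could replace the Householder reflection; the check relies only on $O_{\perp}^{2}=\mathbf{I}$ and on `f2` being odd.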

A.10 Implications for generalization

Since the specific guarantees of such results depend on the choice of activation and target functions, we illustrate this for the case of single-index target functions with matching activations:

Corollary A.9.

Consider the setting of a single-index target and student network with matching activations, i.e., $\sigma=g^{\star}$, such that $\sigma$ is a polynomial of finite degree satisfying the following assumption: $\exists k\in\mathbb{N}$ such that

$$\mathbb{E}\left[\sigma^{k}(z)z\right]\mathbb{E}\left[D^{k}\sigma(z)\right]\neq 0, \tag{108}$$

where $D^{k}$ denotes the $k$-th derivative. Let $\hat{\mathbf{w}}$ be the parameters obtained after two steps of gradient descent with batch size $\mathcal{O}(d)$ using $\eta$ as in Theorem 3.2. Then, almost surely over the initialization $a\sim\mathcal{N}(0,1)$, for any $\epsilon>0$, there exists a step size $\eta^{\prime}$ such that online SGD on the squared loss reaches generalization error $<\epsilon$ in time $\mathcal{O}(d)$.

We verify numerically that the above assumption holds, in particular, for all odd Hermite polynomials up to order $50$. The corollary implies that such target functions can be learned with $\mathcal{O}(d)$ sample complexity using gradient descent alone, without resorting to specialized algorithms and techniques such as spectral initialization.
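As an illustration of how condition (108) can be evaluated, the sketch below computes the product exactly for $\sigma=\mathrm{He}_{3}$ using Gaussian moments. It assumes that $\sigma^{k}(z)$ denotes the $k$-th power of $\sigma(z)$ and that $z\sim\mathcal{N}(0,1)$; the helper names are hypothetical and this is not the authors' verification script.

```python
from math import prod

def gaussian_moment(m):
    # E[z^m] for z ~ N(0, 1): zero for odd m, (m - 1)!! for even m.
    if m % 2 == 1:
        return 0
    return prod(range(1, m, 2))  # empty product gives 1 when m = 0

def poly_mul(p, q):
    # Product of two polynomials given as coefficient lists (increasing degree).
    out = [0] * (len(p) + len(q) - 1)
    for i, a in enumerate(p):
        for j, b in enumerate(q):
            out[i + j] += a * b
    return out

def poly_pow(p, k):
    out = [1]
    for _ in range(k):
        out = poly_mul(out, p)
    return out

def poly_deriv(p, k):
    # k-th derivative of the polynomial p.
    for _ in range(k):
        p = [i * c for i, c in enumerate(p)][1:]
    return p

def gaussian_expectation(p):
    # E[p(z)] for z ~ N(0, 1), computed term by term from the moments.
    return sum(c * gaussian_moment(m) for m, c in enumerate(p))

def condition_108(sigma, k):
    # Assumption: sigma^k(z) is read as the k-th power of sigma(z).
    first = gaussian_expectation(poly_mul(poly_pow(sigma, k), [0, 1]))
    second = gaussian_expectation(poly_deriv(sigma, k))
    return first * second

he3 = [0, -3, 0, 1]  # He_3(z) = z^3 - 3z
values = {k: condition_108(he3, k) for k in range(1, 4)}
print(values)  # → {1: 0, 2: 0, 3: 1944}
```

Under this reading, the product vanishes for $k=1,2$ and is non-zero at $k=3$, so the assumption is met for $\mathrm{He}_{3}$ with $k=3$.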

Proof.

Let $\mathbf{w}^{*}$ denote the single direction in the teacher subspace, with $\lVert\mathbf{w}^{*}\rVert=\sqrt{d}$. We note that the quantity in Equation (108) is proportional to the $(k-1)$-th derivative of $\phi(a_{j})$ defined in Equation (89). Therefore, the condition is sufficient to ensure that $\phi(a_{j})$ is not identically zero and that the student neuron almost surely develops an overlap along $\mathbf{w}^{*}$. The result then follows from Proposition 2.1 in [Ben Arous et al., 2021], which proves that upon weak recovery, i.e., a non-zero overlap with the target direction $\mathbf{v}^{*}$, online SGD on a differentiable activation with polynomially bounded derivatives converges to strong recovery. Concretely, for any starting non-zero overlap $\theta>0$ and any $\epsilon^{\prime}>0$, there exist a constant $C_{\epsilon^{\prime},\theta}$ and a small enough step size such that online SGD run for time $C_{\epsilon^{\prime},\theta}d$ achieves overlap $1-\epsilon^{\prime}$ along $\mathbf{w}^{*}$. Due to the matching activations, this suffices to obtain arbitrarily small generalization error. ∎

Appendix B General Multi-Pass Schemes

Figure 4: Comparison of theory and experiments for gradient descent on the target $z_{1}z_{2}z_{3}+\mathrm{He}_{3}(z_{4})$. Each gradient step uses a mini-batch of $n/5$ samples. On the left we use the data sequentially; on the right we sample the batch from the dataset with replacement. The continuous lines are from the DMFT numerical integration; the dots are simulations with $d=10000$ averaged over $32$ realisations. In the legend, the overlap with the $n$-th direction is the projection of the student weights onto the subspace associated with $z_{n}$. For this figure we have $\sigma=\mathrm{relu}$, $n=5d$, $\eta=0.2$.

B.1 Sketch of Proof for Extending Theorem 3.2 to Cycling over Epochs

Let $\mathbf{Z}^{1},\cdots,\mathbf{Z}^{n_{e}}$ denote $n_{e}$ independent minibatches of size $n_{b}$ such that $\frac{n_{b}}{d}=\mathcal{O}(1)$, with $n_{e}$ finite. The effective dynamics for a finite number of epochs can be obtained by noting that Theorem 3.2 in Gerbelot et al. [2023] allows generalizing Theorem A.1 to dynamics of the form:

$$W^{(t+1)}=W^{(t)}-\eta\lambda W^{(t)}-\sum_{i=1}^{n_{e}}\eta\frac{1}{\sqrt{d}}\sum_{\nu=1}^{n_{b}}F^{t}_{i}\left(\frac{W^{(t)}\mathbf{z}^{i}_{\nu}}{\sqrt{d}}\right)(\mathbf{z}^{i}_{\nu})^{\top}. \tag{109}$$

The above form of the dynamics allows a different update to be used for the data in each of the blocks $\mathbf{Z}^{1},\cdots,\mathbf{Z}^{n_{e}}$. In particular, setting $F^{t}_{i}$ to $0$ whenever $t \bmod i\neq 0$ and to $\nabla_{\bm{h}}\mathcal{L}$ otherwise results in a cycling schedule over the mini-batches $\mathbf{Z}^{1},\cdots,\mathbf{Z}^{n_{e}}$. Subsequently, one can show that the update from $\mathbf{Z}^{i}$ in the first epoch leads to the hidden-progress effect on $M^{t}$ when the model re-uses $\mathbf{Z}^{i}$ in the second epoch.
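To make a dynamics of the form (109) concrete, the following NumPy sketch cycles over $n_{e}$ mini-batches for a small two-layer ReLU network trained on the squared loss. Everything specific here is an assumption made for illustration: the sizes, the $\tanh$ single-index target, the helper names, and in particular the rule that block `t % n_e` (with blocks indexed from zero) is the active one at step `t`. It is not the setup analysed in the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
d, p, n_b, n_e = 200, 4, 400, 3   # illustrative sizes only
eta, lam = 0.2, 0.0

a = rng.choice([-1.0, 1.0], size=p) / np.sqrt(p)  # fixed second-layer weights
W = rng.standard_normal((p, d))                    # first-layer weights
w_star = rng.standard_normal(d)

def target(Z):
    # Hypothetical single-index target, chosen only for this example.
    return np.tanh(Z @ w_star / np.sqrt(d))

batches = [rng.standard_normal((n_b, d)) for _ in range(n_e)]
labels = [target(Z) for Z in batches]

def grad_h(H, y):
    # Gradient of 0.5 * (f - y)^2 w.r.t. the pre-activations H, where
    # f = sum_j a_j relu(h_j). Returns an array of shape (n_b, p).
    err = np.maximum(H, 0.0) @ a - y
    return err[:, None] * a[None, :] * (H > 0)

def F(i, t, H, y):
    # Block i contributes only when it is the active block at step t.
    # Assumption: the active block is t mod n_e.
    return grad_h(H, y) if t % n_e == i else np.zeros_like(H)

def step(W, t):
    # One update with weight decay lam and the block-wise sum over i.
    update = np.zeros_like(W)
    for i, (Z, y) in enumerate(zip(batches, labels)):
        H = Z @ W.T / np.sqrt(d)
        update += eta / np.sqrt(d) * F(i, t, H, y).T @ Z
    return W - eta * lam * W - update

schedule = []
for t in range(2 * n_e):          # two epochs over the same blocks
    schedule.append(t % n_e)
    W = step(W, t)
print(schedule)  # → [0, 1, 2, 0, 1, 2]
```

Evaluating every block and masking the inactive ones mirrors the sum over $i$ in (109); a practical implementation would only touch the active block.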

We believe a similar result would hold for $n_{b}=\mathcal{O}(1)$ samples in the minibatch, as displayed in Figure 5.

Figure 5: Experiments for gradient descent on the target $z_{1}z_{2}z_{3}+\mathrm{He}_{3}(z_{4})$. We use minibatches with $1$ sample each. On the left we use the data sequentially; on the right we sample the data point from the dataset with replacement. The dots are simulations with $d=10000$ averaged over $32$ realisations. In the legend, the overlap with the $n$-th direction denotes the projection of the student weights onto the subspace associated with $z_{n}$. For this figure we have $\sigma=\mathrm{relu}$, $n=5d$, $\eta=0.2$.

Appendix C Details on the numerics

C.1 DMFT equations with a single stochastic process

In this section, we present a set of exact equations equivalent to the ones in the main text, but that depend on a single stochastic process. It is possible to show that asymptotically in the proportional limit, i.e., for $d\to\infty$ and $n=\alpha d$, the pre-activations of the student are distributed as $\mathbf{h}^{(t)}=\mathbf{r}^{(t)}+M^{(t)}\mathbf{h}^{*}$, with the constraint:

$$\mathbf{r}^{(t+1)}=\mathbf{r}^{(t)}-\eta\left[\left(\lambda+\Lambda^{(t)}\right)\mathbf{r}^{(t)}+\nabla_{\mathbf{h}^{(t)}}\ell\left(\mathbf{h}^{(t)}\right)-\sum_{\tau=0}^{t-1}R_{\ell}^{(t,\tau)}\mathbf{r}^{(t)}+\mathbf{\zeta}^{(t)}\right]$$

Here $\mathbf{\zeta}^{(t)}$ is a zero-mean Gaussian process with covariance

$$\mathbb{E}_{\mathbf{\zeta}}\left[\mathbf{\zeta}^{(t)}\mathbf{\zeta}^{(\tau)\top}\right]=\alpha\,\mathbb{E}_{\mathbf{r},\mathbf{h}^{*}}\left[\nabla_{\mathbf{h}^{(t)}}\ell\left(\mathbf{h}^{(t)}\right)\nabla_{\mathbf{h}^{(\tau)}}\ell\left(\mathbf{h}^{(\tau)}\right)^{\top}\right]$$

and the effective regularisation $\Lambda^{(t)}$ concentrates to

$$\Lambda^{(t)}=\alpha\,\mathbb{E}_{\mathbf{h}}\left[\nabla^{2}_{\mathbf{h}^{(t)}}\ell\left(\mathbf{h}^{(t)}\right)\right]. \tag{110}$$

The memory kernel $R_{\ell}^{(t,\tau)}$ is identically zero for $t\leq\tau$, while for $t>\tau$ it concentrates to

$$R_{\ell}^{(t,\tau)}=\alpha\,\mathbb{E}_{\mathbf{h}}\left[\frac{\partial\,\nabla_{\mathbf{h}^{(t)}}\ell\left(\mathbf{h}^{(t)}\right)}{\partial\,\zeta^{(\tau)}}\right] \tag{111}$$

Finally, the low-dimensional projections of the weights $M^{(t)}$ obey the relation

$$M^{(t+1)}=M^{(t)}-\eta\alpha\,\mathbb{E}_{\mathbf{h},\mathbf{h}^{*}}\left[\nabla_{\mathbf{h}^{(t)}}\ell\left(\mathbf{h}^{(t)}\right)\mathbf{h}^{*\top}\right] \tag{112}$$

The procedure is explained in detail in appendix D of [Gerbelot et al., 2023], and can be equivalently derived using non-rigorous field theory techniques [Agoritsas et al., 2018].

C.2 Remark on the numerical integration of the DMFT equations

DMFT is an invaluable tool in itself to probe the behaviour of gradient-based algorithms. It trades the update equation over heavily coupled weights in (5) for one over completely decoupled pre-activations (C.1), which implies that a Monte Carlo estimation based on (C.1) is vastly more efficient and trivially parallelisable. Furthermore, equation (C.1) is exact in the limit of large $d$, which completely removes all finite-size effects. In practice, an implementation of the DMFT equations consists of drawing $n$ samples from the initial-condition distribution of the pre-activations and iterating forward. The Gaussian process is sampled by rotating white Gaussian noise by the LU factor of the covariance. Sampling the Gaussian process is by far the costliest operation, as each time step $T$ has complexity $\mathcal{O}(T^{3})$, for a total complexity of $\mathcal{O}(T^{4})$ considering all the steps up to $T$. Notice that this is a much more direct implementation than what is done in the literature [Roy et al., 2019, Mignacco et al., 2020], which usually starts with a guess for all the quantities and proceeds with a damped fixed-point iteration until convergence, with an overall complexity $\mathcal{O}(mT^{3})$, where $m$ is the number of fixed-point iterations. While it could appear that simply iterating forward is suboptimal, it is a much more stable and reliable procedure: when using $n$ processes and iterating forward, one is sure that at each time step the Monte Carlo estimate is the best possible given the samples.
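The noise-sampling step described above can be sketched as follows. This is a minimal example under stated assumptions: the covariance `C` is a hypothetical stand-in (in the DMFT it would be estimated from the samples at each step), and a Cholesky factor is used here in place of the LU factor mentioned in the text, since it is the standard factorisation for a symmetric positive-definite covariance.

```python
import numpy as np

rng = np.random.default_rng(0)
T, n_samples = 6, 200_000

# Hypothetical covariance over T time steps; any symmetric
# positive-definite matrix would do for this illustration.
times = np.arange(T)
C = np.exp(-0.5 * np.abs(times[:, None] - times[None, :]))

# Factorise C = L L^T; this costs O(T^3) in the number of time steps.
L = np.linalg.cholesky(C)

# Colour white Gaussian noise: each row of zeta is one trajectory
# of a zero-mean Gaussian process with covariance C.
white = rng.standard_normal((n_samples, T))
zeta = white @ L.T

# The empirical covariance of the samples should be close to C.
C_hat = zeta.T @ zeta / n_samples
print(np.max(np.abs(C_hat - C)))
```

Repeating the factorisation each time the covariance grows by one time step is what gives the $\mathcal{O}(T^{4})$ total cost quoted above.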

C.3 Details on the numerical simulations

In all the figures the continuous lines are from the numerical integration of the DMFT equations while the dots are from a direct simulation of the gradient descent dynamics. The specific hyperparameters for each setting are near each figure.

For both we fixed the second-layer weights to $\pm 1/\sqrt{p}$, as for the cases under consideration this is equivalent to the choice of Gaussian second-layer weights $\mathcal{N}(0,\frac{1}{p}\mathbbm{1}_{p})$. For the DMFT integration we used a minimum of $10^{6}$ Monte Carlo samples in order to have accurate lines. The error bars are too small to be visualised. The direct simulation of the gradient descent dynamics was performed either using PyTorch or a direct implementation in NumPy. In all plots we used a minimum size $d=5000$ for the input dimension, and averaged over at least $32$ independent instances of the dynamics.

In Figure 2 we plot the overlap matrix $M^{(t)}$ projected on two different directions: one parallel to the subspace that is learned in the first step, and one in the orthogonal complement of this space. The projection operator is computed by explicitly performing the integrals in (78).

The code is made available through the following GitHub repository: https://github.com/IdePHICS/benefit-reusing-batch.