License: CC BY 4.0
arXiv:2401.10809v2 [cs.LG] 24 Jan 2024

Neglected Hessian component explains mysteries in Sharpness regularization

Yann Dauphin and Atish Agarwala
Google Deepmind
{ynd, thetish}@google.com
Hossein Mobahi
Google Research
[email protected]
Abstract

Recent work has shown that methods like SAM which either explicitly or implicitly penalize second order information can improve generalization in deep learning. Seemingly similar methods like weight noise and gradient penalties often fail to provide such benefits. We show that these differences can be explained by the structure of the Hessian of the loss. First, we show that a common decomposition of the Hessian can be quantitatively interpreted as separating the feature exploitation from feature exploration. The feature exploration, which can be described by the Nonlinear Modeling Error matrix (NME), is commonly neglected in the literature since it vanishes at interpolation. Our work shows that the NME is in fact important as it can explain why gradient penalties are sensitive to the choice of activation function. Using this insight we design interventions to improve performance. We also provide evidence that challenges the long held equivalence of weight noise and gradient penalties. This equivalence relies on the assumption that the NME can be ignored, which we find does not hold for modern networks since they involve significant feature learning. We find that regularizing feature exploitation but not feature exploration yields performance similar to gradient penalties.

1 Introduction

There is a long history in machine learning of trying to use information about the loss landscape geometry to improve gradient-based learning. This has ranged from attempts to use the Fisher information matrix to improve optimization (Martens & Grosse, 2015), to trying to regularize the Hessian to improve generalization (Sankar et al., 2021). More recently, first order methods which implicitly use or penalize second order quantities have been used successfully, including the sharpness aware minimization (SAM) algorithm (Foret et al., 2020). On the other hand, there are many approaches to use second order information which once seemed promising but have had limited success (Dean et al., 2012). These include methods like weight noise (An, 1996) and gradient norm penalties, which have shown mixed success.

Part of the difficulty of using second order information is the difficulty of working with the Hessian of the loss. With the large number of parameters in deep learning architectures, as well as the large number of datapoints, many algorithms use stochastic methods to approximate statistics of the Hessian (Martens & Grosse, 2015; Liu et al., 2023). However, there is a conceptual difficulty as well, which arises from the complicated structure of the Hessian itself. Methods often involve approximating the Hessian via the Gauss-Newton (GN) matrix, which is PSD for convex losses. This is beneficial for conditioners which try to maintain monotonicity of gradient flow via a PSD transformation. Thus the indefinite part of the Hessian is often neglected due to its complexity.

In this work we show that it is important to consider both parts of the Hessian to understand certain methods that use second order information for regularization. We show that with saturating non-linearities, the GN part of the Hessian is related to exploiting existing linear structure, while the indefinite part of the Hessian, which we dub the Nonlinear Modeling Error matrix (NME), is related to exploring the effects of switching to different multi-linear regions. In contrast to commonly held assumptions, this work identifies two distinct cases where neglecting the influence of the indefinite component of the Hessian is demonstrably detrimental:

  • Training with Gradient Penalties. Our theoretical analysis reveals that the activation function controls the sparsity of information encoded within the indefinite component of the Hessian. Notably, we demonstrate that manipulating this sparsity by changing the activation function can transform previously ineffective gradient penalties into potent tools for improved generalization. To the best of our knowledge, this work is the first to show that methods using second order information are more sensitive to the choice of activation function.

  • Training with Hessian penalties. Conventional analysis of weight noise casts it as a penalty on the GN part of the Hessian, but in reality it also penalizes the NME. Our experimental ablations show that the NME exerts a significant influence on generalization performance.

We conclude with a discussion about how these insights might be used to design activation functions not with an eye towards forward or backwards passes (Pennington et al., 2017; Martens et al., 2021), but for compatibility with methods that use second order information.

2 Understanding the structure of the Hessian

In this section, we lay the theoretical groundwork for our experiments by explaining the structure of the Hessian. Given a model $\mathbf{z}(\bm{\theta},\mathbf{x})$ defined on parameters $\bm{\theta}$ and input $\mathbf{x}$, and a loss function $\mathcal{L}(\mathbf{z},\mathbf{y})$ on the model outputs and labels $\mathbf{y}$, we can write the gradient of the training loss with respect to $\bm{\theta}$ as

$$\nabla_{\bm{\theta}}\mathcal{L}=\mathbf{J}^{\rm T}(\nabla_{\mathbf{z}}\mathcal{L}) \tag{1}$$

where the Jacobian $\mathbf{J}\equiv\nabla_{\bm{\theta}}\mathbf{z}$. The Hessian $\nabla_{\bm{\theta}}^{2}\mathcal{L}$ can be decomposed as:

$$\nabla_{\bm{\theta}}^{2}\mathcal{L}=\underbrace{\mathbf{J}^{\rm T}\mathbf{H}_{\mathbf{z}}\mathbf{J}}_{\mathrm{GN}}+\underbrace{\nabla_{\mathbf{z}}\mathcal{L}\cdot\nabla^{2}_{\bm{\theta}}\mathbf{z}}_{\mathrm{NME}} \tag{2}$$

where $\mathbf{H}_{\mathbf{z}}\equiv\nabla^{2}_{\mathbf{z}}\mathcal{L}$. The first term is often called the Gauss-Newton (GN) part of the Hessian (Jacot et al., 2020; Martens, 2020). If the loss function is convex with respect to the model outputs/logits (such as for MSE and CE losses), then the GN matrix is positive semi-definite. This term often contributes large eigenvalues. The second term has previously been studied theoretically where it is called the functional Hessian (Singh et al., 2021; 2023); in order to avoid confusion with the overall Hessian we call it the Nonlinear Modeling Error matrix (NME). It is in general indefinite and vanishes at an interpolating minimum $\bm{\theta}^{*}$ where the model “fits” the data ($\nabla_{\mathbf{z}}\mathcal{L}(\bm{\theta}^{*})=\bm{0}$), as can happen in overparameterized settings. Due to this, it is quite common for studies to drop this term entirely when dealing with the Hessian. For example, many second order optimizers approximate the Hessian $\nabla_{\bm{\theta}}^{2}\mathcal{L}$ with only the Gauss-Newton term (Martens & Sutskever, 2011; Liu et al., 2023). It is also common to neglect this term in theoretical analysis of the Hessian $\nabla_{\bm{\theta}}^{2}\mathcal{L}$ (Bishop, 1995; Sagun et al., 2017). However, we will show why this term should not be ignored.

While the NME term can become small late in training, it encodes significant information during training. More precisely, it is the only part of the Hessian that contains second order information from the model features $\nabla_{\bm{\theta}}^{2}\mathbf{z}$. The GN matrix only contains second order information about the loss w.r.t. the logits with the term $\mathbf{H}_{\mathbf{z}}$. All the information about the model function in the GN matrix is first-order. In fact, the GN matrix can be seen as the Hessian of an approximation of the loss where a first-order approximation of the model $\mathbf{z}(\bm{\theta}',\mathbf{x})\approx\mathbf{z}(\bm{\theta},\mathbf{x})+\mathbf{J}\bm{\delta}$ (with $\bm{\delta}=\bm{\theta}'-\bm{\theta}$) is used (Martens & Sutskever, 2011):

$$\left.\nabla^{2}_{\bm{\delta}}\mathcal{L}(\mathbf{z}(\bm{\theta},\mathbf{x})+\mathbf{J}\bm{\delta},\mathbf{y})\right|_{\bm{\theta}'=\bm{\theta}}=\mathbf{J}^{\rm T}\mathbf{H}_{\mathbf{z}}\mathbf{J} \tag{3}$$

Thus we can see the GN matrix as the result of a linearization of the model and the NME as the part that takes into account the non-linear part of the model. The GN matrix exactly determines the linearized (NTK) dynamics of training, and therefore controls learning over small parameter changes when the features can be approximated as fixed (see Appendix A.1). In contrast, the NME encodes information about the changes in the NTK (Agarwala et al., 2022). For example, given a piecewise multilinear model like a ReLU network, we can think of the GN part of the Hessian as exploiting the linear (NTK) structure, while the NME gives information on exploration: namely, the benefits of switching to a different multilinear region where different neurons are active. See Figure 1 for an illustration of this with a ReLU model. We discuss this aspect further in Section 4.3.
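As a concrete check of Equation 2, the following minimal sketch evaluates the GN and NME terms by hand for a toy two-parameter model and compares their sum with a finite-difference Hessian of the loss. The model $z = a\tanh(bx)$, the parameter values, and all function names are illustrative assumptions and are not taken from the experiments in this paper.

```python
import math

def model(theta, x):
    """Toy model z = a * tanh(b * x); theta = (a, b) is an illustrative choice."""
    a, b = theta
    return a * math.tanh(b * x)

def loss(theta, x, y):
    """MSE loss 0.5 * (z - y)^2, so dL/dz = z - y and H_z = 1."""
    return 0.5 * (model(theta, x) - y) ** 2

def gn_and_nme(theta, x, y):
    """Return the two terms of Equation 2 as 2x2 nested lists."""
    a, b = theta
    t = math.tanh(b * x)
    s = 1.0 - t * t                                  # d tanh(u) / du
    jac = [t, a * x * s]                             # J = dz/dtheta
    hess_z = [[0.0, x * s],
              [x * s, -2.0 * a * x * x * t * s]]     # second derivatives of z
    residual = a * t - y                             # dL/dz for MSE
    gn = [[jac[i] * jac[j] for j in range(2)] for i in range(2)]
    nme = [[residual * hess_z[i][j] for j in range(2)] for i in range(2)]
    return gn, nme

def fd_hessian(f, theta, eps=1e-4):
    """Central finite-difference Hessian of a scalar function f(theta)."""
    n = len(theta)
    hess = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            def f_shift(si, sj):
                p = list(theta)
                p[i] += si * eps
                p[j] += sj * eps
                return f(p)
            hess[i][j] = (f_shift(1, 1) - f_shift(1, -1)
                          - f_shift(-1, 1) + f_shift(-1, -1)) / (4.0 * eps * eps)
    return hess

theta, x, y = [1.5, 0.7], 2.0, 0.3
gn, nme = gn_and_nme(theta, x, y)
full = fd_hessian(lambda p: loss(p, x, y), theta)
for i in range(2):
    for j in range(2):
        print(i, j, gn[i][j] + nme[i][j], full[i][j])
```

If the label is set equal to the model output (an interpolating point), the residual is zero, so every NME entry vanishes and the Hessian reduces to the GN term alone.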

The GN part may seem like it must contain this second order information due to its equivalence to the Fisher information matrix for losses that can be written as negative log-likelihoods, like MSE and cross-entropy. For these, the Fisher information itself can be written as the Hessian of a slightly different loss (Pascanu & Bengio, 2013):

$$\mathbf{F}={\rm E}_{\hat{\mathbf{y}}\sim\mathbf{p}_{\mathbf{z}}}\left[\nabla^{2}_{\bm{\theta}}\mathcal{L}(\mathbf{z},\hat{\mathbf{y}})\right] \tag{4}$$

where the only difference is that the labels $\hat{\mathbf{y}}$ are sampled from the model instead of being the true labels. However, the NME is $0$ for this loss. For example, in the case of MSE using Equation 2 we have

\begin{align}
{\rm E}_{\hat{\mathbf{y}}\sim\mathbf{p}_{\mathbf{z}}}\left[\nabla^{2}_{\bm{\theta}}\mathcal{L}(\mathbf{z},\hat{\mathbf{y}})\right] &= {\rm E}_{\hat{\mathbf{y}}\sim\mathcal{N}(\mathbf{z},\mathbf{I})}\left[\mathbf{J}^{\rm T}\mathbf{H}_{\mathbf{z}}\mathbf{J}+\nabla_{\mathbf{z}}\mathcal{L}(\mathbf{z},\hat{\mathbf{y}})\cdot\nabla^{2}_{\bm{\theta}}\mathbf{z}\right] \tag{5}\\
&= \mathbf{J}^{\rm T}\mathbf{H}_{\mathbf{z}}\mathbf{J}+\cancel{{\rm E}_{\hat{\mathbf{y}}\sim\mathcal{N}(\mathbf{z},\mathbf{I})}[\mathbf{z}-\hat{\mathbf{y}}]}\cdot\nabla^{2}_{\bm{\theta}}\mathbf{z} \tag{6}
\end{align}

The second term in Equation 6 (the NME) vanishes because the expected residual ${\rm E}[\mathbf{z}-\hat{\mathbf{y}}]$ is zero; in expectation, we are at the global minimum for this loss.

Figure 1: Loss (left) and Nonlinear Modeling Error matrix (NME) norm (right) as a function of $2$ parameters in the same hidden layer of an MLP (MSE loss, one datapoint). For ReLU activation, the model is piecewise multilinear, and piecewise linear for parameters in the same layer. The loss is piecewise quadratic for parameters in the same layer (left). There is little NME information accessible pointwise and the main features are the boundaries of the piecewise linear regions (blue, right). For $\beta$-GELU, NME magnitude is high only within distance $1/\beta$ of those boundaries. Therefore the NME encodes information about the utility of switching between piecewise multilinear regions.

3 Experimental Setup

Our analysis of the Hessian raises an immediate question: when does the NME affect learning algorithms? We conducted experimental studies to answer this question in the context of curvature regularization algorithms, which seek to promote convergence to flat areas of the loss landscape. We use the following two setups for the remainder of the paper:

Imagenet. We conduct experiments on the popular Imagenet dataset (Deng et al., 2009). All experiments use the Resnet-50 architecture with the same setup and hyper-parameters as Goyal et al. (2018), except that we use cosine learning rate decay (Loshchilov & Hutter, 2016) over 300 epochs.

CIFAR-10. We also provide results on the CIFAR-10 dataset (Krizhevsky et al., 2009). All experiments use the WideResnet 28-10 architecture with the same setup and hyper-parameters as Zagoruyko & Komodakis (2016), except for the use of cosine learning rate decay.

4 How NME affects training with gradient penalties

In this section we will show that the information contained in the NME has a critical impact on the effectiveness of gradient penalties for generalization. We define a gradient penalty as an additive regularizer of the form:

$$\mathcal{L}_{pen,p}=\rho\|\nabla\mathcal{L}_{0}\|^{p} \tag{7}$$

for a base loss $\mathcal{L}_{0}$. Gradient penalties have recently gained popularity as regularizers (Barrett & Dherin, 2021; Smith et al., 2021; Du et al., 2022; Zhao et al., 2022; Reizinger & Huszár, 2023); this is in part due to their ability to reduce sharpness. In fact, $\mathcal{L}_{pen,p}$ is closely related to Sharpness Aware Minimization (SAM) (Foret et al., 2020). The case $p=1$ corresponds to the original normalized formulation, while $p=2$ corresponds to the unnormalized formulation, which is equally effective and easier to analyze (Andriushchenko & Flammarion, 2022; Agarwala & Dauphin, 2023). A more detailed description of the link between SAM and gradient penalties can be found in Appendix B. We will focus on the $p=1$ case in the remainder of this section.

4.1 Gradient penalty update rules

Consider the SGD update rule for the $p=1$ gradient penalty and base loss $\mathcal{L}_{0}$. With learning rate $\eta$, the parameters $\bm{\theta}$ evolve as:

$$\bm{\theta}_{t+1}-\bm{\theta}_{t}=-\eta\left(\nabla_{\bm{\theta}}\mathcal{L}_{0}+\frac{\rho}{\|\nabla_{\bm{\theta}}\mathcal{L}_{0}\|}\mathbf{H}\nabla_{\bm{\theta}}\mathcal{L}_{0}\right),\quad\mathbf{H}\equiv\nabla_{\bm{\theta}}^{2}\mathcal{L}_{0} \tag{8}$$

The additional contribution to the dynamics comes in the form of a Hessian-gradient product.
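To make the Hessian-gradient product concrete, the sketch below writes out the gradient of the $p=1$ penalty of Equation 7, $\rho\,\mathbf{H}\nabla\mathcal{L}_{0}/\|\nabla\mathcal{L}_{0}\|$, for a toy quadratic base loss whose Hessian is a fixed matrix, with the penalty weight $\rho$ made explicit. The matrix `A`, the parameter values, and the helper names are hypothetical and chosen only for illustration.

```python
import math

# Hessian of the toy base loss L0(theta) = 0.5 * theta^T A theta (hypothetical values).
A = [[3.0, 1.0],
     [1.0, 2.0]]

def matvec(m, v):
    """Matrix-vector product for nested lists."""
    return [sum(m[i][j] * v[j] for j in range(len(v))) for i in range(len(m))]

def norm(v):
    return math.sqrt(sum(c * c for c in v))

def grad_base(theta):
    """Gradient of L0, which is A @ theta for the quadratic above."""
    return matvec(A, theta)

def penalty(theta, rho):
    """Gradient penalty of Equation 7 with p = 1: rho * ||grad L0||."""
    return rho * norm(grad_base(theta))

def penalty_grad(theta, rho):
    """Gradient of the penalty: rho * H g / ||g||, a Hessian-gradient product."""
    g = grad_base(theta)
    hg = matvec(A, g)          # here H = A exactly, because L0 is quadratic
    g_norm = norm(g)
    return [rho * c / g_norm for c in hg]

def sgd_step(theta, lr, rho):
    """One gradient step on L0 plus the p = 1 penalty."""
    g = grad_base(theta)
    pg = penalty_grad(theta, rho)
    return [t - lr * (gi + pgi) for t, gi, pgi in zip(theta, g, pg)]

theta = [1.0, -2.0]
print(penalty_grad(theta, rho=0.1))
print(sgd_step(theta, lr=0.01, rho=0.1))
```

For a neural network the Hessian is never formed explicitly; the same product would typically be obtained with a Hessian-vector product routine, which this toy example sidesteps by using a quadratic loss.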

Figure 2: Test accuracy vs. $\rho$ for ReLU and GELU networks trained with gradient penalty ($p=1$, averaged over $5$ seeds) on (a) Imagenet and (b) CIFAR-10. In both cases performance is similar without regularization, but with regularization test accuracy increases for GELU until $\rho=0.1$ and decreases for ReLU over a similar range.

Since the update rule explicitly involves the Hessian, a natural question is: do the GN and NME both play a significant role in the dynamics? Or does the conventional wisdom hold, that the GN dominates? We explore this question by starting with a simple experiment. In our Imagenet and CIFAR-10 setups, we consider networks trained with ReLU activations and networks trained with GELU activations. Without regularization, both activation functions achieve similar test accuracy ($76.8$ for both on Imagenet). However, as the gradient penalty regularizer is added, differences emerge with increasing $\rho$ (Figure 2). The performance of GELU networks increases with $\rho$ as high as $0.1$; in contrast, the performance of ReLU networks decreases with $\rho$.

Even though both activations have similar forward passes, the addition of the Hessian-gradient product seems to dramatically change the learning dynamics. Since training without regularization seems to be similar across the activation functions, we focus on the difference between the Hessians induced by ReLU and GELU. As we will see, it is in fact the NME which is dramatically different between the two activations.

4.2 Effect of Activation functions on the NME

One important feature of the NME is that it depends on the second derivative of the activation function. We can demonstrate this most easily on a fully-connected network, but the general principle applies to most common architectures. Given an activation function $\phi$, consider a feedforward network with $L$ layers on an input $\mathbf{x}_{0}$, defined iteratively by

$$\mathbf{h}_{l}=\mathbf{W}_{l}\mathbf{x}_{l},\quad\mathbf{x}_{l+1}=\phi(\mathbf{h}_{l}) \tag{9}$$

The gradient of the model output $\mathbf{x}_{L}$ with respect to a weight matrix $\mathbf{W}_{l}$ is given by

$$\frac{\partial\mathbf{x}_{L}}{\partial\mathbf{W}_{l}}=\mathbf{J}_{L(l+1)}\circ\phi'(\mathbf{h}_{l})\otimes\mathbf{x}_{l},\quad\mathbf{J}_{l'l}\equiv\prod_{m=l}^{l'-1}\phi'(\mathbf{h}_{m})\circ\mathbf{W}_{m} \tag{10}$$

where $\circ$ is the Hadamard (elementwise) product. The second derivative can be written as:

$$\frac{\partial^{2}\mathbf{x}_{L}}{\partial\mathbf{W}_{l}\partial\mathbf{W}_{m}}=\left[\frac{\partial\mathbf{J}_{L(l+1)}}{\partial\mathbf{W}_{m}}\circ\phi'(\mathbf{h}_{l})+\mathbf{J}_{L(l+1)}\circ\frac{\partial\phi'(\mathbf{h}_{l})}{\partial\mathbf{W}_{m}}\right]\otimes\mathbf{x}_{l} \tag{11}$$

where without loss of generality $m\geq l$. The full analysis of this derivative can be found in Appendix A.2. The key feature is that the majority of the terms have a factor of the form

$$\frac{\partial\phi'(\mathbf{h}_{o})}{\partial\mathbf{W}_{m}}=\phi''(\mathbf{h}_{o})\circ\frac{\partial\mathbf{h}_{o}}{\partial\mathbf{W}_{m}} \tag{12}$$

which arises when the product rule is applied to Equation 10 and introduces a dependence on $\phi''$. On the diagonal $m=l$, all the terms depend on $\phi''$. We note that a similar analysis can be found in Section 8.1.2 of Martens (2020).
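The dependence on $\phi''$ can be seen in the smallest possible case. The sketch below assumes a scalar two-layer network $z = w_{2}\,\phi(w_{1}x)$, a simplification of Equation 9 with no activation on the output, and writes out the Hessian of $z$ with respect to $(w_{1}, w_{2})$ by hand. The network, the numerical values, and the names `GELU`, `RELU`, and `output_hessian` are illustrative assumptions.

```python
import math

SQRT_2PI = math.sqrt(2.0 * math.pi)

def normal_pdf(h):
    return math.exp(-0.5 * h * h) / SQRT_2PI

def normal_cdf(h):
    return 0.5 * (1.0 + math.erf(h / math.sqrt(2.0)))

# Each activation is a triple (phi, phi', phi'').
GELU = (
    lambda h: h * normal_cdf(h),
    lambda h: normal_cdf(h) + h * normal_pdf(h),
    lambda h: normal_pdf(h) * (2.0 - h * h),
)
RELU = (
    lambda h: max(h, 0.0),
    lambda h: 1.0 if h > 0.0 else 0.0,
    lambda h: 0.0,   # second derivative taken to be 0 everywhere
)

def output_hessian(act, w1, w2, x):
    """Hessian of z = w2 * phi(w1 * x) with respect to (w1, w2)."""
    _, d1, d2 = act
    h = w1 * x
    return [[w2 * d2(h) * x * x, d1(h) * x],   # d2z/dw1^2 carries phi''
            [d1(h) * x, 0.0]]                  # cross term carries only phi'

for name, act in (("relu", RELU), ("gelu", GELU)):
    print(name, output_hessian(act, w1=0.5, w2=2.0, x=1.0))
```

In this toy case the same-layer entry $\partial^{2}z/\partial w_{1}^{2}$ is proportional to $\phi''$, so it is identically zero for ReLU and non-zero for GELU, while the cross-layer entry depends only on $\phi'$.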

4.3 ReLU vs. GELU second derivatives

The second derivative of the activation function is key to controlling the statistics of the NME. Due to the popularity of first order optimizers, activation functions have been designed to have well behaved first derivatives, with little concern for second derivatives. Consider ReLU: it became popular as a way to deal with gradient propagation issues from activations like $\tanh$; however, it suffers from a “missing curvature” phenomenology. Mathematically, the ReLU second derivative is $0$ everywhere except the origin, where it is undefined. In practical implementations it is set to $0$ at the origin as well. This implies that the diagonal of the NME is $0$ for ReLU in practice.

In contrast, GELU has a well-posed second derivative, and therefore a non-trivial NME. We can study the difference between GELU and ReLU by using the $\beta$-GELU, which interpolates between the two. It is given by

$$\beta\text{-GELU}(x)=x\Phi(\beta x) \tag{13}$$

where $\Phi$ is the standard Gaussian CDF. We can recover GELU by setting $\beta=1$, and ReLU is recovered in the limit $\beta\to\infty$ (except for the second derivative at the origin, which, as we will see, is undefined). The second derivative is given by

$$\frac{d^{2}}{dx^{2}}\beta\text{-GELU}(x)=\frac{1}{\sqrt{2\pi\beta^{-2}}}e^{-x^{2}/(2\beta^{-2})}\left[2-(x/\beta^{-1})^{2}\right] \tag{14}$$

For large $\beta$, this function is exponentially small when $|x|\gg\beta^{-1}$, and $O(\beta)$ when $|x|=O(\beta^{-1})$. As $\beta$ increases the non-zero region becomes smaller while the non-zero value becomes larger such that the integral is always $1$. This suggests that rather than being uniformly $0$, the ReLU second derivative is better described by a Dirac delta “function” (really a distribution): $0$ except at the origin, where it is undefined, but still integrable to $1$. Note that $\beta$-GELU for large $\beta$ is different from standard ReLU implementations at the origin, since its second derivative there is of order $\beta$ (Equation 14 gives $\beta\sqrt{2/\pi}$), and not $0$.
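These properties of Equation 14 can be checked numerically. The sketch below implements Equations 13 and 14 with the standard library and integrates the second derivative with a simple midpoint rule; the integration range, the grid size, and the function names are arbitrary illustrative choices.

```python
import math

def beta_gelu(x, beta):
    """beta-GELU(x) = x * Phi(beta * x), Equation 13."""
    return x * 0.5 * (1.0 + math.erf(beta * x / math.sqrt(2.0)))

def beta_gelu_d2(x, beta):
    """Second derivative from Equation 14, written in terms of u = beta * x."""
    u = beta * x
    return beta / math.sqrt(2.0 * math.pi) * math.exp(-0.5 * u * u) * (2.0 - u * u)

def midpoint_integral(f, lo, hi, n):
    """Composite midpoint rule on [lo, hi] with n cells."""
    step = (hi - lo) / n
    return step * sum(f(lo + (k + 0.5) * step) for k in range(n))

for beta in (1.0, 10.0, 50.0):
    area = midpoint_integral(lambda x: beta_gelu_d2(x, beta), -20.0, 20.0, 100000)
    # Columns: beta, peak value at the origin, integral over the real line.
    print(beta, beta_gelu_d2(0.0, beta), area)
```

The peak at the origin grows linearly with $\beta$ while the integral stays at $1$, which is the behavior expected of a sequence of functions approaching a Dirac delta.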

The choice of $\beta$ determines how much information the NME can convey in a practical setting. This second derivative is large only when the input to the activation is within distance $1/\beta$ of $0$. In a deep network this corresponds to being near the boundary of the piecewise multilinear regions where the activations switch on and off. We can illustrate this using two parameters of an MLP in the same layer, where the model is in fact piecewise linear with respect to those parameters (Figure 1). The second derivative serves as an “edge detector” (more generally, a hyperplane detector), and the NME can be used to probe the usefulness of crossing these edges. (Footnote 1: In fact, the negative of the second order derivative of GELU is closely related to the Laplacian of Gaussian, which is a well-known edge-detector in image processing and computer vision.)

From Equation 11, this means that for intermediate $\beta$ many terms of the diagonal of the NME will be non-zero at a typical point. However, as $\beta$ increases, the probability of terms being non-zero becomes low, but when they are non-zero they are large, giving a sparse, spiky structure to the NME, especially on the diagonal. This leads to the NME becoming a high-variance estimator of local structure. Therefore any methods seeking to use this information explicitly at large $\beta$ are doomed to fail.
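The sparse, spiky behavior can be illustrated with a rough calculation. The sketch below assumes, purely for illustration, that a pre-activation $h$ is distributed as a standard Gaussian (an assumption not made in this paper), and computes the probability that $h$ lies within $1/\beta$ of zero together with the mean and variance of the $\beta$-GELU second derivative under that distribution. The function names and quadrature settings are hypothetical.

```python
import math

SQRT_2PI = math.sqrt(2.0 * math.pi)

def beta_gelu_d2(h, beta):
    """Second derivative of beta-GELU (Equation 14) in terms of u = beta * h."""
    u = beta * h
    return beta / SQRT_2PI * math.exp(-0.5 * u * u) * (2.0 - u * u)

def active_fraction(beta):
    """P(|h| < 1/beta) for h ~ N(0, 1): the chance of being near a boundary."""
    return math.erf(1.0 / (beta * math.sqrt(2.0)))

def d2_mean_and_variance(beta, lo=-10.0, hi=10.0, n=100000):
    """Mean and variance of phi''(h) under h ~ N(0, 1), by midpoint quadrature."""
    step = (hi - lo) / n
    mean = 0.0
    second = 0.0
    for k in range(n):
        h = lo + (k + 0.5) * step
        weight = math.exp(-0.5 * h * h) / SQRT_2PI * step   # Gaussian density times cell width
        d2 = beta_gelu_d2(h, beta)
        mean += weight * d2
        second += weight * d2 * d2
    return mean, second - mean * mean

for beta in (1.0, 10.0, 50.0):
    mean, var = d2_mean_and_variance(beta)
    print(beta, active_fraction(beta), mean, var)
```

Under this assumed Gaussian distribution, the fraction of pre-activations near a boundary shrinks as $\beta$ grows and the mean of $\phi''$ stays bounded, while its variance keeps growing, which matches the qualitative picture of rare but large contributions.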

Figure 3: Accuracy vs. $\beta$ for SGD and SGD with gradient penalty ($\rho=0.1$) using $\beta$-GELU activations (average of $5$ seeds) on (a) Imagenet and (b) CIFAR-10. We observe that accuracy decreases with larger $\beta$ with the gradient penalty but not without it. As our theory suggests that the sparsity of the NME increases with $\beta$, this is evidence that it has significant impact on gradient penalties.

Our experiments are consistent with this intuition. In Figure 3, we show that accuracy suffers when training with gradient penalties as we increase $\beta$ but is unaffected for SGD. (We note that large $\beta$ does worse than ReLU due to the non-zero second derivative of $\beta$-GELU at $0$.)

Note that we are not claiming that the choice of the activation function is a sufficient condition for gradient penalties to work with larger $\rho$. There are many architectural changes that can affect the NME matrix, and we have shown that the statistics of the activation function are a significant one.

4.4 Augmented ReLU and diminished GELU

Figure 4: (a) Imagenet, (b) CIFAR-10. Test accuracy as $\rho$ increases for Augmented ReLU and Diminished GELU (average of $5$ seeds). The addition or removal of information from the NME controls the effectiveness of the gradient penalty.

We can perform a more direct experiment probing the effects of the second derivative part of the NME on the learning dynamics by defining the augmented ReLU and the diminished GELU. The basic idea is to design a modified ReLU which has a well-posed second derivative, and to define a GELU that has a second derivative of $0$. This lets us “turn on” the second derivative part of the NME for ReLU, which previously had none, and “turn off” the second derivative part of GELU, making it more similar to the setting with vanilla ReLU.

We will define our augmented and diminished functions using the ability to define custom derivative functions in modern automatic differentiation (AD) frameworks. In AD frameworks, the chain rule is decomposed into derivative operators on basic functions combined with elementary operations. Let us denote the AD derivative operator applied to function $f$ as $\mathcal{D}_{AD}[f]$. Normally this transformation corresponds to the real derivative operator; that is, $\mathcal{D}_{AD}[f] := df/dx$.

However, we can instead define a custom derivative $\mathcal{D}_{AD}[f](x) := g(x)$. The net result is that any chain rule term evaluating $df/dx$ will be replaced by evaluation of $g$ at that point. In this example, a second application of the AD operator nets us $\mathcal{D}_{AD}[\mathcal{D}_{AD}[f]](x) = \mathcal{D}_{AD}[g](x)$, which itself can be a custom derivative.

We define the augmented ReLU as follows: $f_{aug}(x) := \mathrm{ReLU}(x)$ as normal. We make the common choice to define the first AD derivative as $\mathcal{D}_{AD}[f_{aug}](x) := \Theta(x)$, the Heaviside step function ($\Theta(x) = 1$ if $x > 0$, $\Theta(x) = 0$ otherwise). The second AD derivative is $\mathcal{D}_{AD}[\mathcal{D}_{AD}[f_{aug}]](x) = \mathcal{D}_{AD}[\Theta](x)$. Normally in AD frameworks, $\mathcal{D}_{AD}[\Theta](x) := 0$ and therefore ReLU implementations have no second derivative; we instead make the definition:

$$\mathcal{D}_{AD}[\mathcal{D}_{AD}[f_{aug}]](x) := \frac{\beta}{\sqrt{2\pi}} e^{-\beta^{2}x^{2}/2} \tag{15}$$

Therefore the AD program replaces any second derivatives of $f_{aug}$ (e.g. during HVP calculations) with a Gaussian of width $1/\beta$, which approximates the delta function in the limit $\beta \to \infty$. For $\beta$ of $O(1)$, however, this gives an approximation of the delta function that is more numerically stable, and lets us test if gradient penalty with ReLU can be rescued by adding information related to the second derivative piece of the NME.
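The chain of custom derivatives for the augmented ReLU can be mimicked with a toy, framework-free sketch. The `ADFunction` class and the `d_ad` helper below are hypothetical names introduced only for illustration; in a real framework this role would be played by its custom-derivative mechanism (for instance `jax.custom_jvp` in JAX), and the code is not the implementation used in the experiments.

```python
import math
from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class ADFunction:
    """A scalar function together with the rule a toy AD system uses for its derivative."""
    value: Callable[[float], float]
    deriv: Optional["ADFunction"] = None  # None: the AD derivative is identically 0

ZERO = ADFunction(value=lambda x: 0.0)

def d_ad(f: ADFunction) -> ADFunction:
    """Toy version of the operator D_AD: return whatever derivative was registered."""
    return f.deriv if f.deriv is not None else ZERO

def make_plain_relu() -> ADFunction:
    heaviside = ADFunction(value=lambda x: 1.0 if x > 0 else 0.0)  # derivative left at 0
    return ADFunction(value=lambda x: max(x, 0.0), deriv=heaviside)

def make_augmented_relu(beta: float) -> ADFunction:
    # Equation 15: the derivative of the Heaviside step is overridden by a Gaussian.
    gaussian = ADFunction(
        value=lambda x: beta / math.sqrt(2.0 * math.pi) * math.exp(-0.5 * (beta * x) ** 2)
    )
    heaviside = ADFunction(value=lambda x: 1.0 if x > 0 else 0.0, deriv=gaussian)
    return ADFunction(value=lambda x: max(x, 0.0), deriv=heaviside)

beta = 2.0
plain, augmented = make_plain_relu(), make_augmented_relu(beta)
for x in (-0.5, 0.0, 0.5):
    print(
        f"x={x:+.1f}  f={augmented.value(x):.2f}  D[f]={d_ad(augmented).value(x):.0f}  "
        f"plain D[D[f]]={d_ad(d_ad(plain)).value(x):.3f}  "
        f"augmented D[D[f]]={d_ad(d_ad(augmented)).value(x):.3f}"
    )
```

Both functions share the same forward value and first derivative; only the registered second derivative differs, which is the property the augmented ReLU experiment relies on.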

Analogously, the diminished GELU is defined by “turning off” the second derivative of GELU. Defining $f_{dim}(x) := \mathrm{GELU}(x)$, the first derivative is defined normally as $\mathcal{D}_{AD}[f_{dim}](x) := g(x)$ where $g = d\,\mathrm{GELU}/dx$. We define the AD derivative $\mathcal{D}_{AD}[g](x)$ to be $0$, which means:

$$\mathcal{D}_{AD}[\mathcal{D}_{AD}[f_{dim}]](x) := 0 \tag{16}$$

This brings the properties of GELU closer to those of ReLU, at least in terms of the higher order derivatives. Diminished GELU lets us test whether or not the second derivative part of the NME is necessary for the success of gradient penalties with GELU.
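As a small illustration of what Equation 16 removes, consider a single hypothetical neuron $z(w) = f(wx)$. By the chain rule, $dz/dw = f'(wx)\,x$ and $d^{2}z/dw^{2} = f''(wx)\,x^{2}$, and it is the second quantity that feeds the second derivative part of the NME. The sketch below assumes the standard form $\mathrm{GELU}(u) = u\,\Phi(u)$; the one-neuron model and all function names are illustrative assumptions.

```python
import math

def phi(u):
    """Standard normal density."""
    return math.exp(-0.5 * u * u) / math.sqrt(2.0 * math.pi)

def Phi(u):
    """Standard normal CDF."""
    return 0.5 * (1.0 + math.erf(u / math.sqrt(2.0)))

def gelu_d1(u):
    # First derivative of u * Phi(u).
    return Phi(u) + u * phi(u)

def gelu_d2(u):
    # Second derivative of u * Phi(u).
    return (2.0 - u * u) * phi(u)

def diminished_d2(u):
    # Equation 16: the second AD derivative is defined to be 0.
    return 0.0

def neuron_derivatives(w, x, d1, d2):
    """Return (dz/dw, d^2z/dw^2) for z = f(w * x), given f' and f''."""
    u = w * x
    return d1(u) * x, d2(u) * x * x

w, x = 0.4, 1.5
true_first, true_second = neuron_derivatives(w, x, gelu_d1, gelu_d2)
dim_first, dim_second = neuron_derivatives(w, x, gelu_d1, diminished_d2)
print(f"GELU:            dz/dw={true_first:.4f}  d2z/dw2={true_second:.4f}")
print(f"diminished GELU: dz/dw={dim_first:.4f}  d2z/dw2={dim_second:.4f}")
```

The first-order quantities agree exactly, so plain gradient steps are unchanged; only computations that differentiate twice through the activation see the difference.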

We used our Imagenet setup to train with augmented ReLU and diminished GELU (Figure 4). We find that augmented ReLU performs better than plain ReLU and nearly matches the performance of GELU, while diminished GELU has poor performance similar to ReLU. This suggests that second derivative information is necessary for the improved performance of GELU with gradient penalties, and moreover it is helpful to make gradient penalties work with ReLU. This gives us direct evidence that in this setting, information from the NME is crucial for good generalization, and gradient penalties are sensitive to second derivatives of activation functions.

5 How NME affects training with Hessian penalties

In this section we will show that the NME has significant impact on the effectiveness of Hessian penalties. In particular, we consider the case of weight noise because as we will see it is an efficient way to penalize the Hessian. In contrast to the previous section where the NME solely influenced learning dynamics, weight noise implicitly regularizes the NME. We will show through ablations that this regularization is detrimental and explain why.

5.1 Weight Noise analysis neglects the NME

We first review the analysis of training with noise established by Bishop (1995). Though the paper considers input noise, the same analysis can be applied to weight noise. Adding Gaussian noise $\bm{\epsilon} \sim \mathcal{N}(0, \sigma^{2})$ with strength hyper-parameter $\sigma$ to the parameters can be approximated to second order by

$$\mathrm{E}_{\bm{\epsilon}}[\mathcal{L}(\bm{\theta}+\bm{\epsilon})] \approx \mathcal{L}(\bm{\theta}) + \cancel{\mathrm{E}_{\bm{\epsilon}}[\nabla_{\bm{\theta}}\mathcal{L}\cdot\bm{\epsilon}]} + \mathrm{E}_{\bm{\epsilon}}[\bm{\epsilon}^{\mathrm{T}}\mathbf{H}\bm{\epsilon}] = \mathcal{L}(\bm{\theta}) + \sigma^{2}\,\mathrm{tr}(\mathbf{H}) \tag{17}$$

where the second term has zero expectation since $\bm{\epsilon}$ is mean $0$, and the third term is a variation of the Hutchinson trace estimator (Hutchinson, 1989). (We note that though the second term vanishes in expectation, it still can have large effects on the training dynamics.) Bishop (1995) argues that we can simplify the term related to the Hessian by dropping the NME in Equation 2 for the purposes of minimization:

$$\mathrm{tr}(\mathbf{H}) = \mathrm{tr}\left(\mathbf{J}^{\mathrm{T}}\mathbf{H}_{\mathbf{z}}\mathbf{J} + \nabla_{\mathbf{z}}\mathcal{L}\cdot\nabla^{2}_{\bm{\theta}}\mathbf{z}\right) \approx \mathrm{tr}(\mathbf{J}^{\mathrm{T}}\mathbf{H}_{\mathbf{z}}\mathbf{J}) \tag{18}$$

The argument is that for the purposes of training neural networks this term can be dropped because it is zero at the global minimum.
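A one-parameter toy problem makes the argument concrete. The sketch below uses a hypothetical scalar model $z(\theta) = \tanh(\theta x)$ with squared loss $\mathcal{L} = \tfrac{1}{2}(z - y)^{2}$, neither of which is taken from the experiments, and splits the Hessian into its Gauss-Newton and NME pieces. The NME is proportional to the residual $z - y$, so it vanishes exactly when the model interpolates the target but not elsewhere.

```python
import math

def model(theta, x):
    return math.tanh(theta * x)

def loss(theta, x, y):
    return 0.5 * (model(theta, x) - y) ** 2

def hessian_pieces(theta, x, y):
    """Return (Gauss-Newton term, NME term) of d^2 L / d theta^2."""
    z = model(theta, x)
    dz = x * (1.0 - z * z)                  # J = dz/dtheta
    d2z = -2.0 * x * x * z * (1.0 - z * z)  # second derivative of z in theta
    gauss_newton = dz * 1.0 * dz            # J^T H_z J, with H_z = 1 for squared loss
    nme = (z - y) * d2z                     # (dL/dz) * d^2z/dtheta^2
    return gauss_newton, nme

theta, x = 0.7, 1.3

# Away from interpolation: both pieces contribute.
gn, nme = hessian_pieces(theta, x, y=-0.4)
print(f"y=-0.4:  GN={gn:.4f}  NME={nme:.4f}  total={gn + nme:.4f}")

# At interpolation (y equals the model output): the NME vanishes.
gn_fit, nme_fit = hessian_pieces(theta, x, y=model(theta, x))
print(f"y=z:     GN={gn_fit:.4f}  NME={nme_fit:.4f}")
```

The Gauss-Newton piece is identical in both cases, since it depends only on the Jacobian; how large the NME is during training, before interpolation is reached, is exactly what the approximation in Equation 18 leaves out.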

However, the hypothesis that the NME has negligible impact in this setting has not been experimentally verified. We address this gap in the next section by providing evidence that the NME cannot be neglected for modern networks.

5.2 Ablations reveal the influence of the NME

In order to study the impact of the NME in this setting, we evaluate ablations of weight noise to determine the impact of the different components. Recalling Equation 17, the methods we will consider are given by

$$\underbrace{\mathrm{E}_{\bm{\epsilon}}[\mathcal{L}(\bm{\theta}+\bm{\epsilon})]}_{\text{Weight Noise}} = \overbrace{\underbrace{\mathcal{L}(\bm{\theta}) + \sigma^{2}\,\mathrm{tr}\left(\mathbf{J}^{\mathrm{T}}\mathbf{H}_{\mathbf{z}}\mathbf{J}\right)}_{\text{Gauss-Newton Trace Penalty}} + \sigma^{2}\,\mathrm{tr}\left(\nabla_{\mathbf{z}}\mathcal{L}\cdot\nabla^{2}_{\bm{\theta}}\mathbf{z}\right)}^{\text{Hessian Trace Penalty}} + \mathcal{O}(\|\bm{\epsilon}\|^{2}) \tag{19}$$

Hessian Trace penalty This ablation allows us to single out the second order effect of weight noise, as it’s possible the higher order terms from weight noise affect generalization. We implement this penalty with Hutchinson’s trace estimator ($\mathrm{tr}(\mathbf{H}) = \mathrm{E}_{\bm{\epsilon}\sim\mathcal{N}(0,1)}[\bm{\epsilon}^{\mathrm{T}}\mathbf{H}\bm{\epsilon}]$).
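Hutchinson's estimator can be illustrated on a small explicit matrix. The matrix `H`, the seed, and the sample count below are arbitrary illustrative choices. In a network the quadratic form would instead be obtained from a Hessian-vector product, without ever forming $\mathbf{H}$.

```python
import random

def quadratic_form(H, v):
    """Compute v^T H v for a square matrix H given as nested lists."""
    n = len(H)
    return sum(v[i] * H[i][j] * v[j] for i in range(n) for j in range(n))

def hutchinson_trace(H, num_samples, rng):
    """Average eps^T H eps over standard Gaussian probe vectors eps."""
    n = len(H)
    total = 0.0
    for _ in range(num_samples):
        eps = [rng.gauss(0.0, 1.0) for _ in range(n)]
        total += quadratic_form(H, eps)
    return total / num_samples

# A small symmetric, indefinite matrix standing in for a Hessian.
H = [[2.0, 1.0, 0.0],
     [1.0, 3.0, 1.0],
     [0.0, 1.0, -1.0]]
exact_trace = sum(H[i][i] for i in range(len(H)))

rng = random.Random(0)
single_sample = hutchinson_trace(H, 1, rng)
estimate = hutchinson_trace(H, 20_000, rng)
print(f"exact trace = {exact_trace}")
print(f"one-sample estimate = {single_sample:.3f}")
print(f"20000-sample estimate = {estimate:.3f}")
```

Each single-sample estimate is unbiased but noisy; only the average over many samples settles near the exact trace.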

Gauss-Newton Trace penalty This ablation removes the NME’s contribution, enabling us to isolate and measure its specific influence on the model. Recent work has proposed a new estimator for the trace of the Gauss-Newton matrix for cross-entropy loss (Wei et al., 2020). Using this estimator, we can efficiently compute this penalty using

$$\mathrm{tr}\left(\mathbf{J}^{\mathrm{T}}\mathbf{H}_{\mathbf{z}}\mathbf{J}\right) = \mathrm{E}_{\hat{\mathbf{y}}\sim\mathrm{Cat}(\mathbf{z})}\left[\left\|\nabla_{\bm{\theta}}\mathcal{L}(\bm{\theta},\hat{\mathbf{y}})\right\|^{2}\right] \tag{20}$$

where $\mathrm{Cat}(\cdot)$ is the categorical distribution and $\mathbf{z}$ are the logits. This computes the norm of the gradients, but with the labels sampled from the model instead of the ground-truth. We do not pass gradients through the sampling of the labels $\hat{\mathbf{y}}$, but we find similar results if we pass gradients using the straight-through estimator (Bengio et al., 2013). Note the similarity to the gradient penalties studied in the previous section, which we will address in later sections.
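Equation 20 can be verified exactly on a model small enough to enumerate every label. The sketch assumes a linear softmax classifier with logits $\mathbf{z} = W\mathbf{x}$ and reads $\mathrm{Cat}(\mathbf{z})$ as the categorical distribution with probabilities $\mathrm{softmax}(\mathbf{z})$. The weights, input, and sizes are made-up values, and the expectation is computed as a finite sum rather than by sampling.

```python
import math

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

# Hypothetical linear model: K = 3 classes, d = 2 input features.
W = [[0.2, -0.5], [1.0, 0.3], [-0.7, 0.8]]
x = [1.5, -2.0]
K, d = len(W), len(x)

z = [sum(W[k][j] * x[j] for j in range(d)) for k in range(K)]
p = softmax(z)

# Hessian of cross-entropy with respect to the logits: diag(p) - p p^T.
H_z = [[(p[k] if k == l else 0.0) - p[k] * p[l] for l in range(K)] for k in range(K)]

# Jacobian of the logits with respect to the flattened weights W[k2][j].
J = [[(x[j] if k == k2 else 0.0) for k2 in range(K) for j in range(d)] for k in range(K)]
P = K * d

# Left-hand side of Equation 20: tr(J^T H_z J).
gn_trace = sum(
    J[k][a] * H_z[k][l] * J[l][a]
    for a in range(P) for k in range(K) for l in range(K)
)

def grad_sq_norm(label):
    # Gradient of cross-entropy with respect to W is (p - onehot(label)) x^T.
    return sum(
        ((p[k] - (1.0 if k == label else 0.0)) * x[j]) ** 2
        for k in range(K) for j in range(d)
    )

# Right-hand side: expectation over labels drawn from the model's own prediction.
expected_sq_grad = sum(p[k] * grad_sq_norm(k) for k in range(K))

print(f"tr(J^T H_z J)              = {gn_trace:.6f}")
print(f"E_yhat ||grad L(yhat)||^2  = {expected_sq_grad:.6f}")
```

Replacing the finite sum with a single label drawn from `p` gives a one-sample estimator of the same quantity, which is the regime described in the text.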

We draw a single sample to estimate the expectations for the different estimators. We experimented with 2 samples for the Hessian Trace penalty but we found this did not affect the results.

Figure 5 shows that the methods perform quite differently as $\sigma^{2}$ increases - confirming the influence of the NME. We can see that the generalization improvement of the Gauss-Newton Trace penalty is consistently greater than either weight noise or Hessian Trace penalty. Its improvement on Imagenet is a significant $1.6\%$. In contrast, the other methods provide little accuracy improvement.

Figure 5: (a) Imagenet, (b) CIFAR-10. Test Accuracy as $\sigma^{2}$ increases across different datasets and activation functions averaged over $5$ seeds. Large $\sigma^{2}$ reveals a stark contrast between the Gauss-Newton trace penalty (excluding NME) and methods incorporating it, highlighting the NME’s influence.

These results are evidence that the NME term of the Hessian should not be dropped when applying the analysis of Bishop (1995) to weight noise for modern networks. Indeed there is a significant difference between the Hessian trace penalty, which involves the NME, and the Gauss-Newton penalty, which does not. This suggests that while the NME has a positive influence on the learning dynamics as seen for gradient penalties in Section 4, it is detrimental to regularize it directly in the loss function.

This is not contradictory with the analysis in Section 4 which suggested that incorporating NME information into updates helps learning. Minimizing the NME through the loss will reduce its impact on the learning dynamics. We can also see that the Gauss-Newton penalty, which does not involve the NME in the loss, indeed involves the NME in the update rule:

$$\nabla_{\bm{\theta}}\,\mathrm{tr}\left(\mathbf{J}^{\mathrm{T}}\mathbf{H}_{\mathbf{z}}\mathbf{J}\right) = \mathrm{E}_{\hat{\mathbf{y}}\sim\mathrm{Cat}(\mathbf{z})}\left[\left(\nabla_{\mathbf{z}}\mathcal{L}(\bm{\theta},\hat{\mathbf{y}})\cdot\nabla_{\bm{\theta}}^{2}\mathbf{z}\right)\nabla_{\bm{\theta}}\mathcal{L}(\bm{\theta},\hat{\mathbf{y}})\right] \tag{21}$$

This update rule is very similar to the update rule for the gradient penalty in Equation 8. The three differences are the lack of a normalization factor (equivalent to $p = 2$ gradient penalty), the lack of Gauss-Newton vector product, and the fact that the NME is computed over the labels generated by the model and not the true labels. Therefore the Gauss-Newton trace penalty, the best performing of our ablations, does indeed incorporate NME information into the update rule.

6 Discussion

Our theoretical analysis gives some understanding of the structure of the Hessian - in particular, the Nonlinear Modeling Error matrix. This piece of the Hessian is often neglected as it is generally indefinite and doesn’t generate large eigenvalues, and is $0$ at an interpolating minimum. However, the NME can encode important information related to feature learning as it depends on $\nabla^{2}_{\bm{\theta}}\mathbf{z}$ - the gradient of the Jacobian. For example, in networks with saturating activation functions the NME gives information about the potential benefits of switching into the saturated regimes of different neurons. More generally, our analysis shows that the elements of the NME, especially on the diagonal, are sensitive to the second derivative of the activation function.

6.1 NME and gradient penalties

Our experiments suggest that these second derivative properties can be quite important when training with gradient penalty regularizers. ReLU has a poorly defined pointwise second derivative and the regularizer harms training, while GELU has a well defined one and gains benefits from modest values of the regularizer. Our experiments with $\beta$-GELU suggest that if the NME is well-defined but sparse and “spiky”, we also achieve poor training.

One important point here is that the sensitivity to the second derivative comes from the fact that the update rule (Equation 8) involves an explicit Hessian-vector product. We can contrast this with methods which use second order information implicitly via first order measurements. In particular, the SAM algorithm for controlling curvature is equivalent, to low order in regularization strength $\rho$, to gradient penalties (Appendix B). SAM approximates dynamics on a sharpness-penalized objective by taking two steps with gradient information, and it works with both ReLU and GELU - matching the performance of gradient penalties on GELU (Appendix B.3).

The difference between SAM and the gradient penalty is that SAM acquires second order information via discrete, gradient-based steps. It is effectively integrating over the Hessian (and therefore NME) information. Therefore it is not as sensitive to the pointwise properties of the second derivative of the activation function.
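The low-order equivalence and the "integration" picture can be seen on a toy problem. The sketch uses a hypothetical two-parameter loss $\mathcal{L}(a, b) = a^{4}/4 + a^{2}b^{2}/2 + b^{2}/2$ and assumes a penalty of the form $\rho\|\nabla\mathcal{L}\|$, whose gradient is $\rho\,\mathbf{H}\nabla\mathcal{L}/\|\nabla\mathcal{L}\|$. This is an illustrative choice and not necessarily the exact form of Equation 8. SAM's perturbed gradient is a finite difference along the gradient direction, while the penalty update uses a pointwise Hessian-vector product.

```python
import math

def grad(theta):
    """Gradient of the toy loss L(a, b) = a^4/4 + a^2 b^2/2 + b^2/2."""
    a, b = theta
    return (a ** 3 + a * b * b, a * a * b + b)

def hessian_vector_product(theta, v):
    """Exact Hessian of the toy loss applied to the vector v."""
    a, b = theta
    h_aa, h_ab, h_bb = 3 * a * a + b * b, 2 * a * b, a * a + 1.0
    return (h_aa * v[0] + h_ab * v[1], h_ab * v[0] + h_bb * v[1])

def sam_gradient(theta, rho):
    # SAM: evaluate the gradient at a point displaced by rho along the unit gradient.
    g = grad(theta)
    norm = math.hypot(g[0], g[1])
    perturbed = (theta[0] + rho * g[0] / norm, theta[1] + rho * g[1] / norm)
    return grad(perturbed)

def penalty_gradient(theta, rho):
    # Gradient of L + rho * ||grad L|| (assumed penalty form): g + rho * H g / ||g||.
    g = grad(theta)
    norm = math.hypot(g[0], g[1])
    hv = hessian_vector_product(theta, (g[0] / norm, g[1] / norm))
    return (g[0] + rho * hv[0], g[1] + rho * hv[1])

def gap(theta, rho):
    s, q = sam_gradient(theta, rho), penalty_gradient(theta, rho)
    return math.hypot(s[0] - q[0], s[1] - q[1])

theta = (1.0, 0.5)
for rho in (0.04, 0.02, 0.01):
    print(f"rho={rho:.2f}  |SAM gradient - penalty gradient| = {gap(theta, rho):.2e}")
```

Halving $\rho$ shrinks the gap by roughly a factor of four, as expected when the two updates agree through first order in $\rho$; SAM never evaluates a pointwise second derivative, only gradients at two nearby points.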

The NME is also important in understanding the regularizers in Section 5, where we showed that even in the case of the Gauss-Newton trace penalty, the NME shows up in the update rule. Therefore the NME can be important for understanding dynamics even when regularization efforts focus on the Gauss-Newton.

6.2 Lessons for using second order information

Our work suggests that some second order methods may benefit from tuning the NME. This is especially true for methods which result in Hessian-vector products in update rules (like the gradient and Hessian penalties studied here).

Our experiments with augmented ReLU suggest that helpful interventions can be designed to improve propagation of NME information. We hypothesize that this information is related to feature learning, and therefore acts over the totality of training to affect generalization.

7 Conclusion

Our work sheds light on the complexities of using second order information in deep learning. We have identified clear cases where it is important to consider the effects of both the Gauss-Newton and Nonlinear Modeling Error terms, and design algorithms and architectures with that in mind. Designing activation functions for compatibility with second order methods may also be an interesting avenue of future research.

References

  • Agarwala & Dauphin (2023) Atish Agarwala and Yann Dauphin. SAM operates far from home: Eigenvalue regularization as a dynamical phenomenon. In Proceedings of the 40th International Conference on Machine Learning, pp.  152–168. PMLR, July 2023.
  • Agarwala et al. (2020) Atish Agarwala, Jeffrey Pennington, Yann Dauphin, and Sam Schoenholz. Temperature check: Theory and practice for training models with softmax-cross-entropy losses, October 2020.
  • Agarwala et al. (2022) Atish Agarwala, Fabian Pedregosa, and Jeffrey Pennington. Second-order regression models exhibit progressive sharpening to the edge of stability, October 2022.
  • An (1996) Guozhong An. The effects of adding noise during backpropagation training on a generalization performance. Neural computation, 8(3):643–674, 1996.
  • Andriushchenko & Flammarion (2022) Maksym Andriushchenko and Nicolas Flammarion. Towards understanding sharpness-aware minimization. In International Conference on Machine Learning, pp.  639–668. PMLR, 2022.
  • Barrett & Dherin (2021) David G. T. Barrett and Benoit Dherin. Implicit gradient regularization. In 9th International Conference on Learning Representations, ICLR 2021, Virtual Event, Austria, May 3-7, 2021. OpenReview.net, 2021. URL https://openreview.net/forum?id=3q5IqUrkcF.
  • Bengio et al. (2013) Yoshua Bengio, Nicholas Léonard, and Aaron Courville. Estimating or propagating gradients through stochastic neurons for conditional computation. arXiv preprint arXiv:1308.3432, 2013.
  • Bishop (1995) Chris M Bishop. Training with noise is equivalent to tikhonov regularization. Neural computation, 7(1):108–116, 1995.
  • Chizat et al. (2019) Lénaïc Chizat, Edouard Oyallon, and Francis Bach. On Lazy Training in Differentiable Programming. In Advances in Neural Information Processing Systems 32, pp.  2937–2947. Curran Associates, Inc., 2019.
  • Dean et al. (2012) Jeffrey Dean, Greg Corrado, Rajat Monga, Kai Chen, Matthieu Devin, Mark Mao, Marc’aurelio Ranzato, Andrew Senior, Paul Tucker, Ke Yang, et al. Large scale distributed deep networks. Advances in neural information processing systems, 25, 2012.
  • Deng et al. (2009) Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. Imagenet: A large-scale hierarchical image database. In 2009 IEEE conference on computer vision and pattern recognition, pp.  248–255. Ieee, 2009.
  • Du et al. (2022) Jiawei Du, Zhou Daquan, Jiashi Feng, Vincent Tan, and Joey Tianyi Zhou. Sharpness-aware training for free. In Alice H. Oh, Alekh Agarwal, Danielle Belgrave, and Kyunghyun Cho (eds.), Advances in Neural Information Processing Systems, 2022. URL https://openreview.net/forum?id=xK6wRfL2mv7.
  • Foret et al. (2020) Pierre Foret, Ariel Kleiner, Hossein Mobahi, and Behnam Neyshabur. Sharpness-aware minimization for efficiently improving generalization. arXiv preprint arXiv:2010.01412, 2020.
  • Goyal et al. (2018) Priya Goyal, Piotr Dollár, Ross Girshick, Pieter Noordhuis, Lukasz Wesolowski, Aapo Kyrola, Andrew Tulloch, Yangqing Jia, and Kaiming He. Accurate, large minibatch sgd: Training imagenet in 1 hour, 2018.
  • Hutchinson (1989) Michael F Hutchinson. A stochastic estimator of the trace of the influence matrix for laplacian smoothing splines. Communications in Statistics-Simulation and Computation, 18(3):1059–1076, 1989.
  • Jacot et al. (2018) Arthur Jacot, Franck Gabriel, and Clement Hongler. Neural Tangent Kernel: Convergence and Generalization in Neural Networks. In Advances in Neural Information Processing Systems 31, pp.  8571–8580. Curran Associates, Inc., 2018.
  • Jacot et al. (2020) Arthur Jacot, Franck Gabriel, and Clement Hongler. The asymptotic spectrum of the Hessian of DNN throughout training. In International Conference on Learning Representations, March 2020.
  • Krizhevsky et al. (2009) Alex Krizhevsky, Geoffrey Hinton, et al. Learning multiple layers of features from tiny images. 2009.
  • Lee et al. (2019) Jaehoon Lee, Lechao Xiao, Samuel Schoenholz, Yasaman Bahri, Roman Novak, Jascha Sohl-Dickstein, and Jeffrey Pennington. Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent. In Advances in Neural Information Processing Systems 32, pp.  8570–8581. Curran Associates, Inc., 2019.
  • Liu et al. (2023) Hong Liu, Zhiyuan Li, David Hall, Percy Liang, and Tengyu Ma. Sophia: A scalable stochastic second-order optimizer for language model pre-training, 2023.
  • Loshchilov & Hutter (2016) Ilya Loshchilov and Frank Hutter. Sgdr: Stochastic gradient descent with warm restarts. arXiv preprint arXiv:1608.03983, 2016.
  • Martens (2020) James Martens. New Insights and Perspectives on the Natural Gradient Method. Journal of Machine Learning Research, 21(146):1–76, 2020. ISSN 1533-7928.
  • Martens & Grosse (2015) James Martens and Roger Grosse. Optimizing Neural Networks with Kronecker-factored Approximate Curvature. In Proceedings of the 32nd International Conference on Machine Learning, pp.  2408–2417. PMLR, June 2015.
  • Martens & Sutskever (2011) James Martens and Ilya Sutskever. Learning recurrent neural networks with hessian-free optimization. In Proceedings of the 28th international conference on machine learning (ICML-11), pp.  1033–1040, 2011.
  • Martens et al. (2021) James Martens, Andy Ballard, Guillaume Desjardins, Grzegorz Swirszcz, Valentin Dalibard, Jascha Sohl-Dickstein, and Samuel S. Schoenholz. Rapid training of deep neural networks without skip connections or normalization layers using Deep Kernel Shaping. arXiv:2110.01765 [cs], October 2021.
  • Pascanu & Bengio (2013) Razvan Pascanu and Yoshua Bengio. Revisiting natural gradient for deep networks. arXiv preprint arXiv:1301.3584, 2013.
  • Pennington et al. (2017) Jeffrey Pennington, Samuel Schoenholz, and Surya Ganguli. Resurrecting the sigmoid in deep learning through dynamical isometry: Theory and practice. In Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc., 2017.
  • Reizinger & Huszár (2023) Patrik Reizinger and Ferenc Huszár. SAMBA: Regularized autoencoders perform sharpness-aware minimization. In Fifth Symposium on Advances in Approximate Bayesian Inference, 2023. URL https://openreview.net/forum?id=gk3PAmy_UNz.
  • Sagun et al. (2017) Levent Sagun, Utku Evci, V Ugur Guney, Yann Dauphin, and Leon Bottou. Empirical analysis of the hessian of over-parametrized neural networks. arXiv preprint arXiv:1706.04454, 2017.
  • Sankar et al. (2021) Adepu Ravi Sankar, Yash Khasbage, Rahul Vigneswaran, and Vineeth N Balasubramanian. A deeper look at the hessian eigenspectrum of deep neural networks and its applications to regularization. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 35, pp.  9481–9488, 2021.
  • Singh et al. (2021) Sidak Pal Singh, Gregor Bachmann, and Thomas Hofmann. Analytic Insights into Structure and Rank of Neural Network Hessian Maps, July 2021.
  • Singh et al. (2023) Sidak Pal Singh, Thomas Hofmann, and Bernhard Schölkopf. The Hessian perspective into the Nature of Convolutional Neural Networks. In Proceedings of the 40th International Conference on Machine Learning, pp.  31930–31968. PMLR, July 2023.
  • Smith et al. (2021) Samuel L. Smith, Benoit Dherin, David G. T. Barrett, and Soham De. On the origin of implicit regularization in stochastic gradient descent, 2021.
  • Wei et al. (2020) Colin Wei, Sham Kakade, and Tengyu Ma. The implicit and explicit regularization effects of dropout. In International conference on machine learning, pp.  10181–10192. PMLR, 2020.
  • Zagoruyko & Komodakis (2016) Sergey Zagoruyko and Nikos Komodakis. Wide residual networks. arXiv preprint arXiv:1605.07146, 2016.
  • Zhao et al. (2022) Yang Zhao, Hao Zhang, and Xiuyuan Hu. Penalizing gradient norm for efficiently improving generalization in deep learning, 2022.

Appendix A Hessian structure

A.1 Gauss-Newton and NTK learning

In the large width limit (width/channels/patches increasing while dataset is fixed), the learning dynamics of neural networks are well described by the neural tangent kernel, or NTK (Jacot et al., 2018; Lee et al., 2019). Consider a dataset size D𝐷Ditalic_D, with outputs 𝐳(𝜽,𝐗)𝐳𝜽𝐗\mathbf{z}(\bm{\theta},\mathbf{X})bold_z ( bold_italic_θ , bold_X ) over the inputs 𝐗𝐗\mathbf{X}bold_X with parameters 𝜽𝜽\bm{\theta}bold_italic_θ. The (empirical) NTK 𝚯^bold-^𝚯\bm{\hat{\Theta}}overbold_^ start_ARG bold_Θ end_ARG is the D×D𝐷𝐷D\times Ditalic_D × italic_D matrix given by

$$\bm{\hat{\Theta}}\equiv\frac{1}{D}\mathbf{J}\mathbf{J}^{{\rm T}},\quad\mathbf{J}\equiv\frac{\partial\mathbf{z}}{\partial\bm{\theta}} \tag{22}$$

For wide enough networks, the learning dynamics can be written in terms of the model output $\mathbf{z}$ and the NTK $\bm{\hat{\Theta}}$ alone. For small learning rates we can study the gradient flow dynamics. The gradient flow dynamics on the parameters $\bm{\theta}$ with loss function $\mathcal{L}$ (averaged over the dataset) is given by

$$\dot{\bm{\theta}}=-\frac{1}{D}\nabla_{\bm{\theta}}\mathcal{L}=-\frac{1}{D}\mathbf{J}^{{\rm T}}\nabla_{\mathbf{z}}\mathcal{L} \tag{23}$$

We can use the chain rule to write down the dynamics of $\mathbf{z}$:

$$\dot{\mathbf{z}}=\frac{\partial\mathbf{z}}{\partial\bm{\theta}}\dot{\bm{\theta}}=-\frac{1}{D}\mathbf{J}\mathbf{J}^{{\rm T}}\nabla_{\mathbf{z}}\mathcal{L}=-\bm{\hat{\Theta}}\nabla_{\mathbf{z}}\mathcal{L} \tag{24}$$
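Equations (22) to (24) can be checked numerically on a toy model. The following is a minimal sketch, assuming a two-layer network with scalar outputs $z_i = \mathbf{v}\cdot\tanh(\mathbf{W}\mathbf{x}_i)$ and a summed squared-error loss, neither of which comes from the paper; the function names are hypothetical. It builds the Jacobian by hand, forms the empirical NTK, and compares the predicted output velocity with the change in $\mathbf{z}$ after one small Euler step of the parameter flow.

```python
import numpy as np

def model(W, v, X):
    """Scalar outputs z_i = v . tanh(W x_i) for each row x_i of X (toy model)."""
    return np.tanh(X @ W.T) @ v

def jacobian(W, v, X):
    """J = dz/dtheta with theta = (W, v) flattened; shape (D, P)."""
    H = np.tanh(X @ W.T)                                   # (D, width)
    dW = (v * (1.0 - H**2))[:, :, None] * X[:, None, :]    # (D, width, d_in)
    return np.concatenate([dW.reshape(len(X), -1), H], axis=1)

def empirical_ntk(W, v, X):
    """Equation (22): Theta_hat = J J^T / D."""
    J = jacobian(W, v, X)
    return J @ J.T / len(X)

def output_velocity(W, v, X, y):
    """Equation (24), assuming L = 0.5 * sum_i (z_i - y_i)^2 so grad_z L = z - y."""
    return -empirical_ntk(W, v, X) @ (model(W, v, X) - y)

rng = np.random.default_rng(0)
D, d_in, width = 5, 3, 8
X, y = rng.normal(size=(D, d_in)), rng.normal(size=D)
W, v = rng.normal(size=(width, d_in)), rng.normal(size=width)

# One small Euler step of the parameter flow (23) should move z by about dt * z_dot.
dt = 1e-6
theta_dot = -jacobian(W, v, X).T @ (model(W, v, X) - y) / D
dW = theta_dot[: width * d_in].reshape(width, d_in)
dv = theta_dot[width * d_in :]
z_dot_numeric = (model(W + dt * dW, v + dt * dv, X) - model(W, v, X)) / dt
print(np.max(np.abs(z_dot_numeric - output_velocity(W, v, X, y))))
```

The printed discrepancy should shrink with `dt`, since it comes from the second-order terms that the chain rule in (24) discards.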

In the limit of infinite width, the overall changes in individual parameters become small, and $\bm{\hat{\Theta}}$ is fixed during training. This corresponds to the linearized or lazy regime (Chizat et al., 2019; Agarwala et al., 2020). The NTK encodes the linear response of $\mathbf{z}$ to small changes in $\bm{\theta}$, and the dynamics is closed in terms of $\mathbf{z}$. For finite width networks, this approximates the dynamics well for a number of steps related to the network width, amongst other properties (Lee et al., 2019).

In order to understand the dynamics of Equation 24 at small times, or around minima, we can linearize with respect to $\mathbf{z}$. We have:

$$\frac{\partial\dot{\mathbf{z}}}{\partial\mathbf{z}}=-\frac{\partial\bm{\hat{\Theta}}}{\partial\mathbf{z}}\nabla_{\mathbf{z}}\mathcal{L}-\bm{\hat{\Theta}}\mathbf{H}_{\mathbf{z}} \tag{25}$$

where $\mathbf{H}_{\mathbf{z}}=\frac{\partial^{2}\mathcal{L}}{\partial\mathbf{z}\partial\mathbf{z}^{\prime}}$. In the limit of large width, the NTK is constant and the first term vanishes. The local dynamics depends on the spectrum of $\bm{\hat{\Theta}}\mathbf{H}_{\mathbf{z}}$. From the cyclic property of the trace, the non-zero part of the spectrum is equal to the non-zero spectrum of $\frac{1}{D}\mathbf{J}^{{\rm T}}\mathbf{H}_{\mathbf{z}}\mathbf{J}$, which is the Gauss-Newton matrix.

Therefore the eigenvalues of the Gauss-Newton matrix control the short term, linearized dynamics of $\mathbf{z}$, for fixed NTK. It is in this sense that the Gauss-Newton matrix encodes information about exploiting the local linear structure of the model.
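The spectral statement above can be illustrated with random matrices. This is a sketch under stated assumptions: `J` is a random stand-in for the Jacobian and `H_z` a random symmetric positive definite stand-in for the loss Hessian in $\mathbf{z}$, so nothing here is specific to a trained network. It compares the eigenvalues of $\bm{\hat{\Theta}}\mathbf{H}_{\mathbf{z}}$ with those of $\frac{1}{D}\mathbf{J}^{{\rm T}}\mathbf{H}_{\mathbf{z}}\mathbf{J}$.

```python
import numpy as np

rng = np.random.default_rng(0)
D, P = 4, 10                      # fewer data points than parameters
J = rng.normal(size=(D, P))       # stand-in for the Jacobian dz/dtheta
A = rng.normal(size=(D, D))
H_z = A @ A.T + np.eye(D)         # symmetric positive definite stand-in for H_z

ntk_times_hz = (J @ J.T / D) @ H_z     # D x D, governs the linearized z dynamics
gauss_newton = J.T @ H_z @ J / D       # P x P Gauss-Newton matrix

# The D x D product is not symmetric, so use the general eigenvalue solver.
eig_small = np.sort(np.linalg.eigvals(ntk_times_hz).real)
eig_large = np.linalg.eigvalsh(gauss_newton)   # ascending order

print(eig_small)          # spectrum of Theta_hat H_z
print(eig_large[-D:])     # the D largest Gauss-Newton eigenvalues
print(eig_large[:-D])     # the remaining P - D eigenvalues, numerically zero
```

The first two printed arrays should agree to numerical precision, while the extra $P-D$ Gauss-Newton eigenvalues sit at zero.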

A.2 Nonlinear Modeling Error and second derivatives of FCNs

We can explicitly compute the Jacobian and second derivative of the model for a fully connected network. We write a feedforward network as follows:

$$\mathbf{h}_{l}=\mathbf{W}_{l}\mathbf{x}_{l},\quad\mathbf{x}_{l+1}=\phi(\mathbf{h}_{l}) \tag{26}$$

The gradient of $\mathbf{x}_{L}$ with respect to $\mathbf{W}_{l}$ can be written as:

$$\frac{\partial\mathbf{x}_{L}}{\partial\mathbf{W}_{l}}=\frac{\partial\mathbf{x}_{L}}{\partial\mathbf{h}_{l}}\frac{\partial\mathbf{h}_{l}}{\partial\mathbf{W}_{l}} \tag{27}$$

which can be written in coordinate-free notation as

$$\frac{\partial\mathbf{x}_{L}}{\partial\mathbf{W}_{l}}=\frac{\partial\mathbf{x}_{L}}{\partial\mathbf{h}_{l}}\otimes\mathbf{x}_{l} \tag{28}$$

If we define the partial Jacobian $\mathbf{J}_{l^{\prime}l}\equiv\frac{\partial\mathbf{x}_{l^{\prime}}}{\partial\mathbf{x}_{l}}$ for $l^{\prime}>l$, we have

$$\frac{\partial\mathbf{x}_{L}}{\partial\mathbf{W}_{l}}=\mathbf{J}_{L(l+1)}\circ\phi^{\prime}(\mathbf{h}_{l})\otimes\mathbf{x}_{l} \tag{29}$$

Here $\circ$ denotes the Hadamard product, in this case equivalent to matrix multiplication by ${\rm diag}(\phi^{\prime}(\mathbf{h}_{m}))$.

The Jacobian can be explicitly written as

$$\mathbf{J}_{l^{\prime}l}=\prod_{m=l}^{l^{\prime}-1}\phi^{\prime}(\mathbf{h}_{m})\circ\mathbf{W}_{m} \tag{30}$$

Therefore, we can write:

$$\frac{\partial\mathbf{x}_{L}}{\partial\mathbf{W}_{l}}=\left[\prod_{m=l+1}^{L-1}\phi^{\prime}(\mathbf{h}_{m})\circ\mathbf{W}_{m}\right]\circ\phi^{\prime}(\mathbf{h}_{l})\otimes\mathbf{x}_{l} \tag{31}$$
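Equations (29) to (31) can be compared against finite differences. The sketch below assumes a small $\tanh$ network with three square layers and reads each $\phi^{\prime}(\mathbf{h}_{m})\circ\mathbf{W}_{m}$ factor as ${\rm diag}(\phi^{\prime}(\mathbf{h}_{m}))\mathbf{W}_{m}$, with later layers multiplied on the left. The helper names are hypothetical, and the derivative is contracted with an arbitrary direction `A` for $\mathbf{W}_{l}$.

```python
import numpy as np

def phi(h):
    return np.tanh(h)

def dphi(h):
    return 1.0 - np.tanh(h) ** 2

def forward(Ws, x0):
    """Return activations xs[0..L] and pre-activations hs[0..L-1] of eq. (26)."""
    xs, hs = [x0], []
    for W in Ws:
        hs.append(W @ xs[-1])
        xs.append(phi(hs[-1]))
    return xs, hs

def partial_jacobian(Ws, xs, hs, hi, lo):
    """Eq. (30): d x_hi / d x_lo, with later layers multiplied on the left."""
    J = np.eye(len(xs[lo]))
    for m in range(lo, hi):
        J = (dphi(hs[m])[:, None] * Ws[m]) @ J   # diag(phi'(h_m)) W_m
    return J

def output_derivative(Ws, x0, l, A):
    """Eq. (31) contracted with a direction A for W_l."""
    xs, hs = forward(Ws, x0)
    J = partial_jacobian(Ws, xs, hs, len(Ws), l + 1)
    return J @ (dphi(hs[l]) * (A @ xs[l]))

def finite_difference(Ws, x0, l, A, eps=1e-6):
    """Central difference of x_L along the direction A in W_l."""
    plus = [W + eps * A if i == l else W for i, W in enumerate(Ws)]
    minus = [W - eps * A if i == l else W for i, W in enumerate(Ws)]
    return (forward(plus, x0)[0][-1] - forward(minus, x0)[0][-1]) / (2 * eps)

rng = np.random.default_rng(0)
Ws = [rng.normal(size=(4, 4)) / 2 for _ in range(3)]
x0, A = rng.normal(size=4), rng.normal(size=(4, 4))
for l in range(3):
    gap = output_derivative(Ws, x0, l, A) - finite_difference(Ws, x0, l, A)
    print(l, np.max(np.abs(gap)))
```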

The second derivative is more complicated. Consider

$$\frac{\partial^{2}\mathbf{x}_{L}}{\partial\mathbf{W}_{l}\partial\mathbf{W}_{m}}=\frac{\partial}{\partial\mathbf{W}_{m}}\left[\mathbf{J}_{L(l+1)}\circ\phi^{\prime}(\mathbf{h}_{l})\otimes\mathbf{x}_{l}\right] \tag{32}$$

for weight matrices $\mathbf{W}_{l}$ and $\mathbf{W}_{m}$. Without loss of generality, assume $m\geq l$.

We first consider the case where $m>l$. In this case, we have

$$\frac{\partial\phi^{\prime}(\mathbf{h}_{l})}{\partial\mathbf{W}_{m}}=0,\quad\frac{\partial\mathbf{x}_{l}}{\partial\mathbf{W}_{m}}=0 \tag{33}$$

since $\mathbf{W}_{m}$ comes after $\mathbf{h}_{l}$. If we write down the derivative of $\mathbf{J}_{L(l+1)}$, there are two types of terms. The first comes from the direct differentiation of $\mathbf{W}_{m}$; the others come from differentiation of $\phi^{\prime}(\mathbf{h}_{n})$ for $n\geq m$. We have:

$$\frac{\partial\mathbf{J}_{L(l+1)}}{\partial\mathbf{W}_{m}}=\mathbf{J}_{L(m+1)}\phi^{\prime}(\mathbf{h}_{m})\frac{\partial\mathbf{W}_{m}}{\partial\mathbf{W}_{m}}\mathbf{J}_{(m-1)(l+1)}+\sum_{o=m}^{L-1}\mathbf{J}_{L(o+1)}\frac{\partial\phi^{\prime}(\mathbf{h}_{o})}{\partial\mathbf{W}_{m}}\mathbf{W}_{o}\mathbf{J}_{(o-1)(l+1)} \tag{34}$$

The $\mathbf{W}_{m}$ derivative projected into a direction $\mathbf{B}$ can be written as:

$$\begin{split}\frac{\partial\mathbf{J}_{L(l+1)}}{\partial\mathbf{W}_{m}}\cdot\mathbf{B}&=\mathbf{J}_{L(m+1)}\phi^{\prime}(\mathbf{h}_{m})\mathbf{B}\mathbf{J}_{(m-1)(l+1)}\\ &+\sum_{o=m}^{L-1}\mathbf{J}_{L(o+1)}\left(\phi^{\prime\prime}(\mathbf{h}_{o})\circ\mathbf{W}_{o}\frac{\partial\mathbf{x}_{o-1}}{\partial\mathbf{W}_{m}}\cdot\mathbf{B}\right)\mathbf{W}_{o}\mathbf{J}_{(o-1)(l+1)}\end{split} \tag{35}$$

From our previous analysis, we have:

$$\begin{split}\frac{\partial\mathbf{J}_{L(l+1)}}{\partial\mathbf{W}_{m}}\cdot\mathbf{B}&=\mathbf{J}_{L(m+1)}\phi^{\prime}(\mathbf{h}_{m})\mathbf{B}\mathbf{J}_{(m-1)(l+1)}\\ &+\sum_{o=m}^{L-1}\mathbf{J}_{L(o+1)}\left(\phi^{\prime\prime}(\mathbf{h}_{o})\circ\left[\mathbf{W}_{o}\mathbf{J}_{o(m+1)}\circ\phi^{\prime}(\mathbf{h}_{m+1})\circ\mathbf{B}\mathbf{x}_{m}\right]\right)\frac{\partial\phi^{\prime}(\mathbf{h}_{o})}{\partial\mathbf{W}_{m}}\mathbf{W}_{o}\mathbf{J}_{(o-1)(l+1)}\end{split} \tag{36}$$

In total, the second derivative projected into the $(\mathbf{A},\mathbf{B})$ direction for $m>l$ is given by:

$$\begin{split}\frac{\partial^{2}\mathbf{x}_{L}}{\partial\mathbf{W}_{l}\partial\mathbf{W}_{m}}\cdot(\mathbf{A}\otimes\mathbf{B})&=\left[\mathbf{J}_{L(m+1)}\phi^{\prime}(\mathbf{h}_{m})\mathbf{B}\mathbf{J}_{(m-1)(l+1)}+\right.\\ &\left.\sum_{o=m}^{L-1}\mathbf{J}_{L(o+1)}\left(\phi^{\prime\prime}(\mathbf{h}_{o})\circ\left[\mathbf{W}_{o}\mathbf{J}_{o(m+1)}\circ\phi^{\prime}(\mathbf{h}_{m+1})\circ\mathbf{B}\mathbf{x}_{m}\right]\right)\frac{\partial\phi^{\prime}(\mathbf{h}_{o})}{\partial\mathbf{W}_{m}}\mathbf{W}_{o}\mathbf{J}_{(o-1)(l+1)}\right]\\ &\circ\phi^{\prime}(\mathbf{h}_{l})\mathbf{A}\mathbf{x}_{l}\end{split} \tag{37}$$

Now consider the case $m=l$. Here there is no direct differentiation with respect to $\mathbf{W}_{m}$, but there is a derivative with respect to $\phi^{\prime}(\mathbf{h}_{m})$. The derivative is written as:

$$\begin{split}\frac{\partial^{2}\mathbf{x}_{L}}{\partial\mathbf{W}_{m}\partial\mathbf{W}_{m}}\cdot(\mathbf{A}\otimes\mathbf{B})&=\mathbf{J}_{L(m+1)}\circ[\phi^{\prime\prime}(\mathbf{h}_{m})\circ\mathbf{B}\mathbf{x}_{l}]\mathbf{A}\mathbf{x}_{m}+\\ &\left[\sum_{o=m}^{L-1}\mathbf{J}_{L(o+1)}\left(\phi^{\prime\prime}(\mathbf{h}_{o})\circ\left[\mathbf{W}_{o}\mathbf{J}_{o(m+1)}\circ\phi^{\prime}(\mathbf{h}_{m+1})\circ\mathbf{B}\mathbf{x}_{m}\right]\right)\frac{\partial\phi^{\prime}(\mathbf{h}_{o})}{\partial\mathbf{W}_{m}}\mathbf{W}_{o}\mathbf{J}_{(o-1)(m+1)}\right]\\ &\circ\phi^{\prime}(\mathbf{h}_{m})\mathbf{A}\mathbf{x}_{m}\end{split} \tag{38}$$

There are two key points. First, only one of the terms in the off-diagonal second derivative depends solely on first derivatives of the activation; for a deep network, the majority of the terms depend on $\phi^{\prime\prime}$. Second, on the diagonal, all terms depend on $\phi^{\prime\prime}$. Therefore if $\phi^{\prime\prime}(x)=0$, the diagonal of the model second derivative is $0$ as well.
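Both points can be observed numerically with finite differences. The sketch below assumes a three-layer network with square weight matrices and uses the identity activation as the simplest case with $\phi^{\prime\prime}=0$ everywhere (a piecewise-linear activation such as ReLU behaves the same way away from its kink). The function names and the random directions `A`, `B` are hypothetical.

```python
import numpy as np

def net(Ws, x0, phi):
    """Feedforward network of eq. (26): x_{l+1} = phi(W_l x_l)."""
    x = x0
    for W in Ws:
        x = phi(W @ x)
    return x

def second_derivative(Ws, x0, phi, l, A, m, B, eps=1e-3):
    """Central-difference estimate of d^2 x_L / dW_l dW_m contracted with (A, B)."""
    def f(s, t):
        Vs = [W.copy() for W in Ws]
        Vs[l] = Vs[l] + s * A
        Vs[m] = Vs[m] + t * B
        return net(Vs, x0, phi)
    return (f(eps, eps) - f(eps, -eps) - f(-eps, eps) + f(-eps, -eps)) / (4 * eps**2)

def identity(h):
    return h          # phi'' = 0 everywhere

rng = np.random.default_rng(0)
Ws = [rng.normal(size=(3, 3)) / np.sqrt(3) for _ in range(3)]
x0 = rng.normal(size=3)
A, B = rng.normal(size=(3, 3)), rng.normal(size=(3, 3))

for name, phi in [("identity", identity), ("tanh", np.tanh)]:
    diag = np.linalg.norm(second_derivative(Ws, x0, phi, 1, A, 1, B))
    off = np.linalg.norm(second_derivative(Ws, x0, phi, 0, A, 2, B))
    print(name, "diagonal block:", diag, "off-diagonal block:", off)
```

With the identity activation the diagonal block vanishes up to rounding while the off-diagonal block does not, since the term without $\phi^{\prime\prime}$ in (37) survives. With $\tanh$ both blocks are non-zero.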

Appendix B SAM and gradient penalties

The gradient penalties studied in Section 4 are related to the Sharpness Aware Minimization algorithm (SAM) developed to combat high curvature in deep learning (Foret et al., 2020). In this appendix we review the basics of SAM, show the correspondence to gradient penalties, and show that SAM is less sensitive to the choice of activation function.

B.1 SAM

The idea behind the SAM algorithm originates from seeking a minimum with a uniformly low loss in its neighborhood (hence flat). This is formulated in Foret et al. (2020) as a min-max problem,

$$\min_{\bm{\theta}}\max_{\bm{\epsilon}}\mathcal{L}(\bm{\theta}+\bm{\epsilon})\quad\mbox{s.t.}\quad\|\bm{\epsilon}\|\leq\rho\,. \tag{39}$$

For computational tractability, Foret et al. (2020) approximate the inner optimization by linearizing $\mathcal{L}$ w.r.t. $\bm{\epsilon}$ around the origin. Plugging the optimal $\bm{\epsilon}$ into the objective function yields

$$\min_{\bm{\theta}}\mathcal{L}\Big(\bm{\theta}+\rho\,\frac{\nabla_{\bm{\theta}}\mathcal{L}(\bm{\theta})}{\|\nabla_{\bm{\theta}}\mathcal{L}(\bm{\theta})\|}\Big)\,. \tag{40}$$

To minimize the above by gradient descent, we would need to compute the gradient below. (Footnote 2: In our notation the gradient and Hessian operators $\nabla$ and $\nabla^{2}$ precede function evaluation, e.g. $\nabla_{\bm{\theta}}\mathcal{L}(f(\bm{\theta}))$ means $\big(\frac{\partial}{\partial\bm{\tau}}\mathcal{L}(\bm{\tau})\big)_{\bm{\tau}=f(\bm{\theta})}$.)

$$\frac{\partial}{\partial\bm{\theta}}\mathcal{L}\Big(\bm{\theta}+\rho\frac{\mathbf{g}(\bm{\theta})}{\|\mathbf{g}(\bm{\theta})\|}\Big)\,=\,\Bigg(\mathbf{I}+\underbrace{\rho\frac{\mathbf{H}}{\|\mathbf{g}\|}\Big(\mathbf{I}-\frac{\mathbf{g}}{\|\mathbf{g}\|}\frac{\mathbf{g}^{{\rm T}}}{\|\mathbf{g}\|}\Big)}_{\text{Hessian related term}}\Bigg)\,\nabla_{\bm{\theta}}\mathcal{L}\left(\bm{\theta}+\rho\frac{\mathbf{g}}{\|\mathbf{g}\|}\right)\,,\quad\mathbf{g}\equiv\nabla_{\bm{\theta}}\mathcal{L}(\bm{\theta}),\quad\mathbf{H}\equiv\nabla_{\bm{\theta}}^{2}\mathcal{L}(\bm{\theta}) \tag{41}$$
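Equation (41) can be verified on a toy loss whose gradient and Hessian are available in closed form. This is a minimal sketch, assuming the loss $\mathcal{L}(\bm{\theta})=\frac{1}{2}\bm{\theta}^{{\rm T}}\mathbf{Q}\bm{\theta}+\frac{1}{4}\sum_{i}\theta_{i}^{4}$, which is chosen only for illustration and does not appear in the paper. It evaluates the right-hand side of (41) and compares it with a finite-difference gradient of the objective (40).

```python
import numpy as np

def loss(theta, Q):
    return 0.5 * theta @ Q @ theta + 0.25 * np.sum(theta**4)

def grad(theta, Q):
    return Q @ theta + theta**3          # valid because Q is symmetric

def hess(theta, Q):
    return Q + 3.0 * np.diag(theta**2)

def sam_objective(theta, Q, rho):
    """Objective (40): loss at the normalized-gradient perturbation of theta."""
    g = grad(theta, Q)
    return loss(theta + rho * g / np.linalg.norm(g), Q)

def full_gradient(theta, Q, rho):
    """Right-hand side of (41), including the Hessian related term."""
    g, H = grad(theta, Q), hess(theta, Q)
    norm = np.linalg.norm(g)
    g_hat = g / norm
    projector = np.eye(len(theta)) - np.outer(g_hat, g_hat)
    correction = np.eye(len(theta)) + rho * H @ projector / norm
    return correction @ grad(theta + rho * g_hat, Q)

def numeric_gradient(f, theta, eps=1e-6):
    """Central-difference gradient, used only as a reference."""
    out = np.zeros_like(theta)
    for i in range(len(theta)):
        e = np.zeros_like(theta)
        e[i] = eps
        out[i] = (f(theta + e) - f(theta - e)) / (2 * eps)
    return out

rng = np.random.default_rng(0)
M = rng.normal(size=(5, 5))
Q = M @ M.T                              # symmetric positive semi-definite
theta, rho = rng.normal(size=5), 0.1
reference = numeric_gradient(lambda t: sam_objective(t, Q, rho), theta)
print(np.max(np.abs(full_gradient(theta, Q, rho) - reference)))
```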

Evaluating (41) can still be computationally demanding as it involves the computation of a Hessian-vector product $\mathbf{H}\mathbf{g}$. The SAM algorithm drops the Hessian related term in (41), giving the update rule:

$$\bm{\theta}\leftarrow\bm{\theta}-\eta\,\nabla_{\bm{\theta}}\mathcal{L}\left(\bm{\theta}+\rho\tilde{\mathbf{g}}\right),\quad\tilde{\mathbf{g}}\equiv\nabla_{\bm{\theta}}\mathcal{L}(\bm{\theta})/\|\nabla_{\bm{\theta}}\mathcal{L}(\bm{\theta})\| \tag{42}$$

for some step-size parameter $\eta>0$. A related learning algorithm is unnormalized SAM (USAM), with update rule (Andriushchenko & Flammarion, 2022)

$$\bm{\theta}\leftarrow\bm{\theta}-\eta\,\nabla_{\bm{\theta}}\mathcal{L}\left(\bm{\theta}+\rho\mathbf{g}\right),\quad\mathbf{g}\equiv\nabla_{\bm{\theta}}\mathcal{L}(\bm{\theta}) \tag{43}$$

USAM has similar performance to SAM and is easier to analyze (Agarwala & Dauphin, 2023).
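The two update rules (42) and (43) differ only in whether the perturbation direction is normalized. The following is a minimal sketch, assuming a hypothetical toy quadratic loss with a diagonal Hessian (`H_DIAG`) so that the gradient can be written analytically; the function names are illustrative and not taken from any released implementation. It also assumes the gradient is nonzero when normalizing.

```python
import math

# Hypothetical toy problem: L(theta) = 0.5 * sum_i h_i * theta_i^2 (diagonal Hessian).
H_DIAG = [1.0, 10.0]

def grad(theta):
    """Gradient of the toy quadratic loss: g_i = h_i * theta_i."""
    return [h * t for h, t in zip(H_DIAG, theta)]

def norm(v):
    """Euclidean norm of a list of floats."""
    return math.sqrt(sum(x * x for x in v))

def sam_step(theta, eta, rho, normalize=True):
    """One SAM update (normalize=True, eq. 42) or USAM update (normalize=False, eq. 43)."""
    g = grad(theta)
    # SAM perturbs along g / ||g||; USAM perturbs along g itself.
    scale = rho / norm(g) if normalize else rho
    perturbed = [t + scale * gi for t, gi in zip(theta, g)]
    # The second gradient is evaluated at the perturbed point but applied to theta.
    g_perturbed = grad(perturbed)
    return [t - eta * gp for t, gp in zip(theta, g_perturbed)]

theta0 = [1.0, 1.0]
theta_sam = sam_step(theta0, eta=0.01, rho=0.1, normalize=True)
theta_usam = sam_step(theta0, eta=0.01, rho=0.1, normalize=False)
```

Each step uses two gradient evaluations and no Hessian-vector product, which reflects the dropped Hessian related term in (41).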

B.2 Penalty SAM

If $\rho$ is very small, we may approximate $\mathcal{L}$ in (40) by its first order Taylor expansion around the point $\rho=0$ as below.

$$\begin{aligned}
\mathcal{L}_{\text{PSAM}}(\bm{\theta})\, &\triangleq\,\mathcal{L}(\bm{\theta})_{\rho=0}+\rho\Big(\frac{\partial}{\partial\rho}\mathcal{L}\Big(\bm{\theta}+\rho\frac{\nabla_{\bm{\theta}}\mathcal{L}(\bm{\theta})}{\|\nabla_{\bm{\theta}}\mathcal{L}(\bm{\theta})\|}\Big)\Big)_{\rho=0}+O(\rho^{2}) && \text{(44)}\\
&=\,\mathcal{L}(\bm{\theta})+\rho\left\langle\nabla_{\bm{\theta}}\mathcal{L}(\bm{\theta})\,,\,\frac{\nabla_{\bm{\theta}}\mathcal{L}(\bm{\theta})}{\|\nabla_{\bm{\theta}}\mathcal{L}(\bm{\theta})\|}\right\rangle+O(\rho^{2}) && \text{(45)}\\
&=\,\mathcal{L}(\bm{\theta})+\rho\,\|\nabla_{\bm{\theta}}\mathcal{L}(\bm{\theta})\|+O(\rho^{2})\,. && \text{(46)}
\end{aligned}$$

Dropping terms of $O(\rho^{2})$, we arrive at the gradient penalty with $p=1$. If $\rho$ is not close to zero, then the loss landscape of $\mathcal{L}_{\text{PSAM}}$ provides a very poor approximation to that of (40). In the remainder of this section, we refer to this specific gradient penalty as Penalty SAM and denote its associated objective function (46) by PSAM. The unnormalized equivalent, PUSAM, is

$$\mathcal{L}_{\text{PUSAM}}(\bm{\theta})\,\triangleq\,\mathcal{L}(\bm{\theta})+\rho\,\|\nabla_{\bm{\theta}}\mathcal{L}(\bm{\theta})\|^{2}+O(\rho^{2})\,, \tag{47}$$

which corresponds to the $p=2$ case of the gradient penalty.
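The quality of the first order approximation in (44)-(46) can be checked numerically. The sketch below assumes a hypothetical two-parameter quadratic loss (`H_DIAG`) and compares the perturbed loss $\mathcal{L}(\bm{\theta}+\rho\,\mathbf{g}/\|\mathbf{g}\|)$ with the PSAM objective after the $O(\rho^{2})$ terms are dropped. For a quadratic loss the gap is exactly the second order term, so it should grow by a factor of 100 each time $\rho$ grows by a factor of 10.

```python
import math

# Hypothetical toy loss: L(theta) = 0.5 * sum_i h_i * theta_i^2.
H_DIAG = [1.0, 4.0]

def loss(theta):
    return 0.5 * sum(h * t * t for h, t in zip(H_DIAG, theta))

def grad(theta):
    return [h * t for h, t in zip(H_DIAG, theta)]

def grad_norm(theta):
    return math.sqrt(sum(g * g for g in grad(theta)))

def sam_objective(theta, rho):
    """Loss at the normalized-gradient perturbation (assumes a nonzero gradient)."""
    g, n = grad(theta), grad_norm(theta)
    return loss([t + rho * gi / n for t, gi in zip(theta, g)])

def psam_objective(theta, rho):
    """First order approximation, eq. (46) with the O(rho^2) term dropped."""
    return loss(theta) + rho * grad_norm(theta)

theta = [1.0, -0.5]
# Absolute gap between the two objectives for increasing rho.
gaps = {rho: abs(sam_objective(theta, rho) - psam_objective(theta, rho))
        for rho in (0.01, 0.1, 1.0)}
```

In this toy setting the gap scales as $\rho^{2}$, consistent with the statement that the PSAM landscape is only a good proxy when $\rho$ is small.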

B.3 Penalty SAM vs Original SAM

Figure 6 shows our experimental results comparing PSAM and SAM for Imagenet with different activation functions. We already saw in Section 4 that PSAM behaves differently between the two activation functions; by contrast, SAM is insensitive to them. The original SAM algorithm implicitly captures the NME information with the discrete $\rho$ step even in the ReLU case, while the gradient penalty, which uses explicit Hessian-gradient products, does not.

This suggests that another way to combat poor NME performance is to incorporate first order information from nearby points. There may be cases where this is more efficient computationally; SAM requires $2$ gradient computations per step, which is similar to the cost of an HVP.

[Figure 6, two panels: (a) Imagenet with ReLU; (b) Imagenet with GELU]
Figure 6: Test Accuracy as $\rho$ increases across different datasets and activation functions averaged over 2 seeds. For ReLU networks and large $\rho$, there is a significant difference between PSAM and SAM. PSAM with GELU networks more closely follows the behavior of SAM.

B.4 Penalty SAM vs. implicit regularization of SGD

The analysis of Smith et al. (2021) suggested that SGD with learning rate $\eta$ is similar to gradient flow (GF) with PUSAM with $\rho=\eta/4$. In this section we use a linear model to highlight some key differences between PUSAM and the discrete effects from finite stepsize.

Consider a quadratic loss $\mathcal{L}(\bm{\theta})=\frac{1}{2}\bm{\theta}^{\rm T}\mathbf{H}\bm{\theta}$ for some parameters $\bm{\theta}$ and PSD Hessian $\mathbf{H}$. It is illustrative to consider gradient descent (GD) with learning rate $\eta$ and (unnormalized) penalty SAM with radius $\rho$.

The gradient descent update rule is

$$\bm{\theta}_{t+1}-\bm{\theta}_{t}=-\eta(\mathbf{H}+\rho\mathbf{H}^{2})\bm{\theta}_{t} \tag{48}$$

The “effective Hessian” is given by $\mathbf{H}+\rho\mathbf{H}^{2}$ (see Agarwala & Dauphin (2023) for more analysis). Solving the linear equation gives us

$$\bm{\theta}_{t}=\left(1-\eta(\mathbf{H}+\rho\mathbf{H}^{2})\right)^{t}\bm{\theta}_{0} \tag{49}$$

These dynamics are well described by the eigenvalues of the effective Hessian, $\lambda+\rho\lambda^{2}$, where $\lambda$ are the eigenvalues of $\mathbf{H}$. The effect of the regularizer is therefore to introduce eigenvalue-dependent modifications into the Hessian.
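Because $\mathbf{H}$ and $\mathbf{H}^{2}$ share eigenvectors, the update (48) decouples into independent scalar recursions, one per eigenvalue. The following sketch, with hypothetical values for $\eta$, $\rho$, and the eigenvalues, iterates a single eigenmode and compares it with the closed form (49).

```python
def run_penalized_gd(lam, eta, rho, steps, theta0=1.0):
    """Iterate eq. (48) for one eigenmode with Hessian eigenvalue lam."""
    theta = theta0
    for _ in range(steps):
        theta -= eta * (lam + rho * lam ** 2) * theta
    return theta

def closed_form(lam, eta, rho, steps, theta0=1.0):
    """Closed-form solution, eq. (49), restricted to the same eigenmode."""
    return (1.0 - eta * (lam + rho * lam ** 2)) ** steps * theta0

# Hypothetical hyperparameters and eigenvalues, chosen only for illustration.
eta, rho, steps = 0.1, 0.05, 20
results = {}
for lam in (0.1, 1.0, 5.0):
    effective = lam + rho * lam ** 2  # eigenvalue of the effective Hessian
    results[lam] = (effective,
                    run_penalized_gd(lam, eta, rho, steps),
                    closed_form(lam, eta, rho, steps))
```

The relative shift of each eigenvalue is $\rho\lambda$, so large eigenvalues are modified far more than small ones, which is the eigenvalue-dependent effect described above.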

There is a special setting of $\rho$ which can be derived from the calculations in Smith et al. (2021). Consider $\rho=\eta/2$, and consider the dynamics after $2t$ steps. We have:

$$\bm{\theta}_{2t}=\left(1-\eta\left(\mathbf{H}+\frac{1}{2}\eta\mathbf{H}^{2}\right)\right)^{2t}\bm{\theta}_{0} \tag{50}$$

which can be re-written as

$$\bm{\theta}_{2t}=\left(1-2\eta\mathbf{H}+\eta^{3}\mathbf{H}^{3}+\frac{1}{4}\eta^{4}\mathbf{H}^{4}\right)^{t}\bm{\theta}_{0} \tag{51}$$
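For completeness, the step from (50) to (51) follows from squaring the per-step factor; the two terms of order $\eta^{2}\mathbf{H}^{2}$ cancel exactly for this choice of $\rho$:

$$\left(1-\eta\mathbf{H}-\tfrac{1}{2}\eta^{2}\mathbf{H}^{2}\right)^{2}=1-2\eta\mathbf{H}+\underbrace{\eta^{2}\mathbf{H}^{2}-\eta^{2}\mathbf{H}^{2}}_{=\,0}+\eta^{3}\mathbf{H}^{3}+\tfrac{1}{4}\eta^{4}\mathbf{H}^{4}.$$

Raising both sides to the power $t$ gives (51).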

To leading order in $\eta\mathbf{H}$, this is the same as the dynamics for learning rate $2\eta$, $\rho=0$ after $t$ steps:

$$\bm{\theta}_{t}=\left(1-2\eta\mathbf{H}\right)^{t}\bm{\theta}_{0} \tag{52}$$

We note that these two are similar only if $\eta\mathbf{H}\ll 1$. Under this condition, $\eta\rho\mathbf{H}^{2}=\frac{1}{2}\eta^{2}\mathbf{H}^{2}\ll\eta\mathbf{H}$, and the gradient penalty only has a small effect on the overall dynamics. In many practical learning scenarios, including those involving SAM, $\eta\lambda$ can become $O(1)$ for many eigenvalues during training (Agarwala & Dauphin, 2023). In these scenarios there will be qualitative differences between using penalty SAM and training with a different learning rate.
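The size of the mismatch can be made concrete per eigenmode. The sketch below, written in terms of the dimensionless product $x=\eta\lambda$ (the sample values of $x$ are hypothetical), compares the factor from two penalized steps with $\rho=\eta/2$ against the factor from one unpenalized step at learning rate $2\eta$.

```python
def two_step_factor(x):
    """Per-mode factor after two penalized steps with rho = eta / 2; x = eta * lambda."""
    return (1.0 - x - 0.5 * x ** 2) ** 2

def expanded_factor(x):
    """The same factor written in expanded form, as in eq. (51) with t = 1."""
    return 1.0 - 2.0 * x + x ** 3 + 0.25 * x ** 4

def double_lr_factor(x):
    """Per-mode factor after one plain GD step at learning rate 2 * eta, eq. (52) with t = 1."""
    return 1.0 - 2.0 * x

# Gap between the two dynamics for small, moderate, and O(1) values of eta * lambda.
gaps = {x: two_step_factor(x) - double_lr_factor(x) for x in (0.01, 0.1, 1.0)}
```

The gap is $x^{3}+\frac{1}{4}x^{4}$, which is negligible for $x\ll 1$. At $x=1$ the penalized factor is $0.25$ while the double-learning-rate factor is $-1$, so in this toy calculation the two dynamics differ qualitatively: one contracts and the other oscillates without decaying.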

In addition, when $\rho$ is set arbitrarily, the dynamics at learning rates $\eta$ and $2\eta$ will no longer match to second order in $\eta\mathbf{H}$. This provides further theoretical evidence that combining SGD with penalty SAM is qualitatively and quantitatively different from training with a larger learning rate.