Custom Gradient Estimators are Straight-Through Estimators in Disguise

Matt Schoenbauer

Daniele Moro
Google
Mountain View, CA, 94043

Lukasz Lew
Google
Mountain View, CA, 94043

Andrew Howard
Google
Mountain View, CA, 94043
Abstract

Quantization-aware training (QAT) comes with a fundamental challenge: the derivative of quantization functions such as rounding is zero almost everywhere and nonexistent elsewhere. Various differentiable approximations of quantization functions have been proposed to address this issue. In this paper, we prove that a large class of weight gradient estimators is approximately equivalent to the straight-through estimator (STE). Specifically, after swapping in the STE and adjusting both the weight initialization and the learning rate in SGD, the model will train in almost exactly the same way as it did with the original gradient estimator. Moreover, we show that for adaptive learning rate algorithms like Adam, the same result can be seen without any modifications to the weight initialization and learning rate. These results reduce the burden of hyperparameter tuning for practitioners of QAT, as they can now confidently choose the STE for gradient estimation and ignore more complex gradient estimators. We experimentally show that these results hold for both a small convolutional model trained on the MNIST dataset and a ResNet50 model trained on ImageNet.

1 Introduction

The importance of quantized deep learning. Quantized deep learning has gained significant attention as a means of deploying deep neural networks efficiently on resource-constrained devices. Traditional deep learning models typically employ high-precision representations, consuming substantial computational resources and memory. Quantized deep learning techniques offer a compelling solution by reducing the precision of network parameters and activations. Although Post-Training Quantization (PTQ) is easier to apply to an arbitrary trained model, Quantization-Aware Training (QAT) has been shown to provide higher-quality results, since the weights are updated with quantization in the loop throughout training [30].

Gradient estimators are needed in QAT. QAT encounters a problem: the derivatives of quantization functions are zero almost everywhere and nonexistent elsewhere. To sidestep this problem, practitioners backpropagate through differentiable approximations of the quantization functions; the derivatives of these approximations are known as gradient estimators. The straight-through estimator is a common choice, but many believe that a gradient estimator is better when its surrogate more closely approximates the rounding function. We show that this belief is misguided.

Our main contributions are as follows:

  1. A proof, under minimal assumptions, that all nonzero weight gradient estimators lead to approximately equivalent weight movement for non-adaptive learning rate optimizers (SGD, SGD with momentum, etc.) when the learning rate is sufficiently small, once the weight initialization and learning rate have been adjusted.

  2. A proof that for adaptive learning rate optimizers (Adam, RMSProp, etc.) the same result holds without any adjustment to the learning rate or weight initialization.

  3. Empirical evidence demonstrating this result on both a small convolutional network trained on MNIST and a larger ResNet50 model trained on ImageNet.

Value for practitioners. Our findings reduce the burden of hyperparameter tuning for QAT. Practitioners can now confidently choose the straight-through estimator [2] for gradient estimation and focus their attention on problems like choosing the weight initialization scheme, learning rate, and optimization method.

2 Background and Related Work

The standard quantizer function. The core operation in QAT is the application of a quantizer function to weights and activations, which transforms continuous, high-precision values into discrete, lower-precision representations. Quantization functions act elementwise on weight tensors $\mathbf{w}$, and can therefore be described by scalar functions of weights $w$. While there are many options for the arrangement of quantized values [8, 19, 33, 31, 26], we focus on the most popular formulation, uniform quantization functions, which are defined by

$$Q(x) := \Delta \cdot \mathrm{round}\left(\mathrm{clip}\left(\frac{x}{\Delta}, l, u\right)\right) \qquad \text{where} \qquad \mathrm{clip}(x, l, u) = \begin{cases} l & \text{if } x < l, \\ x & \text{if } l \leq x \leq u, \\ u & \text{if } x > u. \end{cases} \tag{1}$$
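To make Equation 1 concrete, here is a minimal Python sketch of the uniform quantizer for a single scalar. It assumes that round breaks ties away from zero (Equation 1 does not fix a tie rule, and Python's built-in round uses round-half-to-even), and the example parameters below (Δ = 0.1, l = -8, u = 7, i.e. a signed 4-bit grid) are illustrative choices of ours, not values from the paper.

```python
import math


def clip(x: float, l: float, u: float) -> float:
    """clip(x, l, u) from Equation 1: clamp x to the interval [l, u]."""
    return min(max(x, l), u)


def quantize(x: float, delta: float, l: float, u: float) -> float:
    """Uniform quantizer Q(x) = delta * round(clip(x / delta, l, u)) (Equation 1).

    Assumption: ties are rounded half away from zero; Python's built-in round()
    would instead round half to even.
    """
    y = clip(x / delta, l, u)
    return delta * math.copysign(math.floor(abs(y) + 0.5), y)
```

For example, quantize(0.234, 0.1, -8, 7) returns approximately 0.2, while any input above 0.7 is clipped to the top level 0.7 (up to floating-point rounding).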

The problem of choosing $\Delta$, $l$, and $u$ is well-researched, and we cover common approaches in Appendix A.

Boundary points. We will refer to the sets of quantizer input values that map to a single output value as quantization bins. The boundaries of these bins are known as boundary points. We will use $w_-$ and $w_+$ to refer to the lower and upper boundary points, respectively, of the bin containing weight $w$. Inside the representable range (see Appendix A) of the quantizer both points exist, but outside of it only one of the two exists. Note that $w_+ - w_- = \Delta$ for all weights in the representable range.
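The bin structure of Equation 1 can be made explicit with a short sketch. Assuming integer clipping bounds l < u, the rounding step places boundary points at Δ(k + 1/2) for k = l, ..., u - 1; the helper below (a hypothetical name) returns the pair (w_-, w_+) for the bin containing w, with None for the side that does not exist in the two outermost bins.

```python
import math
from typing import Optional, Tuple


def boundary_points(
    w: float, delta: float, l: int, u: int
) -> Tuple[Optional[float], Optional[float]]:
    """Return (w_minus, w_plus), the lower and upper boundary points of the bin containing w.

    Assumes the quantizer of Equation 1 with integer clipping bounds l < u, so that
    boundary points sit at delta * (k + 1/2) for k = l, ..., u - 1. None marks a
    boundary that does not exist (the outermost bins are unbounded on one side).
    Which bin a weight lying exactly on a boundary belongs to is a convention.
    """
    k = math.floor(w / delta + 0.5)  # index of the nearest grid level
    k = min(max(k, l), u)  # weights beyond the clipping range fall in an outermost bin
    w_minus = delta * (k - 0.5) if k > l else None
    w_plus = delta * (k + 0.5) if k < u else None
    return w_minus, w_plus
```

Inside the representable range the two returned values differ by Δ, matching the remark above; for w = 5.0 on the same 4-bit grid only the lower boundary 0.65 is returned.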

The Straight-Through Estimator. Because $Q'(x) = dQ/dx$ is zero almost everywhere and nonexistent elsewhere, vanilla gradient descent would never update the weights of a quantized model. The standard approach for addressing this issue is to approximate $Q(x)$ by a differentiable surrogate function $\hat{Q}$ and use its gradient $\hat{Q}'(x)$ for backpropagation. The derivative $\hat{Q}'$ is known as a gradient estimator (or gradient approximation). The earliest popular choice of gradient estimator is known as the straight-through estimator [17, 2] or STE, defined by $\hat{Q}(x) = x$, $\hat{Q}'(x) = 1$.
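The contrast between the true derivative Q' and the STE can be checked numerically. The sketch below reuses an illustrative 4-bit grid (Δ = 0.1, l = -8, u = 7, our assumption) and hypothetical helper names: a central finite difference of Q at a generic point inside a bin is exactly zero, whereas the STE backward rule passes the upstream gradient ∂f/∂Q(w) through unchanged.

```python
import math


def quantize(x: float, delta: float = 0.1, l: float = -8, u: float = 7) -> float:
    """Uniform quantizer of Equation 1 on an illustrative 4-bit grid (ties away from zero)."""
    y = min(max(x / delta, l), u)
    return delta * math.copysign(math.floor(abs(y) + 0.5), y)


def finite_difference(fn, x: float, h: float = 1e-6) -> float:
    """Central finite-difference estimate of fn'(x)."""
    return (fn(x + h) - fn(x - h)) / (2 * h)


def ste_weight_grad(upstream_grad: float, w: float) -> float:
    """STE backward pass: dL/dw = dL/dQ(w) * Q_hat'(w) with Q_hat(x) = x, so Q_hat'(w) = 1."""
    q_hat_prime = 1.0  # the STE derivative does not depend on w
    return upstream_grad * q_hat_prime
```

Near a boundary point (for example x = 0.25 here), the same finite difference instead jumps to roughly Δ/(2h), reflecting that Q' does not exist there.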

Piecewise linear estimators. Piecewise linear (PWL) estimators have derivative $I_{[w_{min}, w_{max}]}$, where $I$ is the indicator function. They make $\hat{Q}$ more closely resemble $Q$ [36, 18, 53]. The simplest way to define a PWL estimator for a multi-bit quantizer is to use Equation 1 with the round step removed; in this case, $[w_{min}, w_{max}]$ is exactly the representable range. This way, the behavior of PWL estimators more closely relates to the quantization function. In general, we will use $PWL_{w_{min}, w_{max}}(x) = \mathrm{clip}(x, w_{min}, w_{max})$ to denote the surrogate function of a PWL gradient estimator.
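A minimal sketch of the PWL surrogate and its gradient estimator, using hypothetical helper names; the closed interval in the indicator follows the notation $I_{[w_{min}, w_{max}]}$ above, and the value assigned at the two kinks is a convention. The last function illustrates the remark that Equation 1 without the round step equals clip(x, Δl, Δu) for Δ > 0, so its PWL range is [Δl, Δu].

```python
def pwl(x: float, w_min: float, w_max: float) -> float:
    """PWL surrogate: PWL_{w_min, w_max}(x) = clip(x, w_min, w_max)."""
    return min(max(x, w_min), w_max)


def pwl_grad(x: float, w_min: float, w_max: float) -> float:
    """PWL gradient estimator: the indicator of [w_min, w_max] (kink values are a convention)."""
    return 1.0 if w_min <= x <= w_max else 0.0


def round_free_quantizer(x: float, delta: float, l: float, u: float) -> float:
    """Equation 1 with the round step removed: delta * clip(x / delta, l, u)."""
    return delta * min(max(x / delta, l), u)
```

Unlike the STE, pwl_grad returns zero outside the range, so a weight that drifts beyond [w_min, w_max] receives no further gradient through this estimator.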

STE and PWL lead to “gradient error”. The simple STE and PWL gradient estimators described above still leave a significant gap between the behavior of the forward pass and the surrogate forward pass. For this reason, researchers have proposed a large number of custom gradient estimators, often citing a high “gradient error” in the simpler choices of gradient estimators as motivation for their work. Gradient error is often described as the difference between $Q$ and $\hat{Q}$.

An abundance of custom gradient estimators. In order to solve the perceived problem of gradient error, many researchers have proposed gradient estimators that carry more complexity than the STE or PWL estimators. In Appendix B, we cite and describe 15 examples of custom gradient estimators in the quantization literature. Plots of some prominent examples are given in Figure 1.


Figure 1: Gradient Estimators from left to right: STE [17], PWL [18], MAD [41], HTGE [32], EDE [35]. The EDE is for binary quantization, and the others are for multi-bit quantization.

3 Gradient Descent Terminology for QAT

For a quantized model with gradient estimator $\hat{Q}$, the gradient value at step $t$ is $\nabla f(Q(w^{(t)}))\,\hat{Q}'(w^{(t)})$, where $f$ is the loss function of the model. Of course, $f$ depends on the dataset and all other network weights, but we suppress this for notational convenience. Going forward, we abbreviate $\nabla f(Q(w^{(t)}))$ as $\nabla f^{(t)}$. The weight update is expressed as

$$w^{(t+1)} = w^{(t)} + g^{(t)}\left(\nabla f^{(0)}\hat{Q}'(w^{(0)}), \ldots, \nabla f^{(t)}\hat{Q}'(w^{(t)}), \eta\right), \tag{2}$$

where $\eta$ is the learning rate. The notation $g^{(t)}$ is borrowed from [1]. By choosing $g^{(t)}$ appropriately, we can recover all of the standard gradient descent algorithms, e.g. SGD, Adam, and RMSProp. In the simplest case, we have $g^{(t)}(\nabla f^{(t)}\hat{Q}'(w^{(t)}), \eta) = -\eta\,\nabla f^{(t)}\hat{Q}'(w^{(t)})$, which gives us the common SGD learning rule

$$w^{(t+1)} = w^{(t)} - \eta\,\nabla f^{(t)}\hat{Q}'(w^{(t)}). \tag{3}$$
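Equation 3 with a pluggable gradient estimator can be sketched in a few lines of Python. The function names and arguments are our own; grad_f_at_q stands for the upstream gradient $\nabla f^{(t)}$ of the loss with respect to the quantized weight, which a real training framework would supply through backpropagation, and the PWL range [-0.8, 0.7] is illustrative.

```python
from typing import Callable


def sgd_step(w: float, grad_f_at_q: float,
             q_hat_prime: Callable[[float], float], eta: float) -> float:
    """One step of Equation 3: w <- w - eta * grad_f(Q(w)) * Q_hat'(w)."""
    return w - eta * grad_f_at_q * q_hat_prime(w)


def ste(w: float) -> float:
    """STE gradient estimator: Q_hat'(w) = 1."""
    return 1.0


def pwl_4bit(w: float) -> float:
    """PWL gradient estimator on an illustrative range [-0.8, 0.7]."""
    return 1.0 if -0.8 <= w <= 0.7 else 0.0
```

With the STE every weight moves by $-\eta\nabla f^{(t)}$; with the PWL estimator a weight outside [-0.8, 0.7] is left where it is.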

The definition of $g^{(t)}$ for SGD with momentum is given in Appendix D. A more complex but highly popular learning rule is the Adam [22] optimizer, which is defined with the above notation in Appendix E.

Adaptive and non-adaptive algorithms. Adam is an example of an adaptive learning rate algorithm, since the weight update steps are normalized by a computation on past gradient values. Other examples of adaptive learning rate methods are RMSprop [17], Adadelta [50], AdaMax [22], and AdamW [29]. We refer to all other update rules, such as SGD and SGD with momentum [34], as non-adaptive learning rate algorithms.
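The normalization that makes Adam adaptive can be seen in a short sketch of the standard Adam update with bias correction, written for a single scalar weight. This is the textbook form of the algorithm, not a transcription of Appendix E, and the function name is our own. With ε = 0, multiplying every gradient by a constant c > 0 scales the first-moment estimate by c and the square root of the second-moment estimate by c, so the updates are unchanged up to floating-point rounding, whereas the SGD update in Equation 3 would scale by c. This only illustrates a constant factor; Theorem 5.2 concerns a gradient estimator whose value changes as the weight moves.

```python
import math
from typing import List


def adam_updates(grads: List[float], eta: float = 1e-3, beta1: float = 0.9,
                 beta2: float = 0.999, eps: float = 1e-8) -> List[float]:
    """Return the Adam update (the g^(t) term of Equation 2) for each step of a scalar gradient sequence."""
    m = v = 0.0
    updates = []
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g      # first-moment (mean) estimate
        v = beta2 * v + (1 - beta2) * g * g  # second-moment estimate
        m_hat = m / (1 - beta1 ** t)         # bias corrections
        v_hat = v / (1 - beta2 ** t)
        updates.append(-eta * m_hat / (math.sqrt(v_hat) + eps))
    return updates
```

For instance, with ε = 0 the first update is exactly -η times the sign of the first gradient, regardless of its magnitude.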

4 Intuition

To help the reader develop intuition about our main results, we tell a brief story below.

The Mirror Room story. Imagine you are in a room with a glass wall. On the other side of the glass wall, there is a person in another room, larger than yours. You are standing at different positions in your respective rooms. Any time you take a step, this other person takes a step in the same direction, albeit with a different step length. You continue to move around, and you are rarely exactly across from this person, but any time you try to leave, this person leaves the room on the same side at the same time.

You realize that the glass wall is not a wall; it’s a funhouse mirror. The person on the other side is you, but the picture is “warped” by the mirror.

The Mirror Room is the quantization bin for two equivalent models. The scenario described above is similar to the relationship between the motion of weights in a model ($\hat{Q}$-net) that uses a complex gradient estimator $\hat{Q}$ and another model (STE-net) that uses the STE with the proper reconfigurations to match $\hat{Q}$-net. In the analogy, you are a weight in STE-net, and your reflection is the corresponding weight in $\hat{Q}$-net. The room is a quantization bin, and the doors are the boundary points. The simultaneous exit of you and your reflection from the room parallels the synchronized quantized weights in both models, which lead to identical gradients and training outcomes.

Figure 2: The funhouse mirror. The blue figure represents you (a weight in STE-net), and the red figure represents your reflection (a weight in $\hat{Q}$-net) on the other side. The reflections line up at the edge of the room.

The “funhouse mirror” effect of $M$ and $\hat{Q}$. In Section 5, we define a map $M$ that acts as a “funhouse mirror”, mapping the weights of $\hat{Q}$-net to those of STE-net. Any initial weight $w^{(0)}$ in $\hat{Q}$-net is re-initialized to $M(w^{(0)})$ in STE-net, and the relationship $M(w_{\hat{Q}}) = w_{STE}$ approximately holds throughout training, where $w_{\hat{Q}}$ is a weight in $\hat{Q}$-net and $w_{STE}$ is the corresponding weight in STE-net. Thus, after the $\hat{Q}$-net weight takes a step, the STE-net weight moves in near lockstep after passing through the “funhouse mirror” of $M$. Furthermore, since $M(w) = w$ whenever $w$ is a boundary point, the two weights cross the quantization boundaries at nearly the same time. This property is what justifies the bisimulation of the two models.

A visualization of the funhouse mirror is given in Figure 2.

5 Main Results

In this section we formalize the intuition of Section 4 and provide our main mathematical results (contributions 1 and 2). These results also show that much of the concern about “gradient error” is unfounded. We provide theorem statements for both the SGD update rule and the Adam update rule, with proofs and generalizations in the Appendices. Note that all of the results below apply to weight quantizers; we do not address activation quantizers in this work.

5.1 Definitions and Notation

Cyclical gradient estimators. We say that a gradient estimator $\hat{Q}$ for a uniform quantizer $Q$ is cyclical if $\hat{Q}'$ is identical on each finite-length quantization bin, i.e. $\hat{Q}'(w) = \hat{Q}'(w + \Delta)$ whenever $w$ and $w + \Delta$ are inside finite-length quantization bins (i.e. within the representable range). Most multi-bit gradient estimators proposed in the literature are cyclical. Binary gradient estimators are cyclical by default, since binary quantizers have no finite-length quantization bins. Unless otherwise specified, we will assume that all gradient estimators are cyclical.

Definitions of $\alpha$ and $M$. We give two more definitions before presenting the details of the models we are comparing. These objects ($\alpha$ and $M$) allow us to succinctly express the learning rate update and weight initialization update needed to mimic the behavior of a positive gradient estimator $\hat{Q}$ using only the STE. If $Q$ is a uniform multi-bit quantizer and $\hat{Q}$ is cyclical, we define the learning rate adjustment factor $\alpha$ and weight readjustment map $M$:

$$\alpha := \frac{\Delta}{\int_{w_-}^{w_+} \frac{ds}{\hat{Q}'(s)}} \qquad\qquad M(w) := w_b + \alpha \int_{w_b}^{w} \frac{ds}{\hat{Q}'(s)} \tag{4}$$

Here $w_-$ and $w_+$ are adjacent boundary points, and $w_b$ is any single boundary point. Since $Q$ is uniform and $\hat{Q}$ is cyclical, the definition of $\alpha$ is independent of the choice of boundary points. If $Q$ is a binary quantizer, then $Q$ has only one boundary point, and we define $\alpha := 1$. Note that $\alpha$ is defined entirely by $\hat{Q}$ and can be computed at the outset of training. It may vary per layer if the parameters of $\hat{Q}$ do. Intuitively, it can be thought of as the ratio between the quantization bin size ($\Delta$) and the “effective bin size” of the gradient estimator $\hat{Q}$ (the denominator of $\alpha$ in Equation 4). The definition of $M$ is likewise independent of the choice of $w_b$. We can think of $M$ as a function that maps a weight $w$ to a new point $M(w)$ whose relative distance from its left and right boundaries matches the relative “effective distance” (under $\hat{Q}$) between the boundary points and the original weight $w$.
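A numerical sketch of Equation 4 may help. It assumes Δ = 1 with boundary points at half-integers, approximates the integrals with the midpoint rule, and uses $\hat{Q}'(s) = 1 - 0.5\cos(2\pi s)$ as a hypothetical cyclical estimator (bounded in [0.5, 1.5] and largest at the boundary points); this estimator and all function names are illustrative, not taken from the literature. For this choice, the integral of $1/\hat{Q}'$ over one bin is $2/\sqrt{3}$, so $\alpha = \sqrt{3}/2 \approx 0.866$.

```python
import math
from typing import Callable

Estimator = Callable[[float], float]


def cosine_estimator(s: float) -> float:
    """Hypothetical cyclical gradient estimator with period Delta = 1, bounded in [0.5, 1.5]."""
    return 1.0 - 0.5 * math.cos(2.0 * math.pi * s)


def integrate(fn: Callable[[float], float], a: float, b: float, n: int = 10_000) -> float:
    """Composite midpoint rule for the integral of fn from a to b (negative if b < a)."""
    h = (b - a) / n
    return h * sum(fn(a + (i + 0.5) * h) for i in range(n))


def lr_adjustment_factor(q_hat_prime: Estimator, w_minus: float, w_plus: float) -> float:
    """alpha = Delta / integral_{w_-}^{w_+} ds / Q_hat'(s), with Delta = w_+ - w_- (Equation 4)."""
    return (w_plus - w_minus) / integrate(lambda s: 1.0 / q_hat_prime(s), w_minus, w_plus)


def weight_readjustment_map(w: float, q_hat_prime: Estimator, w_b: float, alpha: float) -> float:
    """M(w) = w_b + alpha * integral_{w_b}^{w} ds / Q_hat'(s) (Equation 4)."""
    return w_b + alpha * integrate(lambda s: 1.0 / q_hat_prime(s), w_b, w)
```

Under these assumptions, M fixes every boundary point (for example, M(1.5) ≈ 1.5 when $w_b = -0.5$) and stretches the region near grid levels, where $\hat{Q}'$ is small: in closed form M(0.25) = 1/3, which the midpoint rule reproduces closely. For the STE ($\hat{Q}' = 1$) the same code gives α = 1 and M(w) = w.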

Definition of $\hat{Q}$-net and STE-net. For both optimization techniques we consider (SGD and Adam), we study two models, $\hat{Q}$-net and STE-net. The models can have any architecture, as long as the two architectures are identical. We focus on corresponding weights $w_{\hat{Q}}^{(t)}$ and $w_{STE}^{(t)}$, respectively, at iteration $t$. We denote the gradients of the loss function $f$ with respect to $Q(w_{\hat{Q}}^{(t)})$ and $Q(w_{STE}^{(t)})$ as $\nabla f_{\hat{Q}}^{(t)}$ and $\nabla f_{STE}^{(t)}$, respectively. The differences in gradient estimators, learning rates, and weight initialization for SGD and Adam are given in Tables 1 and 2, respectively.

Table 1: $\hat{Q}$-net and STE-net models for SGD

| Model | $\hat{Q}$-net | STE-net |
|---|---|---|
| Gradient estimator | $\hat{Q}$ | STE |
| Learning rate | $\eta$ | $\alpha\eta$ |
| Initial weight | $w_{\hat{Q}}^{(0)}$ | $M(w_{\hat{Q}}^{(0)})$ |

Table 2: $\hat{Q}$-net and STE-net models for Adam

| Model | $\hat{Q}$-net | STE-net |
|---|---|---|
| Gradient estimator | $\hat{Q}$ | STE |
| Learning rate | $\eta$ | $\eta$ |
| Initial weight | $w_{\hat{Q}}^{(0)}$ | $w_{\hat{Q}}^{(0)}$ |

Comparison metric. We can quantify how the weights of $\hat{Q}$-net and STE-net differ using the weight alignment error, which is defined as

$$E^{(t)} := \left|M\left(w_{\hat{Q}}^{(t)}\right) - w_{STE}^{(t)}\right| \quad \text{for SGD, and} \quad E^{(t)} := \left|w_{\hat{Q}}^{(t)} - w_{STE}^{(t)}\right| \quad \text{for Adam.} \tag{5}$$

$E^{(t)}$ measures how far apart the weights of the two models are at iteration $t$, and starts at $E^{(0)} = 0$ due to our choice of initial weights in Tables 1 and 2. Furthermore, since $M$ preserves quantization bins, we have $Q(w_{\hat{Q}}^{(t)}) = Q(w_{STE}^{(t)})$ whenever $E^{(t)}$ is small, unless the weights lie within $E^{(t)}$ of a boundary point.
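The claim that $E^{(t)}$ stays small can be probed with a toy one-weight experiment under the Table 1 configuration. Everything here is an assumption made for illustration: Δ = 1 with rounding to the nearest integer and no clipping, the hypothetical cosine-shaped estimator $\hat{Q}'(s) = 1 - 0.5\cos(2\pi s)$, the loss $f(q) = (q - 3)^2/2$, η = 0.01, and 400 SGD steps. M is evaluated from the boundary point just below w, which is valid here because M fixes every boundary point in this unclipped setting.

```python
import math


def q_hat_prime(s: float) -> float:
    """Hypothetical cyclical gradient estimator (period Delta = 1), bounded in [0.5, 1.5]."""
    return 1.0 - 0.5 * math.cos(2.0 * math.pi * s)


def quantize(w: float) -> float:
    """Round to the nearest integer: Equation 1 with Delta = 1 and no clipping (assumed)."""
    return float(math.floor(w + 0.5))


def reciprocal_integral(a: float, b: float, n: int = 1000) -> float:
    """Midpoint-rule estimate of the integral of 1 / q_hat_prime from a to b."""
    h = (b - a) / n
    return h * sum(1.0 / q_hat_prime(a + (i + 0.5) * h) for i in range(n))


ALPHA = 1.0 / reciprocal_integral(-0.5, 0.5)  # Equation 4 with Delta = 1


def readjust(w: float) -> float:
    """M(w), integrating from the boundary point just below w."""
    w_b = math.floor(w + 0.5) - 0.5
    return w_b + ALPHA * reciprocal_integral(w_b, w)


def simulate(w0: float = 0.3, target: float = 3.0, eta: float = 0.01, steps: int = 400):
    """Train one weight in Q_hat-net and STE-net side by side on f(q) = (q - target)^2 / 2."""
    w_q, w_s = w0, readjust(w0)  # Table 1: STE-net starts at M(w0)
    mismatches = 0  # steps on which the two quantized weights (hence gradients) differ
    for _ in range(steps):
        g_q = quantize(w_q) - target  # df/dQ for Q_hat-net
        g_s = quantize(w_s) - target  # df/dQ for STE-net
        if g_q != g_s:
            mismatches += 1
        w_q -= eta * g_q * q_hat_prime(w_q)  # Equation 3 with Q_hat
        w_s -= ALPHA * eta * g_s  # Equation 3 with the STE and rate alpha * eta
    error = abs(readjust(w_q) - w_s)  # weight alignment error for SGD (Equation 5)
    return w_q, w_s, error, mismatches
```

In this toy run the two quantized weights should disagree only on a handful of steps around each boundary crossing, and both settle in the bin whose level is 3. This is an illustration under the stated assumptions, not a reproduction of the paper's experiments.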

5.2 Theorem Statements

Theorem 5.1 rigorously states contribution 1 for the SGD update rule (Equation 3). It states that after adjusting the learning rate of a model by $\alpha$ and re-initializing the weights by applying $M(w)$, a positive gradient estimator $\hat{Q}$ can be replaced by the STE with minimal differences in training.

Theorem 5.1.

Suppose that $E^{(t)}$ is the alignment error for $\hat{Q}$-net and STE-net with SGD (Table 1). Assume that the following hold:

  5.1.1. $0 < L_- \leq |\hat{Q}'(w)| \leq L_+$ for all $w$. (Bounded, positive gradient estimator)

  5.1.2. $\hat{Q}'(w)$ is $L'$-Lipschitz. (Well-behaved gradient estimator)

Then we have

$$E^{(t+1)} \leq E^{(t)} + \underbrace{\eta\alpha\left|\nabla f_{\hat{Q}}^{(t)} - \nabla f_{STE}^{(t)}\right|}_{\text{gradient error}} + \underbrace{\frac{L'}{2}\cdot\left(\frac{\eta L_+ \nabla f_{\hat{Q}}^{(t)}}{L_-}\right)^2}_{\text{convexity error}} \tag{6}$$

See Appendix C.1 for a rigorous proof. The theorem only considers the standard gradient descent process. For a similar statement for a more general class of non-adaptive learning rate optimizers, see Appendix C.1. See Appendix D for a more specific result for SGD with momentum.
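The first-order step behind Theorem 5.1 can be sketched as follows; this is a heuristic Taylor expansion, not the rigorous argument of Appendix C.1. Differentiating Equation 4 gives $M'(w) = \alpha/\hat{Q}'(w)$, and assumptions 5.1.1 and 5.1.2 keep $M''$ bounded, so the second-order remainder is $O(\eta^2)$.

```latex
\begin{aligned}
\delta &:= w_{\hat{Q}}^{(t+1)} - w_{\hat{Q}}^{(t)} = -\eta\,\nabla f_{\hat{Q}}^{(t)}\,\hat{Q}'\big(w_{\hat{Q}}^{(t)}\big), \\
M\big(w_{\hat{Q}}^{(t+1)}\big) &= M\big(w_{\hat{Q}}^{(t)}\big) + \frac{\alpha}{\hat{Q}'\big(w_{\hat{Q}}^{(t)}\big)}\,\delta + O(\delta^2)
  = M\big(w_{\hat{Q}}^{(t)}\big) - \alpha\eta\,\nabla f_{\hat{Q}}^{(t)} + O(\eta^2), \\
w_{STE}^{(t+1)} &= w_{STE}^{(t)} - \alpha\eta\,\nabla f_{STE}^{(t)}.
\end{aligned}
```

Subtracting the last two lines and applying the triangle inequality gives $E^{(t+1)} \leq E^{(t)} + \alpha\eta\left|\nabla f_{\hat{Q}}^{(t)} - \nabla f_{STE}^{(t)}\right| + O(\eta^2)$, which has the same gradient-error term as Equation 6. The division by $\hat{Q}'$ in $M'$ cancels the estimator's factor in $\delta$, which is why the STE-net learning rate is $\alpha\eta$.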

Theorem 5.2 rigorously proves contribution 2 for the Adam update rule (Equations 57-61). The result here is stronger than Theorem 5.1. When using the Adam update rule, the gradient estimator Q^^𝑄\hat{Q}over^ start_ARG italic_Q end_ARG can be replaced by the STE without any update to the learning rate or weight initialization.

Theorem 5.2.

Suppose that $E^{(t)}$ is the alignment error for $\hat{Q}$-net and STE-net with Adam (Table 2). Assume that the following hold:

  5.2.1. $0 < L_{-} \leq \hat{Q}'(w)$ for all $w$. (Lower-bounded positive gradient estimator)

  5.2.2. $\hat{Q}'(w)$ is $L'$-Lipschitz. (Well-behaved gradient estimator)

Then we have

$$E^{(t+1)} \leq E^{(t)} + \underbrace{\left|g^{(t)}\left(\nabla f^{(0)}_{\hat{Q}},\ldots,\nabla f^{(t)}_{\hat{Q}},\eta\right) - g^{(t)}\left(\nabla f^{(0)}_{STE},\ldots,\nabla f^{(t)}_{STE},\eta\right)\right|}_{\text{gradient error}} + \underbrace{O(\eta^{2})}_{\text{convexity error}}, \tag{7}$$

where $g^{(t)}$ is the gradient update rule for Adam (see Equation 2 and Equations 57-61).

See Appendix E for a rigorous proof. In Theorem 5.2, the exact form of the $O(\eta^{2})$ term is omitted due to its complexity. For a similar statement for a more general class of adaptive learning rate optimizers (not just Adam), see Appendix E. For a discussion of Theorems 5.1 and 5.2 under learning rate schedules, see Appendix F.
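The intuition behind Theorem 5.2 can be illustrated with a simplified, single-weight sketch. If $\hat{Q}'(w)$ were a constant $c$ over a stretch of training, every $\hat{Q}$-net gradient would be $c$ times the corresponding STE-net gradient, and Adam's normalized step $\hat{m}/(\sqrt{\hat{v}}+\epsilon)$ is unchanged by such a rescaling apart from $\epsilon$. This sketch uses the textbook form of Adam [22] rather than the paper's Equations 57-61, and the constant $c$ and the gradient values are hypothetical; the actual theorem lets $\hat{Q}'$ vary and bounds the resulting difference.

```python
import math


def adam_updates(grads, lr=1e-4, beta1=0.9, beta2=0.999, eps=1e-8):
    """Apply the standard Adam update rule to a stream of scalar gradients.

    Returns the list of parameter updates (the amount added to the weight
    at each step).
    """
    m = v = 0.0
    updates = []
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g        # first-moment estimate
        v = beta2 * v + (1 - beta2) * g * g    # second-moment estimate
        m_hat = m / (1 - beta1 ** t)           # bias corrections
        v_hat = v / (1 - beta2 ** t)
        updates.append(-lr * m_hat / (math.sqrt(v_hat) + eps))
    return updates


# Hypothetical gradients with respect to the quantized weight.
grads = [0.3, -0.1, 0.25, 0.05, -0.2]
c = 0.4  # hypothetical constant value of Q-hat'(w)

ste_updates = adam_updates(grads)                     # STE: derivative 1
qhat_updates = adam_updates([c * g for g in grads])   # Q-hat: derivative c
# Only eps breaks the scale invariance, so this difference is tiny.
print(max(abs(a - b) for a, b in zip(ste_updates, qhat_updates)))

# Plain SGD, by contrast, rescales the step by exactly c.
sgd_ste, sgd_qhat = -1e-3 * grads[0], -1e-3 * c * grads[0]
print(sgd_qhat / sgd_ste)
```

The contrast with plain SGD, whose step scales linearly with the estimator's derivative, is consistent with Theorem 5.1 requiring a learning rate adjustment while Theorem 5.2 does not.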

5.3 On the Assumptions and Implications of Theorems 5.1 and 5.2

Theorems 5.1 and 5.2 rely on specific assumptions about the gradient estimator $\hat{Q}$. In this section, we examine these assumptions and describe how the theorems imply contributions 1 and 2.

The assumptions are reasonable: The upper bound on $\hat{Q}'$ in Assumption 5.1.1 is very mild, since gradient estimators with an unbounded derivative would likely cause training instability and are not used in practice. Similarly, the authors are not aware of a gradient estimator that violates Assumptions 5.1.2 and 5.2.2. In addition, the constants $L_{-}$, $L_{+}$, and $L'$ are usually quite small in practice (see Appendix H for calculations). The lower bound on $\hat{Q}'$ in Assumptions 5.1.1 and 5.2.1, however, is often violated in practice; Appendix G describes how the theorems still support contributions 1 and 2 in these cases.

The bounds in Equations 6 and 7 are small: To see how Theorems 5.1 and 5.2 provide contributions 1 and 2, we examine each term in Equations 6 and 7. Together, the gradient and convexity errors in each equation give a worst-case increase in $E^{(t)}$ at each training step, so as long as these terms are small, $\hat{Q}$-net and $STE$-net train in a very similar manner. The convexity error terms are unavoidable, but they are $O(\eta^{2})$ and therefore extremely small in practice. The gradient error terms, on the other hand, are $O(\eta)$, so they can be large if the gradients of the two models are misaligned. However, since the gradient terms $\nabla f_{\hat{Q}}^{(t)}$ and $\nabla f_{STE}^{(t)}$ depend only on quantized weights, the gradient error terms are zero at the beginning of training and remain small as long as $E^{(t)}$ remains small.

The claim is nontrivial: These theorems do not simply say that when the learning rate is small, the models change very little, and therefore $\hat{Q}$-net and $STE$-net are aligned. Because the irreducible (convexity) error term is quadratic in $\eta$, the misalignment introduced at each step is small relative to the size of the step itself, which is proportional to $\eta$.
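Summing Equation 6 over training steps makes this point quantitative. The following sketch assumes, for illustration only, that $E^{(0)}=0$, that the gradients are bounded, $\left|\nabla f_{\hat{Q}}^{(t)}\right|\leq G$ for all $t$, and that training runs for $T\leq\tau/\eta$ steps, i.e. a fixed “training time” $\tau$ (the symbols $G$ and $\tau$ are our notation, not the paper's). Telescoping Equation 6 gives

$$E^{(T)} \leq \sum_{t=0}^{T-1}\eta\alpha\left|\nabla f_{\hat{Q}}^{(t)}-\nabla f_{STE}^{(t)}\right| + \sum_{t=0}^{T-1}\frac{L'}{2}\left(\frac{\eta L_{+}\nabla f_{\hat{Q}}^{(t)}}{L_{-}}\right)^{2},$$

and the accumulated convexity error is bounded by

$$\sum_{t=0}^{T-1}\frac{L'}{2}\left(\frac{\eta L_{+}\nabla f_{\hat{Q}}^{(t)}}{L_{-}}\right)^{2} \leq \frac{\tau}{\eta}\cdot\frac{L'}{2}\left(\frac{\eta L_{+}G}{L_{-}}\right)^{2} = \frac{\tau L' L_{+}^{2} G^{2}}{2L_{-}^{2}}\,\eta .$$

Over a fixed training horizon, this contribution therefore shrinks linearly as $\eta\to 0$, while the weights themselves can still move an $O(1)$ distance.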

The claim applies to networks of any size: The theorems bound the error in a single network weight, but they can be applied to each weight independently in a multi-weight network. Of course, the trajectories of weights in a neural network are not independent, but fortunately, in our case each weight trajectory depends on the other network weights only through their quantized values. To see this, note that the only terms in Equations 6 and 7 that depend on other network weights are the gradient error terms. As stated earlier, these gradient terms depend only on quantized weights, so we do not need perfect alignment of the other network weights to keep the error terms small. Since the gradient error terms can depend on all other quantized weights in the network, larger models are at a greater risk of weight misalignment. However, this is more a property of large models than of gradient estimators: any two large models whose training setups are identical except for a small difference in hyperparameters can show large step-by-step divergences in weight alignment. Moreover, the fundamental difference in training induced by a gradient estimator is indeed small, since in Equations 6 and 7 the true source of all misalignment is an $O(\eta^{2})$ term. This is supported by our experiments in Section 6.

6 Experimental Results

Here we demonstrate our main results on practical models. Our general strategy is to implement $\hat{Q}$-net and $STE$-net for a specific model architecture and compare them on several metrics to demonstrate the following:

  A. $\hat{Q}$-net and $STE$-net train in almost exactly the same way.

  B. If we do not apply the weight re-initialization of Theorem 5.1, the two models no longer train in the same way.

6.1 Models and Training Setup

Models and Quantizers. We use two model architecture/dataset pairs:

  1. A simple three-layer quantized convolutional architecture proposed in [4] for image classification on the MNIST dataset, trained on a CPU, with weights initialized from a uniform distribution with the variance recommended in [14].

  2. ResNet50 [15] on the ILSVRC 2012 ImageNet dataset [7], trained on a TPU, which shows that the results generalize to a more complex model and dataset. We used a fully deterministic version of the Flax ImageNet example [10].

Gradient estimator and optimizers: We quantize weights using a 2-bit uniform quantizer, and for gradient estimation, we use the $\hat{Q}$ given by the HTGE formula [32]; see Appendix B for our justification of this choice. For both models, we consider both SGD and Adam, with a learning rate of 0.001 for SGD and 0.0001 for Adam. All models are trained with the weight initialization and learning rate adjustments given by Tables 1 and 2. For more details on the training recipe and quantizers, see Appendix I.

6.2 Metrics

We use two metrics to establish Points A and B. Both compare $STE$-net weights to $\hat{Q}$-net weights. In addition to these metrics, we also report accuracy and loss statistics for all models.

Quantized Weight Agreement. At the end of training, the quantized weights of both models are computed and compared. We report the proportion of quantized weights that are identical in the two models.

Normalized Weight Alignment Error ($\bar{\mathbf{E}}$). For each pair of models, we compute the average value of $E^{(T)}$ over all weights, where $T$ is the final training step. Note that Equation 5 gives two definitions of $E$; for each model pair, we use the version that matches the weight initialization setup, which gives $E^{(0)}=0$ for all model pairs. Each $E^{(T)}$ is normalized by the length of the representable range, so that a value of 100% indicates that the two models' weights are on opposite sides of the representable range. We denote this average by $\bar{E}$.
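Both metrics can be sketched in a few lines of Python. Equation 5 is not reproduced in this section, so this minimal sketch assumes that both weight lists are already expressed in aligned coordinates (so the per-weight error is simply the absolute difference) and that the quantizer is the common uniform form $\Delta\cdot\mathrm{clip}(\mathrm{round}(w/\Delta), l, u)$. The function names and the example weights are hypothetical.

```python
def quantize_code(w, delta, lo, hi):
    """Integer code of a uniform quantizer: clip(round(w / delta), lo, hi).

    Python's round() rounds ties to even.
    """
    return min(max(round(w / delta), lo), hi)


def alignment_metrics(w_qhat, w_ste, delta, lo, hi):
    """Return (quantized weight agreement, normalized mean alignment error).

    Assumes both weight lists are in aligned coordinates, so the per-weight
    alignment error is |w_qhat - w_ste|.
    """
    n = len(w_qhat)
    range_len = delta * (hi - lo)  # length of the representable range [delta*lo, delta*hi]
    agree = sum(
        quantize_code(a, delta, lo, hi) == quantize_code(b, delta, lo, hi)
        for a, b in zip(w_qhat, w_ste)
    ) / n
    e_bar = sum(abs(a - b) for a, b in zip(w_qhat, w_ste)) / (n * range_len)
    return agree, e_bar


# Hypothetical final weights for a 2-bit symmetric quantizer (lo=-2, hi=1, delta=0.5).
w_qhat = [0.10, -0.40, 0.70, -0.90]
w_ste = [0.12, -0.20, 0.64, -0.95]
agree, e_bar = alignment_metrics(w_qhat, w_ste, delta=0.5, lo=-2, hi=1)
print(f"agreement={agree:.2%}  E_bar={e_bar:.2%}")
```

Comparing integer codes rather than dequantized floats keeps the agreement check exact.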

6.3 Results

| Experiment Name | Experiment Description | $\bar{\mathbf{E}}$ | Interpretation / Comparison to Baseline |
|---|---|---|---|
| baseline | $\hat{Q}$ vs. STE | 0.515% | Baseline |
| lr-tweak | $\hat{Q}$ vs. $\hat{Q}$ with 1% learning rate increase | 0.572% | Replacing $STE$-net with $\hat{Q}$-net is about as impactful as a small change to $\eta$ (A). |
| unadjusted | $\hat{Q}$ vs. STE without reinitializing weights | 2.52% | The two models only see the same weight movement if weights are re-initialized according to $M$ (B). |

Table 3: Normalized weight alignment metric $\bar{E}$ for the MNIST model with SGD + Momentum, including descriptions and interpretations for each experiment type. This table serves as a guide for interpreting Table 4.
| Experiment Name | MNIST $\bar{\mathbf{E}}$ | MNIST Quantized Weight Agreement | ResNet50/ImageNet $\bar{\mathbf{E}}$ | ResNet50/ImageNet Quantized Weight Agreement |
|---|---|---|---|---|
| baseline (S) | 0.515% | 98.31% | 5.42% | 68.94% |
| lr-tweak (S) | 0.572% | 98.66% | 5.46% | 75.64% |
| unadjusted (S) | 2.52% | 96.53% | 7.88% | 67.53% |
| baseline (A) | 2.81% | 94.42% | 7.18% | 72.22% |
| lr-tweak (A) | 1.74% | 95.4% | 4.99% | 76.32% |

Table 4: Alignment metrics for SGD (S) and Adam (A), for the MNIST model and for ResNet50 trained on ImageNet.
MNIST:

| | Train acc | Train loss | Val acc | Val loss |
|---|---|---|---|---|
| STE (S) | 97.05% | 0.1439 | 97.08% | 0.1417 |
| $\hat{Q}$ (S) | 96.98% | 0.1483 | 97.14% | 0.1468 |
| Diff | -0.06% | 0.0044 | 0.06% | 0.0051 |
| STE (A) | 97.56% | 0.1270 | 97.66% | 0.1257 |
| $\hat{Q}$ (A) | 97.63% | 0.1254 | 97.58% | 0.1245 |
| Diff | 0.07% | -0.0016 | -0.08% | -0.0013 |

ResNet50 on ImageNet:

| | Train acc | Train loss | Val acc | Val loss |
|---|---|---|---|---|
| STE (S) | 68.94% | 1.3370 | 69.83% | 1.2227 |
| $\hat{Q}$ (S) | 68.51% | 1.3365 | 68.77% | 1.2793 |
| Diff | 0.43% | 0.0005 | -1.06% | -0.0566 |
| STE (A) | 69.78% | 1.2876 | 70.01% | 1.2209 |
| $\hat{Q}$ (A) | 69.02% | 1.3153 | 69.37% | 1.2490 |
| Diff | -0.77% | 0.0277 | -0.65% | 0.0281 |

Table 5: Loss and accuracy differences between $\hat{Q}$-net and $STE$-net with SGD (S) and Adam (A). Results for the MNIST model are shown first, followed by results for ResNet50 trained on ImageNet. For both optimizers and both models, the differences are small.
Figure 3: (a) $\hat{Q}$-net weights vs. $STE$-net weights for the MNIST convolutional model at the conclusion of training with default SGD. (b) $\hat{Q}$-net weights vs. $STE$-net weights at the conclusion of training without re-initializing $STE$-net weights.

Tables for Points A and B: Table 4 reports all metrics for both the default SGD and Adam models described in Section 6.1, and Table 3 gives detailed interpretations of the $\bar{E}$ metric. Note that Adam has no “unadjusted” case, since no weight initialization adjustment is needed when Adam is used.

Point A is validated. The standard comparison between $\hat{Q}$-net and $STE$-net is labeled “baseline”. We also compute metrics between a $\hat{Q}$-net model and the same model with a 1% learning rate increase (chosen arbitrarily and only once), reported under the label “lr-tweak”. This serves as an example of a “small change” to a model that the reader may be more familiar with, giving context for the scale of the metric results and supporting Point A. For both the MNIST and ImageNet models, the alignment between $\hat{Q}$-net and $STE$-net is similar to the alignment expected from a 1% learning rate change.

Point B is validated. We report alignment measurements between $\hat{Q}$-net and $STE$-net without the weight and learning rate adjustments described in Theorem 5.1 under the label “unadjusted”. For both the MNIST model and the ResNet model, the alignment worsens when the weight re-initialization by $M$ is removed.

Weight Alignment. For a visual of the weight alignment phenomenon, see Figures 3(a) and 3(b).

There is almost no difference in training metrics. Standard training metrics for both $\hat{Q}$-net and $STE$-net are given in Table 5 for both optimizers and both models. The two models have very similar training and validation metrics, indicating that replacing $\hat{Q}$ with the STE has minimal impact once the appropriate weight initialization and learning rate adjustments are applied. As expected, the alignment is stronger for the smaller model.

7 Implications

Here we discuss the implications of this work on the existing literature and future practice and research.

For practitioners. The main message for practitioners is simple, and depends on the optimization strategy used as follows:

  • SGD and other non-adaptive optimizers: In this case, if the learning rate is sufficiently small and you wish to tweak the gradient estimator, you can instead apply a corresponding weight re-initialization and learning rate adjustment to a model with the STE or PWL estimator and obtain nearly the same training trajectory. The proof and related assumptions are given in Theorem C.1.

  • Adam and other adaptive optimizers: In this case, when the learning rate is sufficiently small, the only gradient estimators you need to consider are the STE and PWL estimators. The proof and related assumptions are given in Theorem E.1.

For researchers. For future research, we hope that this work will inspire further study of processes for updating quantized model parameters that are fundamentally different from the use of gradient estimators, and therefore immune to the arguments of this paper. These may include novel computations on gradients that diverge from the standard chain rule [23, 45], optimizers specially designed for QAT [16], or even methods that do not involve gradient computations at all [44]. As for the existing literature, our message is that concerns about “gradient error” should no longer be a consideration in future work.

Why are so many gradient estimators published? A natural question concerning past research is this: if the choice of gradient estimator is so irrelevant, why do so many studies propose new gradient estimators and demonstrate improved performance with them? There are several potential answers. The simplest explanation is that these gradient estimation techniques implicitly uncovered a superior weight re-initialization and learning rate adjustment, as indicated by Theorem 5.1. A more widely applicable answer is that nearly all of these studies propose more than simply a new gradient estimator (as described in Appendix B), so the results may be due to several different contributions. Another answer could be that the performance improvements came from changes to quantized activation gradient estimators, which cannot be equated to the STE. A final answer could be that the learning rates in their experiments were too high to reveal an equivalence between their gradient estimators and the STE. This is a limitation of our main argument, but we expect that this counter-argument will not stand the test of time, since by our main results, the higher learning rate only masks the fact that models with a novel $\hat{Q}$ and with the STE are still approximating the same process.

References

  • [1] Marcin Andrychowicz, Misha Denil, Sergio Gomez, Matthew W Hoffman, David Pfau, Tom Schaul, Brendan Shillingford, and Nando De Freitas. Learning to learn by gradient descent by gradient descent. Advances in neural information processing systems, 29, 2016.
  • [2] Yoshua Bengio, Nicholas Léonard, and Aaron Courville. Estimating or propagating gradients through stochastic neurons for conditional computation. arXiv preprint arXiv:1308.3432, 2013.
  • [3] Jungwook Choi, Zhuo Wang, Swagath Venkataramani, Pierce I-Jen Chuang, Vijayalakshmi Srinivasan, and Kailash Gopalakrishnan. Pact: Parameterized clipping activation for quantized neural networks. arXiv preprint arXiv:1805.06085, 2018.
  • [4] Francois Chollet. Deep learning with Python. Simon and Schuster, 2021.
  • [5] Sajad Darabi, Mouloud Belbahri, Matthieu Courbariaux, and Vahid Partovi Nia. Regularized binary network training. arXiv preprint arXiv:1812.11800, 2018.
  • [6] Christian Darken, Joseph Chang, John Moody, et al. Learning rate schedules for faster stochastic gradient search. In Neural networks for signal processing, volume 2, pages 3–12. Citeseer, 1992.
  • [7] Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. Imagenet: A large-scale hierarchical image database. In 2009 IEEE conference on computer vision and pattern recognition, pages 248–255. IEEE, 2009.
  • [8] Tim Dettmers, Artidoro Pagnoni, Ari Holtzman, and Luke Zettlemoyer. Qlora: Efficient finetuning of quantized llms. arXiv preprint arXiv:2305.14314, 2023.
  • [9] Steven K Esser, Jeffrey L McKinstry, Deepika Bablani, Rathinakumar Appuswamy, and Dharmendra S Modha. Learned step size quantization. arXiv preprint arXiv:1902.08153, 2019.
  • [10] The Flax contributors. Flax imagenet example. https://github.com/google/flax/tree/main/examples/imagenet, 2024. Original implementation of ImageNet example in Flax.
  • [11] Amir Gholami, Sehoon Kim, Zhen Dong, Zhewei Yao, Michael W. Mahoney, and Kurt Keutzer. A survey of quantization methods for efficient neural network inference. CoRR, abs/2103.13630, 2021.
  • [12] Ruihao Gong, Xianglong Liu, Shenghu Jiang, Tianxiang Li, Peng Hu, Jiazhen Lin, Fengwei Yu, and Junjie Yan. Differentiable soft quantization: Bridging full-precision and low-bit neural networks. In 2019 IEEE/CVF International Conference on Computer Vision, ICCV 2019, Seoul, Korea (South), October 27 - November 2, 2019, pages 4851–4860. IEEE, 2019.
  • [13] Priya Goyal, Piotr Dollár, Ross Girshick, Pieter Noordhuis, Lukasz Wesolowski, Aapo Kyrola, Andrew Tulloch, Yangqing Jia, and Kaiming He. Accurate, large minibatch sgd: Training imagenet in 1 hour. arXiv preprint arXiv:1706.02677, 2017.
  • [14] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Delving deep into rectifiers: Surpassing human-level performance on imagenet classification. In Proceedings of the IEEE international conference on computer vision, pages 1026–1034, 2015.
  • [15] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 770–778, 2016.
  • [16] Koen Helwegen, James Widdicombe, Lukas Geiger, Zechun Liu, Kwang-Ting Cheng, and Roeland Nusselder. Latent weights do not exist: Rethinking binarized neural network optimization. Advances in neural information processing systems, 32, 2019.
  • [17] Geoffrey Hinton. COURSERA: Neural networks for machine learning, 2012.
  • [18] Itay Hubara, Matthieu Courbariaux, Daniel Soudry, Ran El-Yaniv, and Yoshua Bengio. Binarized neural networks. Advances in neural information processing systems, 29, 2016.
  • [19] Sangil Jung, Changyong Son, Seohyung Lee, Jinwoo Son, Jae-Joon Han, Youngjun Kwak, Sung Ju Hwang, and Changkyu Choi. Learning to quantize deep networks by optimizing quantization intervals with task loss. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 4350–4359, 2019.
  • [20] Dohyung Kim, Junghyup Lee, and Bumsub Ham. Distance-aware quantization. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 5271–5280, 2021.
  • [21] Jangho Kim, KiYoon Yoo, and Nojun Kwak. Position-based scaled gradient for model quantization and pruning. Advances in neural information processing systems, 33:20415–20426, 2020.
  • [22] Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
  • [23] Junghyup Lee, Dohyung Kim, and Bumsub Ham. Network quantization with element-wise gradient scaling. In IEEE Conference on Computer Vision and Pattern Recognition, CVPR 2021, virtual, June 19-25, 2021, pages 6448–6457. Computer Vision Foundation / IEEE, 2021.
  • [24] Zhiyuan Li and Sanjeev Arora. An exponential learning rate schedule for deep learning. arXiv preprint arXiv:1910.07454, 2019.
  • [25] Mingbao Lin, Rongrong Ji, Zihan Xu, Baochang Zhang, Yan Wang, Yongjian Wu, Feiyue Huang, and Chia-Wen Lin. Rotated binary neural network. Advances in neural information processing systems, 33:7474–7485, 2020.
  • [26] Zechun Liu, Kwang-Ting Cheng, Dong Huang, Eric P Xing, and Zhiqiang Shen. Nonuniform-to-uniform quantization: Towards accurate quantization via generalized straight-through estimation. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 4942–4952, 2022.
  • [27] Zechun Liu, Baoyuan Wu, Wenhan Luo, Xin Yang, Wei Liu, and Kwang-Ting Cheng. Bi-real net: Enhancing the performance of 1-bit cnns with improved representational capability and advanced training algorithm. In Proceedings of the European conference on computer vision (ECCV), pages 722–737, 2018.
  • [28] Ilya Loshchilov and Frank Hutter. Sgdr: Stochastic gradient descent with warm restarts. arXiv preprint arXiv:1608.03983, 2016.
  • [29] Ilya Loshchilov and Frank Hutter. Decoupled weight decay regularization. arXiv preprint arXiv:1711.05101, 2017.
  • [30] Markus Nagel, Marios Fournarakis, Rana Ali Amjad, Yelysei Bondarenko, Mart Van Baalen, and Tijmen Blankevoort. A white paper on neural network quantization. arXiv preprint arXiv:2106.08295, 2021.
  • [31] Sangyun Oh, Hyeonuk Sim, Sugil Lee, and Jongeun Lee. Automated log-scale quantization for low-cost deep neural networks. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 742–751, 2021.
  • [32] Zehua Pei, Xufeng Yao, Wenqian Zhao, and Bei Yu. Quantization via distillation and contrastive learning. IEEE Transactions on Neural Networks and Learning Systems, 2023.
  • [33] Dominika Przewlocka-Rus, Syed Shakib Sarwar, H Ekin Sumbul, Yuecheng Li, and Barbara De Salvo. Power-of-two quantization for low bitwidth and hardware compliant neural networks. arXiv preprint arXiv:2203.05025, 2022.
  • [34] Ning Qian. On the momentum term in gradient descent learning algorithms. Neural networks, 12(1):145–151, 1999.
  • [35] Haotong Qin, Ruihao Gong, Xianglong Liu, Mingzhu Shen, Ziran Wei, Fengwei Yu, and Jingkuan Song. Forward and backward information retention for accurate binary neural networks. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), June 2020.
  • [36] Mohammad Rastegari, Vicente Ordonez, Joseph Redmon, and Ali Farhadi. Xnor-net: Imagenet classification using binary convolutional neural networks. In European conference on computer vision, pages 525–542. Springer, 2016.
  • [37] Herbert Robbins and Sutton Monro. A stochastic approximation method. The annals of mathematical statistics, pages 400–407, 1951.
  • [38] Babak Rokh, Ali Azarpeyvand, and Alireza Khanteymoori. A comprehensive survey on model quantization for deep neural networks. arXiv preprint arXiv:2205.07877, 2022.
  • [39] Sebastian Ruder. An overview of gradient descent optimization algorithms. arXiv preprint arXiv:1609.04747, 2016.
  • [40] Charbel Sakr, Jungwook Choi, Zhuo Wang, Kailash Gopalakrishnan, and Naresh Shanbhag. True gradient-based training of deep binary activated neural networks via continuous binarization. In 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pages 2346–2350. IEEE, 2018.
  • [41] Charbel Sakr, Steve Dai, Rangharajan Venkatesan, Brian Zimmer, William J. Dally, and Brucek Khailany. Optimal clipping and magnitude-aware differentiation for improved quantization-aware training. In Kamalika Chaudhuri, Stefanie Jegelka, Le Song, Csaba Szepesvári, Gang Niu, and Sivan Sabato, editors, International Conference on Machine Learning, ICML 2022, 17-23 July 2022, Baltimore, Maryland, USA, volume 162 of Proceedings of Machine Learning Research, pages 19123–19138. PMLR, 2022.
  • [42] Ratshih Sayed, Haytham Azmi, Heba A. Shawkey, A. H. Khalil, and Mohamed Refky. A systematic literature review on binary neural networks. IEEE Access, 11:27546–27578, 2023.
  • [43] Leslie N Smith. Cyclical learning rates for training neural networks. In 2017 IEEE winter conference on applications of computer vision (WACV), pages 464–472. IEEE, 2017.
  • [44] Masashi Takemoto, Yasutake Masuda, Jingyong Cai, and Hironori Nakajo. Learning algorithm for lesserdnn, a dnn with quantized weights. In Proceedings of the 12th International Symposium on Information and Communication Technology, pages 1–7, 2023.
  • [45] Xuanhong Wangl, Yuan Zhong, and Jiawei Dong. A new low-bit quantization algorithm for neural networks. In 2023 42nd Chinese Control Conference (CCC), pages 8509–8514. IEEE, 2023.
  • [46] Yixing Xu, Kai Han, Chang Xu, Yehui Tang, Chunjing Xu, and Yunhe Wang. Learning frequency domain approximation for binary neural networks. Advances in Neural Information Processing Systems, 34:25553–25565, 2021.
  • [47] Zhe Xu and Ray CC Cheung. Accurate and compact convolutional neural networks with trained binarization. arXiv preprint arXiv:1909.11366, 2019.
  • [48] Jiwei Yang, Xu Shen, Jun Xing, Xinmei Tian, Houqiang Li, Bing Deng, Jianqiang Huang, and Xian-sheng Hua. Quantization networks. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 7308–7316, 2019.
  • [49] Chunyu Yuan and Sos S. Agaian. A comprehensive review of binary neural network. CoRR, abs/2110.06804, 2021.
  • [50] Matthew D Zeiler. Adadelta: an adaptive learning rate method. arXiv preprint arXiv:1212.5701, 2012.
  • [51] Luoming Zhang, Yefei He, Zhenyu Lou, Xin Ye, Yuxing Wang, and Hong Zhou. Root quantization: a self-adaptive supplement ste. Applied Intelligence, 53(6):6266–6275, 2023.
  • [52] Xiangxiong Zhang. Notes for optimization algorithms spring 2023. 2023.
  • [53] Yichi Zhang, Zhiru Zhang, and Lukasz Lew. Pokebnn: A binary pursuit of lightweight accuracy. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 12475–12485, 2022.
  • [54] Shuchang Zhou, Yuxin Wu, Zekun Ni, Xinyu Zhou, He Wen, and Yuheng Zou. Dorefa-net: Training low bitwidth convolutional neural networks with low bitwidth gradients. arXiv preprint arXiv:1606.06160, 2016.

Appendix A Choosing Quantization Parameters

The clipping bounds $l$ and $u$ are determined by the number of bits $b$ in the quantized representation and the desired number of representable values in the positive and negative range of the quantizer. This range of weight values is referred to as the representable range (or quantization range) of the quantizer, and can be computed as $[\Delta\cdot l, \Delta\cdot u]$. Large $\Delta$ values allow large $w$ values to avoid clipping, whereas small values give small $w$ values a more granular representation. These parameters are either learned [9, 3, 12] or set by the user. For $b>1$, $l$ and $u$ are often chosen as $l=-2^{b-1}$, $u=2^{b-1}-1$ for symmetric quantization and $l=0$, $u=2^{b}-1$ for asymmetric quantization. $\Delta$ is often chosen per channel or per token, based on the latent weight data $W$. It is sometimes set as $\max(|W|)/(2^{b}-1)$, or chosen to minimize a loss function (such as MSE or cross entropy [30]) comparing $W$ and $Q(W)$. For binary quantization ($b=1$), $Q(w)$ is typically a sign function [30, 11, 38], and there is no representable range.
For binary PWL estimators, a common choice is to use Equation 1 and simply set $\Delta=1$ and $[w_{min}, w_{max}]=[-1,1]$ [42].
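The parameter choices above can be sketched in Python. Equation 1 is not reproduced in this appendix, so this minimal sketch assumes the common uniform form $Q(w)=\Delta\cdot\mathrm{clip}(\mathrm{round}(w/\Delta), l, u)$, whose outputs span the representable range $[\Delta l, \Delta u]$. For the binary case it assumes the PWL estimator is the clipped STE, with derivative 1 on $[w_{min}, w_{max}]=[-1,1]$ and 0 elsewhere. All function names and example weights are hypothetical.

```python
def clip_bounds(b, symmetric=True):
    """Integer clipping bounds (l, u) for a b-bit quantizer with b > 1."""
    if symmetric:
        return -(2 ** (b - 1)), 2 ** (b - 1) - 1
    return 0, 2 ** b - 1


def max_abs_step(weights, b):
    """Step-size heuristic from the text: delta = max(|W|) / (2^b - 1)."""
    return max(abs(w) for w in weights) / (2 ** b - 1)


def quantize(w, delta, l, u):
    """Uniform quantizer delta * clip(round(w / delta), l, u).

    Outputs lie in the representable range [delta * l, delta * u].
    Python's round() rounds ties to even.
    """
    return delta * min(max(round(w / delta), l), u)


def sign_quantize(w):
    """Binary quantizer (b = 1); maps 0 to +1 by convention."""
    return 1.0 if w >= 0 else -1.0


def pwl_binary_grad(w, w_min=-1.0, w_max=1.0):
    """Clipped-STE (PWL) derivative for binary weights: 1 on [w_min, w_max], else 0."""
    return 1.0 if w_min <= w <= w_max else 0.0


W = [0.0, 0.2, 0.5, 0.9]                  # hypothetical latent weights
l, u = clip_bounds(2, symmetric=False)    # asymmetric 2-bit: (0, 3)
delta = max_abs_step(W, 2)                # 0.9 / 3
print([quantize(w, delta, l, u) for w in W])
print(sign_quantize(-0.2), pwl_binary_grad(-0.2), pwl_binary_grad(1.5))
```

With the asymmetric bounds, the $\max(|W|)/(2^{b}-1)$ heuristic maps the largest weight exactly onto the top quantization level.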

Appendix B Detailed Overview of Custom Gradient Estimators

Custom binary gradient estimators. A substantial amount of research has gone into custom gradient estimators. Many choices [40, 5, 27, 47, 35, 25, 46] for binary gradient estimators are described in [49]. A popular estimator is the “Error Decay Estimator” (EDE) of [35], which uses an evolving $\tanh$ function to approximate the sign function.
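The following Python sketch shows an EDE-style backward pass, to illustrate how an evolving $\tanh$ approximates the sign function. It assumes the surrogate $k \tanh(t w)$ with the schedule $t = T_{min} \cdot 10^{(i/N)\log_{10}(T_{max}/T_{min})}$, $k = \max(1/t, 1)$, $T_{min} = 0.1$, and $T_{max} = 10$. This follows our reading of [35], so the constants should be checked against [35] before relying on them. The function names are hypothetical.

```python
import math

T_MIN, T_MAX = 0.1, 10.0  # assumed schedule endpoints


def ede_params(step, total_steps):
    """(k, t) at training step i of N: t grows geometrically from T_MIN to T_MAX."""
    t = T_MIN * 10 ** (step / total_steps * math.log10(T_MAX / T_MIN))
    k = max(1.0 / t, 1.0)
    return k, t


def ede_grad(w, step, total_steps):
    """Derivative of k * tanh(t * w), used in place of d sign(w) / dw."""
    k, t = ede_params(step, total_steps)
    return k * t * (1.0 - math.tanh(t * w) ** 2)


# Early in training the estimator behaves like the STE (gradient close to 1);
# late in training it concentrates around w = 0, like a sharp sign function.
for step in (0, 50, 100):
    print(step, [round(ede_grad(w, step, 100), 4) for w in (0.0, 0.5, 1.0)])
```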

Custom gradient estimators. The hyperbolic tangent gradient estimator (HTGE) [32] gives a piecewise function locally described by $\tanh$ functions. This approximation is used for both the forward and backward pass of $Q$ in Differentiable Soft Quantization (DSQ) [12]. Approaches similar to the HTGE approximate $Q$ with a sum of sigmoid functions [48] or with a distance-weighted piecewise linear combination of the outputs of $Q$ [20]. These techniques make up the most common choices of gradient estimators, which justifies our choice of HTGE for our experiments. The gradient computation in [21] leverages a special choice of $\hat{Q}$ based on the distance between the full-precision weight and its quantized version. [51] proposes a gradient estimator with an extra parameter intended to make the quantization strategy work well for both low-bit and high-bit quantization. [54] uses the STE for the round function, but replaces the clip function in the forward pass with a modified $\tanh$ function, which affects the gradient calculations as well. [41] introduces a choice of $\hat{Q}$ known as “Magnitude Aware Differentiation” (MAD) that matches the STE on the representable range of the quantizer and follows a reciprocal function outside this range. See Figure 1 for examples of several gradient estimators.
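The tanh-based estimators above (HTGE, and the soft quantizer that DSQ uses in both passes) can be sketched for a unit step size ($\Delta = 1$) as follows. The per-cell parameterization is our own assumption in the spirit of [32, 12], not a transcription of either paper. It uses a rescaled $\tanh$ of sharpness $k$, centred at each rounding threshold $n + 0.5$. Small $k$ recovers the identity (the STE), and large $k$ approaches hard rounding.

```python
import math


def soft_round(x, k):
    """tanh-based smooth approximation of round(x) with Delta = 1.

    On each cell [n, n + 1], a tanh centred at n + 0.5 is rescaled so that the
    curve passes through (n, n) and (n + 1, n + 1), so the pieces join continuously.
    """
    n = math.floor(x)
    r = x - n - 0.5  # offset from the rounding threshold, in [-0.5, 0.5)
    return n + 0.5 + 0.5 * math.tanh(k * r) / math.tanh(0.5 * k)


def soft_round_grad(x, k):
    """Qhat'(x) for soft_round; it peaks at the rounding thresholds n + 0.5."""
    r = x - math.floor(x) - 0.5
    return 0.5 * k * (1.0 - math.tanh(k * r) ** 2) / math.tanh(0.5 * k)


for k in (0.01, 2.0, 50.0):
    print(k, [round(soft_round(x, k), 3) for x in (2.3, 2.5, 2.7)],
          round(soft_round_grad(2.5, k), 3))
```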

Gradient estimators are proposed alongside other innovations, making them hard to evaluate in isolation. Many papers that introduce a novel gradient estimator $\hat{Q}$ simultaneously introduce further changes to the learning recipe. Some make the parameters of $Q$ and $\hat{Q}$ learnable, either through gradient descent or through explicit computations on the weights, or adjust them on a schedule (see Appendix A). Others, such as DSQ [12], use $\hat{Q}$ on the forward pass and gradually update $\hat{Q}$ to more closely approximate $Q$. [25] contributes a process for rotating the entire weight vector to align with the binarized weight vector. Bi-Real Net [27] also includes a trick with network activations to increase the representational capacity of the model. In addition to the Error Decay Estimator, [35] describes a method for maximizing the entropy of quantized parameters to ensure higher parameter diversity.

Implications of our main results. In light of our results 1 and 2, we can sometimes equate these additional algorithmic changes with more well-known training strategies. For example, [35] proposes a schedule for a $\tanh$-based gradient estimator that gradually approaches a sign function throughout training. Since they use SGD in their experiments, each update that sharpens the gradient estimator can be viewed as an effective “shift” of the weights according to the function defined in Equation 4. This particular shift pushes most weights away from 0, which has an effect similar to lowering the learning rate. Thus this adaptive gradient estimation technique is similar to a standard learning rate decay schedule.
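The “shifting” argument can be made concrete with a simplified tanh estimator, $\hat{Q}(w) = \tanh(tw)/t$. This is our assumption: it drops the EDE scale factor $k$. With it, $\hat{Q}'(w) = \mathrm{sech}^2(tw)$. If $M$ satisfies $M'(w) = \alpha/\hat{Q}'(w)$, as in Equation 13 of Appendix C, and we additionally assume $M(0) = 0$, then $M(w) = \alpha\,(w/2 + \sinh(2tw)/(4t))$. The sketch checks that this closed form has the right derivative, and that $|M(w)|$ grows with $t$. In other words, sharpening moves the equivalent STE weights away from 0.

```python
import math


def qhat_grad(w, t):
    """Qhat'(w) = sech^2(t w) for the assumed estimator Qhat(w) = tanh(t w) / t."""
    return 1.0 / math.cosh(t * w) ** 2


def shift_map(w, t, alpha=1.0):
    """Closed-form M with M'(w) = alpha / Qhat'(w) = alpha * cosh^2(t w) and M(0) = 0."""
    return alpha * (w / 2.0 + math.sinh(2.0 * t * w) / (4.0 * t))


def numerical_derivative(f, w, h=1e-5):
    """Central finite difference, used only to sanity-check the closed form."""
    return (f(w + h) - f(w - h)) / (2.0 * h)


# Sharpening the estimator (larger t) pushes M(w) further from 0 for every w != 0.
for t in (0.5, 1.0, 2.0, 4.0):
    print(t, round(shift_map(0.5, t), 4), round(shift_map(-0.5, t), 4))
```

The map is odd and strictly increasing in $|w|$, so it never changes the sign of a weight. It only stretches the weights, and a fixed-size update then moves a stretched weight proportionally less.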

Appendix C Proof of Theorem 5.1

Proving Theorem 5.1 will require several steps. First, in Theorem C.1 we prove a general statement that allows us to bound the increase in weight alignment error at each training step for any non-adaptive learning rate optimization strategy. This will allow us to quickly prove Theorem 5.1, and will also simplify the proof of a similar statement for SGD with momentum, which will be given in Appendix D.

The proof in its full generality requires heavy notation and somewhat obscures the simple point of the Theorem. Because of this, we provide a less formal proof of the SGD case below.

Informal proof of Theorem 5.1.

We have for all $t$,

\begin{align}
E^{(t+1)} &= \left|M\big(w_{\hat{Q}}^{(t+1)}\big) - w_{STE}^{(t+1)}\right| \tag{8}\\
\text{(Expand terms)}\quad &= \left|M\Big(w_{\hat{Q}}^{(t)} - \eta \nabla f_{\hat{Q}}^{(t)} \hat{Q}'\big(w_{\hat{Q}}^{(t)}\big)\Big) - \Big(w_{STE}^{(t)} - \eta\alpha \nabla f_{STE}^{(t)}\Big)\right| \tag{9}\\
\text{(Taylor's Thm.)}\quad &= \left|M\big(w_{\hat{Q}}^{(t)}\big) - \eta \nabla f_{\hat{Q}}^{(t)} \hat{Q}'\big(w_{\hat{Q}}^{(t)}\big) M'\big(w_{\hat{Q}}^{(t)}\big) - \Big(w_{STE}^{(t)} - \eta\alpha \nabla f_{STE}^{(t)}\Big) + O(\eta^2)\right| \tag{10}\\
\text{(Apply Eq. 13)}\quad &= \left|M\big(w_{\hat{Q}}^{(t)}\big) - \eta\alpha \nabla f_{\hat{Q}}^{(t)} - \Big(w_{STE}^{(t)} - \eta\alpha \nabla f_{STE}^{(t)}\Big) + O(\eta^2)\right| \tag{11}\\
\text{(Triangle Ineq.)}\quad &\le E^{(t)} + \eta\alpha \left|\nabla f_{\hat{Q}}^{(t)} - \nabla f_{STE}^{(t)}\right| + O(\eta^2) \tag{12}
\end{align}

Here Equation 10 follows from Taylor’s Theorem. Equation 11 follows from Equation 13 below,

$$\frac{\partial M}{\partial w}(w) = \alpha \cdot \frac{1}{\hat{Q}'(w)}, \tag{13}$$

and Equation 12 follows from the triangle inequality. The complete proof simply requires writing out an explicit form for the $O(\eta^2)$ term, and is given in detail below. ∎
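The informal argument can be checked numerically in one dimension. The sketch uses the assumed estimator $\hat{Q}'(w) = \mathrm{sech}^2(tw)$ and its closed-form $M$ (with $M(0) = 0$). Both nets start aligned ($E^{(t)} = 0$) and receive the same gradient, so the $\eta\alpha|\nabla f_{\hat{Q}} - \nabla f_{STE}|$ term vanishes. Each net then takes one step: the $\hat{Q}$-net with learning rate $\eta$, and the STE-net with learning rate $\alpha\eta$. Equation 12 then predicts $E^{(t+1)} = O(\eta^2)$, so halving $\eta$ should divide the error by roughly 4. The weight, gradient, $t$, and $\alpha$ values are arbitrary illustrative choices.

```python
import math

T, ALPHA = 1.0, 0.5  # assumed estimator sharpness and learning-rate ratio alpha


def qhat_grad(w):
    """Assumed surrogate derivative Qhat'(w) = sech^2(T w)."""
    return 1.0 / math.cosh(T * w) ** 2


def shift_map(w):
    """M with M' = ALPHA / qhat_grad and M(0) = 0."""
    return ALPHA * (w / 2.0 + math.sinh(2.0 * T * w) / (4.0 * T))


def one_step_error(w_qhat, grad, eta):
    """Alignment error after one SGD step when both nets start aligned and see the same gradient."""
    w_ste = shift_map(w_qhat)                               # aligned start: E^(t) = 0
    w_qhat_next = w_qhat - eta * grad * qhat_grad(w_qhat)   # Qhat-net, learning rate eta
    w_ste_next = w_ste - ALPHA * eta * grad                 # STE-net, learning rate alpha * eta
    return abs(shift_map(w_qhat_next) - w_ste_next)


for eta in (1e-2, 5e-3, 2.5e-3):
    print(eta, one_step_error(0.3, 0.8, eta))
```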

Theorem C.1 applies to gradient update rules that satisfy a special property, stated as Assumption C.1.3. We will show later in this section that this property holds for the SGD update rule defined in Equation 3, and in Appendix D that it holds for SGD with momentum. Similar proofs show that it holds for a large class of non-adaptive learning rate gradient update rules.
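The following sketch shows what Assumption C.1.3 asks of an update rule $g^{(t)}$. It evaluates the left-hand side of Equation 15 for a single coordinate. Plain SGD and heavy-ball momentum are written in the assumed forms $g^{(t)} = -\eta x_t$ and $g^{(t)} = -\eta \sum_{s \le t} \beta^{t-s} x_s$ (zero initial buffer), and the helper names are hypothetical. For SGD the residual vanishes identically. For momentum it vanishes when $\hat{Q}'$ takes the same value at every step, but not in general. This is presumably why the momentum case is treated separately in Appendix D.

```python
def sgd_update(grads, lr):
    """g^(t) for plain SGD: only the newest gradient x_t is used."""
    return -lr * grads[-1]


def momentum_update(grads, lr, beta=0.9):
    """g^(t) for heavy-ball momentum with a zero initial buffer: -lr * sum_s beta^(t-s) * x_s."""
    t = len(grads) - 1
    return -lr * sum(beta ** (t - s) * x for s, x in enumerate(grads))


def c13_residual(update, raw_grads, qhat_grads, lr):
    """Left-hand side of Equation 15 for one coordinate.

    raw_grads[s] plays the role of grad f^(s), and qhat_grads[s] of Qhat'(w^(s)).
    """
    scaled = [x * d for x, d in zip(raw_grads, qhat_grads)]
    return abs(update(scaled, lr) / qhat_grads[-1] - update(raw_grads, lr))


raw = [0.5, -0.2, 0.3]
print(c13_residual(sgd_update, raw, [0.9, 0.7, 0.8], lr=0.1))       # ~0 up to rounding
print(c13_residual(momentum_update, raw, [0.8, 0.8, 0.8], lr=0.1))  # ~0: Qhat' is constant
print(c13_residual(momentum_update, raw, [0.9, 0.7, 0.8], lr=0.1))  # clearly nonzero
```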

Theorem C.1.

Suppose that

$$E^{(t)} := \left|M\big(w^{(t)}_{\hat{Q}}\big) - w^{(t)}_{STE}\right| \tag{14}$$

is the alignment error between the $\hat{Q}$-net and the $STE$-net, with gradient estimators, learning rates, and initial weights given by Table 2. Suppose that Assumptions 5.1.1 and 5.1.2 hold and that the model weights are updated according to Equation 2 for some function $g^{(t)}$. In addition, suppose that

  1. C.1.3

    For each $t$, the quantity

    $$\left|\frac{g^{(t)}\big(\nabla f^{(0)}_{\hat{Q}} \hat{Q}'(w^{(0)}), \ldots, \nabla f^{(t)}_{\hat{Q}} \hat{Q}'(w^{(t)}), \eta\big)}{\hat{Q}'(w^{(t)})} - g^{(t)}\big(\nabla f^{(0)}_{\hat{Q}}, \ldots, \nabla f^{(t)}_{\hat{Q}}, \eta\big)\right| = O(c(\eta)). \tag{15}$$

Then we have

\begin{align}
E^{(t+1)} \le\ & E^{(t)} + \left|\alpha\, g^{(t)}\big(\nabla f^{(0)}_{\hat{Q}}, \ldots, \nabla f^{(t)}_{\hat{Q}}, \eta\big) - g^{(t)}\big(\nabla f^{(0)}_{STE}, \ldots, \nabla f^{(t)}_{STE}, \alpha\eta\big)\right| \,+ \tag{16}\\
& \frac{L'}{2} \cdot \left(\frac{g^{(t)}\big(\nabla f^{(0)}_{\hat{Q}} \hat{Q}'(w^{(0)}), \ldots, \nabla f^{(t)}_{\hat{Q}} \hat{Q}'(w^{(t)}), \eta\big)}{L_-}\right)^2 + O(c(\eta)) \tag{17}
\end{align}
Proof.

By Equation 2, we have

\begin{align}
E^{(t+1)} &= \left|M\big(w_{\hat{Q}}^{(t+1)}\big) - w_{STE}^{(t+1)}\right| \tag{18}\\
&= \Big|M\Big(w_{\hat{Q}}^{(t)} + g^{(t)}\big(\nabla f^{(0)}_{\hat{Q}} \hat{Q}'(w^{(0)}), \ldots, \nabla f^{(t)}_{\hat{Q}} \hat{Q}'(w^{(t)}), \eta\big)\Big) \,- \tag{19}\\
&\qquad \Big(w_{STE}^{(t)} + g^{(t)}\big(\nabla f^{(0)}_{STE}, \ldots, \nabla f^{(t)}_{STE}, \alpha\eta\big)\Big)\Big| \tag{20}\\
&= \Big|M\big(w_{\hat{Q}}^{(t)}\big) + g^{(t)}\big(\nabla f^{(0)}_{\hat{Q}} \hat{Q}'(w^{(0)}), \ldots, \nabla f^{(t)}_{\hat{Q}} \hat{Q}'(w^{(t)}), \eta\big)\, M'\big(w_{\hat{Q}}^{(t)}\big) \,- \tag{21}\\
&\qquad \Big(w_{STE}^{(t)} + g^{(t)}\big(\nabla f^{(0)}_{STE}, \ldots, \nabla f^{(t)}_{STE}, \alpha\eta\big)\Big) + R\Big| \tag{22}\\
&= \Big|M\big(w_{\hat{Q}}^{(t)}\big) + \alpha\, g^{(t)}\big(\nabla f^{(0)}_{\hat{Q}} \hat{Q}'(w^{(0)}), \ldots, \nabla f^{(t)}_{\hat{Q}} \hat{Q}'(w^{(t)}), \eta\big) \big/ \hat{Q}'\big(w_{\hat{Q}}^{(t)}\big) \,- \tag{23}\\
&\qquad \Big(w_{STE}^{(t)} + g^{(t)}\big(\nabla f^{(0)}_{STE}, \ldots, \nabla f^{(t)}_{STE}, \alpha\eta\big)\Big) + R\Big| \tag{24}\\
&= \Big|M\big(w_{\hat{Q}}^{(t)}\big) + \alpha\, g^{(t)}\big(\nabla f^{(0)}_{\hat{Q}}, \ldots, \nabla f^{(t)}_{\hat{Q}}, \eta\big) + O(c(\eta)) \,- \tag{25}\\
&\qquad \Big(w_{STE}^{(t)} + g^{(t)}\big(\nabla f^{(0)}_{STE}, \ldots, \nabla f^{(t)}_{STE}, \alpha\eta\big)\Big) + R\Big| \tag{26}\\
&\le E^{(t)} + \left|\alpha\, g^{(t)}\big(\nabla f^{(0)}_{\hat{Q}}, \ldots, \nabla f^{(t)}_{\hat{Q}}, \eta\big) - g^{(t)}\big(\nabla f^{(0)}_{STE}, \ldots, \nabla f^{(t)}_{STE}, \alpha\eta\big)\right| \,+ \tag{27}\\
&\qquad |R| + O(c(\eta)) \tag{28}
\end{align}

Here Equation 22 follows from Taylor’s Theorem, where $R$ is the remainder term. Equation 24 follows from Equation 13 in the previous proof, and Equation 26 follows from Assumption C.1.3. Equation 28 follows from the triangle inequality. By Lemma 2.1 of [52], we can bound $R$ by

$$|R| \le \frac{L'}{2L_-^2}\left(g^{(t)}\big(\nabla f^{(0)}_{\hat{Q}} \hat{Q}'(w^{(0)}), \ldots, \nabla f^{(t)}_{\hat{Q}} \hat{Q}'(w^{(t)}), \eta\big)\right)^2. \tag{29}$$

To see this, we need to show that $M'$ is Lipschitz continuous with Lipschitz constant $L'/L_-^2$. This holds since for any $w, v \in \mathbb{R}$,

$$|M'(w) - M'(v)| = \left|\frac{1}{\hat{Q}'(w)} - \frac{1}{\hat{Q}'(v)}\right| = \left|\frac{\hat{Q}'(v) - \hat{Q}'(w)}{\hat{Q}'(w)\,\hat{Q}'(v)}\right| \le \frac{L'}{L_-^2}\,|w - v|.$$

In the last step we use both Assumptions 5.1.1 and 5.1.2. Putting this all together, we have Equation 17. ∎
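The remainder bound in Equation 29 rests on $M'$ being Lipschitz with constant $L'/L_-^2$. The sketch below checks both facts numerically for a hypothetical surrogate derivative $\hat{Q}'(w) = 1 + \tfrac{1}{2}\sin w$. This surrogate satisfies $\hat{Q}' \ge L_- = 1/2$ and is Lipschitz with $L' = 1/2$. Matching the constant used in the text, the sketch takes $M' = 1/\hat{Q}'$ (that is, $\alpha = 1$). It computes $R = M(w + d) - M(w) - d\,M'(w) = \int_w^{w+d} (M'(s) - M'(w))\,ds$ with Simpson's rule, so $M$ itself is never needed.

```python
import math
import random

L_MINUS, L_PRIME = 0.5, 0.5  # lower bound and Lipschitz constant of qhat_grad


def qhat_grad(w):
    """Hypothetical surrogate derivative with 0.5 <= Qhat'(w) <= 1.5."""
    return 1.0 + 0.5 * math.sin(w)


def m_prime(w):
    return 1.0 / qhat_grad(w)  # Equation 13 with alpha = 1


def simpson(f, a, b, n=200):
    """Composite Simpson rule on [a, b] with an even number n of panels (b < a is allowed)."""
    h = (b - a) / n
    total = f(a) + f(b)
    for i in range(1, n):
        total += (4 if i % 2 else 2) * f(a + i * h)
    return total * h / 3.0


def taylor_remainder(w, d):
    """R = M(w + d) - M(w) - d * M'(w), written as the integral of M'(s) - M'(w)."""
    return simpson(lambda s: m_prime(s) - m_prime(w), w, w + d)


def remainder_bound(d):
    """Right-hand side of Equation 29, with the update written as d."""
    return L_PRIME / (2.0 * L_MINUS ** 2) * d ** 2


random.seed(0)
samples = [(random.uniform(-5.0, 5.0), random.choice((-1.0, 1.0)) * random.uniform(0.01, 1.0))
           for _ in range(500)]
worst = max(abs(taylor_remainder(w, d)) / remainder_bound(d) for w, d in samples)
print(worst)  # below 1 whenever Equation 29 holds
```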

We can now apply Theorem C.1 to the SGD update rule (Equation 3) to prove Theorem 5.1.
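Before the formal argument, the per-step inequality that this application produces (Equation 34) can be iterated and checked in a toy one-dimensional setting. The sketch reuses the hypothetical $\hat{Q}'(w) = 1 + \tfrac{1}{2}\sin w$ with $L_- = L' = 1/2$, and sets $\alpha = 1$ to match the constant in Equation 29. It obtains $M$ by numerical integration, under the assumed normalization $M(0) = 0$. It then feeds the two nets made-up gradient sequences that differ slightly, standing in for $\nabla f_{\hat{Q}}^{(t)}$ and $\nabla f_{STE}^{(t)}$. At every step, the measured alignment error should stay below the accumulated right-hand side of Equation 34.

```python
import math

L_MINUS, L_PRIME, ALPHA = 0.5, 0.5, 1.0


def qhat_grad(w):
    """Hypothetical surrogate derivative, 0.5 <= Qhat'(w) <= 1.5, Lipschitz constant 0.5."""
    return 1.0 + 0.5 * math.sin(w)


def shift_map(w, n=400):
    """M(w) = integral from 0 to w of ALPHA / Qhat'(s) ds, by composite Simpson (assumes M(0) = 0)."""
    h = w / n
    total = ALPHA / qhat_grad(0.0) + ALPHA / qhat_grad(w)
    for i in range(1, n):
        total += (4 if i % 2 else 2) * ALPHA / qhat_grad(i * h)
    return total * h / 3.0


def simulate(eta, steps=50, w0=0.4):
    """Return (alignment errors, accumulated Equation 34 bounds), one entry per step."""
    w_q = w0
    w_s = shift_map(w0)  # aligned initialization: E^(0) = 0
    bound, errors, bounds = 0.0, [], []
    for t in range(steps):
        g_q = math.cos(0.3 * t)          # made-up stand-in for grad f_Qhat^(t)
        g_s = g_q + 0.01 * math.sin(t)   # made-up stand-in for grad f_STE^(t)
        step_q = eta * g_q * qhat_grad(w_q)
        bound += eta * ALPHA * abs(g_q - g_s) + L_PRIME / 2.0 * (step_q / L_MINUS) ** 2
        w_q -= step_q                    # Qhat-net, learning rate eta
        w_s -= ALPHA * eta * g_s         # STE-net, learning rate alpha * eta
        errors.append(abs(shift_map(w_q) - w_s))
        bounds.append(bound)
    return errors, bounds


errors, bounds = simulate(eta=0.05)
print(max(e / b for e, b in zip(errors, bounds)))
```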

Proof of Theorem 5.1.

To prove Theorem 5.1, we first show that Assumption C.1.3 holds for the SGD update rule with $c(\eta) = 0$. We have

\begin{align}
\left|\frac{g^{(t)}\big(\nabla f^{(0)}_{\hat{Q}} \hat{Q}'(w^{(0)}), \ldots, \nabla f^{(t)}_{\hat{Q}} \hat{Q}'(w^{(t)}), \eta\big)}{\hat{Q}'(w^{(t)})} - g^{(t)}\big(\nabla f^{(0)}_{\hat{Q}}, \ldots, \nabla f^{(t)}_{\hat{Q}}, \eta\big)\right| &= \tag{30}\\
\left|\frac{\eta \nabla f^{(t)}_{\hat{Q}} \hat{Q}'(w^{(t)})}{\hat{Q}'(w^{(t)})} - \eta \nabla f^{(t)}_{\hat{Q}}\right| &= 0. \tag{31}
\end{align}

Now we can apply Theorem C.1. We have

\begin{align}
E^{(t+1)} \leq{}& E^{(t)} + \left|\alpha g^{(t)}(\alpha\nabla f^{(0)}_{\hat{Q}},\ldots,\nabla f^{(t)}_{\hat{Q}},\eta) - g^{(t)}(\nabla f^{(0)}_{STE},\ldots,\nabla f^{(t)}_{STE},\alpha\eta)\right| + \tag{32}\\
&\frac{L'}{2}\cdot\left(\frac{g^{(t)}(\nabla f^{(0)}_{\hat{Q}}\hat{Q}'(w^{(0)}),\ldots,\nabla f^{(t)}_{\hat{Q}}\hat{Q}'(w^{(t)}),\eta)}{L_-}\right)^2 + O(c(\eta)) \tag{33}\\
={}& E^{(t)} + \eta\alpha\left|\nabla f^{(t)}_{\hat{Q}} - \nabla f^{(t)}_{STE}\right| + \frac{L'}{2}\cdot\left(\frac{\eta\nabla f^{(t)}_{\hat{Q}}\hat{Q}'(w^{(t)})}{L_-}\right)^2 + 0 \tag{34}\\
\leq{}& E^{(t)} + \eta\alpha\left|\nabla f^{(t)}_{\hat{Q}} - \nabla f^{(t)}_{STE}\right| + \frac{L'}{2}\cdot\left(\frac{\eta L_+\nabla f^{(t)}_{\hat{Q}}}{L_-}\right)^2 \tag{35}
\end{align}

This gives us Equation 6, as desired. ∎
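For plain SGD the update depends only on the current estimated gradient. Dividing the $\hat{Q}$-net step by $\hat{Q}'(w^{(t)})$ therefore reproduces the STE-style step exactly, which is why Equation 31 equals zero.

The minimal sketch below checks this for a single weight and evaluates the per-step growth $E^{(t+1)}-E^{(t)}$ allowed by Equation 35. It rests on a few assumptions:

- It uses the usual rule $g^{(t)}(x^{(0)},\ldots,x^{(t)},\eta)=-\eta x^{(t)}$. The sign does not affect the absolute values involved.
- It uses a hypothetical surrogate derivative $\hat{Q}'(w)=1+0.5\cos w$, for which $L_-=0.5$, $L_+=1.5$ and $L'=0.5$.
- The function names are illustrative only.

```python
import math


def sgd_step(x_t, eta):
    """Plain SGD update g^(t)(x^(0), ..., x^(t), eta) = -eta * x^(t) (assumed convention)."""
    return -eta * x_t


def q_hat_prime(w):
    """Hypothetical surrogate derivative: 0.5 <= Qhat'(w) <= 1.5 and 0.5-Lipschitz."""
    return 1.0 + 0.5 * math.cos(w)


def eq35_increment(grad_qhat, grad_ste, eta, alpha, l_minus, l_plus, l_prime):
    """Growth E^(t+1) - E^(t) permitted by the right-hand side of Equation 35."""
    mismatch = eta * alpha * abs(grad_qhat - grad_ste)
    curvature = 0.5 * l_prime * (eta * l_plus * grad_qhat / l_minus) ** 2
    return mismatch + curvature


w, grad, eta = 0.3, 2.0, 1e-2
# Equation 31: the Qhat-net step divided by Qhat'(w) matches the STE-style step
# up to floating-point rounding.
gap = abs(sgd_step(grad * q_hat_prime(w), eta) / q_hat_prime(w) - sgd_step(grad, eta))
print(gap)
print(eq35_increment(grad, grad, eta, alpha=1.0, l_minus=0.5, l_plus=1.5, l_prime=0.5))
```

When the two nets see the same gradient, only the curvature term of Equation 35 contributes to the increment.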

Appendix D Theorem 5.1 for SGD with momentum

Here we give a version of Theorem 5.1 for stochastic gradient descent with momentum. The weight update rule for this learning algorithm is given by

$$g^{(t)}(\nabla f^{(0)}\hat{Q}'(w^{(0)}),\ldots,\nabla f^{(t)}\hat{Q}'(w^{(t)}),\eta) = -\eta m_t \tag{36}$$

where $m_t$ is defined recursively as

$$m_t = \beta m_{t-1} + (1-\beta)\nabla f^{(t)}\hat{Q}'(w^{(t)}) \tag{37}$$

for a hyperparameter $\beta\in[0,1)$, which is often set to 0.9 or a similar value [39]. Taking $m_{-1}=0$ and expanding this recursive definition, we obtain the single rule

$$g^{(t)}(\nabla f^{(0)}\hat{Q}'(w^{(0)}),\ldots,\nabla f^{(t)}\hat{Q}'(w^{(t)}),\eta) = -\eta(1-\beta)\sum_{i=0}^{t}\beta^{t-i}\nabla f^{(i)}\hat{Q}'(w^{(i)}) \tag{38}$$
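The closed form in Equation 38 can be checked against the recursion of Equations 36 and 37. In the minimal sketch below:

- `xs[i]` stands for the estimated gradient $\nabla f^{(i)}\hat{Q}'(w^{(i)})$.
- The momentum buffer starts at $m_{-1}=0$.
- The input values are arbitrary, and the function names are illustrative.

```python
import math


def momentum_updates_recursive(xs, eta, beta):
    """Equations 36-37: g^(t) = -eta * m_t with m_t = beta*m_{t-1} + (1-beta)*x^(t), m_{-1} = 0."""
    m, out = 0.0, []
    for x in xs:
        m = beta * m + (1 - beta) * x
        out.append(-eta * m)
    return out


def momentum_update_closed_form(xs, eta, beta):
    """Equation 38: g^(t) = -eta*(1-beta) * sum_{i=0}^{t} beta^(t-i) * x^(i), t = len(xs) - 1."""
    t = len(xs) - 1
    return -eta * (1 - beta) * sum(beta ** (t - i) * x for i, x in enumerate(xs))


xs = [0.7, -1.2, 0.4, 2.0, -0.3]  # arbitrary estimated gradients
rec = momentum_updates_recursive(xs, eta=0.05, beta=0.9)
for t in range(len(xs)):
    closed = momentum_update_closed_form(xs[: t + 1], 0.05, 0.9)
    assert math.isclose(rec[t], closed, rel_tol=1e-12, abs_tol=1e-15)
```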

Theorems D.1 and D.2 show that Assumption C.1.3 holds for this update rule under mild conditions. We can then apply Theorem C.1 to SGD with momentum to obtain Theorem D.3, an analog of Theorem 5.1.

Theorem D.1.

Define $g^{(t)}$ by Equation 38. Suppose that Assumption 5.1.1 holds. Further, fix $t$ and suppose that every $\nabla f^{(i)}$ with $i\leq t$ is bounded by

$$|\nabla f^{(i)}| < \frac{g_+}{L_+(1-\beta^{t+1})}. \tag{39}$$

Then

$$|g^{(t)}(\nabla f^{(0)}\hat{Q}'(w^{(0)}),\ldots,\nabla f^{(t)}\hat{Q}'(w^{(t)}),\eta)| < \eta g_+.$$
Proof.

By the triangle inequality and Assumption 5.1.1, we have

$$|g^{(t)}(\nabla f^{(0)}\hat{Q}'(w^{(0)}),\ldots,\nabla f^{(t)}\hat{Q}'(w^{(t)}),\eta)| \leq \eta L_+(1-\beta)\sum_{i=0}^{t}\beta^{t-i}|\nabla f^{(i)}|.$$

Now applying the bound given in Equation 39, we have

$$|g^{(t)}(\nabla f^{(0)}\hat{Q}'(w^{(0)}),\ldots,\nabla f^{(t)}\hat{Q}'(w^{(t)}),\eta)| < \eta g_+\frac{1-\beta}{1-\beta^{t+1}}\sum_{i=0}^{t}\beta^{t-i}.$$

Since

$$\sum_{i=0}^{t}\beta^{t-i} = \frac{1-\beta^{t+1}}{1-\beta}$$

for all $\beta<1$, we have

$$|g^{(t)}(\nabla f^{(0)}\hat{Q}'(w^{(0)}),\ldots,\nabla f^{(t)}\hat{Q}'(w^{(t)}),\eta)| < \eta g_+ \tag{40}$$

as desired. ∎
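The bound can be spot-checked numerically. The minimal sketch below fixes $t$ and draws random inputs:

- gradients with $|\nabla f^{(i)}| < g_+/(L_+(1-\beta^{t+1}))$ for every $i\leq t$;
- surrogate slopes in $[L_-,L_+]$, which is our assumed reading of Assumption 5.1.1.

It then reports the largest observed ratio $|g^{(t)}|/(\eta g_+)$, which should stay below 1. All constants and function names are hypothetical.

```python
import random


def momentum_step(xs, eta, beta):
    """Equation 38 with xs[i] = grad f^(i) * Qhat'(w^(i)) and t = len(xs) - 1."""
    t = len(xs) - 1
    return -eta * (1 - beta) * sum(beta ** (t - i) * x for i, x in enumerate(xs))


def check_d1(t, eta=0.01, beta=0.9, g_plus=1.0, l_minus=0.5, l_plus=1.5, trials=1000, seed=0):
    """Largest |g^(t)| / (eta * g_plus) over random trials; Theorem D.1 says it stays below 1."""
    rng = random.Random(seed)
    bound = g_plus / (l_plus * (1 - beta ** (t + 1)))  # right-hand side of Equation 39
    worst = 0.0
    for _ in range(trials):
        grads = [rng.uniform(-bound, bound) for _ in range(t + 1)]
        # Surrogate slopes in [L_-, L_+] (assumed reading of Assumption 5.1.1).
        slopes = [rng.uniform(l_minus, l_plus) for _ in range(t + 1)]
        g = momentum_step([f * s for f, s in zip(grads, slopes)], eta, beta)
        worst = max(worst, abs(g) / (eta * g_plus))
    return worst


for t in (0, 1, 5, 50):
    print(t, check_d1(t))
```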

Theorem D.2.

Define $g^{(t)}$ by Equation 38. Suppose that

  1. D.2.1

    $0<L_-\leq\hat{Q}'(w)$ for all $w$

  2. D.2.2

    $\hat{Q}'(w)$ is $L'$-Lipschitz

  3. D.2.3

    For each $t$, the update $g^{(t)}$ is bounded so that $|w^{(t+1)}-w^{(t)}|<\eta g_+$.

Then for each t𝑡titalic_t, we have

$$\left|\frac{g^{(t)}(\nabla f^{(0)}\hat{Q}'(w^{(0)}),\ldots,\nabla f^{(t)}\hat{Q}'(w^{(t)}),\eta)}{\hat{Q}'(w^{(t)})} - g^{(t)}(\nabla f^{(0)},\ldots,\nabla f^{(t)},\eta)\right| = O(\eta^2), \tag{41}$$

so that Assumption C.1.3 holds with $c(\eta)=\eta^2$.

Proof.

We have by Equation 38

$$\frac{g^{(t)}(\nabla f^{(0)}\hat{Q}'(w^{(0)}),\ldots,\nabla f^{(t)}\hat{Q}'(w^{(t)}),\eta)}{\hat{Q}'(w^{(t)})} = -\eta(1-\beta)\sum_{i=0}^{t}\beta^{t-i}\nabla f^{(i)}\frac{\hat{Q}'(w^{(i)})}{\hat{Q}'(w^{(t)})}. \tag{42}$$

We would like to show that for each $i$,

$$\beta^{t-i}\frac{\hat{Q}'(w^{(i)})}{\hat{Q}'(w^{(t)})} = \beta^{t-i}(1+O(\eta))$$

since then we would have

\begin{align}
&\left|\frac{g^{(t)}(\nabla f^{(0)}\hat{Q}'(w^{(0)}),\ldots,\nabla f^{(t)}\hat{Q}'(w^{(t)}),\eta)}{\hat{Q}'(w^{(t)})} - g^{(t)}(\nabla f^{(0)},\ldots,\nabla f^{(t)},\eta)\right| = \tag{43}\\
&\left|-\eta(1-\beta)\sum_{i=0}^{t}\beta^{t-i}\nabla f^{(i)}\frac{\hat{Q}'(w^{(i)})}{\hat{Q}'(w^{(t)})} + \eta(1-\beta)\sum_{i=0}^{t}\beta^{t-i}\nabla f^{(i)}\right| = \tag{44}\\
&\left|-\eta(1-\beta)\sum_{i=0}^{t}\beta^{t-i}\nabla f^{(i)}(1+O(\eta)) + \eta(1-\beta)\sum_{i=0}^{t}\beta^{t-i}\nabla f^{(i)}\right| = O(\eta^2) \tag{45}
\end{align}

The first step is to note that $\log(\hat{Q}')$ is Lipschitz with Lipschitz constant $L'/L_-$. To see this, first note that $\log(x)$ is $1/L_-$-Lipschitz on $[L_-,\infty)$. Then, by Assumptions D.2.1 and D.2.2 and the fact that a composition of Lipschitz functions is Lipschitz with the product of their constants, we have

$$|\log(\hat{Q}'(w)) - \log(\hat{Q}'(v))| \leq \frac{L'}{L_-}|w-v|$$

which is our desired Lipschitz property. Making use of this property, Assumption D.2.3, and Equation 2, we have

\begin{align}
|\log(\hat{Q}'(w^{(i)})) - \log(\hat{Q}'(w^{(t)}))| &\leq \frac{L'}{L_-}|w^{(i)}-w^{(t)}| \tag{47}\\
&= \frac{L'}{L_-}\left|\sum_{j=i}^{t-1}\left(w^{(j)}-w^{(j+1)}\right)\right| \tag{48}\\
&\leq \frac{L'}{L_-}\sum_{j=i}^{t-1}\left|w^{(j)}-w^{(j+1)}\right| \tag{49}\\
&\leq \eta\frac{L'}{L_-}(t-i)g_+. \tag{50}
\end{align}

Solving for the quotient $\hat{Q}'(w^{(i)})/\hat{Q}'(w^{(t)})$, we have

$$
\begin{aligned}
-\eta L'(t-i)g_+/L_- &\leq \log(\hat{Q}'(w^{(i)})) - \log(\hat{Q}'(w^{(t)})) \leq \eta L'(t-i)g_+/L_-\\
\exp(-\eta L'(t-i)g_+/L_-) &\leq \frac{\hat{Q}'(w^{(i)})}{\hat{Q}'(w^{(t)})} \leq \exp(\eta L'(t-i)g_+/L_-)\\
\beta^{-\eta L'(t-i)g_+/(\log(\beta)L_-)} &\leq \frac{\hat{Q}'(w^{(i)})}{\hat{Q}'(w^{(t)})} \leq \beta^{\eta L'(t-i)g_+/(\log(\beta)L_-)}
\end{aligned}
$$

Thus, for $i<t$, we have shown that

$$\frac{\hat{Q}'(w^{(i)})}{\hat{Q}'(w^{(t)})} = \left(\frac{\beta_{t,i}}{\beta}\right)^{t-i}$$

where $\beta_{t,i} := \beta\left(\hat{Q}'(w^{(i)})/\hat{Q}'(w^{(t)})\right)^{1/(t-i)}$. Taking the $(t-i)$-th root of the bounds above gives $\beta e^{-\eta L'g_+/L_-}\leq\beta_{t,i}\leq\beta e^{\eta L'g_+/L_-}$, so that

$$\beta_{t,i}=\beta+O(\eta).$$

For $i=t$ the quotient equals 1, and we set $\beta_{t,t}=\beta$.

Therefore we have

$$\beta^{t-i}\frac{\hat{Q}'(w^{(i)})}{\hat{Q}'(w^{(t)})} = \beta_{t,i}^{t-i} = (\beta+O(\eta))^{t-i} = \beta^{t-i}(1+O(\eta)),$$

as desired. The final equality holds since $(\beta+O(\eta))^{t-i}$ is a polynomial in $\beta$ and $O(\eta)$, which can be computed by expanding the product. Every term in the expansion other than $\beta^{t-i}$ contains at least one factor of $O(\eta)$. Hence, for fixed $t$ and $\beta>0$, the sum is $\beta^{t-i}+O(\eta)=\beta^{t-i}(1+O(\eta))$; the case $\beta=0$ is trivial. ∎
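The $O(\eta^2)$ rate in Equation 41 can be observed in a small simulation. The minimal sketch below runs momentum (Equations 36 and 37, with $m_{-1}=0$) on a single weight with a constant loss gradient, under these assumptions:

- It uses the hypothetical surrogate derivative $\hat{Q}'(w)=1+0.5\cos w$, so $L_-=0.5$ and $L'=0.5$.
- It takes Equation 2 to be the additive update $w^{(t+1)}=w^{(t)}+g^{(t)}$.
- The names are illustrative.

Halving $\eta$ should shrink the gap by a factor approaching 4.

```python
import math

BETA = 0.9  # momentum hyperparameter beta


def q_hat_prime(w):
    """Hypothetical surrogate derivative: 0.5 <= Qhat'(w) <= 1.5 and 0.5-Lipschitz."""
    return 1.0 + 0.5 * math.cos(w)


def d2_gap(eta, steps=6, w0=1.0, grad=1.0):
    """Left-hand side of Equation 41 at t = steps - 1 for a constant loss gradient `grad`."""
    w, m, g = w0, 0.0, 0.0
    slopes = []
    for _ in range(steps):
        slopes.append(q_hat_prime(w))                  # Qhat'(w^(t))
        m = BETA * m + (1 - BETA) * grad * slopes[-1]  # Equation 37, m_{-1} = 0
        g = -eta * m                                   # Equation 36
        w = w + g                                      # assumed additive form of Equation 2
    t = steps - 1
    # The same momentum rule applied to the raw gradients (STE-style update).
    g_ste = -eta * (1 - BETA) * sum(BETA ** (t - i) * grad for i in range(steps))
    return abs(g / slopes[t] - g_ste)


for eta in (1e-2, 1e-3, 1e-4):
    print(eta, d2_gap(eta), d2_gap(eta) / d2_gap(eta / 2))
```

With a constant gradient the weight drifts monotonically, so every quotient $\hat{Q}'(w^{(i)})/\hat{Q}'(w^{(t)})$ deviates from 1 in the same direction. This keeps the leading $\eta^2$ coefficient away from zero, so the factor-of-4 scaling is visible.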

We now have all that we need to prove the following analog of Theorem 5.1 for gradient descent with momentum.

Theorem D.3.

Suppose that $E^{(t)}$ is defined by Equation 14, for $\hat{Q}$-net and STE-net with gradient estimators, learning rates, and initial weights given by Table 2. Suppose that Assumptions 5.1.1 and 5.1.2 hold and the model weights are updated according to Equation 2, where $g^{(t)}$ is defined by Equation 38. In addition, suppose that each $\nabla f^{(t)}_{\hat{Q}}$ is bounded by Equation 39. Then we have

$$E^{(t+1)} \le E^{(t)} + \alpha\eta\left|(1-\beta)\sum_{i=0}^{t}\beta^{t-i}\big(\nabla f^{(i)}_{\hat{Q}} - \nabla f^{(i)}_{\mathrm{STE}}\big)\right| + \frac{L'}{2}\left(\frac{\eta g_{+}}{L_{-}}\right)^{2} + O(\eta^{2}) \tag{51}$$
Proof.

Assumption C.1.3 holds by Theorem D.2 with $c(\eta) = \eta^{2}$, so Theorem C.1 applies. Note that Assumption D.2.3 follows from Theorem D.1 combined with Equation 2. We can now obtain Equation 51 from Equation 17 by simplifying terms and applying the appropriate bounds:

$$
\begin{align}
E^{(t+1)} &\le E^{(t)} + \Big|\alpha\, g^{(t)}\big(\nabla f^{(0)}_{\hat{Q}}, \ldots, \nabla f^{(t)}_{\hat{Q}}, \eta\big) - g^{(t)}\big(\nabla f^{(0)}_{\mathrm{STE}}, \ldots, \nabla f^{(t)}_{\mathrm{STE}}, \alpha\eta\big)\Big| \tag{52}\\
&\qquad + \frac{L'}{2}\left(\frac{g^{(t)}\big(\nabla f^{(0)}_{\hat{Q}}\hat{Q}'(w^{(0)}), \ldots, \nabla f^{(t)}_{\hat{Q}}\hat{Q}'(w^{(t)}), \eta\big)}{L_{-}}\right)^{2} + O(c(\eta)) \tag{53}\\
&\le E^{(t)} + \Big|-\alpha\eta(1-\beta)\sum_{i=0}^{t}\beta^{t-i}\nabla f^{(i)}_{\hat{Q}} + \alpha\eta(1-\beta)\sum_{i=0}^{t}\beta^{t-i}\nabla f^{(i)}_{\mathrm{STE}}\Big| \tag{54}\\
&\qquad + \frac{L'}{2}\left(\frac{\eta g_{+}}{L_{-}}\right)^{2} + O(\eta^{2}) \tag{55}\\
&= E^{(t)} + \alpha\eta\left|(1-\beta)\sum_{i=0}^{t}\beta^{t-i}\big(\nabla f^{(i)}_{\hat{Q}} - \nabla f^{(i)}_{\mathrm{STE}}\big)\right| + \frac{L'}{2}\left(\frac{\eta g_{+}}{L_{-}}\right)^{2} + O(\eta^{2}). \tag{56}
\end{align}
$$
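The passage from Equation 52 to Equation 54 rests on the momentum update being linear in the learning rate, so that $\alpha\, g^{(t)}(\cdot, \eta) = g^{(t)}(\cdot, \alpha\eta)$. A minimal sketch, assuming (as Equation 54 indicates) that $g^{(t)}(x_0, \ldots, x_t, \eta) = -\eta(1-\beta)\sum_{i=0}^{t}\beta^{t-i}x_i$ for scalar inputs, checks this identity, checks that the closed form agrees with the usual exponential-moving-average recursion, and confirms that the difference term equals the expression inside the absolute value in Equation 54. The function names and sample gradients are hypothetical.

```python
def momentum_update(grads, eta, beta):
    """Closed form g^(t)(x_0, ..., x_t, eta) = -eta (1 - beta) sum_i beta^(t-i) x_i."""
    t = len(grads) - 1
    return -eta * (1.0 - beta) * sum(beta ** (t - i) * x for i, x in enumerate(grads))


def momentum_update_recursive(grads, eta, beta):
    """Same step via the EMA m_i = beta * m_{i-1} + (1 - beta) * x_i with m_{-1} = 0."""
    m = 0.0
    for x in grads:
        m = beta * m + (1.0 - beta) * x
    return -eta * m


grads_q = [0.3, -0.1, 0.4, 0.2]       # hypothetical stand-ins for ∇f^(i)_Q̂
grads_ste = [0.25, -0.05, 0.35, 0.2]  # hypothetical stand-ins for ∇f^(i)_STE
eta, alpha, beta = 0.01, 2.5, 0.9

# Difference term of Equation 52: alpha * g(Q̂ grads, eta) - g(STE grads, alpha * eta)
lhs = alpha * momentum_update(grads_q, eta, beta) - momentum_update(grads_ste, alpha * eta, beta)

# Expression inside the absolute value in Equation 54
t = len(grads_q) - 1
weights = [beta ** (t - i) for i in range(t + 1)]
eq54 = (-alpha * eta * (1.0 - beta) * sum(w * g for w, g in zip(weights, grads_q))
        + alpha * eta * (1.0 - beta) * sum(w * g for w, g in zip(weights, grads_ste)))

print(lhs, eq54)  # equal up to floating-point rounding
```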

Appendix E Adam

In this Appendix we prove Theorem 5.2 in a manner similar to the proofs given in Appendix C.1. The weight update function for the Adam optimizer is defined by

$$
\begin{align}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\,\nabla f^{(t)}\hat{Q}'(w^{(t)}) \tag{57}\\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\big(\nabla f^{(t)}\hat{Q}'(w^{(t)})\big)^{2} \tag{58}\\
\hat{m}_t &= m_t/(1-\beta_1^{t}) \tag{59}\\
\hat{v}_t &= v_t/(1-\beta_2^{t}) \tag{60}\\
g^{(t)}\big(\nabla f^{(0)}\hat{Q}'(w^{(0)}), \ldots, \nabla f^{(t)}\hat{Q}'(w^{(t)}), \eta\big) &= -\eta\,\hat{m}_t\big/\big(\sqrt{\hat{v}_t} + \epsilon\big) \tag{61}
\end{align}
$$

where $\beta_1, \beta_2 \in [0, 1)$ are hyperparameters and $\epsilon$ is a small constant.
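For concreteness, Equations 57-61 can be written out for a single scalar weight. This is a minimal sketch rather than the paper's implementation: `AdamState` and `adam_step` are hypothetical names, the estimated gradient $\nabla f^{(t)}\hat{Q}'(w^{(t)})$ is passed in already computed, and the step counter starts at 1 so that the bias corrections $1-\beta_1^{t}$ and $1-\beta_2^{t}$ are nonzero. The defaults are the common choices $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$.

```python
import math
from dataclasses import dataclass


@dataclass
class AdamState:
    m: float = 0.0  # first-moment estimate m_t
    v: float = 0.0  # second-moment estimate v_t
    t: int = 0      # number of steps taken so far


def adam_step(state: AdamState, grad: float, eta: float,
              beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8) -> float:
    """Return g^(t) = -eta * m_hat / (sqrt(v_hat) + eps) and update `state` in place.

    `grad` plays the role of the estimated gradient ∇f^(t) Q̂'(w^(t)).
    """
    state.t += 1
    state.m = beta1 * state.m + (1.0 - beta1) * grad        # Eq. 57
    state.v = beta2 * state.v + (1.0 - beta2) * grad ** 2   # Eq. 58
    m_hat = state.m / (1.0 - beta1 ** state.t)               # Eq. 59
    v_hat = state.v / (1.0 - beta2 ** state.t)               # Eq. 60
    return -eta * m_hat / (math.sqrt(v_hat) + eps)           # Eq. 61


state = AdamState()
w = 0.5
for grad in (0.2, -0.1, 0.3):
    w += adam_step(state, grad, eta=0.01)  # weight update of Equation 2
print(w)
```

On the first step $\hat{m}_1 = x_1$ and $\sqrt{\hat{v}_1} = \lvert x_1 \rvert$, so the update is approximately $-\eta\,\mathrm{sign}(x_1)$ regardless of the gradient's magnitude.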

We first state and prove Theorem E.1, a general-purpose precursor to Theorem 5.2 that applies to a large class of adaptive learning rate optimizers. We then reuse parts of the proof of Theorem D.2 to specialize this result to the Adam optimizer and prove Theorem 5.2.

Throughout this section, we follow [22] and, for the sake of the mathematical argument, assume that the constant $\epsilon$ in Equation 61 is zero.
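Setting $\epsilon = 0$ is what makes the Adam step invariant to multiplying every gradient by the same positive constant $c$: $\hat{m}_t$ scales by $c$ and $\sqrt{\hat{v}_t}$ scales by $c$, so their ratio is unchanged. The sketch below checks this numerically and shows that a nonzero $\epsilon$ breaks the invariance. It is illustrative only: the function name and values are hypothetical, and a single constant $c$ is a simplification of the step-dependent factors $\hat{Q}'(w^{(i)})$ handled later in this appendix.

```python
import math


def adam_final_step(grads, eta, beta1=0.9, beta2=0.999, eps=0.0):
    # Run Equations 57-61 over a gradient sequence (1-based step counter)
    # and return the last step g^(t).
    m = v = 0.0
    step = 0.0
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        step = -eta * m_hat / (math.sqrt(v_hat) + eps)
    return step


grads = [0.02, -0.01, 0.03]
scaled = [10.0 * g for g in grads]  # every gradient multiplied by c = 10

# eps = 0: identical steps up to rounding.
print(adam_final_step(grads, 0.01), adam_final_step(scaled, 0.01))
# eps > 0 (exaggerated to 1e-2 so the effect is visible): the steps differ.
print(adam_final_step(grads, 0.01, eps=1e-2), adam_final_step(scaled, 0.01, eps=1e-2))
```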

Theorem E.1.

Suppose that

$$E^{(t)} := \left|w^{(t)}_{\hat{Q}} - w^{(t)}_{\mathrm{STE}}\right| \tag{62}$$

is the alignment error for the $\hat{Q}$-net and the STE-net, with gradient estimators, learning rates, and initial weights given by Table 2. Suppose that the model weights are updated according to Equation 2 for some function $g^{(t)}$. In addition, suppose that

  1. E.1.3

    For each $t$, the quantity

    $$\left|g^{(t)}\big(\nabla f^{(0)}_{\hat{Q}}\hat{Q}'(w^{(0)}), \ldots, \nabla f^{(t)}_{\hat{Q}}\hat{Q}'(w^{(t)}), \eta\big) - g^{(t)}\big(\nabla f^{(0)}_{\hat{Q}}, \ldots, \nabla f^{(t)}_{\hat{Q}}, \eta\big)\right| = O(c(\eta)). \tag{63}$$

Then we have

$$E^{(t+1)} \le E^{(t)} + \left|g^{(t)}\big(\nabla f^{(0)}_{\hat{Q}}, \ldots, \nabla f^{(t)}_{\hat{Q}}, \eta\big) - g^{(t)}\big(\nabla f^{(0)}_{\mathrm{STE}}, \ldots, \nabla f^{(t)}_{\mathrm{STE}}, \eta\big)\right| + O(c(\eta)). \tag{64}$$
Proof.

By Equation 2, we have

$$
\begin{align}
E^{(t+1)} &= \left|w^{(t+1)}_{\hat{Q}} - w^{(t+1)}_{\mathrm{STE}}\right| \tag{65}\\
&= \Big|w^{(t)}_{\hat{Q}} + g^{(t)}\big(\nabla f^{(0)}_{\hat{Q}}\hat{Q}'(w^{(0)}), \ldots, \nabla f^{(t)}_{\hat{Q}}\hat{Q}'(w^{(t)}), \eta\big) \tag{66}\\
&\qquad - \Big(w^{(t)}_{\mathrm{STE}} + g^{(t)}\big(\nabla f^{(0)}_{\mathrm{STE}}, \ldots, \nabla f^{(t)}_{\mathrm{STE}}, \eta\big)\Big)\Big| \tag{67}\\
&= \Big|w^{(t)}_{\hat{Q}} + g^{(t)}\big(\nabla f^{(0)}_{\hat{Q}}, \ldots, \nabla f^{(t)}_{\hat{Q}}, \eta\big) + O(c(\eta)) \tag{68}\\
&\qquad - \Big(w^{(t)}_{\mathrm{STE}} + g^{(t)}\big(\nabla f^{(0)}_{\mathrm{STE}}, \ldots, \nabla f^{(t)}_{\mathrm{STE}}, \eta\big)\Big)\Big| \tag{69}\\
&\le E^{(t)} + \left|g^{(t)}\big(\nabla f^{(0)}_{\hat{Q}}, \ldots, \nabla f^{(t)}_{\hat{Q}}, \eta\big) - g^{(t)}\big(\nabla f^{(0)}_{\mathrm{STE}}, \ldots, \nabla f^{(t)}_{\mathrm{STE}}, \eta\big)\right| + O(c(\eta)) \tag{70}
\end{align}
$$

Here Equation 70 follows from the triangle inequality, and the equality in Equations 68 and 69 follows from Assumption E.1.3. ∎
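Spelled out, the last step of this proof is a single application of the triangle inequality. Writing $a$ and $b$ for the two $g^{(t)}$ terms (evaluated on the $\hat{Q}$-net and STE-net gradients, respectively) and $r$ for the $O(c(\eta))$ remainder supplied by Assumption E.1.3:

```latex
\begin{aligned}
E^{(t+1)} &= \left|\big(w^{(t)}_{\hat{Q}} - w^{(t)}_{\mathrm{STE}}\big) + (a - b) + r\right| \\
&\le \left|w^{(t)}_{\hat{Q}} - w^{(t)}_{\mathrm{STE}}\right| + |a - b| + |r| \\
&= E^{(t)} + |a - b| + O(c(\eta)).
\end{aligned}
```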

Now we can prove Theorem 5.2.

Proof of Theorem 5.2.

To prove Theorem 5.2, we need to show that the assumptions of Theorem 5.2 imply Assumption E.1.3 of Theorem E.1 for the Adam update rule defined in Equations 57-61, with $c(\eta) = \eta^{2}$.

We first expand Equations 57 and 58, which allows us to express $g^{(t)}$ more explicitly as a function of the products $\nabla f^{(i)}_{\hat{Q}}\hat{Q}'(w^{(i)})$:

$$
\begin{align}
m_t &= (1-\beta_1)\sum_{i=0}^{t}\beta_1^{t-i}\,\nabla f^{(i)}_{\hat{Q}}\hat{Q}'(w^{(i)}) \tag{71}\\
v_t &= (1-\beta_2)\sum_{i=0}^{t}\beta_2^{t-i}\big(\nabla f^{(i)}_{\hat{Q}}\hat{Q}'(w^{(i)})\big)^{2} \tag{72}\\
g^{(t)}\big(\nabla f^{(0)}\hat{Q}'(w^{(0)}), \ldots, \nabla f^{(t)}\hat{Q}'(w^{(t)}), \eta\big) &= -\frac{1-\beta_1}{1-\beta_1^{t}}\cdot\sqrt{\frac{1-\beta_2^{t}}{1-\beta_2}}\;\cdot \tag{73}\\
&\qquad \frac{\eta\sum_{i=0}^{t}\beta_1^{t-i}\,\nabla f^{(i)}_{\hat{Q}}\hat{Q}'(w^{(i)})}{\sqrt{\sum_{i=0}^{t}\beta_2^{t-i}\big(\nabla f^{(i)}_{\hat{Q}}\hat{Q}'(w^{(i)})\big)^{2}} + \epsilon} \tag{74}
\end{align}
$$
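Equations 71-74 simply unroll the recursions of Equations 57-61. A quick check, assuming scalar inputs $x_k$ that stand in for $\nabla f^{(k)}_{\hat{Q}}\hat{Q}'(w^{(k)})$, $\epsilon = 0$, and a 1-based step counter (so the sums run over $k = 1, \ldots, t$ rather than $i = 0, \ldots, t$), compares the step-by-step recursion with the factored form. The function names are hypothetical.

```python
import math


def adam_recursive(xs, eta, beta1=0.9, beta2=0.999):
    # Equations 57-61 with eps = 0, run one step at a time; returns the final step g^(t).
    m = v = 0.0
    for x in xs:
        m = beta1 * m + (1 - beta1) * x
        v = beta2 * v + (1 - beta2) * x * x
    t = len(xs)
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    return -eta * m_hat / math.sqrt(v_hat)


def adam_unrolled(xs, eta, beta1=0.9, beta2=0.999):
    # Factored form of Equations 73-74 with eps = 0.
    t = len(xs)
    num = sum(beta1 ** (t - k) * x for k, x in enumerate(xs, start=1))
    den = math.sqrt(sum(beta2 ** (t - k) * x * x for k, x in enumerate(xs, start=1)))
    prefactor = (1 - beta1) / (1 - beta1 ** t) * math.sqrt((1 - beta2 ** t) / (1 - beta2))
    return -prefactor * eta * num / den


xs = [0.3, -0.2, 0.5, 0.1, -0.4]  # hypothetical estimated gradients
print(adam_recursive(xs, 0.01), adam_unrolled(xs, 0.01))  # agree up to rounding
```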

The two fractions in Equation 73 (the bias-correction factors) do not depend on $\hat{Q}'$, so we need only consider the fraction in Equation 74. As stated earlier, we set $\epsilon = 0$, which allows us to write this fraction as

$$
\begin{align}
\frac{\eta\sum_{i=0}^{t}\beta_1^{t-i}\,\nabla f^{(i)}_{\hat{Q}}\hat{Q}'(w^{(i)})}{\sqrt{\sum_{i=0}^{t}\beta_2^{t-i}\big(\nabla f^{(i)}_{\hat{Q}}\hat{Q}'(w^{(i)})\big)^{2}}}
&= \frac{\hat{Q}'(w^{(t)})}{\hat{Q}'(w^{(t)})}\cdot\frac{\eta\sum_{i=0}^{t}\beta_1^{t-i}\,\nabla f^{(i)}_{\hat{Q}}\hat{Q}'(w^{(i)})}{\sqrt{\sum_{i=0}^{t}\beta_2^{t-i}\big(\nabla f^{(i)}_{\hat{Q}}\hat{Q}'(w^{(i)})\big)^{2}}} \tag{75}\\
&= \frac{\eta\sum_{i=0}^{t}\beta_1^{t-i}\,\nabla f^{(i)}_{\hat{Q}}\hat{Q}'(w^{(i)})/\hat{Q}'(w^{(t)})}{\sqrt{\sum_{i=0}^{t}\beta_2^{t-i}\big(\nabla f^{(i)}_{\hat{Q}}\hat{Q}'(w^{(i)})/\hat{Q}'(w^{(t)})\big)^{2}}} \tag{76}
\end{align}
$$

We would like to apply Theorem D.2 to both the numerator and the denominator of the final fraction above. Assumptions D.2.1 and D.2.2 are the same as Assumptions 5.2.1 and 5.2.2, respectively. By Equation 2, Assumption D.2.3 holds with $g_{+} = \max\{1, (1-\beta_1)/\sqrt{1-\beta_2}\}$, since this bound on the size of each update is an inherent property of the Adam optimizer [22]. Applying Theorem D.2 to the numerator, we have

$$\eta\sum_{i=0}^{t}\beta_{1}^{t-i}\nabla f^{(i)}_{\hat{Q}}\hat{Q}'(w^{(i)})/\hat{Q}'(w^{(t)})=\eta\sum_{i=0}^{t}\beta_{1}^{t-i}\nabla f^{(i)}+O(\eta^{2}).$$

Dividing both sides by $\eta$, we see that the normalized numerator tends to $\sum_{i=0}^{t}\beta_{1}^{t-i}\nabla f^{(i)}$ as $\eta\to 0$. A very similar argument shows that the denominator can be approximated as

$$\sqrt{\sum_{i=0}^{t}\beta_{2}^{t-i}\big(\nabla f^{(i)}\big)^{2}+O(\eta)}.$$

The only notable differences are that the denominator has no leading $\eta$ factor, and that the exponent in the bound for $\hat{Q}'(w^{(i)})/\hat{Q}'(w^{(t)})$ picks up an extra factor of 2 because the ratio appears squared; neither changes the result. Since the fraction in the Adam update is unchanged when every gradient is divided by the same positive constant $\hat{Q}'(w^{(t)})$, it equals the final term of Equation 76, and therefore we have

\begin{align}
g^{(t)}\big(\nabla f^{(0)}\hat{Q}'(w^{(0)}),\ldots,\nabla f^{(t)}\hat{Q}'(w^{(t)}),\eta\big) &= -\frac{1-\beta_{1}}{1-\beta_{1}^{t}}\cdot\sqrt{\frac{1-\beta_{2}^{t}}{1-\beta_{2}}}\cdot \tag{77}\\
&\quad\frac{\eta\sum_{i=0}^{t}\beta_{1}^{t-i}\nabla f^{(i)}_{\hat{Q}}\hat{Q}'(w^{(i)})}{\sqrt{\sum_{i=0}^{t}\beta_{2}^{t-i}\big(\nabla f^{(i)}_{\hat{Q}}\hat{Q}'(w^{(i)})\big)^{2}}} \tag{78}\\
&= -\frac{1-\beta_{1}}{1-\beta_{1}^{t}}\cdot\sqrt{\frac{1-\beta_{2}^{t}}{1-\beta_{2}}}\cdot \tag{79}\\
&\quad\frac{\eta\sum_{i=0}^{t}\beta_{1}^{t-i}\nabla f^{(i)}+O(\eta^{2})}{\sqrt{\sum_{i=0}^{t}\beta_{2}^{t-i}\big(\nabla f^{(i)}\big)^{2}+O(\eta)}} \tag{80}\\
&= -\frac{1-\beta_{1}}{1-\beta_{1}^{t}}\cdot\sqrt{\frac{1-\beta_{2}^{t}}{1-\beta_{2}}}\cdot \tag{82}\\
&\quad\frac{\eta\sum_{i=0}^{t}\beta_{1}^{t-i}\nabla f^{(i)}}{\sqrt{\sum_{i=0}^{t}\beta_{2}^{t-i}\big(\nabla f^{(i)}\big)^{2}}}+O(\eta^{2}) \tag{83}\\
&= g^{(t)}\big(\nabla f^{(0)},\ldots,\nabla f^{(t)},\eta\big)+O(\eta^{2}) \tag{84}
\end{align}

so that Assumption E.1.3 holds with $c(\eta)=\eta^{2}$. The only potential issue with this derivation is the removal of the $O(\eta)$ term from the denominator in Equation 83, which requires the denominator to be nonzero. If the denominator is zero, however, Assumption E.1.3 holds trivially. This concludes the proof.

Note: The reader may wonder why the $\hat{Q}'(w^{(i)})$ terms disappear from $g^{(t)}$ while the $\nabla f^{(i)}$ terms do not. The reason is that the $\hat{Q}'(w^{(i)})$ terms vary continuously with the latent weight, whereas the $\nabla f^{(i)}$ terms are stochastic. ∎
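The equality between the final term of Equation 76 and the fraction in Equation 78 rests on the Adam step being unchanged when every gradient is multiplied by the same positive constant, and the note above attributes the disappearance of the $\hat{Q}'(w^{(i)})$ terms to their continuous, slowly varying behavior. The sketch below checks both properties numerically. It assumes $\epsilon=0$ (the formula in Equation 77 has no $\epsilon$) and indexes steps from 1 so the bias corrections are nonzero; the gradients, the drift rate, and the function names are hypothetical toy choices, not the paper's code.

```python
import math

def adam_step_closed_form(grads, lr, beta1, beta2):
    """Adam step (epsilon = 0) in the product form of Equation 77.

    Steps are indexed 1..t here so the bias-correction denominators are nonzero.
    """
    t = len(grads)
    num = sum(beta1 ** (t - i) * g for i, g in enumerate(grads, start=1))
    den = math.sqrt(sum(beta2 ** (t - i) * g * g for i, g in enumerate(grads, start=1)))
    coef = (1 - beta1) / (1 - beta1 ** t) * math.sqrt((1 - beta2 ** t) / (1 - beta2))
    return -coef * lr * num / den

def adam_step_recursive(grads, lr, beta1, beta2):
    """The same step from the usual first/second-moment recursions (epsilon = 0)."""
    m = v = 0.0
    for g in grads:
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
    t = len(grads)
    return -lr * (m / (1 - beta1 ** t)) / math.sqrt(v / (1 - beta2 ** t))

grads = [0.3, -0.1, 0.25, 0.05]   # toy gradients, made up for illustration
lr, b1, b2 = 1e-3, 0.9, 0.95      # beta values as listed in Appendix I

base = adam_step_closed_form(grads, lr, b1, b2)
# One positive constant on every gradient: the step is unchanged.
const = adam_step_closed_form([2.5 * g for g in grads], lr, b1, b2)
# A slowly drifting positive scale (like a continuous Q_hat'): only a small change.
drift = adam_step_closed_form(
    [2.5 * (1 + 1e-4 * i) * g for i, g in enumerate(grads)], lr, b1, b2)
```

The closed form and the recursion agree to rounding error, which is a quick sanity check that the coefficient in Equation 77 is the usual bias correction.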

Appendix F Learning Rate Schedules

Learning rate schedules. All of the learning algorithms described in Section 3 can make use of a learning rate schedule [37, 6, 24, 28, 43]. A learning rate schedule amounts to scaling each gradient update step $g^{(t)}$ by a predetermined positive number $\eta_{t}$. In this case, the initial learning rate $\eta$ acts as a scale on the entire learning rate schedule.

Theorems C.1 and E.1 are general-purpose tools for proving results like Theorems 5.1 and 5.2 for non-adaptive and adaptive learning rate optimizers, respectively. Up to this point we have only considered constant learning rates; here we describe how the theorems can be applied to general learning rate schedules.

As stated in Section 3, a learning rate schedule applies a predetermined scale $\eta_{t}$ to each gradient update step $g^{(t)}$; for non-adaptive optimizers, this scale can effectively be absorbed into the $\nabla f^{(t)}$ terms. This does not affect Assumptions 5.1.1, 5.1.2, 5.2.1, or 5.2.2 in any way. It may affect the bounds on $\nabla f^{(t)}_{\hat{Q}}$ in Theorem D.3, but this would simply require a different value of $g_{+}$.

Thus we can confidently generalize our main results to gradient update rules that take advantage of learning rate schedules.
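For plain SGD, the absorption argument can be checked directly: scaling step $t$ by $\eta_t$ gives the same trajectory as a constant learning rate applied to the rescaled gradients $\eta_t\nabla f^{(t)}$. The sketch below assumes plain SGD without momentum (momentum changes how the scale enters the buffer), and the gradients, schedule values, and function names are hypothetical.

```python
import math

def sgd_scheduled(w, grads, lr, schedule):
    """Plain SGD whose step t is scaled by lr * schedule[t]."""
    for g, eta_t in zip(grads, schedule):
        w -= lr * eta_t * g
    return w

def sgd_absorbed(w, grads, lr, schedule):
    """Constant-lr SGD fed the rescaled gradients eta_t * g instead."""
    scaled = [eta_t * g for g, eta_t in zip(grads, schedule)]
    for g in scaled:
        w -= lr * g
    return w

grads = [0.4, -0.2, 0.1, 0.3]        # toy gradients
schedule = [0.5, 1.0, 0.75, 0.25]    # toy warmup-then-decay multipliers
a = sgd_scheduled(1.0, grads, 0.1, schedule)
b = sgd_absorbed(1.0, grads, 0.1, schedule)
```

The two results agree up to floating-point rounding, since only the order of the multiplications differs.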

Appendix G On nonpositive gradient estimators

Here we describe what can be said, in relation to Theorems 5.1 and 5.2, about gradient estimators that violate the lower-bound conditions in Assumptions 5.1.1 and 5.2.1.

The common case for nonpositive gradient estimators. Assumptions 5.1.1 and 5.2.1 are most commonly violated when $\hat{Q}'$ is positive on some range $[w_{min},w_{max}]$ and zero outside of it, as with the PWL estimator (see Section 2). The behavior of such gradient estimators cannot be mimicked by any model that uses the STE, since the latent weight can reach a point where it no longer receives gradient updates. However, this behavior can be mimicked by a model that uses a PWL estimator. If we set

$$\tilde{w}_{min}:=M(w_{min})\tag{85}$$
$$\tilde{w}_{max}:=M(w_{max}),\tag{86}$$

then Theorems 5.1 and 5.2 apply after replacing the STE with $PWL_{\tilde{w}_{min},\tilde{w}_{max}}$ (for SGD) or $PWL_{w_{min},w_{max}}$ (for Adam), whenever $w^{(t)}_{\hat{Q}}$ and $w^{(t)}_{STE}$ are in the representable range. Technically, $M(w^{(0)}_{\hat{Q}})$ is only defined when $w^{(0)}_{\hat{Q}}\in[w_{min},w_{max}]$, but we can ignore the case $w^{(0)}_{\hat{Q}}\notin[w_{min},w_{max}]$ under the assumption that no practitioner would initialize a weight to be untrainable. There are two remaining cases to consider. In the first, $w^{(t)}_{\hat{Q}}$ and $w^{(t)}_{STE}$ both lie outside the representable range; neither weight can move, so there is no risk of increasing $E^{(t)}$. In the second, only one of them lies in this range, so one weight is “trapped” while the other is “free”. This is unlikely due to the bounds on $E^{(t)}$, but it could in principle lead to large weight alignment errors.
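The “trapped” weight described above can be seen in a few lines of SGD. The sketch assumes the simplest PWL form, with derivative 1 inside $[w_{min},w_{max}]$ and 0 outside (sometimes called a clipped STE); the slope of the PWL estimator in Section 2 may differ, and all names and numbers here are hypothetical.

```python
def pwl_derivative(w, w_min, w_max):
    """Hypothetical PWL estimator derivative: 1 inside [w_min, w_max], 0 outside."""
    return 1.0 if w_min <= w <= w_max else 0.0

def sgd_trajectory(w0, upstream_grads, lr, w_min, w_max):
    """Plain SGD on a latent weight whose gradient passes through the PWL estimator."""
    w = w0
    path = [w]
    for g in upstream_grads:
        # Chain rule through the estimator: dL/dw = dL/dQ(w) * Q_hat'(w)
        w = w - lr * g * pwl_derivative(w, w_min, w_max)
        path.append(w)
    return path

# A constant upstream gradient of -1 pushes the weight upward until it leaves
# [-1, 1]; from then on it receives zero gradient and stays put.
path = sgd_trajectory(0.75, [-1.0] * 5, 0.1, -1.0, 1.0)
```

After three steps the weight crosses $w_{max}=1$ and every later entry of `path` is identical, which is exactly the behavior an STE-based model cannot reproduce.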

Negative gradient estimators. The other way that the lower bound in Assumption 5.1.1 can be violated is if $\hat{Q}'(w)$ is negative for some range of values of $w$. Some work [5, 46] proposes gradient estimators with negative derivatives, but most choose a nonnegative derivative to match the nondecreasing behavior of the quantizer function. When $\hat{Q}'$ takes negative values, slightly modified versions of Theorems 5.1 and 5.2 apply on the negative ranges, where the gradient estimator of $STE$-net is the negative of the STE. Since this is a rare choice for QAT, we do not provide the details here.

Thus almost all common gradient estimators can be replaced with the STE or a PWL estimator.

Appendix H Calculating constants in Theorem 5.1

Many gradient estimators take the form

$$\hat{Q}(w)=\tanh\big(k(w-a)\big)+a$$

for $w$ in the representable range, where $a$ is the center of the quantization bin containing $w$. This is the case for [12] and [32], hence our choice of the gradient estimator from [32] for the experiments. The gradient estimator used in [48] is also very similar.

Given this definition of $\hat{Q}(w)$, we want lower and upper bounds on the first and second derivatives of $\hat{Q}$ on the interval $[-\Delta/2,\Delta/2]$ with $a=0$. First note that we have

$$\hat{Q}'(w)=\frac{k}{\cosh^{2}(kw)}.$$

This attains its maximum value at $w=0$ and its minimum value at $w=\pm\Delta/2$, so that $L_{+}=k$ and $L_{-}=k/\cosh^{2}(k\Delta/2)$.

$$\hat{Q}''(w)=-2k^{2}\frac{\tanh(kw)}{\cosh^{2}(kw)}$$

This attains its extreme values at

$$w=\pm\frac{1}{2k}\log\big(2+\sqrt{3}\big)$$

and is strictly decreasing on the interval between these points. A bound on $|\hat{Q}''(w)|$ is a Lipschitz constant for $\hat{Q}'$, and $|\hat{Q}''|$ is even, increasing on $\big[0,\tfrac{1}{2k}\log(2+\sqrt{3})\big]$ and decreasing beyond it, so $L'$ is given by

$$2k^{2}\frac{\tanh(kw)}{\cosh^{2}(kw)}$$

where

$$w=\min\left(\Delta/2,\ \frac{1}{2k}\log\big(2+\sqrt{3}\big)\right).$$
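The closed forms for $L_+$, $L_-$, and $L'$ can be cross-checked by dense sampling of $[-\Delta/2,\Delta/2]$. The sketch below is an illustrative check with $a=0$, not code from the paper; the function names and the two $(k,\Delta)$ test pairs are hypothetical choices, picked so that one case has its maximum of $|\hat{Q}''|$ in the interior and the other at the endpoint.

```python
import math

def q_hat_prime(w, k):
    # d/dw tanh(k w) = k / cosh^2(k w)
    return k / math.cosh(k * w) ** 2

def q_hat_second(w, k):
    # d^2/dw^2 tanh(k w) = -2 k^2 tanh(k w) / cosh^2(k w)
    return -2 * k**2 * math.tanh(k * w) / math.cosh(k * w) ** 2

def closed_form(k, delta):
    """(L_plus, L_minus, L_prime) from the formulas above on [-delta/2, delta/2]."""
    w_star = min(delta / 2, math.log(2 + math.sqrt(3)) / (2 * k))
    return k, k / math.cosh(k * delta / 2) ** 2, abs(q_hat_second(w_star, k))

def brute_force(k, delta, n=20001):
    """Same quantities by dense sampling of the interval (endpoints included)."""
    ws = [-delta / 2 + delta * j / (n - 1) for j in range(n)]
    d1 = [q_hat_prime(w, k) for w in ws]
    d2 = [abs(q_hat_second(w, k)) for w in ws]
    return max(d1), min(d1), max(d2)
```

For $k=2,\ \Delta=2/3$ the maximizer of $|\hat{Q}''|$ is interior, since $\tfrac{1}{4}\log(2+\sqrt{3})\approx 0.329<1/3$, while for $k=6,\ \Delta=2/15$ it sits at $\Delta/2$.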

In [32], $k$ is set to 8, 6, 4, and 2 for 8-, 4-, 3-, and 2-bit quantization, respectively. They initialize $\Delta$ to $2/(2^{b}-1)$, where $b$ is the number of bits used for quantization. This gives the following values for $L'L_{+}/(2L_{-}^{2})$: 0.25 (8 bits), 2.66 (4 bits), 2.82 (3 bits), and 1.77 (2 bits). These values are small relative to typical values of $1/\eta$, where $\eta$ is the learning rate.
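The four constants quoted above can be reproduced from the formulas for $L_+$, $L_-$, and $L'$ with $\Delta=2/(2^b-1)$. The sketch below is a reproduction of that arithmetic under those stated settings; `ratio_constant` is a hypothetical helper name.

```python
import math

def ratio_constant(k, bits):
    """L' * L_plus / (2 * L_minus**2) for Q_hat(w) = tanh(k w) on [-delta/2, delta/2]."""
    delta = 2 / (2**bits - 1)
    w_star = min(delta / 2, math.log(2 + math.sqrt(3)) / (2 * k))
    l_plus = k
    l_minus = k / math.cosh(k * delta / 2) ** 2
    l_prime = 2 * k**2 * math.tanh(k * w_star) / math.cosh(k * w_star) ** 2
    return l_prime * l_plus / (2 * l_minus**2)

# (bits, k) pairs as quoted from [32] in the text above
settings = [(8, 8), (4, 6), (3, 4), (2, 2)]
for bits, k in settings:
    print(bits, round(ratio_constant(k, bits), 2))
# 8 0.25
# 4 2.66
# 3 2.82
# 2 1.77
```

When $w^\ast=\Delta/2$ (the 8-, 4-, and 3-bit cases) the ratio simplifies to $\tfrac{k}{2}\sinh(k\Delta)$, which is a convenient hand check; only the 2-bit case uses the interior maximizer.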

For [12], the quantizer is parametrized by a value $\alpha$ defined by

$$\alpha=1-\tanh(k\Delta/2).$$

This gives us convenient formulas:

$$\tanh(k\Delta/2)=1-\alpha$$
$$\frac{1}{\cosh^{2}(k\Delta/2)}=1-(1-\alpha)^{2}=2\alpha-\alpha^{2}$$
$$\frac{\tanh(k\Delta/2)}{\cosh^{2}(k\Delta/2)}=(1-\alpha)(2\alpha-\alpha^{2})$$
$$\frac{L_{+}}{L_{-}}=\frac{1}{2\alpha-\alpha^{2}}$$

The constant of interest is then given by

$$\frac{L'L_{+}}{L_{-}^{2}}\leq\frac{1-\alpha}{2\alpha-\alpha^{2}}$$

During training in [12], $\alpha$ is varied for weight quantizers between $0.11$ and $0.25$, giving us

$$\frac{L'L_{+}}{L_{-}^{2}}\in[1.71,\,4.28].$$

These values are again small relative to $1/\eta$.
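The identities in terms of $\alpha$ and the interval $[1.71, 4.28]$ can be checked numerically by choosing $x=k\Delta/2$ with $\tanh(x)=1-\alpha$. This is a minimal arithmetic check under that substitution; the helper names are hypothetical.

```python
import math

def alpha_bound(alpha):
    """Right-hand side (1 - alpha) / (2 alpha - alpha^2) of the bound above."""
    return (1 - alpha) / (2 * alpha - alpha**2)

def identities_hold(alpha):
    """Check the convenient formulas numerically, with x = k * Delta / 2."""
    x = math.atanh(1 - alpha)            # from alpha = 1 - tanh(k Delta / 2)
    sech2 = 1 / math.cosh(x) ** 2
    return (math.isclose(sech2, 2 * alpha - alpha**2)
            and math.isclose(math.tanh(x) * sech2, (1 - alpha) * (2 * alpha - alpha**2))
            # L_plus / L_minus = cosh^2(k Delta / 2)
            and math.isclose(math.cosh(x) ** 2, 1 / (2 * alpha - alpha**2)))

# The bound decreases in alpha on (0, 1), so the alpha range maps to
# [alpha_bound(0.25), alpha_bound(0.11)].
print(round(alpha_bound(0.25), 2), round(alpha_bound(0.11), 2))  # 1.71 4.28
```

The bound is monotone because its numerator $1-\alpha$ decreases while its denominator $\alpha(2-\alpha)$ increases on $(0,1)$, so evaluating the endpoints is enough.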

Appendix I Experiment Setup Details

Weight Initialization and Quantizers: We initialize the weights of $\hat{Q}$-net using He Uniform Initialization (see https://www.tensorflow.org/api_docs/python/tf/keras/initializers/HeNormal). For quantization, we use a uniform weight quantizer whose representable range limits are given by the bounds of the weight initialization distribution. We do not quantize activations. We focus primarily on two-bit weight quantization, and note that results are similar for 1-bit and 4-bit quantization. For gradient estimation, we use the $\hat{Q}$ given by the HTGE [32] gradient estimator formula with shape parameter $t$ set to 5.5 times the maximum value from the weight initialization distribution. This value was chosen so that $\hat{Q}$ differs significantly from the STE, but not so much that parts of $\hat{Q}$ become essentially flat.
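A uniform 2-bit weight quantizer whose range is tied to the He uniform bounds might look like the sketch below. It assumes the Keras HeUniform limit $\sqrt{6/\text{fan\_in}}$, clipping to $[-\text{limit},\text{limit}]$, and $2^b$ evenly spaced levels including both endpoints; the exact level placement and rounding used in the experiments are not specified here, and the fan-in and function names are hypothetical.

```python
import math

def he_uniform_limit(fan_in):
    # Keras HeUniform samples from U(-limit, limit) with limit = sqrt(6 / fan_in).
    return math.sqrt(6.0 / fan_in)

def uniform_quantize(w, limit, bits=2):
    """Map w to the nearest of 2**bits evenly spaced levels spanning [-limit, limit]."""
    steps = 2**bits - 1                # gaps between adjacent levels
    delta = 2 * limit / steps          # bin width
    w = max(-limit, min(limit, w))     # clip to the representable range
    return -limit + round((w + limit) / delta) * delta

limit = he_uniform_limit(fan_in=64)    # hypothetical layer fan-in
levels = [-limit + j * 2 * limit / 3 for j in range(4)]
```

With this level layout the bin width is $\Delta=2\cdot\text{limit}/(2^b-1)$, which matches the $\Delta=2/(2^b-1)$ initialization quoted from [32] in Appendix H when the range is $[-1,1]$.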

Optimization techniques. For both models, we consider SGD with momentum $0.9$ and Adam with $\beta_{1}=0.9$ and $\beta_{2}=0.95$. For all experiments, we use a cosine decay learning rate schedule [28] with a linear learning rate warmup [13] for 2% of training epochs. The reported learning rate for each model is the initial learning rate for the cosine decay. We use a learning rate of 0.001 for our default MNIST SGD-with-momentum model and 0.0001 for our default MNIST Adam model. For the ResNet50 model on ImageNet, we apply the standard learning rate schedule implemented in [10] with a configured learning rate of 0.0001 for Adam and 0.001 for SGD, and otherwise default parameters.
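A cosine decay schedule with linear warmup over the first 2% of training can be sketched as follows. This is a common form of such a schedule, not the implementation in [10]; the assumptions are per-step (rather than per-epoch) warmup, warmup starting from zero, and decay to exactly zero, and the function name is hypothetical.

```python
import math

def warmup_cosine_lr(step, total_steps, base_lr, warmup_frac=0.02):
    """Linear warmup for warmup_frac of training, then cosine decay to zero."""
    warmup_steps = max(1, int(warmup_frac * total_steps))
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * base_lr * (1 + math.cos(math.pi * min(1.0, progress)))
```

Here `base_lr` plays the role of the reported learning rate: it is the peak value reached at the end of warmup and the starting point of the cosine decay.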

Identical Initial Training Period. For the ImageNet-ResNet setup, we ensured that the first 10% of training was identical for $\hat{Q}$-net and $STE$-net. To do this, we trained $STE$-net by first training $\hat{Q}$-net for the first 10 of 100 epochs, then applying $M$ to the weights and optimizer state and switching the model’s quantizer to the STE before continuing training. This was applied for all model comparisons.