Adaptive Adversarial Cross-Entropy Loss
for Sharpness-Aware Minimization

Abstract

Recent advancements in learning algorithms have demonstrated that the sharpness of the loss surface is an effective measure for improving the generalization gap. Building upon this concept, Sharpness-Aware Minimization (SAM) was proposed to enhance model generalization and achieved state-of-the-art performance. SAM consists of two main steps: the weight perturbation step and the weight updating step. However, the perturbation in SAM is determined only by the gradient of the training loss, i.e., the cross-entropy loss. As the model approaches a stationary point, this gradient becomes small and oscillates, leading to inconsistent perturbation directions and a risk of the gradient vanishing. Our research introduces an approach to further enhance model generalization. We propose the Adaptive Adversarial Cross-Entropy (AACE) loss function to replace the standard cross-entropy loss in SAM’s perturbation step. The AACE loss and its gradient increase as the model nears convergence, ensuring a consistent perturbation direction and addressing the gradient diminishing issue. Additionally, we propose a novel perturbation-generating function that uses the AACE loss without normalization, enhancing the model’s exploratory capabilities in near-optimum stages. Empirical testing confirms the effectiveness of AACE, with experiments demonstrating improved performance in image classification tasks using Wide ResNet and PyramidNet across various datasets. The reproduction code is available online at http://www.vip.sc.e.titech.ac.jp/proj/AACE.

Index Terms—  Adaptive Adversarial Cross-Entropy, Model Generalization, Sharpness-Aware Minimization, Deep Learning

© 2024 IEEE. Personal use of this material is permitted. Permission from IEEE must be obtained for all other uses, in any current or future media, including reprinting/republishing this material for advertising or promotional purposes, creating new collective works, for resale or redistribution to servers or lists, or reuse of any copyrighted component of this work in other works.

1 Introduction

In the recent development of machine learning, there has been a noticeable trend where models are becoming highly overparameterized. While these models are excellent at memorizing training data, a significant challenge arises in their performance on new, unseen data. This problem, known as overfitting, leads to a notable gap in performance between training and testing datasets [1]. Understanding how to improve the generalization of these models is crucial, as it can help them perform well not just on the data they were trained on, but also on new data they have never seen before.

To address the issue of generalization, researchers have explored various approaches. Some have taken a Bayesian perspective to understand this problem [2, 3], while others have looked at it through the lens of information theory [4]. Other significant areas of research investigate the impact of learning rate [5, 6, 7] and batch size [8] on a model’s generalization ability. Numerous techniques have been proposed to improve model generalization. Entropy-SGD uses local entropy [9]. Using Adam [10] as an optimizer in early training and switching to SGD [11] in later phases has also been shown to improve generalization [8]. Integrating a partially adaptive parameter into adaptive gradient methods such as Adam and AMSGrad was also introduced [12]. Moreover, FOCA, which avoids co-adaptation between a feature extractor and a particular classifier, is another way to improve generalization [13].

Fig. 1: Comparison of loss and gradient between standard cross-entropy loss and Adaptive Adversarial Cross-Entropy loss at an early stage ($w_i$) and a later stage ($w_t$) of training.

Another important aspect of research focuses on the techniques related to the shape of the loss landscape and its connection to model generalization. Studies have shown that the sharpness of the loss surface and the minimization of derived generalization bounds are crucial for achieving superior performance in various tasks [14, 8, 15].

Developing efficient algorithms that aim for flatter minima, which in turn could improve generalization, remains a challenging area of research. Recently, Sharpness-Aware Minimization (SAM), an algorithm to search for flatter areas by adding small perturbations to the model parameters, was proposed and has proven to be generic and effective on several datasets and model architectures [16].

SAM seeks a flat landscape by modifying the optimization process to explicitly consider the sharpness of the minima. SAM’s algorithm can be decomposed into two main steps. First, it finds a parameter configuration (weights) where the loss is high within a small neighborhood around the current weights. Then, it minimizes the model loss by using the gradient at this worst-case configuration.

Several novel methods further improve SAM’s generalization performance. GSAM introduced a sharpness measurement called the surrogate gap [17]. PoF introduced a technique that updates the feature extractor to search for flatter minima [18]. ASAM introduced adaptive sharpness, which is scale-invariant [19]. GA-SAM is another work that analyzes the relationship between local minima and generalization ability [20].

In this research, we found that while SAM has shown promising performance, some issues remain. In finding the worst-case parameters, SAM’s perturbation depends on the normalized gradient of the cross-entropy loss and a pre-defined constant radius of the neighborhood. At the nearly optimum stage, the gradient of the cross-entropy loss is very small and fluctuates around the optimum point, which leads to an unstable perturbation direction. Another noticeable issue is that, at the nearly optimum stage, the magnitude of the gradient of the cross-entropy loss becomes smaller and smaller and risks reaching zero, which could cause a division-by-zero problem.

We propose a new approach to mitigate those issues by modifying loss in SAM’s perturbation step. Instead of using cross-entropy loss that gets smaller as the model is trained, we introduce a new loss, Adaptive Adversarial Cross-Entropy (AACE) that grows as the model converges.

As demonstrated in Fig. 1, at the early stage of training ($w_i$), both the loss and the magnitude of the gradient of the standard cross-entropy loss are high, and they decrease as the model approaches convergence ($w_t$). In contrast, the AACE loss and its gradient magnitude start low and increase over the training process. This growing loss helps avoid the risk of the gradient diminishing at the saturated stage and leads to a more consistent direction of the gradient and the perturbation.

With this new loss, we also propose not normalizing the gradient in the perturbation step, so that the perturbation no longer depends only on a pre-defined constant. The new method also enlarges the magnitude of the perturbation step as the model converges, making the training more explorative even at the nearly optimum stage.

2 Preliminary

In traditional training of deep neural networks, optimization techniques like Stochastic Gradient Descent (SGD) seek to minimize the loss function. However, this process may converge to sharp minima, which are points in the parameter space where the loss is low for the training data but potentially high for unseen data. Sharp minima are believed to be less robust and generalize worse compared to flat minima.

Sharpness-Aware Minimization (SAM) is a novel training methodology designed to enhance the generalization performance of deep learning models. Traditional training methods often converge to sharp minima, leading to suboptimal generalization. SAM, however, aims to find parameters that reside in neighborhoods having uniformly low loss, thus avoiding sharp minima. This is achieved through a min-max optimization problem efficiently solvable via gradient descent.

Instead of trying to minimize the loss as in vanilla training, SAM’s objective is to minimize the perturbed loss which can be described as:

$$L_{\rm SAM}(w)=\max_{\left\|\varepsilon\right\|\leqslant\rho}L_{s}(w+\varepsilon)\,, \tag{1}$$

where $L_s(w)$ is the training loss, $w$ represents the model parameters, and $\varepsilon$ is a perturbation vector bounded by $\rho$ in the L2-norm. The optimization seeks parameters $w$ such that the loss is minimized not just at $w$ but in its neighborhood within a radius of $\rho$.

In the case of small $\rho$, applying a first-order Taylor expansion around $w$, the $\varepsilon$ that solves the inner maximization in Eq. 1 can be expressed as:

$$\varepsilon={\rm StopGrad}\left(\rho\frac{\nabla L_{s}(w)}{\left\|\nabla L_{s}(w)\right\|_{2}}\right)\,, \tag{2}$$

where ${\rm StopGrad}$ represents the stop-gradient operation. Note that ${\rm StopGrad}$ is not needed for the inner maximization in Eq. 1 itself; we include it for the later discussion. This formula gives the direction in parameter space along which the loss increases most steeply, scaled by the hyperparameter $\rho$. The ${\rm StopGrad}$ function ensures that this $\varepsilon$ is used only for the perturbation step and is treated as a fixed quantity during the computation of gradients for weight updates.
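As a brief sketch of the reasoning behind Eq. 2 (the standard first-order argument, assuming $\nabla L_s(w)\neq 0$), the loss is linearized around $w$ and the linear term is maximized over the ball of radius $\rho$; by the Cauchy-Schwarz inequality, the maximizer is parallel to the gradient:

```latex
\begin{aligned}
L_s(w+\varepsilon) &\approx L_s(w)+\varepsilon^{\top}\nabla L_s(w), \\
\varepsilon^{\top}\nabla L_s(w) &\le \|\varepsilon\|_2\,\|\nabla L_s(w)\|_2 \le \rho\,\|\nabla L_s(w)\|_2, \\
\arg\max_{\|\varepsilon\|_2\le\rho}\ \varepsilon^{\top}\nabla L_s(w)
  &= \rho\,\frac{\nabla L_s(w)}{\|\nabla L_s(w)\|_2}.
\end{aligned}
```

The normalization in the last line is exactly the term that becomes ill-conditioned when the gradient norm approaches zero.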

Fig. 2: Probability distributions and trend patterns of standard cross-entropy loss and Adaptive Adversarial Cross-Entropy loss. (a) Probability distribution of standard cross-entropy loss. (b) Probability distribution of Adaptive Adversarial Cross-Entropy loss. (c) Loss trends of using standard cross-entropy loss vs ours.

SAM’s algorithm includes two main steps. First, the algorithm finds a worst-case perturbation of the current parameters using the perturbation vector calculated from Eq. 2. Then, it updates the model weights by optimizing the model parameters using the gradients of the loss at the calculated perturbed position, as shown in the equation below.

$$w_{t+1}=w_{t}-\eta\nabla L_{s}(w_{t}+\varepsilon)\,, \tag{3}$$

where $\eta$ is the learning rate. Note that, for simplicity, the weight update formula is based on standard SGD without momentum. However, in practical applications, alternative optimization algorithms such as Adam, RMSprop, or SGD with momentum can also be applied.
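The two steps can be sketched on a toy problem. The snippet below is a minimal illustration, not the authors’ implementation: the quadratic `loss`, its analytic `grad`, and the small constant `tiny` guarding the division are assumptions introduced here for demonstration only.

```python
import math

def loss(w):
    # Hypothetical toy loss: an anisotropic quadratic bowl.
    return 0.5 * (10.0 * w[0] ** 2 + w[1] ** 2)

def grad(w):
    # Analytic gradient of the toy loss above.
    return [10.0 * w[0], w[1]]

def sam_step(w, rho=0.05, lr=0.1, tiny=1e-12):
    g = grad(w)
    norm = math.sqrt(sum(gi * gi for gi in g))
    # Perturbation step (Eq. 2): move along the normalized gradient.
    # `tiny` is an assumed guard against a zero gradient norm.
    eps = [rho * gi / (norm + tiny) for gi in g]
    w_adv = [wi + ei for wi, ei in zip(w, eps)]
    # Update step (Eq. 3): the gradient is taken at the perturbed point
    # but applied to the original weights.
    g_adv = grad(w_adv)
    w_new = [wi - lr * gi for wi, gi in zip(w, g_adv)]
    return w_new, eps

w = [1.0, 1.0]
w_next, eps = sam_step(w)
eps_norm = math.sqrt(sum(e * e for e in eps))
```

Because of the normalization, `eps_norm` stays at `rho` regardless of how large or small the gradient is.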

This approach encourages the optimizer to find flatter minima, which are believed to generalize better to unseen data. SAM has been shown to improve the performance of various deep learning models across different tasks, such as image classification.

3 Proposed Method

Although SAM demonstrates encouraging results, it’s important to be aware of certain limitations. The method SAM uses to determine the worst-case parameters relies on a perturbation that is based on the normalized gradient of the cross-entropy loss, coupled with a predetermined constant defining the neighborhood’s radius.

In this research, we consider Eq. 2 as a composition of $\rho$ and a specific function designed for constructing a perturbation vector,

$$\varepsilon={\rm StopGrad}(\rho\,g(w))\,, \tag{4}$$

where $g(w)$ is called a perturbation generating function. In SAM, this function is described as

$$g_{\rm SAM}^{n}(w)=\frac{\nabla L_{s}(w)}{\left\|\nabla L_{s}(w)\right\|_{2}}\,, \tag{5}$$

where the superscript $n$ indicates the normalization.

Consequently, for SAM, the perturbation direction depends solely on the gradient of the training loss $L_s(w)$, i.e., the cross-entropy loss. Near stationary points, the gradient of the cross-entropy loss becomes minuscule and oscillates around the optimal point, resulting in an inconsistent perturbation direction. A further concern is that this gradient tends to diminish, potentially reaching zero, which poses a risk of a divide-by-zero error in the normalization. Moreover, since the perturbation always has a constant small radius, it is less explorative near the optimum.
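Both issues can be illustrated with a small sketch. The gradient values below are hypothetical, chosen only to mimic a near-stationary point, and the function deliberately omits any guard in the denominator.

```python
import math

def sam_direction(g, rho=0.05):
    # Normalized perturbation of Eq. 5 scaled by rho (no guard, on purpose).
    norm = math.sqrt(sum(gi * gi for gi in g))
    return [rho * gi / norm for gi in g]

# Two hypothetical near-stationary gradients: tiny, with opposite signs.
g_a = [1e-9, -2e-9]
g_b = [-1e-9, 2e-9]
eps_a = sam_direction(g_a)
eps_b = sam_direction(g_b)

# Both perturbations have full length rho, yet they point in opposite
# directions: the cosine similarity between them is -1.
dot = sum(a * b for a, b in zip(eps_a, eps_b))
cosine = dot / (0.05 * 0.05)

# An exactly zero gradient makes the normalization undefined.
try:
    sam_direction([0.0, 0.0])
    zero_grad_failed = False
except ZeroDivisionError:
    zero_grad_failed = True
```

Practical implementations commonly add a small constant to the denominator, which avoids the exception but does not stabilize the direction.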

In contrast, we suggest that the perturbation should have the following properties, especially near stationary points:

  1. The direction of the perturbation should be sufficiently stable to meaningfully adjust the parameters.

  2. The gradient of the loss used to compute the perturbation should neither be too small nor keep decreasing at the nearly optimum stage, to avoid the gradient diminishing problem.

  3. The magnitude of the perturbation should be large enough to remain explorative while the model converges.

Hence, we introduce a method to address the challenges associated with SAM’s perturbation step and satisfy the required properties of the perturbation. Our approach alters the loss function used for calculating the perturbation vector. Rather than relying on the cross-entropy loss, which diminishes as the model is trained, we propose a novel loss function named Adaptive Adversarial Cross-Entropy (AACE). This new loss function is designed to increase in magnitude as the model approaches convergence.

The cross-entropy loss is defined as

$$L=-\sum_{i}\tau_{i}\log(q_{i})\,, \tag{6}$$

where $q_i$ is the predicted probability corresponding to class $i$, and $\tau_i$ is the target probability for class $i$, which, for the standard cross-entropy loss, is determined by one-hot encoding as

$$\tau_{i}^{\rm CE}=\begin{cases}1, & (i=t)\\ 0, & (i\neq t)\end{cases}\,, \tag{7}$$

where $t$ stands for the ground-truth class.

In our proposed method, instead of using hard 0 or 1 targets for the ground-truth and negative classes as in the standard cross-entropy, Adaptive Adversarial Cross-Entropy (AACE) defines new adversarial labels. For the ground-truth (positive) class, the label is set to 0. For the negative classes, the target is the ratio of the predicted probability $q_i$ of class $i$ to the sum of the predicted probabilities of all negative classes:

$$\tau_{i}^{\rm AACE}=\xi(\tilde{q_{i}})\,, \tag{8}$$

given

$$\xi(\tilde{q_{i}})=\begin{cases}0, & (i=t)\\ \dfrac{\tilde{q_{i}}}{\sum_{j\neq t}\tilde{q_{j}}}, & (i\neq t)\end{cases}\,, \tag{9}$$

and

$$\tilde{q_{i}}={\rm StopGrad}(q_{i})\,. \tag{10}$$

As a result, our proposed adaptive adversarial labeling keeps the calculated loss high thanks to enlarging the difference between the target probability distributions and the model’s predicted probability distributions. As illustrated in Fig. 2, assuming that the predicted probabilities are equally distributed among the negative classes, in standard cross-entropy loss with one-hot encoding targets (Fig. 2 (a)), the differences between the predicted probabilities and the target probabilities decrease as the model converges. In contrast, with AACE, these differences increase as the model converges (Fig. 2 (b)). Also, while the standard cross-entropy loss decreases as the predicted probability of the positive class approaches 1, the AACE loss, conversely, increases as the predicted probability of the positive class nears 1 (Fig. 2 (c)).
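The opposite trends can be checked numerically with a minimal sketch of Eqs. 6 to 10. The helper names and the four-class setup are hypothetical; the negatives are assumed to share the remaining probability mass equally, as in Fig. 2, and treating `q` as a plain constant when building the targets plays the role of StopGrad.

```python
import math

def aace_targets(q, t):
    # Eq. 9: zero for the ground-truth class, renormalized predictions elsewhere.
    neg_sum = sum(qj for j, qj in enumerate(q) if j != t)
    return [0.0 if i == t else qi / neg_sum for i, qi in enumerate(q)]

def cross_entropy(tau, q):
    # Eq. 6, skipping terms whose target is zero.
    return -sum(ti * math.log(qi) for ti, qi in zip(tau, q) if ti > 0.0)

def uniform_negatives(q_t, num_classes, t=0):
    # Assumption from Fig. 2: remaining mass is spread evenly over negatives.
    rest = (1.0 - q_t) / (num_classes - 1)
    return [q_t if i == t else rest for i in range(num_classes)]

one_hot = [1.0, 0.0, 0.0, 0.0]
ce_vals, aace_vals = [], []
for q_t in (0.4, 0.7, 0.9, 0.99):
    q = uniform_negatives(q_t, 4)
    ce_vals.append(cross_entropy(one_hot, q))
    aace_vals.append(cross_entropy(aace_targets(q, 0), q))
    print(f"q_t={q_t:.2f}  CE={ce_vals[-1]:.3f}  AACE={aace_vals[-1]:.3f}")
```

Under the uniform-negatives assumption the AACE loss reduces to $-\log\big((1-q_t)/(C-1)\big)$ for $C$ classes, which grows without bound as $q_t$ approaches 1, while the standard cross-entropy $-\log q_t$ goes to 0.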

Fig. 3: Diagram illustrating the perturbation step and the updating step of the original SAM and our proposed method.

Moreover, it is well-known that the gradient of cross-entropy loss with respect to the logit before the softmax activation can be calculated from:

$$\frac{\partial L}{\partial z_{i}}=q_{i}-\tau_{i}\,. \tag{11}$$

Here, $\frac{\partial L}{\partial z_i}$ represents the gradient of the loss with respect to the logit $z_i$ for class $i$, and $q_i$ is the predicted probability for class $i$, as output by the softmax function applied to the logits.
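Eq. 11 also applies to the AACE targets, because they are held constant by StopGrad and sum to one. As a short worked sketch (the Kronecker delta $\delta_{ij}$ and the summation index $j$ are notation introduced here), using $\partial \log q_j/\partial z_i=\delta_{ij}-q_i$ and $\sum_{j\neq t} q_j = 1-q_t$:

```latex
\begin{aligned}
\frac{\partial L}{\partial z_i}
  &= -\sum_j \tau_j\,(\delta_{ij}-q_i)
   = q_i\sum_j \tau_j-\tau_i = q_i-\tau_i , \\
\frac{\partial L_{\rm AACE}}{\partial z_t} &= q_t , \qquad
\frac{\partial L_{\rm AACE}}{\partial z_i}
   = q_i-\frac{q_i}{1-q_t} = -\frac{q_t\,q_i}{1-q_t} \quad (i\neq t).
\end{aligned}
```

Every component of the AACE logit gradient is proportional to $q_t$, so it grows as the prediction for the ground-truth class becomes more confident.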

While the gradient of conventional cross-entropy loss decreases as the model converges, our AACE loss with adversarial targets increases due to the growth of the gaps between the predicted probabilities and the newly defined adversarial labels. Hence, the gradient for AACE loss remains high even at a nearly optimum stage.
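A small numerical sketch of this contrast, evaluated in logit space via Eq. 11, is given below. The logit values are hypothetical and chosen only to represent an early (uncertain) and a late (confident) prediction for a four-class problem.

```python
import math

def softmax(z):
    m = max(z)
    e = [math.exp(zi - m) for zi in z]
    s = sum(e)
    return [ei / s for ei in e]

def l2(v):
    return math.sqrt(sum(x * x for x in v))

def grad_norms(z, t):
    # Returns (CE gradient norm, AACE gradient norm) with respect to the logits.
    q = softmax(z)
    one_hot = [1.0 if i == t else 0.0 for i in range(len(z))]
    neg = sum(qi for i, qi in enumerate(q) if i != t)
    aace = [0.0 if i == t else qi / neg for i, qi in enumerate(q)]
    g_ce = [qi - ti for qi, ti in zip(q, one_hot)]    # Eq. 11 with one-hot targets
    g_aace = [qi - ti for qi, ti in zip(q, aace)]     # Eq. 11 with AACE targets
    return l2(g_ce), l2(g_aace)

# Hypothetical logits: the ground-truth class (index 0) gains confidence.
early = grad_norms([0.5, 0.2, 0.1, 0.0], t=0)
late = grad_norms([6.0, 0.2, 0.1, 0.0], t=0)
print("early (CE, AACE):", early)
print("late  (CE, AACE):", late)
```

In this toy setting the cross-entropy gradient norm shrinks toward zero as confidence grows, whereas the AACE gradient norm increases; the gradient with respect to the weights additionally depends on the network Jacobian, which this sketch does not model.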

Because the perturbation loss and its gradient increase while the model converges, the risk of a diminishing gradient, which leads to the division-by-zero problem, is eliminated. More importantly, the larger and growing gradient gives rise to a stronger and more stable direction of the perturbation in SAM’s perturbation step.

In order to define our perturbation generating function, given that

$$L_{\rm AACE}(w)=-\sum_{i}\tau_{i}^{\rm AACE}\log(q_{i})\,, \tag{12}$$

the perturbation generating function can now be defined as

$$g_{\rm AACE}^{n}(w)=-\frac{\nabla L_{\rm AACE}(w)}{\left\|\nabla L_{\rm AACE}(w)\right\|_{2}}\,. \tag{13}$$

Since the trends of our perturbation loss and its gradient are opposite to those of the original SAM, a negative sign is applied here: we need to perturb the model parameters toward the worst-case configuration, in which the AACE loss is low, as opposed to the case of the normal cross-entropy loss, where the worst case corresponds to a high loss.
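One way to see why the negative sign is reasonable, at least in logit space for a single sample (a simplification introduced here; the actual perturbation acts on the weights $w$), is to take the inner product of the two logit gradients given by Eq. 11 with the one-hot and AACE targets:

```latex
\begin{aligned}
\left\langle \nabla_z L_{\rm CE},\, \nabla_z L_{\rm AACE}\right\rangle
 &= (q_t-1)\,q_t+\sum_{i\neq t} q_i\left(q_i-\frac{q_i}{1-q_t}\right) \\
 &= -\,q_t\left[(1-q_t)+\frac{\sum_{i\neq t} q_i^{2}}{1-q_t}\right] < 0
 \qquad (0<q_t<1).
\end{aligned}
```

The two gradients therefore point in opposing half-spaces, so stepping along $-\nabla_z L_{\rm AACE}$ increases the cross-entropy loss to first order, which matches the role of a worst-case perturbation.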

Furthermore, since we prefer to enlarge the magnitude of the perturbation as the model converges, we also propose not normalizing the gradient and define a new perturbation generating function as

$$g_{\rm AACE}(w)=-\nabla L_{\rm AACE}(w)\,. \tag{14}$$

Due to the nature of the AACE loss, this newly defined perturbation vector is expected to grow in magnitude and remain consistent in direction, even at the nearly optimum stage.

Finally, the weight updates of SGD, the original SAM, and our SAM with AACE can be represented by the following expressions:

$$w_{t+1}^{\rm SGD}=w_{t}-\eta\nabla L_{s}(w_{t})\,, \tag{15}$$
$$w_{t+1}^{\rm SAM}=w_{t}-\eta\nabla L_{s}(w_{t}+\varepsilon_{\rm SAM})\,, \tag{16}$$
$$w_{t+1}^{\rm AACE}=w_{t}-\eta\nabla L_{s}(w_{t}+\varepsilon_{\rm AACE})\,, \tag{17}$$

where

$$\varepsilon_{\rm SAM}={\rm StopGrad}\left(\rho\frac{\nabla L_{s}(w)}{\left\|\nabla L_{s}(w)\right\|_{2}}\right)\,, \tag{18}$$

and

$$\varepsilon_{\rm AACE}=-{\rm StopGrad}(\rho\nabla L_{\rm AACE}(w))\,. \tag{19}$$

The comparison of these weight updates is shown in Fig. 3. Instead of updating the model configuration based on the loss’s gradient at the current position as in SGD, SAM and our proposed method slightly perturb the model weights to a new position. Then the gradient at the perturbed weights is calculated. This calculated gradient is used to update the model weight at the current configuration.
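The update of Eqs. 17 and 19 can be sketched end to end on a tiny linear softmax classifier. This is an illustrative sketch under stated assumptions, not the released code: the model, the single training sample `x0`, and all helper names are hypothetical, and the weight gradient uses the fact that for a linear model $z = Wx$ the gradient is $(q-\tau)x^{\top}$ by Eq. 11.

```python
import math

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def probs(W, x):
    # Linear softmax classifier: logits z = W x.
    return softmax([sum(wij * xj for wij, xj in zip(row, x)) for row in W])

def grad_W(W, x, tau):
    # dL/dW = (q - tau) x^T, with the targets tau treated as constants.
    q = probs(W, x)
    return [[(qi - ti) * xj for xj in x] for qi, ti in zip(q, tau)]

def aace_targets(q, t):
    neg = sum(qi for i, qi in enumerate(q) if i != t)
    return [0.0 if i == t else qi / neg for i, qi in enumerate(q)]

def aace_sam_step(W, x, t, rho=0.2, lr=0.1):
    one_hot = [1.0 if i == t else 0.0 for i in range(len(W))]
    # Perturbation step (Eq. 19): eps = -rho * grad L_AACE(w), not normalized.
    g_p = grad_W(W, x, aace_targets(probs(W, x), t))
    eps = [[-rho * g for g in row] for row in g_p]
    W_adv = [[w + e for w, e in zip(rw, re)] for rw, re in zip(W, eps)]
    # Update step (Eq. 17): cross-entropy gradient at the perturbed weights,
    # applied to the unperturbed weights.
    g_u = grad_W(W_adv, x, one_hot)
    W_new = [[w - lr * g for w, g in zip(rw, rg)] for rw, rg in zip(W, g_u)]
    eps_norm = math.sqrt(sum(e * e for row in eps for e in row))
    return W_new, W_adv, eps_norm

W0 = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]   # 3 classes, 2 features
x0, t0 = [1.0, 2.0], 0
W1, W_adv0, eps_norm0 = aace_sam_step(W0, x0, t0)
```

In this example the perturbed weights `W_adv0` assign a lower probability to the ground-truth class than `W0` does, while the updated weights `W1` assign a higher one. The perturbation length `eps_norm0` equals `rho` times the gradient norm rather than `rho` itself.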

Table 1: Accuracies (%) of SAM with AACE on Wide ResNet on CIFAR-100 with different $\rho$, with and without gradient normalization in the perturbation.
$\rho$   $g(w)=-\frac{\nabla L_{\rm AACE}(w)}{\left\|\nabla L_{\rm AACE}(w)\right\|_{2}}$   $g(w)=-\nabla L_{\rm AACE}(w)$
0.05   82.11   83.82
0.1   83.19   84.09
0.2   83.66   84.33
0.5   84.13   84.02
1.0   84.10   84.23
2.0   70.08   78.56
5.0   27.38   71.19

4 Experiments

To evaluate the performance of AACE, we conducted experiments on several model architectures and datasets.

4.1 Hyperparameter grid search

First of all, to evaluate the performance of SAM with AACE loss, we trained Wide ResNet [21] on the CIFAR-100 [22] dataset. We used model depth = 28 and width factor = 10, with SGD as the base optimizer. We applied horizontal flip, padding by four pixels, and random crop for data augmentation. Cutout regularization [23] was also applied. SAM’s only hyperparameter, $\rho$, was tuned via grid search over {0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0}. We trained the models for 200 epochs with batch size = 256, momentum = 0.9, and weight decay = 0.0005. We set the initial learning rate to 0.1 and dropped it by a factor of 0.2 at 30%, 60%, and 80% of the training. The experiments were conducted using perturbations both with and without gradient normalization.
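The learning-rate schedule can be sketched as follows. This assumes that dropping by 0.2 means multiplying the current rate by 0.2 at each milestone, which is one plausible reading of the text; the function name and its arguments are hypothetical.

```python
def step_lr(epoch, total_epochs=200, base_lr=0.1, factor=0.2,
            milestones_pct=(30, 60, 80)):
    # Count the milestones (percent of training) already reached, using
    # integer arithmetic to avoid floating-point boundary effects.
    passed = sum(1 for m in milestones_pct if epoch * 100 >= m * total_epochs)
    # Assumed interpretation: multiply by `factor` once per milestone passed.
    return base_lr * factor ** passed

rho_grid = [0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0]   # grid stated in the text
schedule = {e: step_lr(e) for e in (0, 59, 60, 120, 160, 199)}
```

With 200 epochs, the milestones fall at epochs 60, 120, and 160 under this reading.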

As seen in Table 1, the experiments without gradient normalization achieve higher accuracy for every $\rho$ except 0.5. This aligns with our hypothesis. Our proposed AACE elevates the gradient during training. When gradient normalization is applied, the perturbation’s magnitude remains constant. Conversely, our proposed method removes the gradient normalization from the perturbation, which lets its magnitude increase. Hence, this approach makes the model more explorative at the nearly optimum stage.

Furthermore, during this grid search, the experiment that uses $\rho=0.2$ without gradient normalization shows the best performance. We hence use this parameter for the following experiments.

Fig. 4: Losses comparison of standard SAM and SAM with AACE. Each data point is the average loss in the epoch. (a) Perturbation loss and training loss of original SAM. (b) Perturbation loss and training loss of SAM with AACE.

4.2 Effect of using AACE for perturbation loss

To verify the properties of AACE discussed in the previous section, we conducted experiments with Wide ResNet on the CIFAR-100 dataset using the original SAM and our proposed method. We used the same values for all other hyperparameters. For the original SAM, $\rho$ was set to 0.05, the same as in SAM’s original paper. For our proposed method, we set $\rho$ to 0.2.

Firstly, we investigated the characteristics of standard cross-entropy loss and our adaptive adversarial cross-entropy loss. In standard SAM (Fig. 4 (a)), the curve of the average perturbation loss aligns with the trend of training loss. However, when utilizing SAM with AACE (Fig. 4 (b)), the average perturbation loss increases as the model converges, which is against the trend of the training loss. Note that the rapid rises/drops of the curves are caused by the learning rate scheduler.

Fig. 5: Comparison of magnitudes of perturbation loss’s gradients and perturbation distances between SAM and our method. Note that each data point represents the average value of samples in the epoch. (a) Magnitudes of perturbation loss’s gradients of SAM and the proposed method. (b) Perturbation distances of SAM and the proposed method.

Moreover, the average magnitudes of the gradient of the perturbation loss are compared in Fig. 5 (a). In the standard SAM, the gradient magnitude of the perturbation loss tends to decrease as the model nears convergence. However, when SAM is integrated with AACE, this gradient magnitude increases as the model approaches convergence. Since the magnitude of the gradient of the perturbation loss remains high, it leads to a more stable gradient direction and also avoids the gradient diminishing issue. Also, Fig. 5 (b) shows that the average perturbation distances are consistently equal to $\rho$ (0.05) in the original SAM, due to gradient normalization. In contrast, our approach does not normalize the perturbation loss’s gradient, so the average perturbation distance increases as the model progresses toward convergence, following the increase in the gradient magnitude of the perturbation loss. We believe that this could lead the model to be more explorative at the final stage of training.

Fig. 6: Validation loss and training loss comparison between the models trained with SAM using CE loss and AACE loss in the perturbation step. Each data point represents the average training/validation loss in the epoch.

More importantly, as shown in Fig. 6, we explored the generalization ability of the original SAM versus our method. When comparing the models trained with SAM using standard cross-entropy (CE) loss and Adaptive Adversarial Cross-Entropy (AACE) loss for the perturbation step, we observe that while the model trained with SAM using CE achieves a lower training loss, SAM with AACE shows a lower validation loss. This indicates that SAM integrated with AACE loss exhibits superior generalization capabilities.

4.3 Image Classification Comparisons with Wide ResNet

To confirm the effectiveness of SAM integrated with AACE, we conducted experiments with Wide ResNet on several datasets: CIFAR-100, CIFAR-10 [22], Fashion-MNIST [24], and Food101 [25]. All hyperparameters are the same as in the previous experiment. The models were trained for 200 epochs with the original SAM and the proposed method, and for 400 epochs with vanilla SGD, since each SAM weight update requires two backpropagation passes instead of one. The results for SGD, SAM, and SAM with AACE are shown in Table 2. As seen in the table, our proposed method outperformed SGD and the original SAM on all datasets.
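
The reasoning behind the 200 versus 400 epoch budget can be sketched on a toy problem. The code below is an illustration under stated assumptions, not the training code used in the paper: the loss is a hypothetical quadratic, `CountingQuadratic`, `sgd_step`, and `sam_step` are made-up names, and each loop iteration stands in for one unit of training. It only counts how many gradient evaluations each optimizer performs.

```python
import numpy as np


class CountingQuadratic:
    """Toy loss 0.5 * ||w||^2 that counts its gradient evaluations."""

    def __init__(self):
        self.grad_calls = 0

    def grad(self, w):
        self.grad_calls += 1
        return w  # gradient of 0.5 * ||w||^2


def sgd_step(loss, w, lr=0.1):
    """Plain gradient step: one gradient evaluation."""
    return w - lr * loss.grad(w)


def sam_step(loss, w, lr=0.1, rho=0.05, eps=1e-12):
    """SAM-style step: one gradient for the perturbation, one for the update."""
    g = loss.grad(w)
    w_adv = w + rho * g / (np.linalg.norm(g) + eps)
    g_adv = loss.grad(w_adv)
    return w - lr * g_adv


sgd_loss, sam_loss = CountingQuadratic(), CountingQuadratic()
w_sgd, w_sam = np.ones(3), np.ones(3)

for _ in range(400):  # SGD budget
    w_sgd = sgd_step(sgd_loss, w_sgd)
for _ in range(200):  # SAM budget
    w_sam = sam_step(sam_loss, w_sam)

print(sgd_loss.grad_calls, sam_loss.grad_calls)  # prints "400 400"
```

With twice as many iterations for SGD, both optimizers perform the same number of gradient evaluations, which is the sense in which the comparison is matched in compute.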

4.4 Image Classification Comparisons with PyramidNet

We also evaluated our proposed method on a different model architecture, PyramidNet [26], using the same datasets as in the previous experiment. We used a PyramidNet with depth = 272, alpha = 200, and batch size = 64. The remaining hyperparameters, including the number of epochs, are the same as in the previous experiment. Similar to the Wide ResNet experiment, as seen in Table 3, our proposed method achieved the highest accuracy on all datasets except CIFAR-10, where the original SAM performed best.

Table 2: Accuracies (%) of models trained with SGD, SAM, and our proposed method on Wide ResNet

| Dataset       | SGD   | Original SAM | Proposed |
|---------------|-------|--------------|----------|
| CIFAR-100     | 82.21 | 83.52        | 84.33    |
| CIFAR-10      | 96.63 | 97.02        | 97.04    |
| Fashion-MNIST | 94.57 | 95.26        | 95.41    |
| Food101       | 65.12 | 70.34        | 73.55    |

Table 3: Accuracies (%) of models trained with SGD, SAM, and our proposed method on PyramidNet

| Dataset       | SGD   | Original SAM | Proposed |
|---------------|-------|--------------|----------|
| CIFAR-100     | 81.25 | 83.85        | 84.13    |
| CIFAR-10      | 95.74 | 96.95        | 96.52    |
| Fashion-MNIST | 95.03 | 95.51        | 95.57    |
| Food101       | 66.43 | 72.97        | 75.94    |
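
The margins of the proposed method over the original SAM can be computed directly from Tables 2 and 3. The short script below is only a convenience for reading the tables: the accuracies are copied from them, and the dictionary layout and the helper name `gain_over_sam` are hypothetical.

```python
# Accuracies (%) copied from Tables 2 and 3: (SGD, original SAM, proposed).
results = {
    "Wide ResNet": {
        "CIFAR-100": (82.21, 83.52, 84.33),
        "CIFAR-10": (96.63, 97.02, 97.04),
        "Fashion-MNIST": (94.57, 95.26, 95.41),
        "Food101": (65.12, 70.34, 73.55),
    },
    "PyramidNet": {
        "CIFAR-100": (81.25, 83.85, 84.13),
        "CIFAR-10": (95.74, 96.95, 96.52),
        "Fashion-MNIST": (95.03, 95.51, 95.57),
        "Food101": (66.43, 72.97, 75.94),
    },
}


def gain_over_sam(row):
    """Accuracy difference (percentage points) between proposed and SAM."""
    _, sam, proposed = row
    return round(proposed - sam, 2)


for model, table in results.items():
    for dataset, row in table.items():
        print(f"{model:12s} {dataset:14s} {gain_over_sam(row):+.2f}")
```

The differences are largest on Food101 (+3.21 and +2.97 points) and smallest on CIFAR-10 (+0.02 on Wide ResNet and -0.43 on PyramidNet, the one case where the original SAM is ahead).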

5 Conclusion

In conclusion, this research addresses key limitations of Sharpness-Aware Minimization (SAM) and improves it by proposing a novel perturbation-generating technique. We introduce the Adaptive Adversarial Cross-Entropy (AACE) loss, which replaces the standard cross-entropy loss in SAM’s perturbation step. The AACE loss and its gradient increase as the model approaches convergence, which ensures a more consistent perturbation direction and prevents the gradient diminishing problem. We also propose a new perturbation-generating function that uses the AACE loss without normalization, which increases the magnitude of the perturbation and makes the model more explorative in the near-optimum stage. The empirical results confirmed our hypothesis on the characteristics of AACE, and the experimental results show that the proposed method helps SAM perform better on image classification tasks with Wide ResNet and PyramidNet across various datasets.

6 Acknowledgements

This work was partially supported by Tateishi Research Grant (A) 2241011 and JSPS KAKENHI Grant Number 24K02957.

References

  • [1] Chiyuan Zhang, Samy Bengio, Moritz Hardt, Benjamin Recht, and Oriol Vinyals, “Understanding deep learning (still) requires rethinking generalization,” Communications of the ACM, vol. 64, no. 3, pp. 107–115, 2021.
  • [2] David A McAllester, “Pac-bayesian model averaging,” in Proceedings of the twelfth annual conference on Computational learning theory, pp. 164–170, 1999.
  • [3] Behnam Neyshabur, Srinadh Bhojanapalli, and Nathan Srebro, “A pac-bayesian approach to spectrally-normalized margin bounds for neural networks,” arXiv preprint arXiv:1707.09564, 2017.
  • [4] Tengyuan Liang, Tomaso Poggio, Alexander Rakhlin, and James Stokes, “Fisher-rao metric, geometry, and complexity of neural networks,” in The 22nd international conference on artificial intelligence and statistics. PMLR, pp. 888–896, 2019.
  • [5] Yuanzhi Li, Colin Wei, and Tengyu Ma, “Towards explaining the regularization effect of initial large learning rate in training neural networks,” Advances in Neural Information Processing Systems, vol. 32, 2019.
  • [6] Pratik Chaudhari and Stefano Soatto, “Stochastic gradient descent performs variational inference, converges to limit cycles for deep networks,” CoRR, vol. abs/1710.11029, 2017.
  • [7] Priya Goyal, Piotr Dollár, Ross B. Girshick, Pieter Noordhuis, Lukasz Wesolowski, Aapo Kyrola, Andrew Tulloch, Yangqing Jia, and Kaiming He, “Accurate, large minibatch SGD: training imagenet in 1 hour,” CoRR, vol. abs/1706.02677, 2017.
  • [8] Nitish Shirish Keskar, Dheevatsa Mudigere, Jorge Nocedal, Mikhail Smelyanskiy, and Ping Tak Peter Tang, “On large-batch training for deep learning: Generalization gap and sharp minima,” arXiv preprint arXiv:1609.04836, 2016.
  • [9] Pratik Chaudhari, Anna Choromanska, Stefano Soatto, Yann LeCun, Carlo Baldassi, Christian Borgs, Jennifer Chayes, Levent Sagun, and Riccardo Zecchina, “Entropy-sgd: Biasing gradient descent into wide valleys,” Journal of Statistical Mechanics: Theory and Experiment, vol. 2019, no. 12, pp. 124018, 2019.
  • [10] Diederik P Kingma and Jimmy Ba, “Adam: A method for stochastic optimization,” arXiv preprint arXiv:1412.6980, 2014.
  • [11] Herbert Robbins and Sutton Monro, “A stochastic approximation method,” The Annals of Mathematical Statistics, vol. 22, no. 3, pp. 400–407, 1951.
  • [12] Jinghui Chen and Quanquan Gu, “Closing the generalization gap of adaptive gradient methods in training deep neural networks,” CoRR, vol. abs/1806.06763, 2018.
  • [13] Ikuro Sato, Kohta Ishikawa, Guoqing Liu, and Masayuki Tanaka, “Breaking inter-layer co-adaptation by classifier anonymization,” CoRR, vol. abs/1906.01150, 2019.
  • [14] Gintare Karolina Dziugaite and Daniel M Roy, “Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data,” arXiv preprint arXiv:1703.11008, 2017.
  • [15] Sepp Hochreiter and Jürgen Schmidhuber, “Flat minima,” Neural computation, vol. 9, no. 1, pp. 1–42, 1997.
  • [16] Pierre Foret, Ariel Kleiner, Hossein Mobahi, and Behnam Neyshabur, “Sharpness-aware minimization for efficiently improving generalization,” arXiv preprint arXiv:2010.01412, 2020.
  • [17] Juntang Zhuang, Boqing Gong, Liangzhe Yuan, Yin Cui, Hartwig Adam, Nicha Dvornek, Sekhar Tatikonda, James Duncan, and Ting Liu, “Surrogate gap minimization improves sharpness-aware training,” arXiv preprint arXiv:2203.08065, 2022.
  • [18] Ikuro Sato, Yamada Ryota, Masayuki Tanaka, Nakamasa Inoue, and Rei Kawakami, “Pof: Post-training of feature extractor for improving generalization,” in International Conference on Machine Learning. PMLR, pp. 19221–19230, 2022.
  • [19] Jungmin Kwon, Jeongseop Kim, Hyunseo Park, and In Kwon Choi, “ASAM: adaptive sharpness-aware minimization for scale-invariant learning of deep neural networks,” CoRR, vol. abs/2102.11600, 2021.
  • [20] Zhiyuan Zhang, Ruixuan Luo, Qi Su, and Xu Sun, “Ga-sam: Gradient-strength based adaptive sharpness-aware minimization for improved generalization,” arXiv preprint arXiv:2210.06895, 2022.
  • [21] Sergey Zagoruyko and Nikos Komodakis, “Wide residual networks,” arXiv preprint arXiv:1605.07146, 2016.
  • [22] Alex Krizhevsky and Geoffrey Hinton, “Learning multiple layers of features from tiny images,” Tech. Rep. 0, University of Toronto, Toronto, Ontario, 2009.
  • [23] Terrance DeVries and Graham W Taylor, “Improved regularization of convolutional neural networks with cutout,” arXiv preprint arXiv:1708.04552, 2017.
  • [24] Han Xiao, Kashif Rasul, and Roland Vollgraf, “Fashion-mnist: a novel image dataset for benchmarking machine learning algorithms,” arXiv preprint arXiv:1708.07747, 2017.
  • [25] Lukas Bossard, Matthieu Guillaumin, and Luc Van Gool, “Food-101 – mining discriminative components with random forests,” in European Conference on Computer Vision, 2014.
  • [26] Dongyoon Han, Jiwhan Kim, and Junmo Kim, “Deep pyramidal residual networks,” in Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 5927–5935, 2017.