
License: CC BY-SA 4.0
arXiv:2403.13204v1 [cs.LG] 19 Mar 2024
¹ Monash University, Australia    ² VinAI Research, Vietnam
Email: {tuananh.bui, tran.vo, trunglm, dinh.phung}@monash.edu, [email protected]

Diversity-Aware Agnostic Ensemble of Sharpness Minimizers

Anh Bui*¹, Vy Vo*¹, Tung Pham², Dinh Phung¹, Trung Le¹
Abstract

There has long been plenty of theoretical and empirical evidence supporting the success of ensemble learning. Deep ensembles in particular take advantage of training randomness and the expressivity of individual neural networks to gain prediction diversity, ultimately leading to better generalization, robustness and uncertainty estimation. Regarding generalization, pursuing wider local minima has been found to make models more robust to shifts between training and test sets. A natural research question arising from these two approaches is whether generalization can be further improved by integrating ensemble learning and loss sharpness minimization. Our work investigates this connection and proposes DASH, a learning algorithm that promotes diversity and flatness within deep ensembles. More concretely, DASH encourages base learners to move divergently towards low-loss regions of minimal sharpness. We provide a theoretical backbone for our method along with extensive empirical evidence demonstrating an improvement in ensemble generalizability.

Keywords:
Ensemble learning · Sharpness-aware Minimization · Generalization
* Equal contribution

1 Introduction

Ensemble learning refers to learning a combination of multiple models such that their joint performance is better than that of any of the ensemble members (so-called base learners). An ensemble can be an explicit collection of functionally independent models whose final decision is formed by approaches such as averaging or majority voting over individual predictions. Alternatively, it can be implicit: a single model subject to stochastic perturbation of its architecture during training [49, 51], or a model composed of sub-modules that share some of the model parameters [54, 53]. An ensemble is called homogeneous if its base learners belong to the same model family or architecture, and heterogeneous otherwise.

The traditional bagging technique [5] has been shown to reduce variance among the base learners, while boosting methods [6, 57] are more likely to reduce bias and improve generalization. Empirical evidence further indicates that ensembles perform at least as well as their base learners [31] and are much less fallible when the members err independently in different regions of the feature space [23]. Deep learning models in particular often land in different local minima due to training randomness arising from initialization, mini-batch sampling, etc. As a result, models trained from different initializations disagree in their predictions on the same input. Meanwhile, deep ensembles (i.e., ensembles of deep neural networks) have been found to "smooth out" the highly non-convex loss surface, resulting in better predictive performance [23, 44, 21, 16, 34]. Ensemble models also benefit from enhanced diversity in predictions, which is highlighted as another key driving force behind the success of ensemble learning [10]. Further studies suggest that higher diversity among base learners leads to better robustness and predictive performance [23, 41, 17, 48]. A recent work additionally shows that deep ensembles generally yield the best calibration under dataset shift [41].

Approaching model generalization from a different angle, sharpness-aware minimization is a line of work that seeks minima within flat loss regions, of which SAM [15] is the most popular method. Flat minimizers have been shown, theoretically and empirically and across various applications, to yield better test accuracy [29, 45, 12]. At every training step, SAM performs one gradient ascent step to find the worst-case perturbation of the parameters, and then updates the parameters using the gradient evaluated at the perturbed point. Given the many advantages of ensemble models, a natural question arises as to whether ensemble learning and sharpness-aware minimization can be integrated to boost model generalizability. In other words, can we learn a deep ensemble of sharpness minimizers such that the entire ensemble generalizes better?

Motivated by this connection, our work proposes to improve generalization performance by learning an ensemble of deep sharpness-aware learners. We first develop a theory showing that the general loss of the ensemble can be reduced by minimizing loss sharpness in both the ensemble and its base learners (see Theorem 3.1). Our theoretical development sheds light on how to guide the individual learners in a deep ensemble to be well-behaved and to collaborate effectively on a high-dimensional loss landscape.

Beyond generalization, we also target other desiderata of ensemble learning, including diversity, robustness and low uncertainty. While addressing all of these desiderata within a single framework is ambitious, past studies suggest that fostering diversity among the base learners is a multi-purpose approach that can improve the generalizability [35, 50, 40], stability [59] and adversarial robustness [42] of the ensemble. To this end, we contribute a novel agnostic diversity-aware constraint that steers the individual learners to explore multiple wide minima in a divergent fashion. The diversity-aware term aims to maximize the pairwise KL divergence among the base learners. The term is agnostic in the sense that it is introduced early in the process of searching for the perturbed model. Intuitively, we expect it to "look ahead" for gradient pathways that would guide the updated model towards this goal.
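To make the quantity that the diversity-aware term acts on more concrete, the sketch below computes the mean pairwise KL divergence between the predictive distributions of $m$ base learners. Higher values indicate more disagreement among the learners. This is a minimal sketch under our own assumptions. The function name `mean_pairwise_kl`, the `(m, N, M)` array layout and the averaging over ordered pairs are illustrative choices, and the exact term used by DASH may differ.

```python
import numpy as np

def mean_pairwise_kl(probs, eps=1e-12):
    """Average KL(p_i || p_j) over all ordered pairs i != j of base learners.

    probs: array of shape (m, N, M) holding each learner's predictive
    distribution on N inputs over M classes (hypothetical layout).
    """
    p = np.clip(probs, eps, 1.0)              # avoid log(0)
    m = p.shape[0]
    total = 0.0
    for i in range(m):
        for j in range(m):
            if i != j:
                # per-input KL(p_i || p_j), summed over classes
                kl = np.sum(p[i] * (np.log(p[i]) - np.log(p[j])), axis=-1)
                total += kl.mean()            # average over the N inputs
    return total / (m * (m - 1))

# Two hypothetical learners evaluated on a single binary-classification input.
probs = np.array([[[0.5, 0.5]], [[0.9, 0.1]]])
print(mean_pairwise_kl(probs))                               # about 0.439
print(mean_pairwise_kl(np.repeat(probs[:1], 2, axis=0)))     # identical learners: 0.0
```

Averaging over ordered pairs makes the measure symmetric in the learners, even though KL itself is not.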

Our contributions are summarized as follows:
➀ We propose DASH Ensemble, an ensemble learning method for Diversity-aware Agnostic Ensemble of Sharpness Minimizers. DASH seeks to minimize the generalization loss by directing the ensemble and its base classifiers towards diverse loss regions of maximal flatness.
➁ We provide a theoretical development for our method, followed by technical insights into how adding the agnostic diversity-aware term introduces diversity into the ensemble and yields better predictive performance and uncertainty estimation than the baseline methods.
➂ Across various image classification tasks, we demonstrate improvements of up to 6% in the generalization capacity of both homogeneous and heterogeneous ensembles, with the latter benefiting significantly.

2 Related works

Ensemble Learning.

The rise of ensemble learning dates back to the development of classical techniques such as bagging [5] and boosting [6, 18, 19, 57] for improving model generalization. Bagging trains independent weak learners in parallel. Boosting, by contrast, iteratively combines base learners into a strong model in which each successor learner tries to correct the errors of its predecessors. In the era of deep learning, ensembles of deep neural networks have attracted increasing attention. A deep ensemble made up of low-loss neural learners has consistently been shown to outperform an individual network [23, 44, 26, 21, 13]. In addition to predictive accuracy, deep ensembles have achieved success in other areas such as uncertainty estimation [33, 41, 22] and adversarial robustness [42, 56, 55].

Ensembles often come with high training and testing costs that can grow linearly with the ensemble size. This motivates recent work on efficient ensembles that reduce computational overhead without compromising performance. One direction leverages the success of Dynamic Sparse Training [37, 14] to generate an ensemble of sparse networks with lower training costs, while maintaining performance comparable to dense ensembles [36]. Another lightweight approach uses pseudo or implicit ensembles, which train a single model that exhibits the behavior of an ensemble. Regularization techniques such as Dropout [49, 20], DropConnect [51] or Stochastic Depth [27] can be viewed as ensembles obtained by masking some units, connections or layers of the network. Other implicit strategies include training base learners with different hyperparameter configurations [54], decomposing the weight matrices into individual weight modules for each base learner [53], or using a multi-input/multi-output configuration to learn independent sub-networks within a single model [24].

Sharpness-Aware Minimization.

There is a growing body of theoretical and empirical studies on the connection between loss sharpness and generalization capacity [25, 39, 11, 16]. Convergence to flat regions of wider local minima has been found to improve the out-of-distribution robustness of neural networks [29, 45, 12]. Other works [30, 28, 52] study how the gradient covariance, or training configurations such as batch size, learning rate and dropout rate, affect the flatness of minima. One way to encourage convergence to flat minima is to add regularization terms to the loss function, such as a penalty on low-entropy softmax outputs [43, 8] or distillation losses [60, 58].

SAM [15] is a recent flat minimizer, widely known for its effectiveness and scalability, that encourages the model to search for parameters lying in neighborhoods of uniformly low loss. SAM has been actively exploited in various applications: bi-level optimization for meta-learning [1], federated learning [47], domain generalization [7], multi-task learning [46], vision transformers [9] and language models [4]. Although they come from different directions, ensemble learning and sharpness-aware minimization share the same goal of improving generalization. Yet leveraging these two powerful learning strategies in a single framework remains underexplored. Our work contributes an effort to fill this research gap.

3 Proposed method

In this section, we first present the theoretical development demonstrating why sharpness-aware ensemble learning is beneficial for the generalization of ensemble models. We then show how to promote ensemble diversity by enforcing a novel agnostic diversity-aware constraint among the base learners.

Ensemble Setting and Notations.

We first describe the ensemble setting and the notations used throughout our paper. Given $m$ base learners $f^{(i)}_{\theta_i}(x)$, $i = 1, \dots, m$, we define the ensemble model

$$f^{\mathrm{ens}}_{\theta}(x) = \frac{1}{m}\sum_{i=1}^{m} f^{(i)}_{\theta_i}(x),$$

where $\theta_{\mathrm{ens}} = [\theta_i]_{i=1}^{m}$, $x \in \mathbb{R}^d$, and $f(x) \in \Delta_{M-1} = \{\pi \in \mathbb{R}^M : \pi \geq 0 \land \|\pi\|_1 = 1\}$. Here $\theta_i$ and $\theta_{\mathrm{ens}}$ denote the parameters w.r.t. the classifier $f^{(i)}_{\theta_i}$ and the ensemble classifier $f^{\mathrm{ens}}_{\theta}$, respectively. Note that the base learners $f^{(i)}_{\theta_i}$ can have different architectures.
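The ensemble definition above can be sketched in a few lines, assuming each base learner is a callable that returns a probability vector in $\Delta_{M-1}$ (here, a softmax output). The two toy learners are hypothetical stand-ins with random weights: a linear model and a one-hidden-layer ReLU network. They illustrate that base learners may have different architectures while their averaged output still lies in the simplex.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())                   # shift for numerical stability
    return e / e.sum()

def ensemble_predict(base_learners, x):
    """f^ens(x) = (1/m) * sum_i f^(i)(x); each learner maps x into the simplex."""
    return np.mean([f(x) for f in base_learners], axis=0)

# Hypothetical heterogeneous base learners (d = 4 inputs, M = 3 classes).
rng = np.random.default_rng(0)
W_lin = rng.normal(size=(4, 3))
W1, W2 = rng.normal(size=(4, 8)), rng.normal(size=(8, 3))

def linear(x):
    return softmax(x @ W_lin)

def mlp(x):
    return softmax(np.maximum(x @ W1, 0.0) @ W2)   # one hidden ReLU layer

x = rng.normal(size=4)
p = ensemble_predict([linear, mlp], x)        # again a probability vector of length M
print(p, p.sum())
```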

Assume that $\ell : \mathbb{R}^M \times \mathcal{Y} \to \mathbb{R}$, where $\mathcal{Y} = [M] = \{1, \dots, M\}$ is the label set, is a convex and bounded loss function. The training set is denoted by $\mathcal{S} = \{(x_i, y_i)\}_{i=1}^{N}$ of data points $(x_i, y_i) \sim \mathcal{D}$, where $\mathcal{D}$ is a data-label distribution.
We denote $\mathcal{L}_{\mathcal{S}}(\theta_i) = \frac{1}{N}\sum_{j=1}^{N} \ell\big(f^{(i)}_{\theta_i}(x_j), y_j\big)$ and $\mathcal{L}_{\mathcal{D}}(\theta_i) = \mathbb{E}_{(x,y)\sim\mathcal{D}}\big[\ell\big(f^{(i)}_{\theta_i}(x), y\big)\big]$ as the empirical and general losses w.r.t. the base learner $\theta_i$, respectively.

Similarly, for the ensemble model, we respectively define the empirical and general losses as $\mathcal{L}_{\mathcal{S}}(\theta_{\mathrm{ens}}) = \frac{1}{N}\sum_{j=1}^{N} \ell\big(f^{\mathrm{ens}}_{\theta}(x_j), y_j\big)$ and $\mathcal{L}_{\mathcal{D}}(\theta_{\mathrm{ens}}) = \mathbb{E}_{(x,y)\sim\mathcal{D}}\big[\ell\big(f^{\mathrm{ens}}_{\theta}(x), y\big)\big]$.
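To make the two empirical losses concrete, the sketch below uses cross-entropy, $\ell(\pi, y) = -\log \pi_y$, as the loss. This choice is ours, for illustration only. Cross-entropy is convex in $\pi$ but not bounded, so it only partially satisfies the assumptions above. Because the loss is convex, Jensen's inequality implies that the ensemble's empirical loss never exceeds the average empirical loss of its members. The example checks this property on random member predictions.

```python
import numpy as np

def empirical_loss(probs, y):
    """(1/N) * sum_j ell(f(x_j), y_j) with ell(pi, y) = -log pi_y (cross-entropy)."""
    return float(-np.log(probs[np.arange(len(y)), y]).mean())

rng = np.random.default_rng(1)
m, N, M = 3, 50, 5                            # learners, samples, classes
logits = rng.normal(size=(m, N, M))           # hypothetical member outputs
member_probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
y = rng.integers(0, M, size=N)

base_losses = [empirical_loss(p, y) for p in member_probs]   # L_S(theta_i)
ens_loss = empirical_loss(member_probs.mean(axis=0), y)      # L_S(theta_ens)
# Convexity of ell (Jensen): L_S(theta_ens) <= (1/m) * sum_i L_S(theta_i)
print(ens_loss, np.mean(base_losses))
```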

3.1 Sharpness-aware Ensemble Learning

Standard Sharpness-Aware Minimization.

As introduced in SAM [15], for a single model $f_\theta$, the generalization error $\mathcal{L}_{\mathcal{D}}(\theta)$ can be upper-bounded in terms of the sharpness of the model, i.e., $\max_{\theta' : \|\theta' - \theta\| \leq \rho} \mathcal{L}_{\mathcal{S}}(\theta') - \mathcal{L}_{\mathcal{S}}(\theta)$, where $\rho > 0$ is the perturbation radius. More specifically, Theorem 1 in [15] shows that

$$\mathcal{L}_{\mathcal{D}}(\theta) \leq \max_{\theta' : \|\theta' - \theta\| \leq \rho} \mathcal{L}_{\mathcal{S}}(\theta') + h\big(\|\theta\|_2^2/\rho\big),$$

where $h$ is a strictly increasing function of $\|\theta\|_2^2/\rho$. The theorem suggests that minimizing the sharpness of a single model can improve its generalizability. Following the success of SAM, many subsequent works have proposed improvements to sharpness-aware minimization. However, they are all limited to a single model.
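For reference, one SAM update can be sketched on a toy quadratic loss. This is a minimal NumPy sketch, not the reference implementation of [15]. The inner maximization is replaced by its standard first-order solution, $\epsilon = \rho\,\nabla\mathcal{L}_{\mathcal{S}}(\theta)/\|\nabla\mathcal{L}_{\mathcal{S}}(\theta)\|$. The names `sam_step` and `grad_fn` and the quadratic loss are hypothetical.

```python
import numpy as np

def sam_step(theta, grad_fn, rho, lr):
    """One SAM update using the first-order solution of the inner maximization."""
    g = grad_fn(theta)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)   # worst-case perturbation, ||eps|| = rho
    g_adv = grad_fn(theta + eps)                   # gradient at the perturbed point
    return theta - lr * g_adv                      # descend from the original theta

# Hypothetical quadratic loss L(theta) = 0.5 * theta^T A theta with one sharp direction.
A = np.diag([10.0, 0.1])

def grad_fn(theta):
    return A @ theta

theta = np.array([1.0, 1.0])
theta_new = sam_step(theta, grad_fn, rho=0.05, lr=0.05)   # approx. [0.475, 0.995]
print(theta_new)
```

The gradient used for the descent step is evaluated at the perturbed point. In the sharp direction, this gradient is larger than the gradient at $\theta$, so the update is steeper there than in plain gradient descent.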

Sharpness-Aware Ensemble learning.

We now present a sharpness-aware upper bound on the general loss of the ensemble model. For readability, we state a simplified version in the following theorem. The full development can be found in the supplementary material.

Theorem 3.1

Assume that the loss function $\ell$ is convex and upper-bounded by $L$. With probability at least $1-\delta$ over the choice of $\mathcal{S} \sim \mathcal{D}^N$, for any $0 \leq \gamma \leq 1$, we have

$$\begin{aligned}
\mathcal{L}_{\mathcal{D}}(\theta_{\mathrm{ens}}) \leq\ & \frac{1-\gamma}{m}\sum_{i=1}^{m} \max_{\theta_i' : \|\theta_i' - \theta_i\| \leq \rho} \mathcal{L}_{\mathcal{S}}(\theta_i') \\
&+ \gamma \max_{\theta'_{\mathrm{ens}} : \|\theta'_{\mathrm{ens}} - \theta_{\mathrm{ens}}\| \leq \sqrt{m}\,\rho} \mathcal{L}_{\mathcal{S}}(\theta'_{\mathrm{ens}}) + \mathcal{H}\big(m, \{\|\theta_i\|_2^2/\rho\}_{i=1}^{m}\big),
\end{aligned} \tag{1}$$

where $\mathcal{H}$ is a strictly increasing function of $m$, $\rho$ and the set of model parameters $\{\theta_i\}_{i=1}^{m}$.

On the right-hand side of the inequality, the first term is the average sharpness of the independent base learners, while the second term concerns the sharpness of the entire ensemble model. The trade-off coefficient $\gamma$ controls how strongly sharpness-aware minimization is enforced on the ensemble model as a whole versus on its individual base learners. Theorem 3.1 thus indicates that the generalization performance of the ensemble model can be improved by minimizing the sharpness of both the entire ensemble and the individual base learners.
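To make the role of $\gamma$ concrete, the sketch below evaluates the first two terms on the right-hand side of Eq. 1. The complexity term $\mathcal{H}$ is omitted. The sketch assumes the common first-order approximation $\max_{\|\epsilon\| \leq r} \mathcal{L}_{\mathcal{S}}(\theta + \epsilon) \approx \mathcal{L}_{\mathcal{S}}(\theta) + r\,\|\nabla\mathcal{L}_{\mathcal{S}}(\theta)\|$. The function names and the input numbers are hypothetical.

```python
import numpy as np

def worst_case_loss(loss, grad, radius):
    """First-order estimate of max_{||eps|| <= radius} L(theta + eps)."""
    return loss + radius * np.linalg.norm(grad)

def eq1_sharpness_terms(base_losses, base_grads, ens_loss, ens_grad, rho, gamma):
    """(1 - gamma) * mean_i worst-case base loss + gamma * worst-case ensemble loss.

    ens_grad is the gradient w.r.t. the concatenated parameters theta_ens,
    so the ensemble term uses the enlarged radius sqrt(m) * rho from Eq. 1.
    """
    m = len(base_losses)
    base_term = np.mean([worst_case_loss(l, g, rho)
                         for l, g in zip(base_losses, base_grads)])
    ens_term = worst_case_loss(ens_loss, ens_grad, np.sqrt(m) * rho)
    return (1.0 - gamma) * base_term + gamma * ens_term

# Hypothetical numbers for m = 2 base learners with 2 parameters each.
base_losses = [1.0, 2.0]
base_grads = [np.array([3.0, 4.0]), np.array([0.0, 0.0])]
ens_loss, ens_grad = 1.2, np.array([0.6, 0.8, 0.0, 0.0])
for gamma in (0.0, 0.1, 1.0):
    print(gamma, eq1_sharpness_terms(base_losses, base_grads, ens_loss, ens_grad, 0.1, gamma))
```

Setting $\gamma = 0$ recovers a pure average of per-learner SAM objectives. Setting $\gamma = 1$ treats the ensemble as a single model under SAM with radius $\sqrt{m}\,\rho$.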

The dynamics of two modes of sharpness.

Intuitively, Eq. 1 suggests an effective ensemble dynamic in which the base learners are encouraged not only to perform well individually but also to contribute synergistically to the ensemble. Note that the former may foster the latter behavior, whereas the latter alone is likely to be insufficient. Sec. 3.2 will later discuss one possible antagonistic behavior within an ensemble. We now investigate how $\gamma$ should be chosen by studying the impact of these two modes of sharpness on the ensemble performance.

We conduct experiments on the CIFAR-100 dataset, varying $\gamma$ and observing the ensemble performance, as shown in Fig. 1. Varying $\gamma$ significantly affects the ensemble performance, with a difference of more than 1.8% in ensemble accuracy. Interestingly, the ensemble accuracy and its uncertainty estimation capability are best at $\gamma = 0.1$ and degrade as $\gamma$ increases. This empirical observation supports our intuition that, to enhance the generalization ability of the ensemble model, one should focus more on minimizing the sharpness of the base learners than on minimizing the sharpness of the ensemble model. It also concurs with the finding in [2] that the ensemble model's generalization ability is more sensitive to the sharpness of the base learners than to that of the ensemble model itself.

Figure 1: Tuning the hyper-parameter $\gamma$. Both the ensemble accuracy (ACC, higher is better) and the expected calibration error (ECE, lower is better) are best at $\gamma = 0.1$. See Tab. 4 for other metrics.

3.2 Diversity-Aware Agnostic Ensemble of Flat Base Learners

The previous section shows that solely minimizing the sharpness of the entire ensemble, i.e., treating the ensemble as a single model and naively applying SAM, is not an optimal strategy. Fig. 1 also highlights that placing greater emphasis on sharpness minimization within the individual learners does yield a positive collaborative effect. Nevertheless, we argue that this strategy alone does not yet maximize the synergy among the learners. In the following, we provide a theoretical analysis of one potential antagonistic behavior of the base learners.

For the current mini-batch $B$, we define

$$\widetilde{\mathcal{L}_B}(\theta_i) = \mathcal{L}_B(\theta_i) + \gamma\, \mathcal{L}^{\mathrm{ens}}_B(\theta_i, \theta_{\neq i}).$$

When we apply sharpness-aware minimization to the learner $\theta_i$, the model $f^{(i)}_{\theta_i}$ is updated as

$$\begin{aligned}
\theta_i^a &= \theta_i + \rho_1 \frac{\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i)}{\|\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i)\|}, \\
\theta_i &= \theta_i - \eta\, \nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i^a),
\end{aligned} \tag{2}$$

where $\rho_1 > 0$ is the perturbation radius and $\eta > 0$ is the learning rate.
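The per-learner update in Eq. 2 can be sketched in NumPy. This is a minimal sketch under several assumptions that are ours rather than the paper's:

- each base learner is a linear softmax classifier with weight matrix $W_i$;
- $\mathcal{L}_B$ is the mean cross-entropy of learner $i$ on the mini-batch;
- $\mathcal{L}^{\mathrm{ens}}_B$ is the mean cross-entropy of the averaged ensemble prediction, with the other learners $\theta_{\neq i}$ held fixed.

The function names are hypothetical. The sketch covers only plain SAM on the combined loss $\widetilde{\mathcal{L}_B}$, not the diversity-aware term that DASH adds.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)      # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def combined_loss_and_grad(Ws, i, X, y, gamma):
    """L~_B(W_i) = CE(learner i) + gamma * CE(ensemble), and its gradient w.r.t. W_i."""
    n, m = X.shape[0], len(Ws)
    Y = np.eye(Ws[i].shape[1])[y]             # one-hot labels, shape (n, M)
    probs = [softmax(X @ W) for W in Ws]
    p_i, p_ens = probs[i], sum(probs) / m
    rows = np.arange(n)
    py_i, py_ens = p_i[rows, y], p_ens[rows, y]
    loss = -np.log(py_i).mean() - gamma * np.log(py_ens).mean()
    # dCE_i/dlogits_i = p_i - Y;  dCE_ens/dlogits_i = -(p_i[y] / (m * p_ens[y])) * (Y - p_i)
    dz = (p_i - Y) - gamma * (py_i / (m * py_ens))[:, None] * (Y - p_i)
    return loss, X.T @ dz / n

def sam_update_learner(Ws, i, X, y, gamma, rho1, eta):
    """One step of Eq. 2 for learner i; the other learners are left unchanged."""
    _, g = combined_loss_and_grad(Ws, i, X, y, gamma)
    W_adv = Ws[i] + rho1 * g / (np.linalg.norm(g) + 1e-12)     # theta_i^a
    _, g_adv = combined_loss_and_grad(Ws[:i] + [W_adv] + Ws[i + 1:], i, X, y, gamma)
    return Ws[:i] + [Ws[i] - eta * g_adv] + Ws[i + 1:]

# Usage on a hypothetical random mini-batch: sweep over the m = 3 learners.
rng = np.random.default_rng(0)
X, y = rng.normal(size=(32, 5)), rng.integers(0, 3, size=32)
Ws = [0.1 * rng.normal(size=(5, 3)) for _ in range(3)]
for step in range(20):
    for i in range(len(Ws)):
        Ws = sam_update_learner(Ws, i, X, y, gamma=0.1, rho1=0.05, eta=0.1)
```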

Using a first-order Taylor expansion, we have

$$\begin{aligned}
\nabla_{\theta_i}\widetilde{\mathcal{L}_B}\left(\theta_i^a\right) &= \nabla_{\theta_i}\left[\widetilde{\mathcal{L}_B}\left(\theta_i + \rho_1\frac{\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i)}{\|\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i)\|}\right)\right]\\
&\approx \nabla_{\theta_i}\left[\widetilde{\mathcal{L}_B}(\theta_i) + \rho_1\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i)\cdot\frac{\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i)}{\|\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i)\|}\right]\\
&= \nabla_{\theta_i}\left[\widetilde{\mathcal{L}_B}(\theta_i) + \rho_1\|\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i)\|\right], \qquad (3)
\end{aligned}$$

where $\cdot$ denotes the dot product.
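Reading $\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i^a)$ as the gradient evaluated at $\theta_i^a$ (the usual SAM convention), the first-order expansion in Eq. 3 is exact for a quadratic loss, which gives a quick numerical sanity check. The sketch below assumes a random toy quadratic in place of the mini-batch loss and uses central finite differences for the gradient of $\widetilde{\mathcal{L}_B}(\theta) + \rho_1\|\nabla\widetilde{\mathcal{L}_B}(\theta)\|$; all names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
M = rng.normal(size=(3, 3))
A = M @ M.T + np.eye(3)                 # symmetric positive-definite Hessian (toy assumption)
b = rng.normal(size=3)

loss = lambda th: 0.5 * th @ A @ th - b @ th
grad = lambda th: A @ th - b

def surrogate(th, rho):
    # Bracket on the last line of Eq. 3: L(theta) + rho * ||grad L(theta)||.
    return loss(th) + rho * np.linalg.norm(grad(th))

def numerical_grad(f, th, h=1e-6):
    # Central finite differences, one coordinate at a time.
    e = np.eye(len(th))
    return np.array([(f(th + h * e[k]) - f(th - h * e[k])) / (2 * h) for k in range(len(th))])

theta, rho = rng.normal(size=3), 0.05
g = grad(theta)
theta_a = theta + rho * g / np.linalg.norm(g)   # perturbed model theta^a

lhs = grad(theta_a)                                         # gradient at theta^a
rhs = numerical_grad(lambda th: surrogate(th, rho), theta)  # gradient of the surrogate
print(np.max(np.abs(lhs - rhs)))
```

For a non-quadratic loss the two sides agree only up to terms of order $\rho_1^2$.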

The approximation in Eq. 3 shows that, because we follow the negative gradient $-\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i^a)$ when updating the current model $\theta_i$, the new model tends to decrease both the loss $\widetilde{\mathcal{L}_B}(\theta_i)$ and the gradient norm $\|\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i)\|$, steering the base learners toward low-loss, flat regions as intended. In this setting, however, all base learners, each sufficiently expressive, may converge to the neighborhood of the same low-loss region.
Moreover, because the normalized gradients $\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i)/\|\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i)\|$, $i=1,\dots,m$, are all computed on the same mini-batch $B$, the perturbed models $\theta_i^a$, $i=1,\dots,m$, are also less diverse. Our intuition is that adding constraints to the objective $\widetilde{\mathcal{L}_B}(\theta_i)$ of each individual base learner shrinks that learner's solution space whenever these constraints are independent and do not interact with the other base learners. Because each base learner is optimized independently, this eventually yields less diverse updated models $\theta_i$, $i=1,\dots,m$. Fig. 2 illustrates this intuition.

Figure 2: Illustration of the model dynamics under the sharpness-aware term on the loss landscape. Two base learners $\theta_i$ and $\theta_j$ (red and black vectors, respectively) happen to be initialized close to each other. At each step, $\theta_i$ and $\theta_j$ are updated independently but on the same mini-batch, so the two perturbed models $\theta_i^a$ and $\theta_j^a$ are less diverse; hence the two updated models $\theta_i$ and $\theta_j$ are also less diverse and more likely to end up in the same low-loss, flat region.

Although we expect that minimizing sharpness in the ensemble via $\mathcal{L}_B^{\mathrm{ens}}(\cdot)$ would foster model dynamics in which the learners complement one another to support generalization, the above analysis warns of an adverse effect: ensemble diversity is reduced, depriving us of the desired synergy. Empirical evidence for this intuition can be found in Table 5. The question, then, is how to make the sharpness-aware learning of each base learner interact with the other base learners so that the ensemble achieves both flatness and diversity.

To this end, we propose the following agnostic update

$$\begin{aligned}
\theta_i^a &= \theta_i + \rho_1\frac{\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i)}{\|\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i)\|} + \rho_2\frac{\nabla_{\theta_i}\mathcal{L}_B^{div}(\theta_i,\theta_{\neq i})}{\|\nabla_{\theta_i}\mathcal{L}_B^{div}(\theta_i,\theta_{\neq i})\|}, \qquad (4)\\
\theta_i &= \theta_i - \eta\,\nabla_{\theta_i}\widetilde{\mathcal{L}_B}\left(\theta_i^a\right),
\end{aligned}$$

where $\theta_{\neq i}$ denotes the set of base learners other than $\theta_i$, and the $i$-th divergence loss is defined as

$$\mathcal{L}_B^{div}(\theta_i,\theta_{\neq i}) = \frac{1}{|B|}\sum_{x\in B,\, j\neq i}\mathrm{KL}\left(\sigma\left(h^i_{\theta_i}(x)/\tau\right),\ \sigma\left(h^j_{\theta_j}(x)/\tau\right)\right), \qquad (5)$$

where $h^k_{\theta_k}$ returns the non-target logits (i.e., all logits except that of the ground-truth class) of the $k$-th base learner, $\sigma$ is the softmax function, $\tau > 0$ is the temperature, $\rho_2$ is a second perturbation radius, and $\mathrm{KL}$ denotes the Kullback-Leibler divergence. In practice, we set $\rho_2 = \rho_1$ for simplicity and choose $\tau < 1$, which sharpens the softmax so that the divergence emphasizes the dominant modes of each base learner.
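To make Eqs. 4 and 5 concrete, here is a minimal NumPy sketch under stated assumptions: `divergence_loss` computes the value of Eq. 5 on non-target logits, and `dash_step` applies the agnostic update of Eq. 4 to a flat parameter vector given two gradient callables. The function names, the temperature $\tau = 0.5$, the toy quadratic loss and the stand-in diversity gradient in the usage lines are illustrative assumptions; a real implementation would obtain both gradients by automatic differentiation through the networks.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)        # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def non_target_logits(logits, labels):
    """Drop each row's ground-truth logit: (N, C) -> (N, C-1)."""
    n, c = logits.shape
    mask = np.ones((n, c), dtype=bool)
    mask[np.arange(n), labels] = False
    return logits[mask].reshape(n, c - 1)

def divergence_loss(all_logits, labels, i, tau=0.5, eps=1e-12):
    """Eq. 5 for learner i; all_logits is a list of (N, C) arrays, one per learner.

    tau=0.5 is an illustrative assumption; the text only states tau < 1.
    """
    p_i = softmax(non_target_logits(all_logits[i], labels) / tau)
    total = 0.0
    for j, logits_j in enumerate(all_logits):
        if j == i:
            continue
        p_j = softmax(non_target_logits(logits_j, labels) / tau)
        total += np.sum(p_i * (np.log(p_i + eps) - np.log(p_j + eps)))  # KL(p_i || p_j)
    return total / len(labels)                     # average over the mini-batch

def unit(v, eps=1e-12):
    return v / (np.linalg.norm(v) + eps)

def dash_step(theta_i, grad_loss, grad_div, rho1=0.05, rho2=0.05, lr=0.1):
    """Agnostic update of Eq. 4 on a flat parameter vector.

    grad_loss(theta) and grad_div(theta) are hypothetical callables returning the
    gradients of the mini-batch loss and of L_B^div (other learners held fixed).
    """
    theta_a = theta_i + rho1 * unit(grad_loss(theta_i)) + rho2 * unit(grad_div(theta_i))
    return theta_i - lr * grad_loss(theta_a)       # descend with the gradient at theta_i^a

rng = np.random.default_rng(0)
logits = [rng.normal(size=(8, 10)) for _ in range(3)]
labels = rng.integers(0, 10, size=8)
div_value = divergence_loss(logits, labels, i=0)

A = np.diag([4.0, 1.0])                            # toy quadratic loss 0.5 * theta^T A theta
theta_new = dash_step(np.array([1.0, 1.0]),
                      grad_loss=lambda th: A @ th,
                      grad_div=lambda th: th - np.array([0.5, 0.5]))  # stand-in, not the KL gradient
```

In this sketch the diversity term only shapes the perturbation $\theta_i^a$; the descent step itself still uses only the gradient of the classification loss, mirroring the second line of Eq. 4.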

Note that Eq. 5 considers only the non-target logits when diversifying the base learners, so as not to interfere with their ability to predict the ground-truth labels. To inspect the agnostic behavior induced by the second (diversity) perturbation term in $\theta_i^a$, we again apply the first-order Taylor expansion:

$$\begin{aligned}
\nabla_{\theta_i}\widetilde{\mathcal{L}_B}\left(\theta_i^a\right) &= \nabla_{\theta_i}\Biggl[\widetilde{\mathcal{L}_B}\Biggl(\theta_i + \rho_1\frac{\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i)}{\|\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i)\|} + \rho_2\frac{\nabla_{\theta_i}\mathcal{L}_B^{div}(\theta_i,\theta_{\neq i})}{\|\nabla_{\theta_i}\mathcal{L}_B^{div}(\theta_i,\theta_{\neq i})\|}\Biggr)\Biggr]\\
&\approx \nabla_{\theta_i}\Biggl[\widetilde{\mathcal{L}_B}(\theta_i) + \rho_1\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i)\cdot\frac{\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i)}{\|\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i)\|}\\
&\qquad\qquad + \rho_2\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i)\cdot\frac{\nabla_{\theta_i}\mathcal{L}_B^{div}(\theta_i,\theta_{\neq i})}{\|\nabla_{\theta_i}\mathcal{L}_B^{div}(\theta_i,\theta_{\neq i})\|}\Biggr]\\
&= \nabla_{\theta_i}\Biggl[\widetilde{\mathcal{L}_B}(\theta_i) + \rho_1\|\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i)\|\\
&\qquad\qquad - \rho_2\frac{-\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i)\cdot\nabla_{\theta_i}\mathcal{L}_B^{div}(\theta_i,\theta_{\neq i})}{\|\nabla_{\theta_i}\mathcal{L}_B^{div}(\theta_i,\theta_{\neq i})\|}\Biggr]. \qquad (6)
\end{aligned}$$

In Eq. 6, the first two terms drive the base learners toward their low-loss, flat regions, as discussed above. We now analyze the agnostic behavior of the third term. Because the update rule for $\theta_i$ in Eq. 4 descends along $-\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i^a)$, the minus sign in front of $\rho_2$ in Eq. 6 is flipped, so we follow the positive direction of
$$\nabla_{\theta_i}\mathcal{L}_B^{d} = \nabla_{\theta_i}\left[\frac{-\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i)\cdot\nabla_{\theta_i}\mathcal{L}_B^{div}(\theta_i,\theta_{\neq i})}{\|\nabla_{\theta_i}\mathcal{L}_B^{div}(\theta_i,\theta_{\neq i})\|}\right],$$
which implies that the updated base learners aim to maximize $\frac{-\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i)\cdot\nabla_{\theta_i}\mathcal{L}_B^{div}(\theta_i,\theta_{\neq i})}{\|\nabla_{\theta_i}\mathcal{L}_B^{div}(\theta_i,\theta_{\neq i})\|}$.
This quantity is the projection of the low-loss direction $-\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i)$ onto the unit diversity direction $\nabla_{\theta_i}\mathcal{L}_B^{div}(\theta_i,\theta_{\neq i})/\|\nabla_{\theta_i}\mathcal{L}_B^{div}(\theta_i,\theta_{\neq i})\|$. Maximizing it therefore makes the two directions more congruent, meaning that the base learners tend to diverge from one another while moving along low-loss, flat directions. Fig. 3 visualizes this intuition.
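The third-term quantity in Eq. 6 equals $\|\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i)\|$ times the cosine of the angle between $-\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i)$ and $\nabla_{\theta_i}\mathcal{L}_B^{div}(\theta_i,\theta_{\neq i})$. A diagnostic one might log during training could compute both numbers from flattened gradient vectors, as in this sketch; the function name and its use for monitoring are our assumptions, not part of the described method.

```python
import numpy as np

def diversity_alignment(grad_loss, grad_div, eps=1e-12):
    """Third-term quantity of Eq. 6 for one learner (flattened gradients).

    Returns (projection, cosine): the projection of the descent direction
    -grad_loss onto the unit diversity direction, and the cosine between them.
    """
    d_norm = np.linalg.norm(grad_div) + eps
    projection = np.dot(-grad_loss, grad_div) / d_norm
    cosine = projection / (np.linalg.norm(grad_loss) + eps)
    return projection, cosine

# Descent direction -g = [-1, 0] is parallel to the diversity gradient [-2, 0].
proj, cos = diversity_alignment(np.array([1.0, 0.0]), np.array([-2.0, 0.0]))
```

A cosine that rises over training would be consistent with the congruence argument above; a value near zero would indicate that the diversity perturbation is roughly orthogonal to the descent direction.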

Figure 3: Illustration of the model dynamics under the diversity-aware term. Given two base learners $\theta_i$ and $\theta_j$ (red and black vectors, respectively), the gradients $-\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i)$ and $-\nabla_{\theta_j}\widetilde{\mathcal{L}_B}(\theta_j)$ navigate the models toward their low-loss (and flat) regions.
Moreover, the two gradients $\nabla_{\theta_i}\mathcal{L}_B^{div}(\theta_i,\theta_{\neq i})$ and $\nabla_{\theta_j}\mathcal{L}_B^{div}(\theta_j,\theta_{\neq j})$ encourage the models to move apart.
As discussed, our update strategy makes $-\nabla_{\theta_i}\widetilde{\mathcal{L}_B}(\theta_i)$ and $\nabla_{\theta_i}\mathcal{L}_B^{div}(\theta_i,\theta_{\neq i})$ more congruent. As a result, the two models are steered toward two non-overlapping low-loss, flat regions. The same behavior applies to the corresponding pair of gradients for $\theta_j$, which together enhance ensemble diversity.

4 Experiments

We evaluate our method on classification tasks on CIFAR10/100 and Tiny-ImageNet. We experiment with homogeneous ensembles, in which all base learners share the same architecture; e.g., R18x3 is an ensemble of three ResNet18 models. We also experiment with heterogeneous ensembles; e.g., RME is an ensemble of ResNet18, MobileNet and EfficientNet models. Our method and the baselines share the same training configuration: 200 epochs with the SGD optimizer and a weight decay of 0.005. We follow the standard data pre-processing scheme, which consists of zero-padding with 4 pixels on each side, random cropping, horizontal flipping and normalization. The ensemble prediction is obtained by averaging the softmax predictions of all base classifiers. Our code is anonymously published at https://anonymous.4open.science/r/DASH/. In all tables, bold/underline indicates the best/second-best method, and ↑/↓ indicates that higher/lower is better. We provide our algorithm and further ablation studies in the supplementary materials.
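The pre-processing and ensemble-averaging steps just described can be sketched as follows. This assumes height-width-channel image arrays and plain NumPy in place of the authors' data pipeline; per-channel normalization is left out because the mean and standard deviation values are not given in the text.

```python
import numpy as np

def augment(img, rng, pad=4):
    """Zero-pad, randomly crop back to the original size, randomly flip horizontally.

    img: (H, W, C) float array. Normalization is omitted (constants not stated).
    """
    h, w, _ = img.shape
    padded = np.pad(img, ((pad, pad), (pad, pad), (0, 0)))   # zero-padding on each side
    top, left = rng.integers(0, 2 * pad + 1, size=2)          # random crop offsets
    out = padded[top:top + h, left:left + w]
    if rng.random() < 0.5:
        out = out[:, ::-1]                                    # horizontal flip (width axis)
    return out

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def ensemble_predict(logits_per_learner):
    """Average the softmax outputs of all base classifiers: list of (N, C) -> (N, C)."""
    return np.mean([softmax(z) for z in logits_per_learner], axis=0)

rng = np.random.default_rng(0)
x_aug = augment(rng.random((32, 32, 3)), rng)                 # CIFAR-sized dummy image
probs = ensemble_predict([rng.normal(size=(4, 10)) for _ in range(3)])
pred = probs.argmax(axis=1)                                   # ensemble class predictions
```

Averaging probabilities rather than logits keeps the ensemble output a valid distribution regardless of how differently the base learners are scaled.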

4.1 Baselines

This work focuses on improving the generalization of ensembles. We compare our method against top-performing ensemble methods in the literature: Deep Ensembles [33], Snapshot Ensembles [26], Fast Geometric Ensemble (FGE) [21], and the sparse ensembles EDST and DST [36]. We also train ensembles with SGD and with SAM [15] as the optimizer and treat these as two additional baselines.

4.2 Metrics

We use ensemble accuracy (Acc) as the primary measure of generalization. To evaluate uncertainty estimation, we use the standard metrics Negative Log-Likelihood (NLL), Brier score and Expected Calibration Error (ECE), which are widely used in the literature. Following [3], we also report calibrated uncertainty estimation (UE) metrics, namely Cal-NLL, Cal-Brier and Cal-AAC, computed at the optimal temperature so that calibration error that could be removed by simple temperature scaling is not counted. To measure ensemble diversity, we use the Disagreement (D) of predictions [32] and the Log-Determinant (LD) of a matrix built from the base classifiers' non-target predictions, as proposed in [42]. Unlike simple disagreement, LD has a geometric interpretation: the determinant reflects the volume spanned by the base classifiers' non-target prediction vectors, which makes it a more informative measure of ensemble diversity.
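The uncertainty and diversity metrics above can be sketched in NumPy as follows. This is an illustrative implementation under our own assumptions: 15 equal-width confidence bins for ECE, disagreement averaged over all member pairs, and LD computed as the log-determinant of the Gram matrix of L2-normalized non-target prediction vectors, averaged over inputs, in the spirit of [42]; the paper's exact normalization may differ. The calibrated variants would apply the same functions after tuning a softmax temperature, which is omitted here.

```python
import numpy as np

def nll(probs, labels, eps=1e-12):
    return -np.mean(np.log(probs[np.arange(len(labels)), labels] + eps))

def brier(probs, labels):
    onehot = np.eye(probs.shape[1])[labels]
    return np.mean(np.sum((probs - onehot) ** 2, axis=1))

def ece(probs, labels, n_bins=15):                 # 15 bins is an assumption
    conf = probs.max(axis=1)
    correct = probs.argmax(axis=1) == labels
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    total = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (conf > lo) & (conf <= hi)
        if in_bin.any():
            # bin weight * |accuracy - mean confidence| within the bin
            total += in_bin.mean() * abs(correct[in_bin].mean() - conf[in_bin].mean())
    return total

def disagreement(member_probs):
    """Mean over member pairs of the fraction of inputs with different predicted labels."""
    preds = [p.argmax(axis=1) for p in member_probs]
    m = len(preds)
    pairs = [(a, b) for a in range(m) for b in range(a + 1, m)]
    return np.mean([np.mean(preds[a] != preds[b]) for a, b in pairs])

def log_det_diversity(member_probs, labels, eps=1e-12):
    """Assumed ADP-style LD: log det of the Gram matrix of normalized non-target predictions."""
    n, c = member_probs[0].shape
    mask = np.ones((n, c), dtype=bool)
    mask[np.arange(n), labels] = False
    nt = np.stack([p[mask].reshape(n, c - 1) for p in member_probs], axis=1)  # (N, m, C-1)
    nt = nt / (np.linalg.norm(nt, axis=2, keepdims=True) + eps)
    gram = nt @ nt.transpose(0, 2, 1)                                         # (N, m, m)
    _, logdet = np.linalg.slogdet(gram + eps * np.eye(len(member_probs)))
    return np.mean(logdet)

rng = np.random.default_rng(0)
labels = rng.integers(0, 10, size=100)
members = [np.exp(z) / np.exp(z).sum(1, keepdims=True) for z in rng.normal(size=(3, 100, 10))]
ens = np.mean(members, axis=0)                      # averaged ensemble prediction
```

Because each row of the normalized matrix has unit length, the Gram determinant is at most 1, so LD is non-positive and approaches 0 as the non-target predictions become mutually orthogonal.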

4.3 Evaluation of Predictive Performance

The results in Table 1 demonstrate the effectiveness of our proposed method, DASH, in improving the generalization ability of ensemble methods. Across all datasets and architectures, DASH consistently outperforms all baselines by clear margins. For example, compared to SGD with the R18x3 architecture, DASH achieves improvement gaps of 1.5%, 3.3%, and 7.6% on the CIFAR10, CIFAR100, and Tiny-ImageNet datasets, respectively. Compared to Deep Ensemble, the corresponding gaps are 3.0%, 6.8%, and 4.0% on the same datasets. Our results also provide evidence that seeking flatter classifiers can bring significant benefits to ensemble learning: SAM improves over both SGD and Deep Ensemble, but DASH achieves even greater improvements. Specifically, on the CIFAR100 dataset, DASH outperforms SAM by 3.1%, 2.1%, and 2.3% with the R10x5, R18x3, and RME architectures, respectively, while the improvement on the Tiny-ImageNet dataset is 3.8%. These improvements indicate the benefit of effectively combining the flatness-seeking and diversity-seeking objectives in deep ensembles.

Unlike the Fast Geometric, Snapshot, and EDST methods, which are limited to homogeneous ensembles, DASH is a general method that improves ensemble performance even when combining different architectures. This is evidenced by the larger improvement over SAM with the RME architecture (1.4% on CIFAR10) than with the R18x3 architecture (0.9% on the same dataset). These results demonstrate the versatility and effectiveness of DASH in improving the generalization ability of deep ensembles across diverse architectures and datasets.

Table 1: Ensemble accuracy (%) on the CIFAR10/100 and Tiny-ImageNet datasets. R10x5 indicates an ensemble of five ResNet10 models. R18x3 indicates an ensemble of three ResNet18 models. RME indicates an ensemble of ResNet18, MobileNet and EfficientNet models.
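| Accuracy (%) ↑ | CIFAR10 R10x5 | CIFAR10 R18x3 | CIFAR10 RME | CIFAR100 R10x5 | CIFAR100 R18x3 | CIFAR100 RME | Tiny-ImageNet R18x3 |
|---|---|---|---|---|---|---|---|
| Deep Ensemble | 92.7 | 93.7 | 89.0 | 73.7 | 75.4 | 62.7 | 65.9 |
| Fast Geometric | 92.5 | 93.3 | - | 63.2 | 72.3 | - | 61.8 |
| Snapshot | 93.6 | 94.8 | - | 72.8 | 75.7 | - | 62.2 |
| EDST | 92.0 | 92.8 | - | 68.4 | 69.6 | - | 62.3 |
| DST | 93.2 | 94.7 | 93.4 | 70.8 | 70.4 | 71.7 | 61.9 |
| SGD | 95.1 | 95.2 | 92.6 | 75.9 | 78.9 | 72.6 | 62.3 |
| SAM | 95.4 | 95.8 | 93.8 | 77.7 | 80.1 | 76.4 | 66.1 |
| DASH (Ours) | 95.7 | 96.7 | 95.2 | 80.8 | 82.2 | 78.7 | 69.9 |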
Table 2: Evaluation of Uncertainty Estimation (UE). Calibrated-Brier score is chosen as the representative UE metric reported in this table. Evaluation on all six UE metrics for CIFAR10/100 can be found in the supplementary material. Overall, our method achieves better calibration than baselines on several metrics, especially in the heterogeneous ensemble setting.
| Cal-Brier ↓ | CIFAR10 R10x5 | CIFAR10 R18x3 | CIFAR10 RME | CIFAR100 R10x5 | CIFAR100 R18x3 | CIFAR100 RME | Tiny-ImageNet R18x3 |
|---|---|---|---|---|---|---|---|
| Deep Ensemble | 0.091 | 0.079 | 0.153 | 0.329 | 0.308 | 0.433 | 0.453 |
| Fast Geometric | 0.251 | 0.087 | - | 0.606 | 0.344 | - | 0.499 |
| Snapshot | 0.083 | 0.071 | - | 0.338 | 0.311 | - | 0.501 |
| EDST | 0.122 | 0.113 | - | 0.427 | 0.412 | - | 0.495 |
| DST | 0.102 | 0.083 | 0.102 | 0.396 | 0.405 | 0.393 | 0.500 |
| SGD | 0.078 | 0.076 | 0.113 | 0.346 | 0.304 | 0.403 | 0.518 |
| SAM | 0.073 | 0.067 | 0.094 | 0.321 | 0.285 | 0.347 | 0.469 |
| DASH (Ours) | 0.067 | 0.056 | 0.075 | 0.267 | 0.255 | 0.298 | 0.407 |

4.4 Evaluation of Uncertainty Estimation

Although improving uncertainty estimation is not the primary focus of our method, in this section we investigate its effectiveness in this respect by measuring six UE metrics across all experimental settings. Table 2 compares the uncertainty estimation capacity of our method with various baselines, using the Calibrated-Brier score as the representative metric. Our method achieves the best score in every experimental setting. For instance, on CIFAR10 with the R10x5 setting, our method obtains a score of 0.067, a relative improvement of 26% over Deep Ensemble. Across all settings (CIFAR10 with R10x5, R18x3 and RME; CIFAR100 with R10x5, R18x3 and RME; Tiny-ImageNet with R18x3), the relative improvements over Deep Ensemble are 26%, 29%, 51%, 19%, 17%, 31%, and 10%, respectively. Furthermore, Table 3 reports all six UE metrics on the Tiny-ImageNet dataset. In this setting, our method achieves the best performance on five of the six metrics, the exception being ECE. Compared to Deep Ensemble, our method obtains relative improvements of 10%, 3%, and 14% on the Cal-Brier, Cal-AAC, and Cal-NLL metrics, respectively. Overall, these results show that our method also improves uncertainty estimation.
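The relative improvements quoted above can be recomputed from the tables as 100 × (baseline − ours) / baseline for a lower-is-better metric. The short script below does this for the Cal-Brier column of Table 2 and for two entries of Table 3; since the tables report rounded scores, the recomputed percentages are approximate. The helper name `relative_improvement` is our own.

```python
# Cal-Brier scores (lower is better), copied from Table 2.
settings = ["C10/R10x5", "C10/R18x3", "C10/RME",
            "C100/R10x5", "C100/R18x3", "C100/RME", "Tiny/R18x3"]
deep_ensemble = [0.091, 0.079, 0.153, 0.329, 0.308, 0.433, 0.453]
dash = [0.067, 0.056, 0.075, 0.267, 0.255, 0.298, 0.407]


def relative_improvement(baseline: float, ours: float) -> float:
    """Relative reduction (in %) of a lower-is-better metric with respect to a baseline."""
    return 100.0 * (baseline - ours) / baseline


gains = [round(relative_improvement(b, o)) for b, o in zip(deep_ensemble, dash)]
for name, gain in zip(settings, gains):
    print(f"{name}: {gain}%")

# Table 3 (Tiny-ImageNet, R18x3): Cal-AAC and Cal-NLL of Deep Ensemble vs. DASH.
print(round(relative_improvement(0.210, 0.204)), round(relative_improvement(1.413, 1.213)))
```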

Table 3: Evaluation of Uncertainty Estimation (UE) across six standard UE metrics on the Tiny-ImageNet dataset with R18x3.
| Method | NLL ↓ | Brier ↓ | ECE ↓ | Cal-Brier ↓ | Cal-AAC ↓ | Cal-NLL ↓ |
|---|---|---|---|---|---|---|
| Deep Ensemble | 1.400 | 0.452 | 0.110 | 0.453 | 0.210 | 1.413 |
| Fast Geometric | 1.548 | 0.501 | 0.116 | 0.499 | 0.239 | 1.544 |
| Snapshot | 1.643 | 0.505 | 0.118 | 0.501 | 0.237 | 1.599 |
| EDST | 1.581 | 0.496 | 0.115 | 0.495 | 0.235 | 1.548 |
| DST | 1.525 | 0.499 | 0.110 | 0.500 | 0.239 | 1.536 |
| SGD | 1.999 | 0.601 | 0.283 | 0.518 | 0.272 | 1.737 |
| SAM | 1.791 | 0.563 | 0.297 | 0.469 | 0.242 | 1.484 |
| DASH (Ours) | 1.379 | 0.447 | 0.184 | 0.407 | 0.204 | 1.213 |

4.5 Evaluation on Adversarial Robustness

In this section, we evaluate the robustness of our proposed method against adversarial attacks. We conduct experiments on the CIFAR10 dataset with the R18x3 and R10x5 architectures and employ the PGD attack [38], a standard attack for evaluating robustness. Specifically, we set the number of attack steps to k = 10 and the step size to η = 1/255, and vary the perturbation size ε from 1/255 to 6/255.
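The attack configuration above can be sketched as k-step L∞ PGD. For a self-contained illustration we attack a hypothetical linear logistic classifier instead of the R18x3 ensemble, and we start from the clean input rather than from the random point in the ε-ball used in [38], so the run is deterministic; the weights, input and helper names are assumptions made for the example.

```python
import math
from typing import List

Vector = List[float]


def sign(v: float) -> float:
    return float((v > 0) - (v < 0))


def logit(w: Vector, b: float, x: Vector) -> float:
    return sum(wi * xi for wi, xi in zip(w, x)) + b


def loss_grad_x(w: Vector, b: float, x: Vector, y: int) -> Vector:
    """Gradient w.r.t. x of the logistic loss log(1 + exp(-y * f(x))), with y in {-1, +1}."""
    coeff = -y / (1.0 + math.exp(y * logit(w, b, x)))
    return [coeff * wi for wi in w]


def pgd_linf(w: Vector, b: float, x0: Vector, y: int,
             eps: float, step: float, k: int) -> Vector:
    """k-step L-infinity PGD starting from the clean input (no random start)."""
    x = list(x0)
    for _ in range(k):
        g = loss_grad_x(w, b, x, y)
        x = [xi + step * sign(gi) for xi, gi in zip(x, g)]  # ascend the loss
        # Project back onto the eps-ball around x0, then onto the pixel range [0, 1].
        x = [min(max(xi, x0i - eps), x0i + eps) for xi, x0i in zip(x, x0)]
        x = [min(max(xi, 0.0), 1.0) for xi in x]
    return x


# Hypothetical linear classifier and clean input (pixel values in [0, 1]).
w, b = [1.0, -2.0, 0.5], 0.0
x0, y = [0.5, 0.2, 0.8], 1
for eps_255 in range(1, 7):  # eps from 1/255 to 6/255, as in the experiments
    x_adv = pgd_linf(w, b, x0, y, eps=eps_255 / 255, step=1 / 255, k=10)
    print(eps_255, round(y * logit(w, b, x_adv), 4))
```

Each printed value is the classifier's margin on the adversarial input; it shrinks as ε grows, mirroring the accuracy degradation discussed below.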

It is widely recognized in the adversarial machine learning literature that strong attacks are required to truly challenge defense methods (e.g., a PGD attack with more than 200 steps and a perturbation size of ε = 8/255). Nevertheless, we chose a weaker attack for our experiments because none of the evaluated methods were specifically designed to enhance adversarial robustness, so all of them may perform poorly against a stronger attack.

As shown in Fig. 3(a), DASH achieves better adversarial robustness than all baselines with the R18x3 architecture. More specifically, our method consistently outperforms SGD by around 3% across different levels of ε. While SAM's robust accuracy drops sharply as the attack becomes stronger (from 61.28% at ε = 1/255 to 27.61% at ε = 2/255), our method degrades less (from 65.53% at ε = 1/255 to 42.23% at ε = 2/255). With the R10x5 architecture, our method still outperforms SGD and SAM at all attack strengths; however, DASH performs worse than the DST and EDST methods when the perturbation size ε ≥ 2/255, as shown in Fig. 3(b). Although our method does not specifically target adversarial robustness, its superior performance with the R18x3 architecture suggests that combining sharpness-aware and diversity-aware mechanisms could be a promising direction for addressing this issue.

Figure 4: Evaluation on adversarial robustness: (a) R18x3, (b) R10x5. The x-axis denotes the perturbation size ε (×255).

5 Ablation studies

5.1 Hyper-parameter sensitivity

Table 4 reports the effect of the hyper-parameter γ on the performance of our method, with γ tuned over the range (0, 1], specifically γ ∈ {0.1, 0.2, 0.5, 0.8, 1.0}, on the CIFAR100 dataset with the R10x5 architecture. Recall that a small value such as γ = 0.1 prioritizes seeking flatness for the individual base classifiers over the entire ensemble model, while γ = 1 seeks flatness only for the aggregated ensemble model. Our method achieves the best performance in both generalization and uncertainty estimation when γ = 0.1, and its accuracy drops by about 2.0% (from 80.84% to 78.86%) when γ = 1. We therefore use γ = 0.1 as the default setting in our experiments.

Table 4: Ensemble performance under various values of the trade-off parameter γ on the CIFAR100 dataset with the R10x5 architecture.
| γ | Accuracy (%) ↑ | NLL ↓ | Brier ↓ | ECE ↓ |
|---|---|---|---|---|
| 0.1 | 80.84 | 0.86 | 0.32 | 0.18 |
| 0.2 | 80.48 | 0.97 | 0.35 | 0.23 |
| 0.5 | 80.42 | 0.95 | 0.34 | 0.22 |
| 0.8 | 79.81 | 1.08 | 0.38 | 0.29 |
| 1.0 | 78.86 | 1.12 | 0.40 | 0.28 |

5.2 Contribution of the diversity-aware agnostic constraint

In this section, we assess the contribution of each component by comparing two variants: DASH and DASH^F, where the latter is our method in flatness-seeking mode only. We run the experiments on the CIFAR10 and CIFAR100 datasets with the RME architecture and present the results in Table 5. Using the flatness-seeking mode only, DASH^F outperforms standard SGD by a substantial margin of 1.72% and 3.73% on CIFAR10 and CIFAR100, respectively. We attribute this enhancement to the improvement of each individual base classifier, which the ensemble then combines into better generalization. In particular, on CIFAR10 the average accuracy of the base classifiers with DASH^F is 93.21%, which is 5.07% higher than that achieved with SGD.

However, in terms of ensemble diversity as measured by the LD metric, the base classifiers of DASH^F are less diverse than those of SGD. Specifically, on CIFAR100, SGD obtains an LD score of -16.88, while DASH^F obtains only -19.47, which is 15.3% lower in relative terms. The lower LD score indicates that the predictions of the base classifiers under DASH^F are more similar to one another than under SGD. Consequently, on some hard samples, all base classifiers may fall into similar incorrect patterns, making the final ensemble prediction incorrect. In contrast, DASH obtains a higher LD score than DASH^F on both datasets while also improving the average accuracy of the base classifiers. As a result, DASH improves over DASH^F by 0.84% and 2.44% on CIFAR10 and CIFAR100, respectively.
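The two diversity metrics in Table 5 can be sketched as follows. For LD we follow the ensemble-diversity definition of [42] as we understand it: for one sample, each member's prediction with the true-class entry removed is scaled to unit L2 norm, these vectors form the columns of a matrix M, and LD = log det(MᵀM). Averaging per-sample LD over the test set, and computing D as the label-level disagreement averaged over member pairs, are assumptions; all numbers below are hypothetical.

```python
import math
from itertools import combinations
from typing import List, Sequence

Matrix = List[List[float]]


def det(a: Matrix) -> float:
    """Determinant by Gaussian elimination with partial pivoting."""
    m = [row[:] for row in a]
    n = len(m)
    result = 1.0
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        if m[pivot][col] == 0.0:
            return 0.0
        if pivot != col:
            m[col], m[pivot] = m[pivot], m[col]
            result = -result
        result *= m[col][col]
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for c in range(col, n):
                m[r][c] -= factor * m[col][c]
    return result


def log_det_diversity(member_probs: Sequence[Sequence[float]], label: int) -> float:
    """LD for one sample: log det(M^T M), where column k of M is member k's
    prediction with the true-class entry removed, scaled to unit L2 norm."""
    cols = []
    for p in member_probs:
        v = [pc for c, pc in enumerate(p) if c != label]
        norm = math.sqrt(sum(x * x for x in v))
        cols.append([x / norm for x in v])
    gram = [[sum(a * b for a, b in zip(ci, cj)) for cj in cols] for ci in cols]
    d = det(gram)
    return math.log(d) if d > 0 else -math.inf  # a singular Gram matrix (e.g. identical members) gives -inf


def disagreement(preds: Sequence[Sequence[int]]) -> float:
    """Fraction of inputs on which two members predict different labels, averaged over pairs."""
    pairs = list(combinations(range(len(preds)), 2))
    n = len(preds[0])
    return sum(sum(a != b for a, b in zip(preds[i], preds[j])) / n for i, j in pairs) / len(pairs)


# Hypothetical 4-class predictions of three members for one sample whose true label is 0.
similar = [[0.70, 0.20, 0.05, 0.05], [0.60, 0.25, 0.10, 0.05], [0.65, 0.20, 0.10, 0.05]]
diverse = [[0.70, 0.20, 0.05, 0.05], [0.60, 0.05, 0.30, 0.05], [0.65, 0.05, 0.05, 0.25]]
print(log_det_diversity(similar, 0), log_det_diversity(diverse, 0))
print(disagreement([[0, 1, 2, 1], [0, 1, 1, 1], [0, 2, 2, 1]]))
```

All three members agree on the true class in both examples, yet LD separates them because it looks only at how the non-target probability mass is spread. Since the Gram matrix of unit vectors has determinant at most 1, LD is never positive, and a non-singular value requires at least as many non-target classes as members.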

Table 5: Ablation study on the contribution of each component on the CIFAR10 (C10) and CIFAR100 (C100) datasets with the RME architecture. DASH^F denotes our method in flatness-seeking mode only.
| Dataset | Method | Accuracy (%) ↑ | LD ↑ | D ↑ | Avg. Accuracy (%) ↑ |
|---|---|---|---|---|---|
| C10 | SGD | 92.61 | -24.7 | 0.149 | 88.14 |
| C10 | DASH^F | 94.33 | -25.8 | 0.034 | 93.21 |
| C10 | DASH | 95.17 | -23.3 | 0.068 | 93.41 |
| C100 | SGD | 72.55 | -16.88 | 0.853 | 38.09 |
| C100 | DASH^F | 76.28 | -19.47 | 0.123 | 73.38 |
| C100 | DASH | 78.72 | -18.92 | 0.237 | 74.69 |

6 Conclusion

We developed DASH Ensemble, a learning algorithm that optimizes for deep ensembles of diverse and flat minimizers. Our method begins with a theoretical development to minimize a sharpness-aware upper bound on the general loss of the ensemble, followed by a novel agnostic term that promotes divergence among base classifiers. Our experimental results support the effectiveness of the agnostic term in introducing diversity into individual predictions, which ultimately improves generalization performance. This work demonstrates the potential of integrating sharpness-aware minimization techniques into the ensemble learning paradigm, and we hope it motivates future work to exploit this connection to develop more powerful and efficient ensemble models.

References

  • [1] Abbas, M., Xiao, Q., Chen, L., Chen, P.Y., Chen, T.: Sharp-maml: Sharpness-aware model-agnostic meta learning. arXiv preprint arXiv:2206.03996 (2022)
  • [2] Allen-Zhu, Z., Li, Y.: Towards understanding ensemble, knowledge distillation and self-distillation in deep learning. In: The Eleventh International Conference on Learning Representations (2022)
  • [3] Ashukha, A., Lyzhov, A., Molchanov, D., Vetrov, D.: Pitfalls of in-domain uncertainty estimation and ensembling in deep learning. In: International Conference on Learning Representations (2020)
  • [4] Bahri, D., Mobahi, H., Tay, Y.: Sharpness-aware minimization improves language model generalization. In: Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics. Association for Computational Linguistics, Dublin, Ireland (May 2022). https://doi.org/10.18653/v1/2022.acl-long.508, https://aclanthology.org/2022.acl-long.508
  • [5] Breiman, L.: Bagging predictors. Machine learning 24 (1996)
  • [6] Breiman, L.: Bias, variance, and arcing classifiers. Tech. Rep. 460, Statistics Department, University of California, Berkeley … (1996)
  • [7] Cha, J., Chun, S., Lee, K., Cho, H.C., Park, S., Lee, Y., Park, S.: Swad: Domain generalization by seeking flat minima. Advances in Neural Information Processing Systems 34 (2021)
  • [8] Chaudhari, P., Choromańska, A., Soatto, S., LeCun, Y., Baldassi, C., Borgs, C., Chayes, J.T., Sagun, L., Zecchina, R.: Entropy-sgd: biasing gradient descent into wide valleys. Journal of Statistical Mechanics: Theory and Experiment (2019)
  • [9] Chen, X., Hsieh, C.J., Gong, B.: When vision transformers outperform resnets without pre-training or strong data augmentations. arXiv preprint arXiv:2106.01548 (2021)
  • [10] Dietterich, T.G.: Ensemble methods in machine learning. In: Multiple Classifier Systems: First International Workshop, Proceedings 1. Springer (2000)
  • [11] Dinh, L., Pascanu, R., Bengio, S., Bengio, Y.: Sharp minima can generalize for deep nets. In: ICML (2017)
  • [12] Dziugaite, G.K., Roy, D.M.: Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data. In: UAI. AUAI Press (2017)
  • [13] Evci, U., Gale, T., Menick, J., Castro, P.S., Elsen, E.: Rigging the lottery: Making all tickets winners. In: International Conference on Machine Learning. PMLR (2020)
  • [14] Evci, U., Ioannou, Y., Keskin, C., Dauphin, Y.: Gradient flow in sparse neural networks and how lottery tickets win. In: Proceedings of the AAAI Conference on Artificial Intelligence. vol. 36 (2022)
  • [15] Foret, P., Kleiner, A., Mobahi, H., Neyshabur, B.: Sharpness-aware minimization for efficiently improving generalization. In: International Conference on Learning Representations (2021), https://openreview.net/forum?id=6Tm1mposlrM
  • [16] Fort, S., Ganguli, S.: Emergent properties of the local geometry of neural loss landscapes. arXiv preprint arXiv:1910.05929 (2019)
  • [17] Fort, S., Hu, H., Lakshminarayanan, B.: Deep ensembles: A loss landscape perspective. arXiv preprint arXiv:1912.02757 (2019)
  • [18] Freund, Y., Schapire, R.E., et al.: Experiments with a new boosting algorithm. In: ICML. vol. 96 (1996)
  • [19] Friedman, J.H.: Greedy function approximation: a gradient boosting machine. Annals of statistics (2001)
  • [20] Gal, Y., Ghahramani, Z.: Dropout as a bayesian approximation: Representing model uncertainty in deep learning. In: international conference on machine learning. PMLR (2016)
  • [21] Garipov, T., Izmailov, P., Podoprikhin, D., Vetrov, D.P., Wilson, A.G.: Loss surfaces, mode connectivity, and fast ensembling of dnns. Advances in neural information processing systems 31 (2018)
  • [22] Gustafsson, F.K., Danelljan, M., Schon, T.B.: Evaluating scalable bayesian deep learning methods for robust computer vision. In: Proceedings of the IEEE/CVF conference on computer vision and pattern recognition workshops (2020)
  • [23] Hansen, L.K., Salamon, P.: Neural network ensembles. IEEE transactions on pattern analysis and machine intelligence 12(10) (1990)
  • [24] Havasi, M., Jenatton, R., Fort, S., Liu, J.Z., Snoek, J., Lakshminarayanan, B., Dai, A.M., Tran, D.: Training independent subnetworks for robust prediction. arXiv preprint arXiv:2010.06610 (2020)
  • [25] Hochreiter, S., Schmidhuber, J.: Simplifying neural nets by discovering flat minima. In: NIPS. MIT Press (1994)
  • [26] Huang, G., Li, Y., Pleiss, G., Liu, Z., Hopcroft, J.E., Weinberger, K.Q.: Snapshot ensembles: Train 1, get m for free. arXiv preprint arXiv:1704.00109 (2017)
  • [27] Huang, G., Sun, Y., Liu, Z., Sedra, D., Weinberger, K.Q.: Deep networks with stochastic depth. In: Computer Vision–ECCV 2016: 14th European Conference, Amsterdam, The Netherlands, October 11–14, 2016, Proceedings, Part IV 14. Springer (2016)
  • [28] Jastrzebski, S., Kenton, Z., Arpit, D., Ballas, N., Fischer, A., Bengio, Y., Storkey, A.J.: Three factors influencing minima in sgd. ArXiv abs/1711.04623 (2017)
  • [29] Jiang, Y., Neyshabur, B., Mobahi, H., Krishnan, D., Bengio, S.: Fantastic generalization measures and where to find them. In: ICLR. OpenReview.net (2020)
  • [30] Keskar, N.S., Mudigere, D., Nocedal, J., Smelyanskiy, M., Tang, P.T.P.: On large-batch training for deep learning: Generalization gap and sharp minima. In: ICLR. OpenReview.net (2017)
  • [31] Krogh, A., Vedelsby, J.: Neural network ensembles, cross validation, and active learning. Advances in neural information processing systems 7 (1994)
  • [32] Kuncheva, L.I., Whitaker, C.J.: Measures of diversity in classifier ensembles and their relationship with the ensemble accuracy. Machine learning 51(2) (2003)
  • [33] Lakshminarayanan, B., Pritzel, A., Blundell, C.: Simple and scalable predictive uncertainty estimation using deep ensembles. Advances in neural information processing systems 30 (2017)
  • [34] Li, H., Xu, Z., Taylor, G., Studer, C., Goldstein, T.: Visualizing the loss landscape of neural nets. Advances in neural information processing systems 31 (2018)
  • [35] Li, N., Yu, Y., Zhou, Z.H.: Diversity regularized ensemble pruning. In: Machine Learning and Knowledge Discovery in Databases: European Conference, ECML PKDD 2012, Bristol, UK, September 24-28, 2012. Proceedings, Part I 23. pp. 330–345. Springer (2012)
  • [36] Liu, S., Chen, T., Atashgahi, Z., Chen, X., Sokar, G., Mocanu, E., Pechenizkiy, M., Wang, Z., Mocanu, D.C.: Deep ensembling with no overhead for either training or testing: The all-round blessings of dynamic sparsity. In: International Conference on Learning Representations (2022)
  • [37] Liu, S., Mocanu, D.C., Matavalam, A.R.R., Pei, Y., Pechenizkiy, M.: Sparse evolutionary deep learning with over one million artificial neurons on commodity hardware. Neural Computing and Applications 33 (2021)
  • [38] Madry, A., Makelov, A., Schmidt, L., Tsipras, D., Vladu, A.: Towards deep learning models resistant to adversarial attacks. In: International Conference on Learning Representations (2017)
  • [39] Neyshabur, B., Bhojanapalli, S., McAllester, D., Srebro, N.: Exploring generalization in deep learning. Advances in neural information processing systems 30 (2017)
  • [40] Ortega, L.A., Cabañas, R., Masegosa, A.: Diversity and generalization in neural network ensembles. In: International Conference on Artificial Intelligence and Statistics. pp. 11720–11743. PMLR (2022)
  • [41] Ovadia, Y., Fertig, E., Ren, J., Nado, Z., Sculley, D., Nowozin, S., Dillon, J., Lakshminarayanan, B., Snoek, J.: Can you trust your model’s uncertainty? evaluating predictive uncertainty under dataset shift. Advances in neural information processing systems 32 (2019)
  • [42] Pang, T., Xu, K., Du, C., Chen, N., Zhu, J.: Improving adversarial robustness via promoting ensemble diversity. In: International Conference on Machine Learning. PMLR (2019)
  • [43] Pereyra, G., Tucker, G., Chorowski, J., Kaiser, L., Hinton, G.E.: Regularizing neural networks by penalizing confident output distributions. In: ICLR (Workshop). OpenReview.net (2017)
  • [44] Perrone, M.P., Cooper, L.N.: When networks disagree: Ensemble methods for hybrid neural networks. In: How We Learn; How We Remember: Toward An Understanding Of Brain And Neural Systems: Selected Papers of Leon N Cooper. World Scientific (1995)
  • [45] Petzka, H., Kamp, M., Adilova, L., Sminchisescu, C., Boley, M.: Relative flatness and generalization. In: NeurIPS (2021)
  • [46] Phan, H., Tran, L., Tran, N.N., Ho, N., Phung, D., Le, T.: Improving multi-task learning via seeking task-based flat regions. arXiv preprint arXiv:2211.13723 (2022)
  • [47] Qu, Z., Li, X., Duan, R., Liu, Y., Tang, B., Lu, Z.: Generalized federated learning via sharpness aware minimization. arXiv preprint arXiv:2206.02618 (2022)
  • [48] Sinha, S., Bharadhwaj, H., Goyal, A., Larochelle, H., Garg, A., Shkurti, F.: Diversity inducing information bottleneck in model ensembles. arXiv preprint arXiv:2003.04514 (2020)
  • [49] Srivastava, N., Hinton, G., Krizhevsky, A., Sutskever, I., Salakhutdinov, R.: Dropout: a simple way to prevent neural networks from overfitting. The journal of machine learning research 15(1) (2014)
  • [50] Sun, T., Zhou, Z.H.: Structural diversity for decision tree ensemble learning. Frontiers of Computer Science 12, 560–570 (2018)
  • [51] Wan, L., Zeiler, M., Zhang, S., Le Cun, Y., Fergus, R.: Regularization of neural networks using dropconnect. In: International conference on machine learning. PMLR (2013)
  • [52] Wei, C., Kakade, S., Ma, T.: The implicit and explicit regularization effects of dropout. In: International conference on machine learning. PMLR (2020)
  • [53] Wen, Y., Tran, D., Ba, J.: Batchensemble: an alternative approach to efficient ensemble and lifelong learning. arXiv preprint arXiv:2002.06715 (2020)
  • [54] Wenzel, F., Snoek, J., Tran, D., Jenatton, R.: Hyperparameter ensembles for robustness and uncertainty quantification. Advances in Neural Information Processing Systems 33 (2020)
  • [55] Yang, H., Zhang, J., Dong, H., Inkawhich, N., Gardner, A., Touchet, A., Wilkes, W., Berry, H., Li, H.: Dverge: diversifying vulnerabilities for enhanced robust generation of ensembles. Advances in Neural Information Processing Systems 33 (2020)
  • [56] Yang, Z., Li, L., Xu, X., Zuo, S., Chen, Q., Zhou, P., Rubinstein, B., Zhang, C., Li, B.: Trs: Transferability reduced ensemble via promoting gradient diversity and model smoothness. Advances in Neural Information Processing Systems 34 (2021)
  • [57] Zhang, C.X., Zhang, J.S.: Rotboost: A technique for combining rotation forest and adaboost. Pattern recognition letters 29(10) (2008)
  • [58] Zhang, L., Song, J., Gao, A., Chen, J., Bao, C., Ma, K.: Be your own teacher: Improve the performance of convolutional neural networks via self distillation. In: Proceedings of the IEEE/CVF International Conference on Computer Vision (2019)
  • [59] Zhang, S., Liu, M., Yan, J.: The diversified ensemble neural network. Advances in Neural Information Processing Systems 33, 16001–16011 (2020)
  • [60] Zhang, Y., Xiang, T., Hospedales, T.M., Lu, H.: Deep mutual learning. 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition (2018)