License: CC BY-SA 4.0
arXiv:2312.15297v1 [cs.LG] 23 Dec 2023

Make Me a BNN: A Simple Strategy for Estimating
Bayesian Uncertainty from Pre-trained Models

Gianni Franchi,1,*,† Olivier Laurent,1,2,* Maxence Leguéry,1 Andrei Bursuc,3
Andrea Pilzer4 & Angela Yao5
U2IS, ENSTA Paris, Institut Polytechnique de Paris,1 Université Paris-Saclay,2
NVIDIA,3 valeo.ai,4 National University of Singapore5
Abstract

Deep Neural Networks (DNNs) are powerful tools for various computer vision tasks, yet they often struggle with reliable uncertainty quantification, a critical requirement for real-world applications. Bayesian Neural Networks (BNNs) are equipped for uncertainty estimation but do not scale to large DNNs, where they are highly unstable to train. To address this challenge, we introduce the Adaptable Bayesian Neural Network (ABNN), a simple and scalable strategy to seamlessly transform DNNs into BNNs in a post-hoc manner with minimal computational and training overheads. ABNN preserves the main predictive properties of DNNs while enhancing their uncertainty quantification abilities through simple BNN adaptation layers (attached to normalization layers) and a few fine-tuning steps on pre-trained models. We conduct extensive experiments across multiple datasets for image classification and semantic segmentation tasks, and our results demonstrate that ABNN achieves state-of-the-art performance without the computational budget typically associated with ensemble methods.

Figure 0: Benefits of the ABNN approach. (left) Evaluation of the trade-off between computational cost (training time and model size) and performance (FPR95 score, lower is better) for various uncertainty quantification techniques on CIFAR-10 [49] with WideResNet [94] and ensembles of size 4. (right) Using ABNN ensembling at test time: given an out-of-distribution input, a simple DNN may make high-confidence incorrect predictions, whereas ABNN produces more uncertain decisions through its diverse predictions.
* Equal contribution. † [email protected]

1 Introduction

Deep Neural Networks (DNNs) have emerged as powerful tools with a profound impact on various perception tasks, such as image classification [50, 18], object detection [32, 70], and natural language processing [76, 17, 69]. With this progress, there is growing excitement and expectation about the potential applications of DNNs across industries. To this end, there is a critical need to address a fundamental challenge: improving DNN reliability by quantifying the inherent uncertainty in their predictions [82, 37, 36]. Deploying DNNs in real-world applications, particularly in safety-critical domains such as autonomous driving, medical diagnosis, and industrial visual inspection, requires a comprehensive understanding of their limitations and vulnerabilities beyond raw predictive accuracy, which is often considered the primary performance metric. These models have millions of parameters, non-trivial inner workings [1, 93], and failure modes [33, 66] that surface across the many different long-tail scenarios [56, 7]; quantifying their uncertainty lets us make informed decisions about when and how to rely on their predictions.

In deep learning, uncertainty estimation has been traditionally addressed under Bayesian approaches drawing inspiration from findings in Bayesian Neural Networks (BNNs) [58, 64, 5, 85] that stand on solid theoretical grounds and properties [62, 89, 44]. BNNs estimate the posterior distribution of the model parameters given the training dataset. Ensembles can be sampled from this distribution at runtime, and their predictions can be averaged for reliable decisions. BNNs promise improved predictions and reliable uncertainty estimates with intuitive decomposition of the uncertainty sources [16, 41]. However, although they are easy to formulate, BNNs are notoriously difficult to train over large DNNs [67, 20], in particular for complex computer vision tasks, due to training instability and computational inefficiency as they are typically trained through variational inference [45]. This limitation of BNNs has inspired two major lines of research toward scalable uncertainty estimation: ensembles and last-layer uncertainty approaches.

Deep Ensembles [51] emerge as a simple and highly effective alternative to BNNs for uncertainty estimation on large DNNs. Deep Ensembles are trivial to train, essentially repeating the same training procedure over different weight initializations of the same network, and have been shown to preserve many of the properties of BNNs in terms of prediction diversity [22, 89]. This is a beneficial property for out-of-distribution (OOD) generalization [63]. However, their high computational cost (during both training and inference) makes them inapplicable to many practical applications with computational constraints. In the last few years, multiple computationally efficient alternatives have emerged, aiming to reduce the cost of training and inference [59, 23, 25, 86, 30, 26, 13]. However, these methods rely on specific network architectures and involve non-trivial trade-offs between computational cost, accuracy, and predictive uncertainty quality.

Last-layer uncertainty approaches aim for BNNs with fewer stochastic weights to produce ensembles or uncertainty estimates [60, 12, 6, 48, 9]. These methods leverage popular DNN architectures [31] to which they attach a stochastic layer and train all parameters for a complete training cycle. While training stability is improved compared to standard BNNs, joint optimization of deterministic and stochastic parameters requires careful tuning. Daxberger et al. [12] propose training the last layer separately in a post-hoc manner, effectively leveraging the Laplace approximation for optimization [80, 58, 72]. Decoupling the optimization of the encoder from that of the uncertainty layer enables the use of typical training recipes for the encoder, or simply of off-the-shelf pre-trained networks. The main limitation of last-layer methods is that they only access high-level features to produce uncertainty estimates. In contrast, signals of distribution shift or small anomaly patterns (e.g., in semantic segmentation) can be found primarily in low-level features earlier in the network. Indeed, strong uncertainty estimation methods leverage information from multiple layers of the network [55, 25, 13].

In this work, we aim for scalable and effective uncertainty estimation without sophisticated optimization schemes and potential training instability and without compromising predictive performance. We propose a post-hoc strategy that starts from a pre-trained DNN and transforms it into a BNN with a simple plug-in module attached to the normalization layers and only a few epochs of fine-tuning. We show that this strategy, dubbed Adaptable-BNN (ABNN), can estimate the posterior distribution around the local minimum of the pre-trained model in a resource-efficient manner while still achieving competitive uncertainty estimates with diversity. Furthermore, ABNN allows for sequential training of multiple BNNs starting from the same checkpoint, thus modeling various modes within the true posterior distribution.

Our contributions are: (1) We propose ABNN, a simple strategy to transform a pre-trained DNN into a BNN with uncertainty estimation capabilities. ABNN is computationally efficient and compatible with multiple DNN architectures (ConvNets: ResNet-50, WideResnet28-10; ViTs), provided they are equipped with normalization layers. (2) We observe that the variance of the gradient for ABNN’s parameters is lower compared to that of a classic BNN, resulting in a more stable backpropagation. (3) Extensive experiments validate that ABNN, although simple and computationally frugal, achieves competitive performance in terms of accuracy and uncertainty estimation over multiple datasets and tasks: image classification (CIFAR-10, CIFAR-100 [49], ImageNet [15]) and semantic segmentation (StreetHazards, BDD-Anomaly [35], MUAD [24]) in both in- and out-of-distribution settings.

2 Related work

Epistemic uncertainty and Bayesian posterior. Tackling epistemic uncertainty estimation [39], that is, estimating the uncertainty on the model itself, is essential to improve the reliability of DNNs [41]. However, obtaining satisfying approximations of this uncertainty remains a challenge as it requires a scalable estimation of the extremely high-dimensional distribution of the weights, the posterior. Our work presents ABNN, a significantly more scalable method compared to the BNNs [28] that predominantly shape the landscape of epistemic uncertainty estimation [27].

Bayesian Neural Networks and Ensembles. BNNs [81] formulate probabilistic predictions by both introducing explicit and controllable prior knowledge on the network weights [64, 46] and estimating the posterior. While formulating the posterior distribution mathematically is possible [89], its computation for modern models is intractable. This need for scalability leads to approximation techniques that include variational inference BNNs [5, 4], which fit simpler distributions to the posterior (a diagonal Gaussian for the former). Many other approximation methods have been proposed, such as the efficient probabilistic backpropagation [38], Monte-Carlo Dropout [26], which models the posterior as a mixture of Diracs, and Laplace methods, which estimate the posterior from the local curvature of the loss [80, 58, 72, 74]. However, the most successful solution, deep ensembles [29, 51], is simpler: it consists of averaging the predictions of several independently trained models. This method is powerful as it improves the reliability and the quality of the predictions, albeit at a high cost in training and inference. Many approximate methods stepped into the breach and proposed to reduce the number of parameters, the training time, or the number of forward passes [86, 25, 23, 30, 54, 90]. The proposed method follows this line of scalability and computational efficiency.

Post-hoc uncertainty quantification. In today's context of computer vision, with ever-increasing datasets [75, 47], model sizes [14], and inference constraints, post-hoc methods could be a solution to benefit from both the power of foundation models [42, 73, 95] and uncertainty quantification. The Laplace method reigns among these post-hoc uncertainty quantification methods, especially with its last-layer approximations [72] that make it even more scalable. However, even with the coarse approximation of the posterior achieved by Kronecker-factored Laplace [74], it is not always very efficient on modern datasets. We propose a straightforward and effective method that works with pre-trained models.

3 Background

We start by introducing our formalism and a brief overview of the Bayesian posterior and BNNs for uncertainty quantification.

3.1 Preliminaries

Notations. Let us denote by $\mathcal{D}=\{(\mathbf{x}_{i},y_{i})\}_{i=1}^{N}$ the training set containing $N$ samples and labels drawn from a joint distribution $P_{(X,Y)}$. The input $\mathbf{x}_{i}\in\mathbb{R}^{d}$ is processed by a neural network $f_{\boldsymbol{\omega}}$, of parameters $\boldsymbol{\omega}$, that outputs classification predictions $\hat{y}_{i}=f_{\boldsymbol{\omega}}(\mathbf{x}_{i})\in\mathbb{R}$.

From MLE to MAP. In our context, $P(Y\!=\!y_{i}\mid X\!=\!\mathbf{x}_{i},\boldsymbol{\omega})$ is a categorical distribution over the classes within the range of $Y$. We omit the random variable notation in the following for clarity. The negative log-likelihood of this distribution typically corresponds to the cross-entropy loss, which practitioners often minimize with stochastic gradient descent to obtain a maximum likelihood estimate (MLE): $\mathcal{L}_{\text{MLE}}(\boldsymbol{\omega})=-\sum_{(\mathbf{x}_{i},y_{i})\in\mathcal{D}}\log P(y_{i}\mid\mathbf{x}_{i},\boldsymbol{\omega})$.

Going further, the Bayesian framework allows us to incorporate prior knowledge regarding $\boldsymbol{\omega}$, denoted as the distribution $P(\boldsymbol{\omega})$, that complements the likelihood and leads to the search for the maximum a posteriori (MAP) estimate via the minimization of the following loss function:

$$\mathcal{L}_{\text{MAP}}(\boldsymbol{\omega})=-\sum_{(\mathbf{x}_{i},y_{i})\in\mathcal{D}}\log P(y_{i}\mid\mathbf{x}_{i},\boldsymbol{\omega})-\log P(\boldsymbol{\omega}). \tag{1}$$

The normal prior is the standard choice for $P(\boldsymbol{\omega})$, leading to the omnipresent L2 weight regularization.
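To make the link between the normal prior and L2 regularization explicit, the short derivation below assumes an isotropic zero-mean Gaussian prior. The prior variance $\sigma_p^{2}$, the parameter count $K$, and the coefficient $\lambda$ are notation introduced here for illustration only; they do not appear in the original text.

```latex
\begin{align*}
P(\boldsymbol{\omega}) &= \mathcal{N}\!\left(\boldsymbol{\omega};\,\mathbf{0},\,\sigma_p^{2}\mathds{1}\right)
\;\Longrightarrow\;
-\log P(\boldsymbol{\omega}) = \frac{1}{2\sigma_p^{2}}\lVert\boldsymbol{\omega}\rVert_2^{2}
  + \frac{K}{2}\log\!\left(2\pi\sigma_p^{2}\right), \\
\mathcal{L}_{\text{MAP}}(\boldsymbol{\omega}) &= \mathcal{L}_{\text{MLE}}(\boldsymbol{\omega})
  + \lambda\lVert\boldsymbol{\omega}\rVert_2^{2} + \text{const},
\qquad \lambda = \frac{1}{2\sigma_p^{2}}.
\end{align*}
```

Under this assumption, a tighter prior (smaller $\sigma_p$) corresponds to a stronger weight-decay coefficient, and the additive constant does not affect the minimizer.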

3.2 Bayesian Posterior and BNNs

Typically, DNNs retain a single set of weights $\boldsymbol{\omega}_{\text{MAP}}$ at the end of the training to use at inference. As such, we de facto consider this model as an oracle. In contrast, BNNs attempt to model the posterior distribution $P(\boldsymbol{\omega}\mid\mathcal{D})$ to take all possible models into account. The prediction $y$ for a new sample $\mathbf{x}$ is computed as the expected outcome from an infinite ensemble, including all possible weights sampled from the posterior distribution:

$$P(y\mid\mathbf{x},\mathcal{D})=\int\limits_{\boldsymbol{\omega}\in\Omega}P(y\mid\mathbf{x},\boldsymbol{\omega})\,P(\boldsymbol{\omega}\mid\mathcal{D})\,d\boldsymbol{\omega}. \tag{2}$$

However, in practice, this Bayes ensemble approach is intractable since the integral in Eq. (2) is computed over the entire parameter space $\Omega$. Practitioners [51] approximate this integral by averaging predictions derived from a finite set $\{\boldsymbol{\omega}_{1},\ldots,\boldsymbol{\omega}_{M}\}$ of $M$ weight configurations sampled from the posterior distribution:

$$P(y\mid\mathbf{x},\mathcal{D})\approx\frac{1}{M}\sum_{m=1}^{M}P(y\mid\mathbf{x},\boldsymbol{\omega}_{m}). \tag{3}$$
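The finite average of Eq. (3) can be sketched in a few lines of plain Python. This is a minimal illustration, not the authors' implementation: the logits are made-up values standing in for the outputs of three sampled weight configurations, and the helper names (`softmax`, `ensemble_predict`, `entropy`) are hypothetical.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def ensemble_predict(member_logits):
    """Eq. (3): average the class probabilities of M sampled models."""
    probs = [softmax(z) for z in member_logits]
    num_members, num_classes = len(probs), len(probs[0])
    return [sum(p[c] for p in probs) / num_members for c in range(num_classes)]

def entropy(p):
    """Shannon entropy (in nats) of a categorical distribution."""
    return -sum(q * math.log(q) for q in p if q > 0.0)

# Hypothetical logits from M = 3 weight samples that disagree on the class.
member_logits = [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]
avg = ensemble_predict(member_logits)
single = softmax(member_logits[0])
```

In this toy case each member is confident on its own, but the members disagree, so the averaged prediction has higher entropy than any single member. This is the mechanism by which ensembling exposes uncertainty.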

Without loss of generality, let us start with a simple Multi-Layer Perceptron (MLP) with two hidden layers. For a given input data point $\mathbf{x}$, with $\mathbf{h}_{1}$ and $\mathbf{h}_{2}$ the preactivation maps and $a(\cdot)$ the activation function, the prediction of the DNN is defined as:

$$\begin{aligned}
&\mathbf{h}_{1}=W^{(1)}\mathbf{x}\\
&\mathbf{u}_{1}=\operatorname{norm}(\mathbf{h}_{1},\beta_{1},\gamma_{1})=\frac{\mathbf{h}_{1}-\hat{\mu}_{1}}{\hat{\sigma}_{1}}\times\gamma_{1}+\beta_{1}\\
&\mathbf{a}_{1}=a(\mathbf{u}_{1})\\
&\mathbf{u}_{2}=\operatorname{norm}\left(W^{(2)}\mathbf{a}_{1},\beta_{2},\gamma_{2}\right)\text{, and }\mathbf{a}_{2}=a(\mathbf{u}_{2})\\
&\mathbf{h}_{3}=W^{(3)}\mathbf{a}_{2}\text{, and }P(y\mid\mathbf{x},\boldsymbol{\omega})=\operatorname{soft}(\mathbf{h}_{3})
\end{aligned} \tag{4}$$

In Eq. (4), $\operatorname{soft}(\cdot)$ is the softmax, $\mathbf{a}_{1}$ and $\mathbf{a}_{2}$ are the hidden activations, and $\{W^{(j)}\}_{j\in\{1,2,3\}}$ correspond to the weights of the linear layers. The operator $\operatorname{norm}(\cdot,\beta_{j},\gamma_{j})$, of trainable parameters $\beta_{j}$ and $\gamma_{j}$, can refer to any batch, layer, or instance normalization (BN, LN, IN). Finally, $\operatorname{norm}$ comes with its empirical mean $\hat{\mu}_{j}$ and variance $\hat{\sigma}_{j}$. We omit the small value often added for computational stability.
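The forward pass of Eq. (4) can be sketched as follows. This is a minimal pure-Python illustration under stated assumptions: statistics are computed per sample over the feature dimension (layer-norm style), the activation $a(\cdot)$ is taken to be a ReLU, and the weights are arbitrary toy values. None of these choices comes from the paper.

```python
import math

def matvec(W, x):
    """Matrix-vector product W x, with W given as a list of rows."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def norm(h, beta, gamma):
    """norm(h, beta, gamma) with layer-norm statistics (stability constant omitted)."""
    mu = sum(h) / len(h)
    sigma = math.sqrt(sum((v - mu) ** 2 for v in h) / len(h))
    return [(v - mu) / sigma * g + b for v, g, b in zip(h, gamma, beta)]

def relu(u):
    """Assumed activation function a(.)."""
    return [max(0.0, v) for v in u]

def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def mlp_forward(x, W1, W2, W3, beta1, gamma1, beta2, gamma2):
    """Two-hidden-layer MLP of Eq. (4)."""
    h1 = matvec(W1, x)
    u1 = norm(h1, beta1, gamma1)
    a1 = relu(u1)
    u2 = norm(matvec(W2, a1), beta2, gamma2)
    a2 = relu(u2)
    h3 = matvec(W3, a2)
    return softmax(h3)

# Toy weights chosen only so that the example runs; they carry no meaning.
W1 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W2 = [[1.0, 0.0, 1.0], [0.0, 1.0, -1.0], [1.0, 1.0, 0.0]]
W3 = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
ones, zeros = [1.0] * 3, [0.0] * 3
probs = mlp_forward([1.0, 2.0], W1, W2, W3, zeros, ones, zeros, ones)
```

Because the stability constant is omitted, as in the text, this sketch would divide by zero if all features of a layer were equal; practical implementations add a small value to the denominator.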

In the current form, we can leverage this architecture to learn different tasks. However, modeling the uncertainty of the predictions beyond the use of softmax scores as confidence proxies is non-trivial. BNNs [5] are one solution to improve uncertainty estimation. Generally, they hypothesize the independence of the layers and sample from the resulting posterior estimate. For the $j$-th layer, this yields:

$$\begin{aligned}
&\mathbf{u}_{j}=\operatorname{norm}(W^{(j)}\mathbf{x},\beta_{j},\gamma_{j})\text{, }W^{(j)}\sim P(W^{(j)}\mid\mathcal{D})\text{, and}\\
&\mathbf{a}_{j}=a(\mathbf{u}_{j}).
\end{aligned} \tag{5}$$

As such, BNNs approximate the marginalization in Eq. (2) over the parameters, an extremely complex task, by generating multiple predictions. Variational inference BNNs [5], the most scalable version among these methods, base their estimation on the "reparametrization trick", here at layer $j$:

$$\begin{aligned}
&\mathbf{u}_{j}=\operatorname{norm}\left(\left[W^{(j)}_{\mu}+\boldsymbol{\epsilon}_{j}W^{(j)}_{\sigma}\right]\mathbf{h}_{j-1},\beta_{j},\gamma_{j}\right)\text{ and}\\
&\mathbf{a}_{j}=a(\mathbf{u}_{j}),
\end{aligned} \tag{6}$$

where the matrices $W^{(j)}_{\mu}$ and $W^{(j)}_{\sigma}$ denote the mean and standard deviation of the posterior distribution of layer $j$, and $\boldsymbol{\epsilon}_{j}\sim\mathcal{N}(\mathbf{0},\mathds{1})$ is a zero-mean unit-diagonal Gaussian vector or matrix. This method enables learning an estimate of a diagonal posterior distribution at the cost of tripling the number of parameters compared to a standard network.
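The weight sampling inside Eq. (6) can be illustrated with a small sketch. It assumes that the product $\boldsymbol{\epsilon}_{j}W^{(j)}_{\sigma}$ is elementwise, with one independent noise value per weight; the function name `sample_weights` and the toy matrices are hypothetical.

```python
import random

def sample_weights(W_mu, W_sigma, rng):
    """Draw W = W_mu + eps * W_sigma with eps ~ N(0, 1), elementwise (assumed)."""
    return [
        [mu + rng.gauss(0.0, 1.0) * sd for mu, sd in zip(row_mu, row_sd)]
        for row_mu, row_sd in zip(W_mu, W_sigma)
    ]

rng = random.Random(0)
W_mu = [[0.5, -1.0], [2.0, 0.0]]      # toy posterior means
W_sigma = [[0.5, 0.5], [0.5, 0.5]]    # toy posterior standard deviations

# Averaging many draws should recover the mean matrix.
num_draws = 5000
draws = [sample_weights(W_mu, W_sigma, rng) for _ in range(num_draws)]
empirical_mean = [
    [sum(d[i][j] for d in draws) / num_draws for j in range(2)]
    for i in range(2)
]
```

Writing the sample as a deterministic function of $W_{\mu}$, $W_{\sigma}$, and an external noise source is what allows gradients to flow to both sets of parameters during training.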

4 ABNN

Figure 1: Illustration of the training process for the ABNN. The procedure begins with training a single DNN $\boldsymbol{\omega}_{\text{MAP}}$, followed by architectural adjustments to transform it into an ABNN. The final step involves fine-tuning the ABNN model.

4.1 Converting DNNs into BNNs

We base our post-hoc Bayesian strategy on pre-trained DNNs that incorporate normalization layers such as batch [43], layer [3], or instance normalization [83]. This is not a limiting factor, as most modern architectures include one of these layer types [31, 18, 57]. We replace each normalization layer with our novel Bayesian Normalization Layer (BNL), which injects a Gaussian perturbation into the normalization and thereby transforms the initially deterministic DNN into a BNN. Since only the normalization layers change, BNLs let us leverage pre-trained models efficiently and convert them into uncertainty-aware BNNs with minimal alterations. Formally, our BNN is defined as:

$$\begin{aligned}
&\mathbf{u}_{j}=\operatorname{\mathbf{BNL}}\left(W^{(j)}\mathbf{h}_{j-1}\right),\text{ and}\\
&\mathbf{a}_{j}=a(\mathbf{u}_{j})\text{ with}\\
&\operatorname{\mathbf{BNL}}(\mathbf{h}_{j})=\frac{\mathbf{h}_{j}-\hat{\mu}_{j}}{\hat{\sigma}_{j}}\times\gamma_{j}(1+\boldsymbol{\epsilon}_{j})+\beta_{j}.
\end{aligned} \tag{7}$$

The empirical mean and variance are still represented by $\hat{\mu}_{j}$ and $\hat{\sigma}_{j}$ and computed through batch, layer, or instance normalization. In Eq. (7), $\boldsymbol{\epsilon}_{j}\sim\mathcal{N}(\mathbf{0},\mathds{1})$ is a sample drawn from a Normal distribution, and $\gamma_{j}$ and $\beta_{j}$ are the two learnable vectors of the $j$-th layer.
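A Bayesian normalization layer following Eq. (7) can be sketched as below. This is an illustrative reading, not the authors' code: it assumes layer-norm style statistics over the feature vector and one noise value per feature (matching the shape of $\gamma_{j}$), and the names `bnl` and `sample_eps` are hypothetical.

```python
import math
import random

def bnl(h, gamma, beta, eps):
    """Eq. (7): (h - mu) / sigma * gamma * (1 + eps) + beta, elementwise."""
    mu = sum(h) / len(h)
    sigma = math.sqrt(sum((v - mu) ** 2 for v in h) / len(h))
    return [
        (v - mu) / sigma * g * (1.0 + e) + b
        for v, g, b, e in zip(h, gamma, beta, eps)
    ]

def sample_eps(dim, rng):
    """One standard-normal noise value per feature (an assumption on the shape)."""
    return [rng.gauss(0.0, 1.0) for _ in range(dim)]

rng = random.Random(0)
h = [1.0, 2.0, 3.0]
gamma = [1.0, 1.0, 1.0]
beta = [0.5, 0.5, 0.5]

# With eps = 0 the layer reduces to the usual normalization of Eq. (4).
deterministic = bnl(h, gamma, beta, [0.0, 0.0, 0.0])

# Each forward pass with fresh noise gives a different output.
stochastic = [bnl(h, gamma, beta, sample_eps(3, rng)) for _ in range(4000)]
mean_out = [sum(s[k] for s in stochastic) / len(stochastic) for k in range(3)]
```

Since the noise has zero mean, the stochastic outputs are centered on the deterministic normalization output, which is one way to see why the conversion keeps the pre-trained predictions as a starting point.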

Once the DNN is transformed into a BNN, we retrain only the parameters $\gamma_{j}$ and $\beta_{j}$ for a limited number of epochs using the loss introduced in Section 4.2. To further improve reliability and generalization [89], we do not train a single ABNN but rather multiple copies, as explained in Section 4.2, resulting in a finite set $\{\boldsymbol{\omega}_{1},\ldots,\boldsymbol{\omega}_{M}\}$ of $M$ weight configurations. We discuss the benefits of this multi-modality in Appendix A.2.

During inference, for each ABNN $\boldsymbol{\omega}_{m}$, we augment the number of samples by independently sampling multiple $\boldsymbol{\epsilon}_{j}\sim\mathcal{N}(\mathbf{0},\mathds{1})$. With $\boldsymbol{\epsilon}$ the concatenation of all $\boldsymbol{\epsilon}_{j}$, and $\{\boldsymbol{\epsilon}_{l}\}_{l\in[1,L]}$ the set of sampled $\boldsymbol{\epsilon}$, each individual ABNN sample is expressed as $P(y\mid\mathbf{x},\boldsymbol{\omega},\boldsymbol{\epsilon})$. The prediction $y$ for a new sample $\mathbf{x}$ is then computed as the expected outcome from a finite ensemble, encompassing all the weights sampled from the posterior distribution:

$$P(y \mid \mathbf{x}, \mathcal{D}) \approx \frac{1}{ML} \sum_{l=1}^{L} \sum_{m=1}^{M} P(y \mid \mathbf{x}, \boldsymbol{\omega}_m, \boldsymbol{\epsilon}_l). \tag{8}$$
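The double average in Eq. (8) can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: the `make_member` toy model, the `member(x, eps)` interface and the use of a scalar noise value in place of the full vector $\boldsymbol{\epsilon}_l$ are all hypothetical and are not taken from the authors' code.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def make_member(scale, bias):
    """Hypothetical two-class member; eps perturbs the scale multiplicatively."""
    def member(x, eps):
        h = scale * (1.0 + eps) * x + bias
        return [h, -h]
    return member


def abnn_predict(x, members, num_noise_samples, rng):
    """Average class probabilities over M members and L noise draws (Eq. 8)."""
    total = None
    for _ in range(num_noise_samples):        # l = 1, ..., L
        eps = rng.gauss(0.0, 1.0)             # one draw of epsilon_l ~ N(0, 1)
        for member in members:                # m = 1, ..., M
            probs = softmax(member(x, eps))
            if total is None:
                total = [0.0] * len(probs)
            total = [t + p for t, p in zip(total, probs)]
    count = num_noise_samples * len(members)  # M * L terms in the double sum
    return [t / count for t in total]


members = [make_member(1.0, 0.2), make_member(0.8, -0.1), make_member(1.3, 0.0)]
prediction = abnn_predict(0.5, members, num_noise_samples=10, rng=random.Random(0))
```

Averaging probabilities rather than logits keeps the result a valid distribution, which is what the left-hand side of Eq. (8) denotes.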

4.2 ABNN training loss

The multimodality [22, 44, 53] of the posterior distribution of DNNs is a challenge to any attempt to perform variational inference with a mono-modal distribution. Wilson and Izmailov [89] have proposed to tackle this issue by training multiple BNNs. However, such approaches inherit the instability of classical BNNs and may struggle to capture different modes accurately. ABNN encounters a similar challenge, requiring safeguards against collapsing into the same local minima during post-training. To mitigate this problem, we introduce a small perturbation to the loss function, preventing collapse and encouraging diversity in the training process. This perturbation involves a modification of the class weights within the cross-entropy loss, now defined as:

$$\mathcal{E}(\boldsymbol{\omega}) = -\sum_{(\mathbf{x}_i, y_i) \in \mathcal{D}} \eta_i \log P(y_i \mid \mathbf{x}_i, \boldsymbol{\omega}). \tag{9}$$

In this formula, $\eta_i$ represents the class-dependent random weight that we initialize at the beginning of training. Typically, it can be set to zero or one to amplify the effect of certain classes. In contrast to classical variational BNNs [5], which optimize the evidence lower-bound loss, ABNN maximizes the MAP. The optimization involves the following loss:

$$\mathcal{L}(\boldsymbol{\omega}) = \mathcal{L}_{\text{MAP}}(\boldsymbol{\omega}) + \mathcal{E}(\boldsymbol{\omega}). \tag{10}$$

We employ this procedure to train various ABNNs, resulting in ABNNs trained with different losses due to the presence of $\mathcal{E}(\boldsymbol{\omega})$.
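As a minimal sketch of Eqs. (9) and (10), the snippet below draws one weight per class in {0, 1} and adds the resulting weighted cross-entropy to a MAP term. Two points are assumptions made for illustration: the MAP term is taken to be the cross-entropy plus an L2 penalty (a Gaussian prior), and the helper names `draw_class_weights`, `perturbation_term`, `map_term` and `abnn_loss` are hypothetical.

```python
import math
import random


def draw_class_weights(num_classes, rng):
    """Draw one eta per class in {0, 1}, once, before fine-tuning starts."""
    return [float(rng.randint(0, 1)) for _ in range(num_classes)]


def perturbation_term(probs, labels, eta):
    """E(omega) = -sum_i eta[y_i] * log P(y_i | x_i, omega), as in Eq. (9)."""
    return -sum(eta[y] * math.log(p[y]) for p, y in zip(probs, labels))


def map_term(probs, labels, weights, weight_decay):
    """Assumed MAP objective: cross-entropy plus an L2 (Gaussian prior) penalty."""
    cross_entropy = -sum(math.log(p[y]) for p, y in zip(probs, labels))
    penalty = 0.5 * weight_decay * sum(w * w for w in weights)
    return cross_entropy + penalty


def abnn_loss(probs, labels, weights, weight_decay, eta):
    """L(omega) = L_MAP(omega) + E(omega), as in Eq. (10)."""
    return (map_term(probs, labels, weights, weight_decay)
            + perturbation_term(probs, labels, eta))


probs = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.3, 0.3, 0.4]]
labels = [0, 1, 2]
weights = [0.5, -1.0]
eta = draw_class_weights(3, random.Random(0))  # a different seed per ABNN copy
loss = abnn_loss(probs, labels, weights, 1e-4, eta)
```

With this form, a class whose weight is one contributes twice to the data term, while a class whose weight is zero contributes only through the MAP term, so each copy emphasizes a different subset of classes.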

4.3 ABNN training procedure

ABNN is trained through a post-hoc process designed to leverage the strength of Bayesian concepts and improve the uncertainty prediction of DNNs. The training pseudo code for ABNN, detailed in Alg. 1, outlines the step-by-step process of transforming a conventional DNN into an ensemble of Bayesian models.

We start from a pre-trained neural network (Alg. 1, line 1) and introduce the Bayesian normalization layers, replacing the old batch, instance or layer normalization of the former DNN (Alg. 1 lines 2 to 10). This operation transforms the conventional deterministic network into a BNN to help quantify the uncertainty. We initialize the weights of the new layer with the values of the replaced normalizations.
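The replacement step can be illustrated with a toy layer. The sketch assumes that a Bayesian normalization layer multiplies the learned scale by $(1 + \epsilon)$ with $\epsilon \sim \mathcal{N}(0, 1)$ drawn at each forward pass; the layer's exact definition lies outside this section, so this form, the scalar parameters and the class names `ToyNorm` and `ToyBayesianNorm` are assumptions rather than the authors' implementation.

```python
import math
import random


class ToyNorm:
    """Stand-in for a pre-trained normalization layer with scalar gamma and beta."""

    def __init__(self, gamma, beta, tiny=1e-5):
        self.gamma, self.beta, self.tiny = gamma, beta, tiny

    def normalize(self, h):
        mean = sum(h) / len(h)
        var = sum((v - mean) ** 2 for v in h) / len(h)
        return [(v - mean) / math.sqrt(var + self.tiny) for v in h]

    def forward(self, h):
        return [self.gamma * v + self.beta for v in self.normalize(h)]


class ToyBayesianNorm(ToyNorm):
    """Assumed BNL form: gamma is scaled by (1 + eps), with eps ~ N(0, 1) per call."""

    def __init__(self, gamma, beta, rng, tiny=1e-5):
        super().__init__(gamma, beta, tiny)
        self.rng = rng

    @classmethod
    def from_norm(cls, norm, rng):
        # Initialize the new layer with the values of the replaced normalization.
        return cls(norm.gamma, norm.beta, rng, norm.tiny)

    def forward(self, h):
        eps = self.rng.gauss(0.0, 1.0)
        scale = self.gamma * (1.0 + eps)
        return [scale * v + self.beta for v in self.normalize(h)]


pretrained = ToyNorm(gamma=1.5, beta=0.3)
bnl = ToyBayesianNorm.from_norm(pretrained, random.Random(0))
features = [0.2, -1.0, 0.7, 2.1]
sample_a = bnl.forward(features)
sample_b = bnl.forward(features)  # a second stochastic pass on the same input
```

Because the new layer starts from the pre-trained scale and shift, the network initially behaves like the original one on average, and two passes on the same input differ only through the sampled noise.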

Then, we fine-tune the modified network to better capture the inherent uncertainty (Alg. 1 lines 12 to 24). The full process is described in Figure 3, providing a clear overview of the modifications made to enable Bayesian modeling. More precisely, we only fine-tune the normalization weights in Table 1. To improve the posterior estimation of our ABNN models, we fine-tune multiple instances of the normalization layers (typically 3 to 4). This ensemble approach provides robustness and contributes to a more reliable estimation of the posterior distribution. Training multiple ABNNs, each starting from the same checkpoint, enhances our ability to capture diverse modes of the true posterior, thereby improving the overall uncertainty quantification.

Algorithm 1 ABNN training procedure
1: $f_{\boldsymbol{\omega}_{\text{MAP}}}$: pre-trained network, $\lambda$: learning rate, nb_epoch: number of epochs
2: (Step 2: adapt the DNN to a BNN)
3: # Build a list of all the normalisation layers
4: normalisation = [Batch_normalisation, Layer_normalisation, Instance_normalisation]
5: for layer $\in f_{\boldsymbol{\omega}_{\text{MAP}}}$.layers do
6:     begin
7:     (Transform all Normalization Layers)
8:     if layer $\in$ normalisation then
9:         replace layer by BNL
10:     end
11: (Step 3: train ABNN)
12: t = 0
13: for epoch $\in$ [1, nb_epoch] do
14:     begin
15:     for $(x, y) \in$ trainloader do
16:         begin
17:         (Forward pass)
18:         $\forall x_i \in B(t)$ calculate $f_{\boldsymbol{\omega}(t)}(x_i)$
19:         evaluate the loss $\mathcal{L}(\boldsymbol{\omega}(t), B(t))$
20:         $\boldsymbol{\omega}(t) \leftarrow \boldsymbol{\omega}(t-1) - \lambda \nabla \mathcal{L}_{\boldsymbol{\omega}(t)}$
21:         (Step update)
22:         $t \leftarrow t + 1$
23:         end
24:     end
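The two stages of Algorithm 1 can be mimicked on a deliberately tiny model. The sketch below is a toy under several assumptions: layers are plain dictionaries, the normalization layer is reduced to its affine part (no batch statistics), the loss is a plain binary cross-entropy rather than the loss of Eq. (10), and the helpers `to_abnn` and `fine_tune` are hypothetical names. It only illustrates that normalization layers are swapped and that the scale and shift are the sole parameters updated.

```python
import math
import random

NORMALISATION = {"batch_norm", "layer_norm", "instance_norm"}


def to_abnn(layers):
    """Step 2: swap each normalisation layer for a BNL, keeping gamma and beta."""
    return [dict(layer, kind="bnl") if layer["kind"] in NORMALISATION else dict(layer)
            for layer in layers]


def sigmoid(z):
    """Logistic function written to avoid overflow for large |z|."""
    if z >= 0.0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def fine_tune(layers, data, lr, nb_epoch, rng):
    """Step 3: SGD on gamma and beta only; the linear weight stays frozen."""
    linear, bnl = layers  # toy assumption: one frozen linear layer, then one BNL
    for _ in range(nb_epoch):
        for x, y in data:
            eps = rng.gauss(0.0, 1.0)            # fresh noise at every step
            h = linear["w"] * x
            z = bnl["gamma"] * (1.0 + eps) * h + bnl["beta"]
            grad_z = sigmoid(z) - y              # d(binary cross-entropy)/dz
            bnl["gamma"] -= lr * grad_z * (1.0 + eps) * h
            bnl["beta"] -= lr * grad_z
    return layers


pretrained = [{"kind": "linear", "w": 2.0},
              {"kind": "batch_norm", "gamma": 1.0, "beta": 0.0}]
data = [(-1.0, 0), (-0.5, 0), (0.5, 1), (1.0, 1)]
abnn = fine_tune(to_abnn(pretrained), data, lr=0.05, nb_epoch=5, rng=random.Random(0))
```

Running `fine_tune` several times from the same `pretrained` list with different seeds would give the multiple copies discussed above, since `to_abnn` copies the layers instead of modifying them in place.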

4.4 Theoretical analysis

Our approach raises several theoretical questions. In the supplementary material (Section A.1), we show that ABNN exhibits greater stability than classical BNNs. Indeed, in variational inference BNNs, the gradients vary greatly: $\boldsymbol{\epsilon}$, crucial for obtaining the Bayesian interpretation, often introduces instability, perturbing the training. ABNN reduces this burden by applying this term on the latent space rather than the weights, thereby reducing the variance of the gradients, as empirically demonstrated in Appendix A.1.

Another question concerns the theoretical need to modify the loss and add the second term $\mathcal{E}$. We show in Appendix A.2 that it is theoretically sound in the case of a convex problem. Given that DNN optimization is inherently non-convex, adding this term may be theoretically debatable. However, a sensitivity analysis of this term, developed in Appendix D, shows empirical benefits for performance and uncertainty quantification.

Finally, we discuss the challenge of estimating the equivalent BNNs to our networks in Appendix A.3. Despite the theoretical value this information could provide concerning the posterior, it remains unused in practice. We solely sample the $\boldsymbol{\epsilon}$ and average over multiple training terms to generate robust predictions during inference.

5 Experiments & Results

We test ABNN on image classification and semantic segmentation tasks. In each case, we measure both the predictive performance of the models and their uncertainty quantification abilities. All our models are implemented in PyTorch and trained on a single Nvidia RTX 3090. Appendix G details the hyper-parameters used in our experiments across architectures and datasets.

5.1 Image classification

| Dataset | Architecture | Method | Acc ↑ | NLL ↓ | ECE ↓ | AUPR ↑ | AUC ↑ | FPR95 ↓ | ΔParam ↓ | Time (h) ↓ |
|---|---|---|---|---|---|---|---|---|---|---|
| CIFAR-10 | ResNet-50 | Single Model | 95.1 | 0.211 | 3.1 | 95.2 | 91.9 | 23.6 | ∅ | 1.7 |
| | | BatchEnsemble | 93.9 | 0.255 | 3.3 | 94.7 | 91.3 | 20.1 | 0.11 | 17.2 |
| | | MIMO ($\rho=1$) | 95.4 | 0.197 | 3.0 | 95.1 | 90.8 | 26.0 | 0.07 | 6.7 |
| | | LPBNN | 95.0 | 0.251 | 9.4 | 98.4 | 96.9 | 10.3 | 1.83 | 17.2 |
| | | Deep Ensembles | 96.0 | 0.136 | 0.8 | 97.0 | 94.7 | 15.5 | 70.56 | 6.8 |
| | | Laplace | 95.3 | 0.160 | 1.3 | 96.0 | 93.3 | 18.8 | / | 1.7 |
| | | ABNN | 95.4 | 0.215 | 0.845 | 97.0 | 94.7 | 15.1 | 0.16 | 2.0 |
| CIFAR-10 | WideResNet-28×10 | Single Model | 95.4 | 0.200 | 2.9 | 96.1 | 93.2 | 20.4 | ∅ | 4.2 |
| | | BatchEnsemble | 95.6 | 0.206 | 2.7 | 95.5 | 92.5 | 22.1 | 0.10 | 25.6 |
| | | MIMO ($\rho=1$) | 94.7 | 0.234 | 3.4 | 94.9 | 90.6 | 30.9 | 0.12 | 12.6 |
| | | LPBNN | 95.1 | 0.249 | 2.9 | 95.4 | 91.2 | 29.5 | 0.71 | 23.3 |
| | | Deep Ensembles | 95.8 | 0.143 | 1.3 | 97.8 | 96.0 | 12.5 | 109.47 | 16.6 |
| | | Laplace | 95.6 | 0.151 | 0.8 | 95.0 | 90.7 | 31.9 | / | 4.2 |
| | | ABNN | 93.7 | 0.198 | 1.8 | 98.5 | 96.9 | 12.6 | 0.05 | 5.0 |
| CIFAR-100 | ResNet-50 | Single Model | 78.3 | 0.905 | 8.9 | 87.4 | 77.9 | 57.6 | ∅ | 1.7 |
| | | BatchEnsemble | 66.6 | 1.788 | 18.2 | 85.2 | 74.6 | 60.6 | 0.11 | 17.2 |
| | | MIMO ($\rho=1$) | 79.0 | 0.876 | 7.9 | 87.5 | 76.9 | 64.7 | 0.63 | 6.7 |
| | | LPBNN | 78.5 | 1.02 | 11.3 | 88.2 | 77.8 | 73.5 | 1.83 | 17.2 |
| | | Deep Ensembles | 80.9 | 0.713 | 2.6 | 89.2 | 80.8 | 52.5 | 71.12 | 6.8 |
| | | Laplace | 78.2 | 0.987 | 14.2 | 89.2 | 81.0 | 51.8 | / | 1.7 |
| | | ABNN | 78.9 | 0.889 | 5.5 | 89.4 | 81.0 | 50.1 | 0.16 | 2.0 |
| CIFAR-100 | WideResNet-28×10 | Single Model | 80.3 | 0.963 | 15.6 | 81.0 | 64.2 | 80.1 | ∅ | 4.2 |
| | | BatchEnsemble | 82.3 | 0.835 | 13.0 | 88.1 | 78.2 | 69.8 | 0.10 | 25.6 |
| | | MIMO ($\rho=1$) | 80.2 | 0.822 | 2.8 | 84.9 | 72.0 | 72.8 | 0.19 | 12.6 |
| | | LPBNN | 79.7 | 0.831 | 7.0 | 79.0 | 70.1 | 71.4 | 0.72 | 23.3 |
| | | Deep Ensembles | 82.5 | 0.903 | 22.9 | 81.6 | 67.9 | 71.3 | 109.64 | 16.6 |
| | | Laplace | 80.1 | 0.942 | 16.0 | 83.4 | 72.1 | 59.9 | / | 4.2 |
| | | ABNN | 80.4 | 1.08 | 5.5 | 85.0 | 75.0 | 57.7 | 0.05 | 5.0 |

Table 1: Performance comparison (averaged over five runs) on CIFAR-10/100 using ResNet-50 and WideResNet-28×10. All ensembles have $M=4$ subnetworks. We highlight the best performances in bold. The ResNet-50 single model has respectively {23.52, 23.70}$\cdot 10^{6}$ parameters. $\Delta$Param is the number of parameters in excess of the corresponding method compared to the single model. Time is the training time in hours on a single RTX 3090.
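As a quick consistency check on the $\Delta$Param column, and assuming that it is expressed in millions of parameters and that 23.52 is the CIFAR-10 figure (the caption states neither explicitly), a Deep Ensemble of $M=4$ ResNet-50 models adds three extra copies of the single model:

```latex
\Delta\text{Param}_{\text{Deep Ensembles}}
  = (M - 1) \times 23.52 \cdot 10^{6}
  = 3 \times 23.52 \cdot 10^{6}
  = 70.56 \cdot 10^{6}
```

This agrees with the 70.56 reported for Deep Ensembles with ResNet-50 on CIFAR-10, whereas ABNN adds 0.16 on the same scale.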

Datasets. We demonstrate the efficiency of ABNN on different datasets and backbones. We start with CIFAR-10 and CIFAR-100 [49] with ResNet-50 [31] and WideResNet28-10 [94]. We then report results for ABNN on ImageNet [15] with ResNet-50 and ViT [18]. In the former case, we train all models from scratch. In the latter, we start from torchvision pre-trained models [68].

Baselines. We compare ABNN against Deep Ensembles [51] and four other ensembles: BatchEnsemble [86], MIMO [30], Masksembles [19], and Laplace [12].

Metrics. We evaluate the performance on classification tasks with the accuracy (Acc) and the Negative Log-Likelihood (NLL). We complement these metrics with the expected top-label calibration error (ECE) [61] and measure the quality of the OOD detection using the Area Under the Precision-Recall curve (AUPR) and the Area Under the receiver operating characteristic Curve (AUC), as well as the False Positive Rate at 95% recall (FPR95), similarly to Hendrycks et al. [34]. We express all metrics in %.
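For readers less familiar with the calibration metrics, the following sketch computes the NLL and a top-label ECE from predicted probabilities. It is a plain-Python illustration, not the evaluation code used for the tables: the choice of ten equal-width bins and the function names are assumptions, and binning conventions vary between implementations.

```python
import math


def negative_log_likelihood(probs, labels):
    """Mean of -log P(y_i | x_i) over the evaluation set."""
    return -sum(math.log(p[y]) for p, y in zip(probs, labels)) / len(labels)


def top_label_ece(probs, labels, num_bins=10):
    """Expected calibration error on the top-label confidence, equal-width bins."""
    conf_sum = [0.0] * num_bins
    hit_sum = [0.0] * num_bins
    for p, y in zip(probs, labels):
        conf = max(p)
        b = min(int(conf * num_bins), num_bins - 1)  # conf == 1.0 goes to the last bin
        conf_sum[b] += conf
        hit_sum[b] += float(p.index(conf) == y)
    # sum_b (n_b / n) * |acc_b - conf_b| equals sum_b |hits_b - conf_total_b| / n
    return sum(abs(h - c) for h, c in zip(hit_sum, conf_sum)) / len(labels)


probs = [[0.8, 0.2], [0.8, 0.2], [0.2, 0.8], [0.2, 0.8]]
labels = [0, 1, 1, 0]
ece = top_label_ece(probs, labels)
nll = negative_log_likelihood(probs, labels)
```

In this toy case every prediction has confidence 0.8 but only half are correct, so the calibration gap is 0.3; multiplying by 100 gives the percentage scale used for ECE in the tables.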

OOD detection datasets. For OOD detection tasks on CIFAR-10 and CIFAR-100, we use the SVHN dataset [65] as the out-of-distribution dataset and transform the initial problem into binary classification between in-distribution and out-of-distribution data using the maximum softmax probability as the criterion. For ImageNet, we use Describable Texture [84] as the out-of-distribution dataset.
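The binary in-distribution versus out-of-distribution problem described here can be sketched as follows, using the maximum softmax probability as the score. The snippet assumes that in-distribution samples form the positive class, which is one of several conventions in the literature, and the function names are hypothetical.

```python
import math


def msp(probs):
    """Maximum softmax probability, used as the in-distribution score."""
    return max(probs)


def fpr_at_recall(id_scores, ood_scores, recall=0.95):
    """Fraction of OOD samples accepted when `recall` of ID samples are accepted."""
    ranked = sorted(id_scores, reverse=True)
    k = math.ceil(recall * len(ranked))       # number of ID samples to keep
    threshold = ranked[k - 1]
    return sum(s >= threshold for s in ood_scores) / len(ood_scores)


def auroc(id_scores, ood_scores):
    """Probability that a random ID sample outscores a random OOD sample."""
    wins = sum((i > o) + 0.5 * (i == o) for i in id_scores for o in ood_scores)
    return wins / (len(id_scores) * len(ood_scores))


id_scores = [msp(p) for p in ([0.9, 0.1], [0.2, 0.8], [0.95, 0.05], [0.01, 0.99])]
ood_scores = [msp(p) for p in ([0.5, 0.5], [0.85, 0.15])]
fpr95 = fpr_at_recall(id_scores, ood_scores)
auc = auroc(id_scores, ood_scores)
```

The pairwise form of `auroc` is quadratic in the number of samples and is meant for clarity only; a rank-based computation would be preferable on full test sets.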

Results. Tables 1 and 2 present the performance of ABNN across various architectures for CIFAR-10/100 and ImageNet, respectively. Notably, ABNN consistently surpasses Laplace approximation and single-model baselines on most datasets and architectures. Furthermore, ABNN exhibits competitive performance compared to Deep Ensembles, achieving equivalent results with a similar number of parameters and training time to a single model. These findings underscore ABNN as a powerful and efficient method, demonstrating superior uncertainty quantification capabilities in image classification tasks while being easier to train.

5.2 Semantic segmentation

For the semantic segmentation part, we compare ABNN against MCP [34], Deep Ensembles [51], MC Dropout [26], TRADI [23], MIMO [30] and LP-BNN [25], on StreetHazards [35], BDD-Anomaly [35], and MUAD [24], which allow comparison on diverse uncertainty quantification aspects of semantic segmentation.

StreetHazards [35]. StreetHazards is a large-scale synthetic dataset comprising various images depicting street scenes. The dataset consists of 5,125 images for training and an additional 1,500 images for testing. The training dataset has pixel-wise annotations available for 13 different classes. The test dataset is designed with 13 classes seen during training and an additional 250 out-of-distribution (OOD) classes that were not part of the training set. This diverse composition allows for assessing the algorithm's robustness in the face of various potential scenarios. In our experiments, we employed DeepLabv3+ with a ResNet-50 encoder, as introduced by Chen et al. [8].

BDD-Anomaly [35]. BDD-Anomaly, a subset of the BDD100K dataset [91], comprises 6,688 street scenes for training and an additional 361 for the test set. Within the training set, pixel-level annotations are available for 17 distinct classes. The test dataset consists of the same 17 classes seen during training and introduces 2 out-of-distribution (OOD) classes: motorcycle and train. In our experimental setup, we adopted DeepLabv3+ [8] and followed the experimental protocol outlined in [35]. Similar to previous experiments, we utilized a ResNet-50 encoder [31] for the neural network architecture.

| Architecture | Method | Acc ↑ | ECE ↓ | AUPR ↑ | AUC ↑ | FPR95 ↓ |
|---|---|---|---|---|---|---|
| ResNet-50 | Single Model | 77.8 | 12.1 | 18.0 | 80.9 | 68.6 |
| | BatchEnsemble | 75.9 | 3.5 | 20.2 | 81.6 | 66.5 |
| | MIMO ($\rho=1$) | 77.6 | 14.7 | 18.4 | 81.6 | 66.8 |
| | Deep Ensembles | 79.2 | 23.3 | 19.6 | 83.4 | 62.1 |
| | Laplace | 80.4 | 44.3 | 13.9 | 75.9 | 82.8 |
| | ABNN | 79.5 | 9.65 | 17.8 | 82.0 | 65.2 |
| ViT | Single Model | 80.0 | 5.2 | 19.5 | 84.1 | 58.5 |
| | Deep Ensembles | 81.7 | 13.5 | 21.7 | 85.5 | 60.3 |
| | Laplace | 81.0 | 10.8 | 22.1 | 83.1 | 70.6 |
| | ABNN | 80.6 | 4.32 | 21.7 | 85.4 | 55.1 |

Table 2: Performance on ImageNet using ResNet-50 and ViT concerning in-distribution and out-of-distribution metrics.

MUAD [24]. MUAD consists of 3,420 images in the training set and 492 in the validation set. The test set comprises 6,501 images, distributed across various subsets: 551 in the normal set, 102 in the normal set with no shadow, 1,668 in the out-of-distribution (OOD) set. All these sets cover both day and night conditions, with a distribution of 2/3 day images and 1/3 night images. MUAD encompasses 21 classes, with the initial 19 classes mirroring those found in Cityscapes [10]. Additionally, three classes are introduced to represent object anomalies and animals, adding diversity to the dataset. In our first experiment, we employed a DeepLabV3+ [8] network with a ResNet-50 encoder [31] for training on MUAD.

Results. Table 3 presents the results of ABNN, compared to various baselines on the three datasets. ABNN performs competitively with Deep Ensembles, a technique known for accurately quantifying uncertainty. Moreover, our approach exhibits faster training times, making it potentially more appealing for practitioners. We have not included a comparison with Laplace Approximation, as it is not commonly applied to semantic segmentation, and adapting DNNs for Laplace Approximation is not straightforward.

| Dataset | Method | mIoU ↑ | AUPR ↑ | AUC ↑ | FPR95 ↓ | ECE ↓ |
|---|---|---|---|---|---|---|
| StreetHazards | Single Model | 53.90 | 6.91 | 86.60 | 35.74 | 6.52 |
| | TRADI | 52.46 | 6.93 | 87.39 | 38.26 | 6.33 |
| | Deep Ensembles | 55.59 | 8.32 | 87.94 | 30.29 | 5.33 |
| | MIMO | 55.44 | 6.90 | 87.38 | 32.66 | 5.57 |
| | BatchEnsemble | 56.16 | 7.59 | 88.17 | 32.85 | 6.09 |
| | LP-BNN | 54.50 | 7.18 | 88.33 | 32.61 | 5.20 |
| | ABNN (ours) | 53.82 | 7.85 | 88.39 | 32.02 | 6.09 |
| BDD-Anomaly | Single Model | 47.63 | 4.50 | 85.15 | 28.78 | 17.68 |
| | TRADI | 44.26 | 4.54 | 84.80 | 36.87 | 16.61 |
| | Deep Ensembles | 51.07 | 5.24 | 84.80 | 28.55 | 14.19 |
| | MIMO | 47.20 | 4.32 | 84.38 | 35.24 | 16.33 |
| | BatchEnsemble | 48.09 | 4.49 | 84.27 | 30.17 | 16.90 |
| | LP-BNN | 49.01 | 4.52 | 85.32 | 29.47 | 17.16 |
| | ABNN (ours) | 48.76 | 5.98 | 85.74 | 29.01 | 14.03 |
| MUAD | Single Model | 57.32 | 26.04 | 86.24 | 39.43 | 6.07 |
| | MC-Dropout | 55.62 | 22.25 | 84.39 | 45.75 | 6.45 |
| | Deep Ensembles | 58.29 | 28.02 | 87.10 | 37.60 | 5.88 |
| | BatchEnsemble | 57.10 | 25.70 | 86.90 | 38.81 | 6.01 |
| | MIMO | 57.10 | 24.18 | 86.62 | 34.80 | 5.81 |
| | ABNN (ours) | 61.96 | 24.37 | 91.55 | 21.68 | 5.58 |

Table 3: Comparative results on the OOD task for semantic segmentation. We run all methods in similar settings using publicly available code for related methods. Results are averaged over three seeds. The architecture is a DeepLabv3+ based on ResNet-50.

6 Discussions

6.1 General discussions

We develop several discussions in the supplementary materials. First, we explore the theoretical aspects, including the stability of the DNNs in Appendix A.1, the importance of multi-mode in Appendix A.2, and the relationship with classical BNNs in Appendix A.3. We experiment on the transfer of ViT-B-16 from ImageNet-21k [71] to CIFAR-100, which highlights the potential of ABNN in transfer learning, achieving an accuracy of 92.18%. Additionally, we perform several ablation studies, notably on the impact of discarding multi-mode or the loss term $\mathcal{E}$ (defined in Section 4.2) in Appendix D. We show that discarding $\mathcal{E}$ reduces the performance while incorporating the multi-mode improves uncertainty quantification. Moreover, in Appendix B, we analyze the variance of the gradients, confirming that our technique exhibits lower gradient variance than BNNs, making it more stable and easier to train. Finally, Appendix C delves into the variability of our method under different scenarios, exploring cases where we initiate ABNN from a single model and optimize from various initial checkpoints. Our technique inherits the instabilities of the DNN: we observe that its standard deviation is five times larger than that of the single model, indicating less stability than a standard DNN.

6.2 Diversity of ABNN

Concerning the diversity, we train ABNN on CIFAR-10 using a ResNet-50 architecture. Specifically, we optimize a ResNet for 200 epochs and then fine-tune three ABNNs, starting from the optimal checkpoint. Additionally, we train two other ResNet-50s on CIFAR-10 to form a proper Deep Ensemble. As depicted in Figure 2, ABNN does not exhibit the same level of diversity as the Deep Ensembles. However, it is intriguing that even when initiated from a single DNN, ABNN manages to depart from its local minimum and explore different modes. This concept of different modes is further supported by Appendix E, where we analyze the mutual information of various ABNN checkpoints.

[Figure 2 panels: (a) ABNN, (b) Deep Ensembles]
Figure 2: Comparison of the diversities of ABNN and Deep Ensembles [51]. t-SNE plot of the 20 principal components of the logits generated from 384 images for ABNN (a) and Deep Ensembles (b).
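The mutual information mentioned above is not defined in this section. A common choice, assumed here purely for illustration, is the entropy of the averaged prediction minus the average entropy of the individual predictions; it is zero when all members agree and grows as they disagree. The function names are hypothetical.

```python
import math


def entropy(p):
    """Shannon entropy in nats, skipping zero-probability classes."""
    return -sum(q * math.log(q) for q in p if q > 0.0)


def mutual_information(member_probs):
    """Entropy of the mean prediction minus the mean entropy of the members."""
    m = len(member_probs)
    mean = [sum(col) / m for col in zip(*member_probs)]
    return entropy(mean) - sum(entropy(p) for p in member_probs) / m


agreeing = [[0.7, 0.3], [0.7, 0.3], [0.7, 0.3]]
disagreeing = [[1.0, 0.0], [0.0, 1.0]]
mi_low = mutual_information(agreeing)
mi_high = mutual_information(disagreeing)
```

Applied to the softmax outputs of several ABNN checkpoints on the same image, a larger value would indicate that the checkpoints make more diverse predictions.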

7 Conclusion

In conclusion, our proposed approach, ABNN, introduces a novel perspective to uncertainty quantification. Leveraging the strengths of pre-trained deterministic models, ABNN strategically transforms them into Bayesian networks with minimal modifications, offering enhanced stability and efficient posterior exploration. Through comprehensive experimental analyses, we demonstrate the effectiveness of ABNN both in predictive performance and uncertainty quantification, showcasing its potential applications in diverse scenarios.

The multi-mode characteristic of ABNN, coupled with a carefully designed loss function, not only addresses the challenge of the multi-modality of the posterior but also provides a stable and diverse ensemble of models. Our empirical evaluations on various datasets and architectures highlight the superiority of ABNN over traditional BNNs and post-hoc uncertainty quantification methods such as the Laplace approximation, and showcase its competitiveness with state-of-the-art methods such as Deep Ensembles.

Moreover, ABNN exhibits promising results in transfer learning scenarios, underscoring its versatility and potential for broader applications. The insights gained from theoretical discussions and ablation studies further elucidate the underlying mechanisms of ABNN, contributing to a deeper understanding of its behavior and performance.

In summary, ABNN emerges as a robust and flexible solution for uncertainty-aware deep learning, offering a pragmatic bridge between deterministic and Bayesian paradigms. Its simplicity in implementation, coupled with superior performance and stability, positions ABNN as a valuable tool in the contemporary landscape of machine learning and Bayesian modeling.

Acknowledgments

This work was performed using HPC resources from GENCI-IDRIS (Grant 2022 - AD011011970R2) and (Grant 2023 - Grant 2022 - AD011011970R3).

References

  • Arrieta et al. [2020] Alejandro Barredo Arrieta, Natalia Díaz-Rodríguez, Javier Del Ser, Adrien Bennetot, Siham Tabik, Alberto Barbado, Salvador García, Sergio Gil-López, Daniel Molina, Richard Benjamins, et al. Explainable artificial intelligence (xai): Concepts, taxonomies, opportunities and challenges toward responsible ai. Information fusion, 2020.
  • Ashukha et al. [2019] Arsenii Ashukha, Alexander Lyzhov, Dmitry Molchanov, and Dmitry Vetrov. Pitfalls of in-domain uncertainty estimation and ensembling in deep learning. In ICLR, 2019.
  • Ba et al. [2016] Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E Hinton. Layer normalization. In NeurIPSW, 2016.
  • Blei et al. [2017] David M Blei, Alp Kucukelbir, and Jon D McAuliffe. Variational inference: A review for statisticians. JASA, 2017.
  • Blundell et al. [2015] Charles Blundell, Julien Cornebise, Koray Kavukcuoglu, and Daan Wierstra. Weight uncertainty in neural network. In ICML, 2015.
  • Brosse et al. [2020] Nicolas Brosse, Carlos Riquelme, Alice Martin, Sylvain Gelly, and Éric Moulines. On last-layer algorithms for classification: Decoupling representation from uncertainty estimation. arXiv preprint arXiv:2001.08049, 2020.
  • Chan et al. [2021] Robin Chan, Krzysztof Lis, Svenja Uhlemeyer, Hermann Blum, Sina Honari, Roland Siegwart, Pascal Fua, Mathieu Salzmann, and Matthias Rottmann. Segmentmeifyoucan: A benchmark for anomaly segmentation. In NeurIPS, 2021.
  • Chen et al. [2018] Liang-Chieh Chen, Yukun Zhu, George Papandreou, Florian Schroff, and Hartwig Adam. Encoder-decoder with atrous separable convolution for semantic image segmentation. In ECCV, 2018.
  • Corbière et al. [2021] Charles Corbière, Marc Lafon, Nicolas Thome, Matthieu Cord, and Patrick Pérez. Beyond first-order uncertainty estimation with evidential models for open-world recognition. In ICMLW, 2021.
  • Cordts et al. [2016] Marius Cordts, Mohamed Omran, Sebastian Ramos, Timo Rehfeld, Markus Enzweiler, Rodrigo Benenson, Uwe Franke, Stefan Roth, and Bernt Schiele. The cityscapes dataset for semantic urban scene understanding. In CVPR, 2016.
  • Cubuk et al. [2020] Ekin D Cubuk, Barret Zoph, Jonathon Shlens, and Quoc V Le. Randaugment: Practical automated data augmentation with a reduced search space. In CVPR, 2020.
  • Daxberger et al. [2021a] Erik Daxberger, Agustinus Kristiadi, Alexander Immer, Runa Eschenhagen, Matthias Bauer, and Philipp Hennig. Laplace redux–effortless Bayesian deep learning. In NeurIPS, 2021a.
  • Daxberger et al. [2021b] Erik Daxberger, Eric Nalisnick, James U Allingham, Javier Antorán, and José Miguel Hernández-Lobato. Bayesian deep learning via subnetwork inference. In ICML, 2021b.
  • Dehghani et al. [2023] Mostafa Dehghani, Josip Djolonga, Basil Mustafa, Piotr Padlewski, Jonathan Heek, Justin Gilmer, Andreas Peter Steiner, Mathilde Caron, Robert Geirhos, Ibrahim Alabdulmohsin, et al. Scaling vision transformers to 22 billion parameters. In ICML, 2023.
  • Deng et al. [2009] Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. Imagenet: A large-scale hierarchical image database. In CVPR, 2009.
  • Depeweg et al. [2018] Stefan Depeweg, Jose-Miguel Hernandez-Lobato, Finale Doshi-Velez, and Steffen Udluft. Decomposition of uncertainty in bayesian deep learning for efficient and risk-sensitive learning. In ICML, 2018.
  • Devlin et al. [2018] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. Bert: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805, 2018.
  • Dosovitskiy et al. [2021] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, et al. An image is worth 16x16 words: Transformers for image recognition at scale. In ICLR, 2021.
  • Durasov et al. [2021] Nikita Durasov, Timur Bagautdinov, Pierre Baque, and Pascal Fua. Masksembles for uncertainty estimation. In CVPR, 2021.
  • Dusenberry et al. [2020] Michael Dusenberry, Ghassen Jerfel, Yeming Wen, Yian Ma, Jasper Snoek, Katherine Heller, Balaji Lakshminarayanan, and Dustin Tran. Efficient and scalable bayesian neural nets with rank-1 factors. In ICML, 2020.
  • Feller [1991] William Feller. An introduction to probability theory and its applications, Volume 2. John Wiley & Sons, 1991.
  • Fort et al. [2019] Stanislav Fort, Huiyi Hu, and Balaji Lakshminarayanan. Deep ensembles: A loss landscape perspective. arXiv preprint arXiv:1912.02757, 2019.
  • Franchi et al. [2020] Gianni Franchi, Andrei Bursuc, Emanuel Aldea, Séverine Dubuisson, and Isabelle Bloch. Tradi: Tracking deep neural network weight distributions. In ECCV, 2020.
  • Franchi et al. [2022] Gianni Franchi, Xuanlong Yu, Andrei Bursuc, Angel Tena, Rémi Kazmierczak, Séverine Dubuisson, Emanuel Aldea, and David Filliat. Muad: Multiple uncertainties for autonomous driving, a benchmark for multiple uncertainty types and tasks. In BMVC, 2022.
  • Franchi et al. [2023] Gianni Franchi, Andrei Bursuc, Emanuel Aldea, Séverine Dubuisson, and Isabelle Bloch. Encoding the latent posterior of bayesian neural networks for uncertainty quantification. T-PAMI, 2023.
  • Gal and Ghahramani [2016] Yarin Gal and Zoubin Ghahramani. Dropout as a bayesian approximation: Representing model uncertainty in deep learning. In ICML, 2016.
  • Gawlikowski et al. [2023] Jakob Gawlikowski, Cedrique Rovile Njieutcheu Tassi, Mohsin Ali, Jongseok Lee, Matthias Humt, Jianxiang Feng, Anna Kruspe, Rudolph Triebel, Peter Jung, Ribana Roscher, et al. A survey of uncertainty in deep neural networks. Artificial Intelligence Review, 2023.
  • Goan and Fookes [2020] Ethan Goan and Clinton Fookes. Bayesian neural networks: An introduction and survey. Case Studies in Applied Bayesian Data Science, 2020.
  • Hansen and Salamon [1990] Lars Kai Hansen and Peter Salamon. Neural network ensembles. IEEE transactions on pattern analysis and machine intelligence, 1990.
  • Havasi et al. [2021] Marton Havasi, Rodolphe Jenatton, Stanislav Fort, Jeremiah Zhe Liu, Jasper Snoek, Balaji Lakshminarayanan, Andrew Mingbo Dai, and Dustin Tran. Training independent subnetworks for robust prediction. In ICLR, 2021.
  • He et al. [2016] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In CVPR, 2016.
  • He et al. [2017] Kaiming He, Georgia Gkioxari, Piotr Dollár, and Ross Girshick. Mask r-cnn. In ICCV, 2017.
  • Hein et al. [2019] Matthias Hein, Maksym Andriushchenko, and Julian Bitterwolf. Why relu networks yield high-confidence predictions far away from the training data and how to mitigate the problem. In CVPR, 2019.
  • Hendrycks and Gimpel [2017] Dan Hendrycks and Kevin Gimpel. A baseline for detecting misclassified and out-of-distribution examples in neural networks. In ICLR, 2017.
  • Hendrycks et al. [2019] Dan Hendrycks, Steven Basart, Mantas Mazeika, Mohammadreza Mostajabi, Jacob Steinhardt, and Dawn Song. A benchmark for anomaly segmentation. arXiv preprint arXiv:1911.11132, 2019.
  • Hendrycks et al. [2021a] Dan Hendrycks, Steven Basart, Norman Mu, Saurav Kadavath, Frank Wang, Evan Dorundo, Rahul Desai, Tyler Zhu, Samyak Parajuli, Mike Guo, et al. The many faces of robustness: A critical analysis of out-of-distribution generalization. In ICCV, 2021a.
  • Hendrycks et al. [2021b] Dan Hendrycks, Nicholas Carlini, John Schulman, and Jacob Steinhardt. Unsolved problems in ml safety. arXiv preprint arXiv:2109.13916, 2021b.
  • Hernández-Lobato and Adams [2015] José Miguel Hernández-Lobato and Ryan Adams. Probabilistic backpropagation for scalable learning of bayesian neural networks. In ICML, 2015.
  • Hora [1996] Stephen C Hora. Aleatory and epistemic uncertainty in probability elicitation with an example from hazardous waste management. Reliability Engineering & System Safety, 1996.
  • Hron et al. [2022] Jiri Hron, Roman Novak, Jeffrey Pennington, and Jascha Sohl-Dickstein. Wide bayesian neural networks have a simple weight posterior: theory and accelerated sampling. In ICML, pages 8926–8945. PMLR, 2022.
  • Hüllermeier and Waegeman [2021] Eyke Hüllermeier and Willem Waegeman. Aleatoric and epistemic uncertainty in machine learning: An introduction to concepts and methods. Machine Learning, 2021.
  • Ilharco et al. [2021] Gabriel Ilharco, Mitchell Wortsman, Ross Wightman, Cade Gordon, Nicholas Carlini, Rohan Taori, Achal Dave, Vaishaal Shankar, Hongseok Namkoong, John Miller, Hannaneh Hajishirzi, Ali Farhadi, and Ludwig Schmidt. Openclip, 2021.
  • Ioffe and Szegedy [2015] Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift. In ICML, 2015.
  • Izmailov et al. [2021] Pavel Izmailov, Sharad Vikram, Matthew D Hoffman, and Andrew Gordon Wilson. What are bayesian neural network posteriors really like? In ICML, 2021.
  • Jordan et al. [1999] Michael I Jordan, Zoubin Ghahramani, Tommi S Jaakkola, and Lawrence K Saul. An introduction to variational methods for graphical models. ML, 1999.
  • Kendall and Gal [2017] Alex Kendall and Yarin Gal. What uncertainties do we need in bayesian deep learning for computer vision? NeurIPS, 2017.
  • Kirillov et al. [2023] Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alexander C Berg, Wan-Yen Lo, et al. Segment anything. arXiv preprint arXiv:2304.02643, 2023.
  • Kristiadi et al. [2020] Agustinus Kristiadi, Matthias Hein, and Philipp Hennig. Being bayesian, even just a bit, fixes overconfidence in relu networks. In ICML, 2020.
  • Krizhevsky [2009] Alex Krizhevsky. Learning multiple layers of features from tiny images. Technical report, University of Toronto, 2009.
  • Krizhevsky et al. [2012] A Krizhevsky, I Sutskever, and G Hinton. Imagenet classification with deep convolutional networks. In NeurIPS, 2012.
  • Lakshminarayanan et al. [2017] Balaji Lakshminarayanan, Alexander Pritzel, and Charles Blundell. Simple and scalable predictive uncertainty estimation using deep ensembles. In NeurIPS, 2017.
  • Lambert et al. [2018] John Lambert, Ozan Sener, and Silvio Savarese. Deep learning under privileged information using heteroscedastic dropout. In CVPR, pages 8886–8895, 2018.
  • Laurent et al. [2023a] Olivier Laurent, Emanuel Aldea, and Gianni Franchi. A symmetry-aware exploration of bayesian neural network posteriors. arXiv preprint arXiv:2310.08287, 2023a.
  • Laurent et al. [2023b] Olivier Laurent, Adrien Lafage, Enzo Tartaglione, Geoffrey Daniel, Jean-Marc Martinez, Andrei Bursuc, and Gianni Franchi. Packed-ensembles for efficient uncertainty estimation. In ICLR, 2023b.
  • Lee et al. [2018] Kimin Lee, Kibok Lee, Honglak Lee, and Jinwoo Shin. A simple unified framework for detecting out-of-distribution samples and adversarial attacks. In NeurIPS, 2018.
  • Li et al. [2022] Kaican Li, Kai Chen, Haoyu Wang, Lanqing Hong, Chaoqiang Ye, Jianhua Han, Yukuai Chen, Wei Zhang, Chunjing Xu, Dit-Yan Yeung, et al. Coda: A real-world road corner case dataset for object detection in autonomous driving. In ECCV, 2022.
  • Liu et al. [2022] Zhuang Liu, Hanzi Mao, Chao-Yuan Wu, Christoph Feichtenhofer, Trevor Darrell, and Saining Xie. A convnet for the 2020s. In CVPR, 2022.
  • MacKay [1992] David JC MacKay. A practical bayesian framework for backpropagation networks. Neural computation, 1992.
  • Maddox et al. [2019] Wesley J Maddox, Pavel Izmailov, Timur Garipov, Dmitry P Vetrov, and Andrew Gordon Wilson. A simple baseline for bayesian uncertainty in deep learning. In NeurIPS, 2019.
  • Malinin and Gales [2018] Andrey Malinin and Mark Gales. Predictive uncertainty estimation via prior networks. In NeurIPS, 2018.
  • Naeini et al. [2015] Mahdi Pakdaman Naeini, Gregory F. Cooper, and Milos Hauskrecht. Obtaining well calibrated probabilities using bayesian binning. In AAAI, 2015.
  • Nalisnick [2018] Eric Thomas Nalisnick. On priors for Bayesian neural networks. University of California, Irvine, 2018.
  • Nayman et al. [2022] Niv Nayman, Avram Golbert, Asaf Noy, Tan Ping, and Lihi Zelnik-Manor. Diverse imagenet models transfer better. arXiv preprint arXiv:2204.09134, 2022.
  • Neal [2012] Radford M Neal. Bayesian learning for neural networks. 2012.
  • Netzer et al. [2011] Yuval Netzer, Tao Wang, Adam Coates, Alessandro Bissacco, Bo Wu, and Andrew Y. Ng. Reading digits in natural images with unsupervised feature learning. In NeurIPSW, 2011.
  • Neuhaus et al. [2023] Yannic Neuhaus, Maximilian Augustin, Valentyn Boreiko, and Matthias Hein. Spurious features everywhere-large-scale detection of harmful spurious features in imagenet. In ICCV, 2023.
  • Ovadia et al. [2019] Yaniv Ovadia, Emily Fertig, Jie Ren, Zachary Nado, David Sculley, Sebastian Nowozin, Joshua Dillon, Balaji Lakshminarayanan, and Jasper Snoek. Can you trust your model’s uncertainty? evaluating predictive uncertainty under dataset shift. In NeurIPS, 2019.
  • Paszke et al. [2019] Adam Paszke, Sam Gross, Francisco Massa, Adam Lerer, James Bradbury, Gregory Chanan, Trevor Killeen, Zeming Lin, Natalia Gimelshein, Luca Antiga, et al. Pytorch: An imperative style, high-performance deep learning library. In NeurIPS, 2019.
  • Radford et al. [2018] Alec Radford, Karthik Narasimhan, Tim Salimans, Ilya Sutskever, et al. Improving language understanding by generative pre-training. Technical report, OpenAI, 2018.
  • Redmon et al. [2016] Joseph Redmon, Santosh Divvala, Ross Girshick, and Ali Farhadi. You only look once: Unified, real-time object detection. In CVPR, 2016.
  • Ridnik et al. [2021] Tal Ridnik, Emanuel Ben-Baruch, Asaf Noy, and Lihi Zelnik-Manor. Imagenet-21k pretraining for the masses, 2021.
  • Ritter et al. [2018] Hippolyt Ritter, Aleksandar Botev, and David Barber. A scalable laplace approximation for neural networks. In ICLR, 2018.
  • Rombach et al. [2022] Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, and Björn Ommer. High-resolution image synthesis with latent diffusion models. In CVPR, 2022.
  • Roy et al. [2022] Subhankar Roy, Martin Trapp, Andrea Pilzer, Juho Kannala, Nicu Sebe, Elisa Ricci, and Arno Solin. Uncertainty-guided source-free domain adaptation. In ECCV, 2022.
  • Schuhmann et al. [2022] Christoph Schuhmann, Romain Beaumont, Richard Vencu, Cade Gordon, Ross Wightman, Mehdi Cherti, Theo Coombes, Aarush Katta, Clayton Mullis, Mitchell Wortsman, et al. Laion-5b: An open large-scale dataset for training next generation image-text models. NeurIPS, 2022.
  • Severyn and Moschitti [2015] Aliaksei Severyn and Alessandro Moschitti. Learning to rank short text pairs with convolutional deep neural networks. In SIGIR, 2015.
  • Shanmugam et al. [2021] Divya Shanmugam, Davis Blalock, Guha Balakrishnan, and John Guttag. Better aggregation in test-time augmentation. In ICCV, pages 1214–1223, 2021.
  • Son and Kang [2023] Jongwook Son and Seokho Kang. Efficient improvement of classification accuracy via selective test-time augmentation. Information Sciences, 642:119148, 2023.
  • Szegedy et al. [2016] Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jon Shlens, and Zbigniew Wojna. Rethinking the inception architecture for computer vision. In CVPR, 2016.
  • Tierney and Kadane [1986] Luke Tierney and Joseph B Kadane. Accurate approximations for posterior moments and marginal densities. Journal of the American Statistical Association, 1986.
  • Tishby et al. [1989] Tishby, Levin, and Solla. Consistent inference of probabilities in layered networks: predictions and generalizations. In IJCNN, 1989.
  • Tran et al. [2022] Dustin Tran, Jeremiah Liu, Michael W Dusenberry, Du Phan, Mark Collier, Jie Ren, Kehang Han, Zi Wang, Zelda Mariet, Huiyi Hu, et al. Plex: Towards reliability using pretrained large model extensions. arXiv preprint arXiv:2207.07411, 2022.
  • Ulyanov et al. [2016] Dmitry Ulyanov, Andrea Vedaldi, and Victor Lempitsky. Instance normalization: The missing ingredient for fast stylization. arXiv preprint arXiv:1607.08022, 2016.
  • Wang et al. [2022] Haoqi Wang, Zhizhong Li, Litong Feng, and Wayne Zhang. ViM: Out-of-distribution with virtual-logit matching. In CVPR, 2022.
  • Welling and Teh [2011] Max Welling and Yee W Teh. Bayesian learning via stochastic gradient langevin dynamics. In ICML, 2011.
  • Wen et al. [2019] Yeming Wen, Dustin Tran, and Jimmy Ba. BatchEnsemble: an alternative approach to efficient ensemble and lifelong learning. In ICLR, 2019.
  • Wightman [2019] Ross Wightman. Pytorch image models. https://github.com/rwightman/pytorch-image-models, 2019.
  • Wightman et al. [2021] Ross Wightman, Hugo Touvron, and Herve Jegou. Resnet strikes back: An improved training procedure in timm. In NeurIPSW, 2021.
  • Wilson and Izmailov [2020] Andrew G Wilson and Pavel Izmailov. Bayesian deep learning and a probabilistic perspective of generalization. NeurIPS, 2020.
  • Xia and Bouganis [2023] Guoxuan Xia and Christos-Savvas Bouganis. Window-based early-exit cascades for uncertainty estimation: When deep ensembles are more efficient than single models. In ICCV, 2023.
  • Yu et al. [2020] Fisher Yu, Haofeng Chen, Xin Wang, Wenqi Xian, Yingying Chen, Fangchen Liu, Vashisht Madhavan, and Trevor Darrell. Bdd100k: A diverse driving dataset for heterogeneous multitask learning. In CVPR, 2020.
  • Yun et al. [2019] Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, and Youngjoon Yoo. Cutmix: Regularization strategy to train strong classifiers with localizable features. In CVPR, 2019.
  • Zablocki et al. [2022] Éloi Zablocki, Hédi Ben-Younes, Patrick Pérez, and Matthieu Cord. Explainability of deep vision-based autonomous driving systems: Review and challenges. IJCV, 2022.
  • Zagoruyko and Komodakis [2016] Sergey Zagoruyko and Nikos Komodakis. Wide residual networks. In BMVC, 2016.
  • Zhai et al. [2023] Xiaohua Zhai, Basil Mustafa, Alexander Kolesnikov, and Lucas Beyer. Sigmoid loss for language image pre-training. In ICCV, 2023.
  • Zhang et al. [2018] Hongyi Zhang, Moustapha Cisse, Yann N Dauphin, and David Lopez-Paz. mixup: Beyond empirical risk minimization. In ICLR, 2018.

Supplementary Material

The supplementary material provides details and insights that complement the main paper. Section H introduces and clarifies the notation used throughout the paper. Section A presents the theoretical foundations of our method. Section B evaluates the stability of ABNN, shedding light on its robustness. Section C examines the stability of the training procedure, an aspect that deserves attention in every post-hoc technique. Section D presents a sensitivity analysis and ablation study of the key components and their impact on performance. Section E focuses on the quality of the posterior estimated by ABNN. Finally, Sections G and F detail the training hyperparameters and present additional experiments.

Appendix A Theoretical Analysis

In this section, we develop a mathematical formalism to study the theoretical properties of ABNN.

A.1 Stability of ABNN

Variational inference BNNs [5] are not commonly used in computer vision because they are difficult to scale to deeper, high-capacity DNNs [20]. In this section, we derive theoretical insights that support the greater stability of ABNN, which is arguably also a Variational Inference BNN (VI-BNN). We start by deriving the gradients for a layer in a classic 2-hidden-layer MLP BNN. For the gradient of the loss with respect to the mean of the weights of layer $j$, we have:

$$
\frac{\partial\mathcal{L}_{\text{MAP}}(\boldsymbol{\omega})}{\partial W^{(j)}_{\mu,i,i'}}
= \sum_{i''}\frac{\partial\mathcal{L}_{\text{MAP}}(\boldsymbol{\omega})}{\partial\mathbf{h}_{j+1,i''}}
\left[W^{(j+1)}_{\mu,i'',i}+\mathcal{E}^{(j+1)}_{i'',i}\,W^{(j+1)}_{\sigma,i'',i}\right]
\frac{a'(\mathbf{h}_{j,i})\,\mathbf{a}_{j-1,i'}}{\sigma_{j}}.
\tag{11}
$$

On its standard deviation, we have:

$$
\frac{\partial\mathcal{L}_{\text{MAP}}(\boldsymbol{\omega})}{\partial W^{(j)}_{\sigma,i,i'}}
= \sum_{i''}\frac{\partial\mathcal{L}_{\text{MAP}}(\boldsymbol{\omega})}{\partial\mathbf{h}_{j+1,i''}}
\left[W^{(j+1)}_{\mu,i'',i}+\mathcal{E}^{(j+1)}_{i'',i}\,W^{(j+1)}_{\sigma,i'',i}\right]
\frac{a'(\mathbf{h}_{j,i})\,\boldsymbol{\epsilon}_{j,i,i'}\,\mathbf{a}_{j-1,i'}}{\sigma_{j}}.
\tag{12}
$$

For the gradients of the ABNN parameters, in the case of a 2-hidden-layer MLP BNN, we have:

$$
\frac{\partial\mathcal{L}_{\text{MAP}}(\boldsymbol{\omega})}{\partial\gamma^{(j)}_{i}}
= \sum_{i''}\frac{\partial\mathcal{L}_{\text{MAP}}(\boldsymbol{\omega})}{\partial\mathbf{h}_{j+1,i''}}
\,W^{(j+1)}_{i'',i}\,\frac{\mathbf{h}_{j}-\mu_{j}}{\sigma_{j}}\,(1+\boldsymbol{\epsilon}_{j,i})\,a'(\mathbf{u}_{j,i}),
\tag{13}
$$

as well as, on $\beta$,

$$
\frac{\partial\mathcal{L}_{\text{MAP}}(\boldsymbol{\omega})}{\partial\beta_{j,i}}
= \sum_{i''}\frac{\partial\mathcal{L}_{\text{MAP}}(\boldsymbol{\omega})}{\partial\mathbf{h}_{j+1,i''}}
\,W^{(j+1)}_{i'',i}\,a'(\mathbf{u}_{j,i}).
\tag{14}
$$
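As a sanity check on Eqs. (13) and (14), the sketch below compares the analytic gradients with central finite differences for a single neuron. It assumes the forward form $\mathbf{u} = \gamma(1+\boldsymbol{\epsilon})\frac{\mathbf{h}-\mu}{\sigma}+\beta$ implied by these two equations, a tanh activation, and a toy loss equal to the next-layer pre-activation (so the upstream gradient is 1). The function names, the activation choice, and the numerical values are illustrative assumptions, not the authors' implementation.

```python
import math


def bnl_output(h, mu, sigma, gamma, beta, eps):
    """Assumed forward form of the noisy normalization layer:
    u = gamma * (1 + eps) * (h - mu) / sigma + beta."""
    return gamma * (1.0 + eps) * (h - mu) / sigma + beta


def toy_loss(h, mu, sigma, gamma, beta, eps, w_next):
    """Toy loss: next-layer pre-activation w_next * tanh(u), so dL/dh_{j+1} = 1."""
    return w_next * math.tanh(bnl_output(h, mu, sigma, gamma, beta, eps))


def analytic_grads(h, mu, sigma, gamma, beta, eps, w_next):
    """Gradients following the structure of Eqs. (13) and (14) for one neuron."""
    u = bnl_output(h, mu, sigma, gamma, beta, eps)
    a_prime = 1.0 - math.tanh(u) ** 2  # derivative of tanh at u
    grad_gamma = w_next * (h - mu) / sigma * (1.0 + eps) * a_prime
    grad_beta = w_next * a_prime
    return grad_gamma, grad_beta


def numeric_grads(h, mu, sigma, gamma, beta, eps, w_next, step=1e-6):
    """Central finite differences with respect to gamma and beta."""
    d_gamma = (
        toy_loss(h, mu, sigma, gamma + step, beta, eps, w_next)
        - toy_loss(h, mu, sigma, gamma - step, beta, eps, w_next)
    ) / (2.0 * step)
    d_beta = (
        toy_loss(h, mu, sigma, gamma, beta + step, eps, w_next)
        - toy_loss(h, mu, sigma, gamma, beta - step, eps, w_next)
    ) / (2.0 * step)
    return d_gamma, d_beta


# Hypothetical values, for illustration only.
params = dict(h=1.2, mu=0.4, sigma=0.8, gamma=0.9, beta=-0.1, eps=0.35, w_next=1.5)
analytic = analytic_grads(**params)
numeric = numeric_grads(**params)
print(analytic, numeric)
```

The noise enters the gradient of $\gamma$ through the factor $(1+\boldsymbol{\epsilon})$ and through $a'(\mathbf{u})$, whereas the gradient of $\beta$ depends on the noise only through $a'(\mathbf{u})$.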

We have four random variables: $\frac{\partial\mathcal{L}_{\text{MAP}}(\boldsymbol{\omega})}{\partial\mathbf{h}_{j+1,i''}}$, $\mathcal{E}^{(j+1)}_{i'',i}$ and $\boldsymbol{\epsilon}_{j,i}$, along with $a'(\mathbf{u}_{j,i})$. We now compute the conditional variance given $\frac{\partial\mathcal{L}_{\text{MAP}}(\boldsymbol{\omega})}{\partial\mathbf{h}_{j+1,i''}}$ for all $i''$. Assuming that all random variables associated with a single neuron are independent, we have:

$$
\operatorname{var}\left(\frac{\partial\mathcal{L}_{\text{MAP}}(\boldsymbol{\omega})}{\partial W^{(j)}_{\mu,i,i'}}\right)
= \sum_{i''}\left(\frac{\partial\mathcal{L}_{\text{MAP}}(\boldsymbol{\omega})}{\partial\mathbf{h}_{j+1,i''}}
\,\frac{W^{(j+1)}_{\sigma,i'',i}\,\mathbf{a}_{j-1,i'}}{\sigma_{j}}\right)^{2}
\times\operatorname{var}\left[\mathcal{E}^{(j+1)}_{i'',i}\,a'(\mathbf{h}_{j,i})\right]
\tag{15}
$$

and

$$
\operatorname{var}\left(\frac{\partial\mathcal{L}_{\text{MAP}}(\boldsymbol{\omega})}{\partial W^{(j)}_{\sigma,i,i'}}\right)
= \sum_{i''}\left(\frac{\partial\mathcal{L}_{\text{MAP}}(\boldsymbol{\omega})}{\partial\mathbf{h}_{j+1,i''}}
\,\frac{W^{(j+1)}_{\sigma,i'',i}\,\mathbf{a}_{j-1,i'}}{\sigma_{j}}\right)^{2}
\operatorname{var}\left[\mathcal{E}^{(j+1)}_{i'',i}\,a'(\mathbf{h}_{j,i})\,\boldsymbol{\epsilon}_{j,i,i'}\right].
\tag{16}
$$

Using the fact that $\boldsymbol{\epsilon}_{j,i,i'}$ is independent of $\mathcal{E}^{(j+1)}_{i'',i}$ and $a'(\mathbf{h}_{j,i})$ and, being standard normal noise, has zero mean and unit variance, we have that

$$
\operatorname{var}\left[\mathcal{E}^{(j+1)}_{i'',i}\,a'(\mathbf{h}_{j,i})\,\boldsymbol{\epsilon}_{j,i,i'}\right]
= \operatorname{var}\left[\mathcal{E}^{(j+1)}_{i'',i}\,a'(\mathbf{h}_{j,i})\right]
+ \mathbb{E}\left(\mathcal{E}^{(j+1)}_{i'',i}\,a'(\mathbf{h}_{j,i})\right)^{2}
\tag{17}
$$
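The identity in Eq. (17) is an instance of the general fact that, for a random variable $X$ and an independent standard normal $\epsilon$, $\operatorname{var}[X\epsilon] = \mathbb{E}[X^2] = \operatorname{var}[X] + \mathbb{E}[X]^2$. The sketch below checks this numerically with a simple Monte Carlo estimate. Here $X$ stands in for the product $\mathcal{E}^{(j+1)}_{i'',i}\,a'(\mathbf{h}_{j,i})$; its Gaussian distribution with mean 0.5 and standard deviation 2 is an arbitrary assumption made only so that the mean term is visibly non-zero.

```python
import random


def mean(values):
    """Sample mean."""
    return sum(values) / len(values)


def variance(values):
    """Population variance of the sample."""
    m = mean(values)
    return sum((v - m) ** 2 for v in values) / len(values)


def check_variance_identity(n_samples=100_000, seed=0):
    """Monte Carlo estimate of both sides of var[X * eps] = var[X] + E[X]^2."""
    rng = random.Random(seed)
    # X: arbitrary stand-in with non-zero mean (illustrative assumption).
    xs = [rng.gauss(0.5, 2.0) for _ in range(n_samples)]
    # eps: standard normal noise, drawn independently of X.
    eps = [rng.gauss(0.0, 1.0) for _ in range(n_samples)]
    lhs = variance([x * e for x, e in zip(xs, eps)])
    rhs = variance(xs) + mean(xs) ** 2
    return lhs, rhs


lhs, rhs = check_variance_identity()
# Both sides should agree up to Monte Carlo error.
print(lhs, rhs)
```

With these assumed parameters the exact value of both sides is $2^2 + 0.5^2 = 4.25$. The second term vanishes whenever $X$ has zero mean, so the size of the gap between Eqs. (15) and (16) depends on $\mathbb{E}\left(\mathcal{E}^{(j+1)}_{i'',i}\,a'(\mathbf{h}_{j,i})\right)$.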

In the case of ABNN, we can express the conditional variance as follows:

$$
\operatorname{var}\left(\frac{\partial\mathcal{L}_{\text{MAP}}(\boldsymbol{\omega})}{\partial\gamma^{(j)}_{i}}\right)
= \sum_{i''}\left(\frac{\partial\mathcal{L}_{\text{MAP}}(\boldsymbol{\omega})}{\partial\mathbf{h}_{j+1,i''}}
\,W^{(j+1)}_{i'',i}\,\frac{\mathbf{h}_{j}-\mu_{\mathbf{h}_{j}}}{\sigma_{j}}\right)^{2}
\operatorname{var}\left[\boldsymbol{\epsilon}_{j,i}\,a'(\mathbf{u}_{j,i})\right]
\tag{18}
$$

$$
\operatorname{var}\left(\frac{\partial\mathcal{L}_{\text{MAP}}(\boldsymbol{\omega})}{\partial\beta_{j,i}}\right)
= \sum_{i''}\left(\frac{\partial\mathcal{L}_{\text{MAP}}(\boldsymbol{\omega})}{\partial\mathbf{h}_{j+1,i''}}
\,W^{(j+1)}_{i'',i}\right)^{2}
\operatorname{var}\left[a'(\mathbf{u}_{j,i})\right]
\tag{19}
$$

Assuming that $\operatorname{var}\left[\boldsymbol{\epsilon}_{j,i}\,a'(\mathbf{u}_{j,i})\right]=\operatorname{var}\left[\mathcal{E}^{(j+1)}_{i'',i}\,a'(\mathbf{h}_{j,i})\right]$, we find that $\operatorname{var}\left(\frac{\partial\mathcal{L}_{\text{MAP}}(\boldsymbol{\omega})}{\partial W^{(j)}_{\mu,i,i'}}\right)$ and $\operatorname{var}\left(\frac{\partial\mathcal{L}_{\text{MAP}}(\boldsymbol{\omega})}{\partial W^{(j)}_{\sigma,i,i'}}\right)$ are directly proportional to $\operatorname{var}\left(\frac{\partial\mathcal{L}_{\text{MAP}}(\boldsymbol{\omega})}{\partial\gamma_{j,i}}\right)$. This proportionality depends on the magnitudes of the weight values, which we assume to be similar. Consequently, the variance of the gradient with respect to $\beta_{j,i}$ is the smallest, followed by the variances for $\gamma_{j,i}$ and $W^{(j)}_{\mu,i,i'}$, which are roughly equivalent, while the highest variance is observed for $W^{(j)}_{\sigma,i,i'}$. This property results in more stable backpropagation for ABNN compared to classic VI-BNNs.

A.2 Multi-modes with ABNN

The posterior of the DNN often comprises multiple modes [89, 53], making it non-trivial for a unimodal distribution chosen to represent the BNN’s posterior to account for these different modes effectively. One approach to address this issue is to train multiple BNNs, as proposed in the multi-SWAG method by Wilson and Izmailov [89]. However, adapting this strategy to VI-BNNs inherits the instability issue of classic BNNs and may struggle to fit multiple modes accurately.

Our solution, ABNN, also faces a similar challenge, where we need to ensure that the technique doesn’t collapse into the same local minima during training. We introduce a small perturbation to the loss function to prevent this collapse, which helps diversify the optimization process. This perturbation involves modifying the class weights within the cross-entropy loss. More precisely, contrary to classic VI-BNN that optimizes the Evidence Lower Bound (ELBO) loss, we propose to maximize the MAP. ABNNs optimize the following loss:

Figure 3: Illustration of the training loss (in blue) and the corresponding posterior distribution (in red). Due to the multi-modal nature of the posterior, training multiple ABNNs with distinct final weights (such as $\boldsymbol{\omega}_{1}^{*}$, $\boldsymbol{\omega}_{2}^{*}$, and $\boldsymbol{\omega}_{3}^{*}$) enables sampling from different modes, enhancing the overall estimation of the posterior.
$$\mathcal{L}(\boldsymbol{\omega})=-\sum_{(\mathbf{x}_{i},y_{i})\in\mathcal{D}}\alpha(y_{i})\log P(y_{i}\mid\mathbf{x}_{i},\boldsymbol{\omega})-\log P(\boldsymbol{\omega}), \tag{20}$$

In the standard cross-entropy loss, all classes are given equal weight, typically represented as $\alpha(y_{i})=1$. However, our approach deliberately introduces the random prior: random weight adjustments for certain classes, denoted as $\eta_{i}$ (such that $\alpha(y_{i})=1+\eta_{i}$). This manipulation encourages various ABNNs to specialize as experts in different classes. Consequently, the training loss is formulated as follows:

$$\mathcal{L}(\boldsymbol{\omega})=\mathcal{L}_{\text{MAP}}(\boldsymbol{\omega})+\mathcal{E}(\boldsymbol{\omega}) \tag{21}$$

where

$$\mathcal{E}(\boldsymbol{\omega})=-\sum_{(\mathbf{x}_{i},y_{i})\in\mathcal{D}}\eta_{i}\log P(y_{i}\mid\mathbf{x}_{i},\boldsymbol{\omega}). \tag{22}$$
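To make the decomposition in Eqs. (20) to (22) concrete, the following minimal Python sketch evaluates the weighted loss on toy values and checks that it equals the MAP loss plus the random-prior term. It assumes the per-sample log-likelihoods are already computed, that $\eta$ is looked up by class label (matching $\alpha(y_i)$), and that the negative log-prior is a given scalar. The function names and numbers are illustrative assumptions, not the authors' implementation.

```python
import math

def map_loss(log_probs, neg_log_prior):
    """L_MAP: unweighted negative log-likelihood plus the negative log-prior."""
    return -sum(log_probs) + neg_log_prior

def random_prior_term(log_probs, labels, eta):
    """E(omega) as in Eq. (22): -sum_i eta[y_i] * log P(y_i | x_i, omega)."""
    return -sum(eta[y] * lp for lp, y in zip(log_probs, labels))

def abnn_loss(log_probs, labels, eta, neg_log_prior):
    """Eq. (20) with class weights alpha(y_i) = 1 + eta[y_i]."""
    weighted_nll = -sum((1.0 + eta[y]) * lp for lp, y in zip(log_probs, labels))
    return weighted_nll + neg_log_prior

# Toy values (hypothetical): three samples, two classes.
log_probs = [math.log(0.9), math.log(0.6), math.log(0.8)]
labels = [0, 1, 0]
eta = {0: 0.0, 1: 0.5}   # only class 1 is up-weighted
neg_log_prior = 0.1      # stand-in for -log P(omega)

total = abnn_loss(log_probs, labels, eta, neg_log_prior)
split = map_loss(log_probs, neg_log_prior) + random_prior_term(log_probs, labels, eta)
```

Here `total` and `split` agree up to floating-point error, which is the identity stated in Eq. (21).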

Let $\boldsymbol{\omega}^{(0)}$ denote the parameter configuration that minimizes $\mathcal{L}_{\text{MAP}}$. To provide a theoretical grounding for the random prior, suppose for simplicity that the loss function is convex. After a single step of gradient descent (GD), we have:

$$\boldsymbol{\omega}^{(1)}=\boldsymbol{\omega}^{(0)}-\lambda\nabla\mathcal{L}(\boldsymbol{\omega}^{(0)}), \tag{23}$$

where $\boldsymbol{\omega}^{(1)}$ represents the parameters after the first optimization step, the superscript denotes the iteration number, and $\lambda$ is the learning rate.

We can express the updated loss $\mathcal{L}(\boldsymbol{\omega}^{(1)})$ using the GD update, as shown in Eq. (24):

$$\mathcal{L}(\boldsymbol{\omega}^{(1)})=\mathcal{L}\left(\boldsymbol{\omega}^{(0)}-\lambda\nabla\mathcal{L}(\boldsymbol{\omega}^{(0)})\right) \tag{24}$$

Now, by applying a first-order Taylor expansion to $\mathcal{L}$, we can express the updated loss $\mathcal{L}(\boldsymbol{\omega}^{(1)})$ as a function of the initial loss $\mathcal{L}(\boldsymbol{\omega}^{(0)})$ and the gradient update, as shown in Eq. (25):

$$\mathcal{L}(\boldsymbol{\omega}^{(1)})=\mathcal{L}(\boldsymbol{\omega}^{(0)})-\lambda\nabla\mathcal{L}(\boldsymbol{\omega}^{(0)})^{t}\,\nabla\mathcal{L}(\boldsymbol{\omega}^{(0)}) \tag{25}$$

This equation can be further simplified. By Eq. (21), $\nabla\mathcal{L}=\nabla\mathcal{L}_{\text{MAP}}+\nabla\mathcal{E}$, and since $\boldsymbol{\omega}^{(0)}$ minimizes $\mathcal{L}_{\text{MAP}}$, we have $\nabla\mathcal{L}_{\text{MAP}}(\boldsymbol{\omega}^{(0)})=0$. Only the gradient of $\mathcal{E}$ remains:

$$\mathcal{L}(\boldsymbol{\omega}^{(1)})=\mathcal{L}(\boldsymbol{\omega}^{(0)})-\lambda\nabla\mathcal{E}(\boldsymbol{\omega}^{(0)})^{t}\,\nabla\mathcal{E}(\boldsymbol{\omega}^{(0)}) \tag{26}$$

Starting with Eq. (26), we have:

$$\begin{aligned}\mathcal{L}(\boldsymbol{\omega}^{(1)})=\mathcal{L}(\boldsymbol{\omega}^{(0)})-\lambda\sum_{(\mathbf{x}_{i},y_{i})\in\mathcal{D}}\sum_{(\mathbf{x}_{i'},y_{i'})\in\mathcal{D}}&\eta_{i}\eta_{i'}\\ &\log P(y_{i'}\mid\mathbf{x}_{i'},\boldsymbol{\omega}^{(0)})\log P(y_{i}\mid\mathbf{x}_{i},\boldsymbol{\omega}^{(0)})\end{aligned} \tag{27}$$

Under the assumption that $\lambda\eta_{i}\eta_{i'}$ is small and that $\log P(y_{i'}\mid\mathbf{x}_{i'},\boldsymbol{\omega})$ is bounded, we can approximate the loss as follows: $\mathcal{L}(\boldsymbol{\omega}^{(1)})\simeq\mathcal{L}(\boldsymbol{\omega}^{(0)})$. Consequently, we have $\nabla\mathcal{L}(\boldsymbol{\omega}^{(1)})\simeq\nabla\mathcal{L}(\boldsymbol{\omega}^{(0)})$, and after another optimization step, $\boldsymbol{\omega}^{(2)}$ is updated as $\boldsymbol{\omega}^{(2)}=\boldsymbol{\omega}^{(1)}-\lambda\nabla\mathcal{L}(\boldsymbol{\omega}^{(1)})\simeq\boldsymbol{\omega}^{(0)}-2\lambda\nabla\mathcal{L}(\boldsymbol{\omega}^{(0)})$.

By applying the same technique iteratively, for all $t$, we can approximate the loss as $\mathcal{L}(\boldsymbol{\omega}^{(t)})\simeq\mathcal{L}(\boldsymbol{\omega}^{(0)})$. This also leads to the relationship $\boldsymbol{\omega}^{(t)}\simeq\boldsymbol{\omega}^{(0)}-t\lambda\nabla\mathcal{L}(\boldsymbol{\omega}^{(0)})$.
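The linear drift away from $\boldsymbol{\omega}^{(0)}$ can be checked numerically on a one-dimensional toy problem. The sketch below assumes a quadratic $\mathcal{L}_{\text{MAP}}$ centred at $w_0 = 0$ and a linear perturbation $\mathcal{E}(w) = \texttt{eta\_grad}\cdot w$. Both choices, and all names and constants, are hypothetical and serve only to illustrate the argument.

```python
def gd_drift(curvature, eta_grad, lr, steps):
    """Run GD on L(w) = 0.5*curvature*(w - w0)**2 + eta_grad*w, starting at w0 = 0.

    Returns the exact GD iterate and the linear approximation
    w0 - steps * lr * grad L(w0) from the text.
    """
    w0 = 0.0
    w = w0
    for _ in range(steps):
        grad = curvature * (w - w0) + eta_grad  # grad L_MAP + grad E
        w -= lr * grad
    # grad L(w0) reduces to eta_grad because grad L_MAP(w0) = 0.
    approx = w0 - steps * lr * eta_grad
    return w, approx

exact, approx = gd_drift(curvature=1.0, eta_grad=0.05, lr=1e-3, steps=20)
rel_err_short = abs(exact - approx) / abs(approx)

exact_long, approx_long = gd_drift(curvature=1.0, eta_grad=0.05, lr=1e-3, steps=2000)
rel_err_long = abs(exact_long - approx_long) / abs(approx_long)
```

In this toy setting the approximation is accurate to about one percent after 20 steps but degrades substantially after 2000 steps, which suggests the linear relationship should be read as a short-horizon statement that holds while $t\lambda$ times the curvature remains small.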

Under the conditions that the loss is convex, the derivatives of the DNN are bounded, and $\lambda\eta_{i}\eta_{i'}$ is small, we can find minima of $\mathcal{L}$ with similar loss values by introducing weight diversity. This loss function can be valuable in encouraging each DNN to move away from the global minimum, particularly in convex cases. In non-convex cases, standard SGD may already help escape from local minima, but this additional loss may offer extra assistance in avoiding the same local minima. Figure 3 visualizes the training loss alongside the posterior distribution. It highlights the importance of training multiple ABNNs with different optimal solutions to improve the quality of the posterior estimate.

A.3 Discussion on the Bayesian Neural Network Nature of ABNN

The Law of the Unconscious Statistician (LOTUS) [21] is a theorem in probability theory that offers a method for computing the expected value of a function of a random variable. Specifically, for a continuous random variable $X$ with a probability density function $f_{X}(x)$, the expected value of a function $Y=g(X)$ is expressed as:

$$E_{Y}(Y)=E_{X}[g(X)] \tag{28}$$

In our scenario, let $U_{j}$ denote the random variable associated with $\mathbf{u}_{j}$, and let $W^{(j)}$ also denote the random variable associated with the weights $W^{(j)}$ (we use the same letter for simplicity). Thus, we have:

$$E_{U_{j}}(U_{j})=\frac{\mathbf{h}_{j}-\mu_{\mathbf{h}_{j}}}{\sqrt{\sigma_{\mathbf{h}_{j}}^{2}+\epsilon}}\gamma_{j}+\beta_{j} \tag{29}$$

For simplification, we set $\beta_{j}=0$, which leads to

$$\mathrm{var}_{U_{j}}(U_{j})=\left(\frac{\mathbf{h}_{j}-\mu_{\mathbf{h}_{j}}}{\sqrt{\sigma_{\mathbf{h}_{j}}^{2}+\epsilon}}\gamma_{j}\right)^{2} \tag{30}$$

Here, the parameters $\gamma_{j}$ and $\beta_{j}$ are optimized to obtain the best Bayesian Neural Network (BNN).
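The mean and variance in Eqs. (29) and (30) can be checked by simulation. The sketch below assumes the noise enters multiplicatively as $\gamma_j(1+\varepsilon)$ with $\varepsilon\sim\mathcal{N}(0,1)$, which is one form consistent with those two equations. The function name, the scalar inputs, and the numerical-stability constant `eps` are illustrative assumptions.

```python
import math
import random
import statistics

def bnl_sample(h, mu, sigma2, gamma, beta, rng, eps=1e-5):
    """One stochastic output u = z * gamma * (1 + noise) + beta, with z the normalized input."""
    z = (h - mu) / math.sqrt(sigma2 + eps)
    noise = rng.gauss(0.0, 1.0)  # assumed standard Gaussian perturbation
    return z * gamma * (1.0 + noise) + beta

rng = random.Random(0)
h, mu, sigma2, gamma, beta = 2.0, 0.5, 4.0, 0.3, 0.0
samples = [bnl_sample(h, mu, sigma2, gamma, beta, rng) for _ in range(20_000)]

z = (h - mu) / math.sqrt(sigma2 + 1e-5)
expected_mean = z * gamma + beta   # Eq. (29)
expected_var = (z * gamma) ** 2    # Eq. (30), with beta = 0
empirical_mean = statistics.fmean(samples)
empirical_var = statistics.pvariance(samples)
```

With these toy values the empirical moments should land close to the closed-form ones, up to Monte Carlo error.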

Appendix B Experiment on the variance of the gradient

To validate our hypothesis that ABNN is more stable than VI-BNNs, as detailed in Section A.1, we analyze the variances of the gradients of classic DNNs, VI-BNNs, and ABNNs. Table 4 reveals that the gradient variance of BNNs is significantly greater than that of DNNs, aligning with the inherent challenges in training BNNs. Notably, ABNN exhibits a considerably lower gradient variance, because only the weights of the BNL layers are trained. Both empirical observations and theoretical considerations affirm the superior stability of ABNN in this respect. Additionally, Table 4 includes results for a VI-BNN, which, as discussed in this section, exhibits suboptimal performance. Notably, other works [20, 40] also express concerns regarding the stability of VI-BNNs.

| Method | Variance ($\times 10^{-4}$) |
|---|---|
| ResNet50 | 1.43 |
| ResNet50 BNN | 2.54 |
| ResNet50 ABNN | $3.20\cdot 10^{-2}$ |
| Wide-ResNet | 1.37 |
| Wide-ResNet BNN | 2.53 |
| Wide-ResNet ABNN | $9.91\cdot 10^{-2}$ |

Table 4: Variance of gradients on relevant parameters. Single model: every weight (no bias). BNN: means of the weight samplers. ABNN: parameters linked to the BNL layers.

Appendix C Discussion on stability of the training of ABNN

Since ABNN is a post-hoc technique, it is imperative to ensure that it does not introduce instability to DNNs, especially in the critical domain of uncertainty quantification. To verify this point, we train multiple single models based on a ResNet-50 architecture on CIFAR-10 and calculate the standard deviation of the different metrics. Additionally, we derive several ABNNs starting from a single checkpoint and assess their variance. Finally, we apply our technique to train an ABNN for each checkpoint of the single models. Table 5 demonstrates that our approach minimally increases the variance, confirming that it does not introduce instability to the uncertainty quantification process.

| | Acc | ECE | AUPR | AUC | FPR95 |
|---|---|---|---|---|---|
| Single model | 0.356 | 0.0002 | 1.036 | 1.445 | 2.740 |
| One $\boldsymbol{\omega}_{\text{MAP}}$ + multiple ABNN | 0.066 | 0.0009 | 0.130 | 0.224 | 0.658 |
| Multiple $\boldsymbol{\omega}_{\text{MAP}}$ + ABNN | 0.324 | 0.0011 | 1.131 | 1.570 | 3.202 |

Table 5: Standard Deviation (SD) Comparison of ABNN and DNN. The first row presents the SD of a single DNN, while the second row depicts the SD of ABNN starting from a single checkpoint. The last row quantifies the SD of ABNN when trained from different checkpoints. All training scenarios use the optimal hyperparameters for ABNN on a ResNet-50 architecture on the CIFAR-10 dataset.

Appendix D Ablation study of ABNN and Sensitivity analysis

In Table 6, we study the impact of adding or removing two characteristics of our method. First, we investigate whether the random prior (RP), linked to the $\mathcal{E}$ term in Eq. (22), which introduces a disturbance to the loss, improves the performance of ABNN. Notably, when training a single RP model (MM disabled), RP appears to degrade performance as it corrupts the cross-entropy. Conversely, when training multiple ABNNs (MM enabled), RP seems to improve uncertainty quantification metrics. The second aspect is the training of multiple modes (MM): Table 6 shows that incorporating multiple modes improves the quality of the uncertainty quantification, in particular for OOD detection.

| Dataset | Model | RP | MM | Acc ↑ | ECE ↓ | AUPR ↑ | AUC ↑ | FPR95 ↓ |
|---|---|---|---|---|---|---|---|---|
| CIFAR-10 | ResNet-50 | | | 94.7 | 1.0 | 96.8 | 94.1 | 17.3 |
| | | | | 95.2 | 0.89 | 96.7 | 94.4 | 15.3 |
| | | | | 95.3 | 1.34 | 97.1 | 94.8 | 15.7 |
| | | | | 95.4 | 0.845 | 97.0 | 94.7 | 15.1 |
| | WideResNet | | | 94.4 | 1.34 | 96.4 | 93.7 | 18.2 |
| | | | | 92.8 | 2.2 | 97.5 | 95.3 | 14.8 |
| | | | | 94.9 | 1.38 | 97.3 | 95.1 | 15.7 |
| | | | | 93.7 | 1.8 | 98.5 | 96.9 | 12.6 |
| CIFAR-100 | ResNet-50 | | | 78.8 | 5.7 | 89.3 | 80.8 | 50.4 |
| | | | | 78.7 | 5.5 | 89.4 | 81.0 | 50.1 |
| | | | | 78.3 | 5.8 | 89.6 | 81.6 | 48.2 |
| | | | | 78.8 | 5.6 | 89.7 | 81.6 | 49.0 |
| | WideResNet | | | 80.5 | 5.6 | 84.0 | 72.6 | 62.2 |
| | | | | 79.6 | 5.5 | 84.7 | 74.9 | 55.8 |
| | | | | 80.4 | 5.5 | 85.0 | 75.0 | 57.7 |
| | | | | 78.2 | 5.9 | 87.8 | 79.7 | 49.0 |

Table 6: Performance comparison (averaged over five runs) on CIFAR-10/100 using ResNet-50 and WideResNet28×10. RP is Random Prior, and MM is multi-mode. For OOD detection, we use the SVHN dataset.

| Coef. LR | Acc ↑ | ECE ↓ | AUPR ↑ | AUC ↑ | FPR95 ↓ |
|---|---|---|---|---|---|
| $\times 10^{-1}$ | 95.5 | 0.9 | 96.2 | 93.0 | 20.7 |
| $\times 2\cdot 10^{-1}$ | 95.4 | 0.9 | 96.5 | 93.8 | 18.2 |
| $\times 5\cdot 10^{-1}$ | 95.4 | 1.0 | 96.3 | 93.3 | 19.9 |
| $\times 1$ | 95.4 | 0.9 | 97.0 | 94.7 | 15.1 |
| $\times 2$ | 95.3 | 1.1 | 95.7 | 92.3 | 22.0 |
| $\times 5$ | 94.6 | 1.5 | 96.1 | 93.1 | 19.6 |
| $\times 10$ | 93.9 | 1.1 | 96.7 | 94.2 | 19.2 |

Table 7: Sensitivity analysis of the learning rate on CIFAR-10. We trained ABNN using a learning rate of 0.0057 multiplied by the Coef. LR.

| Coef. LR | Acc ↑ | ECE ↓ | AUPR ↑ | AUC ↑ | FPR95 ↓ |
|---|---|---|---|---|---|
| $\times 10^{-1}$ | 79.0 | 5.4 | 89.0 | 80.3 | 51.5 |
| $\times 2\cdot 10^{-1}$ | 78.9 | 5.5 | 88.6 | 79.9 | 52.5 |
| $\times 5\cdot 10^{-1}$ | 78.9 | 5.6 | 88.9 | 80.0 | 52.5 |
| $\times 1$ | 78.9 | 5.5 | 89.4 | 81.0 | 50.1 |
| $\times 2$ | 78.8 | 5.7 | 88.6 | 79.8 | 52.6 |
| $\times 15$ | 79.0 | 5.5 | 88.8 | 80.1 | 52.0 |
| $\times 10$ | 78.8 | 5.7 | 88.8 | 80.2 | 51.4 |

Table 8: Sensitivity analysis of the learning rate on CIFAR-100. We trained ABNN using a learning rate of 0.00139 multiplied by the Coef. LR.

We analyze the performance variations in Tables 7 and 8 by modifying the learning rate during the fine-tuning phase. Note that the learning rate is ABNN's only hyperparameter. For this evaluation, we fine-tune a single model with various learning rates after adapting it to ABNN. Remarkably, the learning rate appears non-critical, as the performance on CIFAR-10 and CIFAR-100 varies only minimally, by around one percent.
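As a small illustration of the sweep, the grid of learning rates can be generated from the base rates quoted in the captions of Tables 7 and 8. The dictionary keys and the function name below are hypothetical, and the coefficient list follows the Coef. LR column of Table 7.

```python
# Base learning rates quoted in the captions of Tables 7 and 8.
BASE_LR = {"cifar10": 0.0057, "cifar100": 0.00139}
# Multiplicative coefficients, as listed in the Coef. LR column of Table 7.
COEFFICIENTS = (0.1, 0.2, 0.5, 1, 2, 5, 10)

def lr_grid(dataset):
    """Return the fine-tuning learning rates to sweep for the given dataset."""
    return [BASE_LR[dataset] * c for c in COEFFICIENTS]

cifar10_grid = lr_grid("cifar10")
cifar100_grid = lr_grid("cifar100")
```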

Appendix E Discussion on ABNN’s diversity

We trained a Deep Ensemble of three ResNet-50 networks for 200 epochs. As demonstrated in [22], Deep Ensembles strike a good balance between accuracy and diversity, enabling effective uncertainty quantification. In parallel, we trained three ABNNs, initializing them from the same checkpoint. Note that all ABNN parameters were trained for this experiment. To visually compare the results, akin to [22] (Figure 2.c), we conduct a t-SNE analysis on the latent space of all the checkpoints after each epoch. Initially, ABNN shows limited diversity because all three models start from a single checkpoint. However, as training progresses, some diversity emerges, though less than observed in Deep Ensembles. Table 9 shows that ABNN exhibits lower mutual information than BNN on both in-distribution (ID) and out-of-distribution (OOD) samples. Yet, interestingly, ABNN achieves a higher OOD/ID mutual information ratio. Mutual information serves as a metric for measuring diversity and can also quantify epistemic uncertainty. For a finite set $\{\boldsymbol{\omega}_{1},\ldots,\boldsymbol{\omega}_{M}\}$ of $M$ weight configurations sampled from the posterior distribution, it is defined as:

$$\underbrace{I\left(P(\boldsymbol{\omega}\mid\mathcal{D})\right)}_{\text{Epistemic uncertainty}}=\underbrace{\mathcal{H}\left(\frac{1}{M}\sum_{m=1}^{M}P(y\mid\mathbf{x},\boldsymbol{\omega}_{m})\right)}_{\text{Total uncertainty}}-\underbrace{\frac{1}{M}\sum_{m=1}^{M}\mathcal{H}\left(P(y\mid\mathbf{x},\boldsymbol{\omega}_{m})\right)}_{\text{Aleatoric uncertainty}}. \tag{31}$$

where the entropy $\mathcal{H}(\cdot)$ is defined by:

$$\mathcal{H}(P(y\mid\mathbf{x},\boldsymbol{\omega}))=-\sum_{y}P(y_{i}\mid\mathbf{x}_{i},\boldsymbol{\omega})\log P(y_{i}\mid\mathbf{x}_{i},\boldsymbol{\omega}) \tag{32}$$
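Eqs. (31) and (32) translate directly into a few lines of code. The sketch below is a minimal illustration that assumes each of the $M$ sampled weight configurations yields a list of class probabilities for a single input. The function names and the toy predictions are hypothetical.

```python
import math

def entropy(p):
    """Shannon entropy (in nats) of a discrete distribution, as in Eq. (32)."""
    return -sum(q * math.log(q) for q in p if q > 0.0)

def mutual_information(member_probs):
    """Eq. (31): entropy of the mean prediction minus the mean of the entropies."""
    m = len(member_probs)
    n_classes = len(member_probs[0])
    mean_pred = [sum(p[c] for p in member_probs) / m for c in range(n_classes)]
    total = entropy(mean_pred)                             # total uncertainty
    aleatoric = sum(entropy(p) for p in member_probs) / m  # aleatoric uncertainty
    return total - aleatoric                               # epistemic uncertainty

# Toy predictions (hypothetical) for one input.
agree = [[0.7, 0.3], [0.7, 0.3], [0.7, 0.3]]  # members agree
disagree = [[1.0, 0.0], [0.0, 1.0]]           # members fully disagree

mi_agree = mutual_information(agree)
mi_disagree = mutual_information(disagree)
```

When all members agree, the mutual information is zero (up to floating-point error). When two confident members disagree completely, it reaches $\log 2$, the maximum for two classes.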

The higher mutual information ratio highlights ABNN’s superior ability to detect out-of-distribution samples compared to BNN. Moreover, Figure 2 illustrates that, by benefiting from multiple training instances, ABNN can effectively model multiple modes, a capability beyond the reach of classical BNNs. This is particularly valuable given the inherent multi-modal nature of the posterior.

| | ID | OOD | OOD/ID |
|---|---|---|---|
| ABNN | 1.39e-4 | 5.22e-4 | 3.76 |
| BNN | 0.139 | 0.179 | 1.28 |

Table 9: Evaluation of the average mutual information on the test set and the OOD set for different samples from an ABNN or a BNN.

Appendix F Extra Experiments

F.1 ViT transfer learning

We also show that ABNN can be used for transfer learning. Table 10 compares the performance of a ViT B-16 pre-trained on ImageNet-21k and fine-tuned on CIFAR-100 for 10,000 steps, with and without a mono-modal ABNN. Despite the mono-modality of ABNN in this experiment, the corresponding ViT outperforms classic fine-tuning in calibration and AUPR.

| | Acc | ECE | AUPR | AUC | FPR95 |
|---|---|---|---|---|---|
| Single model | 92.0 | 4.4 | 96.6 | 92.7 | 28.1 |
| ABNN | 91.9 | 1.4 | 96.9 | 92.5 | 33.9 |

Table 10: Transfer learning of a ViT pre-trained on ImageNet-21k on CIFAR-100. The first row is the pre-trained model fine-tuned with the classic layer normalization, and the second is the same model fine-tuned with ABNN.

F.2 Extra baselines

To benchmark our method against other post-hoc techniques, we implement a variant of Test-Time Augmentation [2, 77, 78] that adds random Gaussian noise with a standard deviation of 0.08. The objective is to introduce diversity and ensemble different predictions by leveraging the added noise. Like [52], we also experimented with adding noise to the latent space, which represents the scenario where ABNN is not trained. We tested various levels of standard deviation, and the corresponding results are summarized in Table 11. Notably, the untrained ABNN does not perform effectively, underscoring the significance of a brief fine-tuning phase. Additionally, we train a VI-BNN, a non-post-hoc technique, to understand the performance of a traditional BNN. VI-BNNs proved challenging to train and demonstrated subpar performance.
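The noise-based Test-Time Augmentation baseline can be sketched as follows. This is an illustrative outline only: `toy_model` is a hypothetical two-class stand-in for the trained network, the number of noisy copies (`n_samples=16`) is an assumption, and only the standard deviation of 0.08 comes from the text.

```python
import math
import random

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    s = sum(exps)
    return [e / s for e in exps]

def toy_model(x):
    """Hypothetical 2-class classifier standing in for the trained network."""
    return softmax([x[0] - x[1], x[1] - x[0]])

def noisy_tta_predict(model, x, std=0.08, n_samples=16, rng=None):
    """Average the model's predictions over n_samples Gaussian-perturbed copies of x."""
    rng = rng or random.Random(0)
    n_classes = len(model(x))
    mean = [0.0] * n_classes
    for _ in range(n_samples):
        noisy = [v + rng.gauss(0.0, std) for v in x]
        probs = model(noisy)
        mean = [a + p / n_samples for a, p in zip(mean, probs)]
    return mean

x = [0.2, -0.1]
pred = noisy_tta_predict(toy_model, x)
pred_no_noise = noisy_tta_predict(toy_model, x, std=0.0)
```

With `std=0.0` the average collapses to the deterministic prediction, which makes explicit that all diversity in this baseline comes from the injected noise.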

Dataset Method Acc ↑ ECE ↓ AUPR ↑ AUC ↑ FPR95 ↓
CIFAR-10 Single model 95.53 0.83 96.52 93.70 18.43
CIFAR-10 VI-BNN 75.66 5.40 80.60 69.53 66.62
CIFAR-10 Test-Time Augmentation 89.95 2.49 93.78 89.95 25.12
CIFAR-10 Noise on latent space std=0.01 95.51 0.82 96.51 93.68 18.56
CIFAR-10 Noise on latent space std=0.1 95.46 0.90 95.88 92.64 20.94
CIFAR-10 Noise on latent space std=1 18.66 31.58 71.89 50.86 93.27
CIFAR-10 ABNN 95.43 0.85 97.03 94.73 15.11
CIFAR-100 Single model 79.05 5.34 88.72 79.96 52.04
CIFAR-100 VI-BNN 41.17 8.97 78.44 61.82 86.15
CIFAR-100 Test-Time Augmentation 72.65 7.67 86.64 76.89 55.52
CIFAR-100 Noise on latent space std=0.01 79.03 5.39 88.68 79.85 52.54
CIFAR-100 Noise on latent space std=0.1 77.49 6.36 86.53 75.19 64.42
CIFAR-100 Noise on latent space std=1 1.03 11.91 74.08 50.92 95.74
CIFAR-100 ABNN 78.94 5.47 89.36 81.04 50.12
Table 11: Performance comparison of different Post-hoc and BNN uncertainty quantification baselines on CIFAR-10/100 using ResNet-50.

Appendix G Training hyperparameters

Table 12 provides a detailed overview of all the hyperparameters employed throughout our study. We use SGD in conjunction with a multistep learning-rate scheduler for image classification tasks, multiplying the learning rate by $\gamma$-lr at each milestone. Note that, for stability reasons, BatchEnsemble based on ResNet-50 used a lower learning rate of 0.08 instead of the default 0.1. Our "Medium" data augmentation strategy combines Mixup [96] and CutMix [92] with a switch probability of 0.5. Additionally, timm's augmentation classes [87] were incorporated with coefficients of 0.5 and 0.2. RandAugment [11] with parameters $m = 9$, $n = 2$, and $mstd = 1$, along with label smoothing [79] of intensity 0.1, was also applied.
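To make the multistep schedule concrete, the sketch below evaluates the learning rate for the CIFAR settings of Table 12 (start lr 0.1, $\gamma$-lr 0.2, milestones 60, 120, 160). It assumes that the decay is applied as soon as the epoch index reaches a milestone, which is the convention of common deep-learning frameworks; the helper name `multistep_lr` is hypothetical.

```python
def multistep_lr(epoch, start_lr, gamma, milestones):
    """Learning rate at `epoch` when it is multiplied by `gamma` at each milestone."""
    # Count how many milestones have already been reached (assumed convention).
    passed = sum(1 for m in milestones if epoch >= m)
    return start_lr * gamma ** passed


# Learning rate at a few representative epochs of a 200-epoch run.
epochs = (0, 59, 60, 120, 160, 199)
schedule = [multistep_lr(e, 0.1, 0.2, [60, 120, 160]) for e in epochs]
```

Under this convention, the learning rate is 0.1 up to epoch 59, then roughly 0.02, 0.004, and 0.0008 after the first, second, and third milestones.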

In the case of ImageNet, we follow the A3 procedure outlined in [88] for all models. It’s worth mentioning that training according to the exact A3 procedure was not consistently feasible; please refer to the specific subsections for additional details.

We highlight that, to enhance training stability and speed up training, we introduce a hyperparameter $\alpha$ in the BNL layer. This transforms the layer as follows:

$$\operatorname{\textbf{BNL}}(\mathbf{h}_j) = \frac{\mathbf{h}_j - \hat{\mu}_j}{\hat{\sigma}_j} \times \gamma_j (1 + \boldsymbol{\epsilon}_j \alpha) + \beta_j. \tag{33}$$

The hyperparameter $\alpha$ is typically set to 0.01, except in the case of ViT, where a smaller value is used (0.00035 in Table 13).
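A minimal sketch of Eq. (33) for a single feature vector is given below. It assumes that one standard normal $\epsilon$ is drawn per channel and that $\hat{\sigma}_j$ is a strictly positive standard deviation that already includes any stabilizing constant; the function name `bnl` and the toy values are hypothetical.

```python
import random


def bnl(h, mu_hat, sigma_hat, gamma, beta, alpha=0.01, rng=None):
    """Apply Eq. (33) channel-wise to one feature vector h."""
    if rng is None:
        rng = random.Random()
    out = []
    for h_c, mu_c, sigma_c, gamma_c, beta_c in zip(h, mu_hat, sigma_hat, gamma, beta):
        eps_c = rng.gauss(0.0, 1.0)  # assumed: one epsilon ~ N(0, 1) per channel
        normalized = (h_c - mu_c) / sigma_c
        # The noise perturbs the scale gamma multiplicatively, damped by alpha.
        out.append(normalized * gamma_c * (1.0 + eps_c * alpha) + beta_c)
    return out


# Toy values (assumed for illustration).
h, mu, sigma = [1.0, 4.0], [0.0, 2.0], [1.0, 2.0]
gamma, beta = [1.0, 3.0], [0.0, 1.0]

plain = bnl(h, mu, sigma, gamma, beta, alpha=0.0)  # reduces to standard normalization
noisy = bnl(h, mu, sigma, gamma, beta, alpha=0.01, rng=random.Random(0))
```

Setting $\alpha = 0$ recovers the standard normalization layer, so a small $\alpha$ keeps the perturbed layer close to the pre-trained one at the start of fine-tuning.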

Dataset Networks Epochs Batch size start lr Momentum Weight decay $\gamma$-lr Milestones Data augmentations
C10 R50 200 128 0.1 0.9 $5 \cdot 10^{-4}$ 0.2 60, 120, 160 HFlip
C10 WR28-10 200 128 0.1 0.9 $5 \cdot 10^{-4}$ 0.2 60, 120, 160 HFlip
C100 R50 200 128 0.1 0.9 $5 \cdot 10^{-4}$ 0.2 60, 120, 160 HFlip
C100 WR28-10 200 128 0.1 0.9 $5 \cdot 10^{-4}$ 0.2 60, 120, 160 Medium
Table 12: Hyperparameters for image classification experiments. HFlip denotes the classical horizontal flip.
Dataset Networks Epochs Batch size start lr Alpha Momentum Weight decay $\gamma$-lr Milestones Data augmentations
C10 R50 2 128 0.0057 0.01 0.9 $5 \cdot 10^{-4}$ / / HFlip
C10 WR28-10 2 128 0.0091 0.01 0.9 $5 \cdot 10^{-4}$ / / HFlip
C100 R50 10 128 0.00139 0.01 0.9 $5 \cdot 10^{-4}$ 0.5 5 HFlip
C100 WR28-10 10 128 0.034 0.01 0.9 $5 \cdot 10^{-4}$ 0.5 5 HFlip
Imagenet R50 1 128 0.00439 0.01 0.9 $5 \cdot 10^{-4}$ / / HFlip
Imagenet ViT 0.25 128 $7.33 \cdot 10^{-6}$ 0.00035 0.9 $7 \cdot 10^{-6}$ $5 \cdot 10^{-4}$ Constant HFlip
StreetHazards DeepLabv3+ 10 4 0.01 0.01 0.9 $1 \cdot 10^{-4}$ 0.9 Polynomial HFlip, RandomCrop, ColorJitter, RandomScale
BDD-Anomaly DeepLabv3+ 10 4 0.01 0.01 0.9 $1 \cdot 10^{-4}$ 0.9 / HFlip, RandomCrop, ColorJitter, RandomScale
MUAD DeepLabv3+ 6 4 0.044 0.01 0.9 $1 \cdot 10^{-4}$ 0.9 / HFlip, RandomCrop, ColorJitter, RandomScale
Table 13: Hyperparameters for the image classification and semantic segmentation experiments with ABNN. HFlip denotes the classical horizontal flip. A random prior was used.

Appendix H Notations

We summarize the main notations used in the paper in Table 14.

Table 14: Summary of the main notations of the paper.
Notation Meaning
$\mathcal{D} = \{(\mathbf{x}_i, \mathbf{y}_i)\}_{i=1}^{N}$ The set of $N$ data samples and the corresponding labels
$j$ The index of the current layer
$\boldsymbol{\omega}$ The set of all the weights of the DNN
$\boldsymbol{\omega}_m \sim P(\boldsymbol{\omega} | \mathcal{D})$ The $m$-th sample from the concatenation of weights of the posterior of the DNN
$M$ The number of networks in an ensemble
$\boldsymbol{\omega}^{(t)}$ The concatenation of all the weights of the DNN after $t$ steps of optimization
$\mathbf{h}_j$ The pre-activation feature map and output of layer $(j-1)$ & input of layer $j$ before normalization
$\mathbf{u}_j$ The pre-activation feature map and output of layer $(j-1)$ & input of layer $j$ before normalization
$\gamma_j$, $\beta_j$ The parameters of the batch, instance, or layer normalization of layer $j$
$\mu_j$, $\sigma_j$ The empirical mean and variance used by the batch, instance, or layer normalization of layer $j$
$\mathbf{a}_j$ The feature map and output of layer $j$, $\mathbf{a}_j = a(\mathbf{u}_j)$
$a(\cdot)$ The activation function
$W^{(j)}$ The weights of the $j$-th layer in a Multi-Layer Perceptron (MLP)
$W^{(j)}_{\mu}$ The mean weights of the $j$-th layer in a BNN MLP
$W^{(j)}_{\sigma}$ The standard deviation weights of the $j$-th layer in a BNN MLP
$\epsilon^{(j)} \sim \mathcal{N}(\mathbf{0}, \mathds{1})$ A vector sampled from a standard normal distribution at layer $j$
$\epsilon$ The concatenation of the $\epsilon^{(j)}$ of all layers $j$
$\epsilon_l$ The $l$-th sample of $\epsilon$
$\mathcal{H}$ The entropy function