License: arXiv.org perpetual non-exclusive license
arXiv:2307.07816v2 [cs.LG] 04 Dec 2023

Minimal Random Code Learning with Mean-KL Parameterization

Jihao Andreas Lin    Gergely Flamich    José Miguel Hernández-Lobato
Abstract

This paper studies the qualitative behavior and robustness of two variants of Minimal Random Code Learning (MIRACLE) used to compress variational Bayesian neural networks. MIRACLE implements a powerful, conditionally Gaussian variational approximation for the weight posterior $Q_{\mathbf{w}}$ and uses relative entropy coding to compress a weight sample from the posterior using a Gaussian coding distribution $P_{\mathbf{w}}$. To achieve the desired compression rate, $D_{\mathrm{KL}}[Q_{\mathbf{w}} \| P_{\mathbf{w}}]$ must be constrained, which requires a computationally expensive annealing procedure under the conventional mean-variance (Mean-Var) parameterization for $Q_{\mathbf{w}}$. Instead, we parameterize $Q_{\mathbf{w}}$ by its mean and KL divergence from $P_{\mathbf{w}}$ to constrain the compression cost to the desired value by construction. We demonstrate that variational training with Mean-KL parameterization converges twice as fast and maintains predictive performance after compression. Furthermore, we show that Mean-KL leads to more meaningful variational distributions with heavier tails and compressed weight samples which are more robust to pruning.

Neural Compression

Figure 1: Layerwise histograms of variational mean and log standard deviation for Mean-Var (blue) versus Mean-KL (orange) parameterizations. Mean-Var struggles to learn meaningful distributions: means are concentrated at zero and standard deviations are clustered at high values. Mean-KL learns more reasonable distributions with heavier tails and a broader range of values.

1 Introduction

With the ever-growing size of neural network architectures, such as large language models (e.g. BERT, Kenton & Toutanova, 2019), it is now a key challenge to ensure their memory and energy efficiency. While there is a large literature on model compression, almost all works rely on some form of quantization scheme. In this paper, we consider an alternative to quantization, namely Minimal Random Code Learning (MIRACLE, Havasi et al., 2019), which has recently demonstrated state-of-the-art performance for neural network compression. The MIRACLE framework employs a powerful, conditionally Gaussian variational distribution $Q_{\mathbf{w}}$ over the weights $\mathbf{w}$ of a neural network and uses relative entropy coding (REC, Flamich et al., 2020) with a Gaussian coding distribution $P_{\mathbf{w}}$ to encode a random weight sample from $Q_{\mathbf{w}}$. The average coding cost of encoding a weight sample is $D_{\mathrm{KL}}[Q_{\mathbf{w}} \| P_{\mathbf{w}}]$, which needs to be carefully controlled in a practical compression scheme.
To this end, we propose to use the Mean-KL parameterization for Gaussians (Flamich et al., 2022) to parameterize $Q_{\mathbf{w}}$, allowing explicit control over $D_{\mathrm{KL}}[Q_{\mathbf{w}} \| P_{\mathbf{w}}]$ by construction. We demonstrate that Mean-KL leads to many practical benefits over the conventional mean-variance (Mean-Var) parameterization used by Havasi et al. (2019), which requires a computationally expensive annealing procedure to control the coding cost. In particular, we show that, compared to Mean-Var parameterization, variational training with Mean-KL parameterization converges in half the number of iterations while maintaining predictive performance after compression. Furthermore, we illustrate that the resulting variational distribution exhibits more meaningful shapes with heavy tails, which makes the compressed weight sample more robust to pruning weights to zero.

2 Background

Minimal Random Code Learning

Havasi et al. (2019) consider a setting akin to the $\beta$-VAE (Higgins et al., 2017) to encode neural network weights with a limited information budget $C$. To this end, let $\mathcal{X}$, $\mathcal{Y}$ and $\mathcal{W}$ be the input, output and weight spaces, respectively, let $\mathcal{D} = \{(\mathbf{x}_n, \mathbf{y}_n)\}_{n=1}^{N}$ be a dataset and let $h : \mathcal{X} \times \mathcal{W} \to \mathcal{Y}$ be a neural network with input $\mathbf{x}$ and weights $\mathbf{w}$. To control the information content of the weights, let $P_{\mathbf{w}}$ be the coding distribution and $Q_{\mathbf{w}}$ be the variational distribution over $\mathbf{w}$. In this setting, Hinton & Van Camp (1993) show that the information content of the weights is $D_{\mathrm{KL}}[Q_{\mathbf{w}} \| P_{\mathbf{w}}]$. Further, let $\Delta : \mathcal{Y} \times \mathcal{Y} \to \mathbb{R}^{+}$ be a distortion function. MIRACLE minimizes

$$\mathbb{E}_{\mathbf{w} \sim Q_{\mathbf{w}}} \sum_{(\mathbf{x}, \mathbf{y}) \in \mathcal{D}} \Delta(\mathbf{y}, h(\mathbf{x}, \mathbf{w})) + \beta\, D_{\mathrm{KL}}[Q_{\mathbf{w}} \| P_{\mathbf{w}}] \tag{1}$$

with respect to $Q_{\mathbf{w}}$ to minimize distortion within the given information budget of $D_{\mathrm{KL}}[Q_{\mathbf{w}} \| P_{\mathbf{w}}] = C$ nats. During optimization, $\beta$ is dynamically adapted to anneal the KL divergence, such that the constraint is eventually satisfied.
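As a rough illustration of how Equation 1 and the adaptive $\beta$ interact, the Python sketch below evaluates the penalized objective from given estimates of the distortion and the KL divergence and applies a simple annealing step. The multiplicative update rule, its step size `eps_beta`, and the function names are assumptions made for this example; they are not claimed to be the exact schedule used by Havasi et al. (2019).

```python
def miracle_objective(distortion: float, kl: float, beta: float) -> float:
    """Equation 1 for given estimates of the expected distortion and D_KL[Q_w || P_w]."""
    return distortion + beta * kl


def anneal_beta(beta: float, kl: float, budget: float, eps_beta: float = 5e-5) -> float:
    """Hypothetical annealing step: tighten the KL penalty while KL > C, relax it otherwise."""
    if kl > budget:
        return beta * (1.0 + eps_beta)
    return beta / (1.0 + eps_beta)


# Toy run: while the KL divergence stays above the budget C, beta keeps growing.
beta = 1.0
for kl in (50.0, 40.0, 30.0):
    beta = anneal_beta(beta, kl, budget=20.0)
print(beta)
```

Because each step only rescales $\beta$ by a factor close to one, a schedule of this kind may need many iterations before the constraint is met.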

In this paper, we encode the samples using minimal random coding (MRC, Havasi et al., 2019) for simplicity, though more sophisticated approaches, such as A* coding (Flamich et al., 2022) or greedy Poisson rejection sampling (Flamich, 2023), have since been proposed. Given a suitable $Q_{\mathbf{w}}$, a random sample from $Q_{\mathbf{w}}$ is compressed by first drawing $K = \exp(D_{\mathrm{KL}}[Q_{\mathbf{w}} \| P_{\mathbf{w}}])$ samples $\mathbf{w}_0, \ldots, \mathbf{w}_{K-1}$ from $P_{\mathbf{w}}$. These $K$ samples are then used to construct a discrete distribution $Q_k$ whose probability mass function is proportional to the importance weights $r_k = \frac{dQ_{\mathbf{w}}}{dP_{\mathbf{w}}}(\mathbf{w}_k)$, i.e. $Q_k(k) = r_k / \sum_{j} r_j$, where $\frac{dQ_{\mathbf{w}}}{dP_{\mathbf{w}}}$ is the Radon-Nikodym derivative, i.e. the density ratio, of $Q_{\mathbf{w}}$ with respect to $P_{\mathbf{w}}$.
The compressed weight sample is represented by an index $k_* \sim Q_k$. Since $0 \leq k_* < K$, it is always possible to encode $k_*$ using $\log K = D_{\mathrm{KL}}[Q_{\mathbf{w}} \| P_{\mathbf{w}}] = C$ nats. The weight sample can be decoded by drawing the $k_*$-th sample from $P_{\mathbf{w}}$ using a shared random number generator with a shared random seed. Since $K$ grows exponentially with the KL divergence, which typically grows with the number of dimensions of $\mathbf{w}$, simulating $K$ samples is intractable if $\mathbf{w}$ has many dimensions. Havasi et al. (2019) solve this issue by partitioning $\mathbf{w}$ dimensionwise into smaller blocks with local information budgets $C_{\mathrm{block}}$, such that $K$ is feasible.
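The encoding and decoding steps described above can be sketched in NumPy for a single block, assuming a diagonal Gaussian $Q_{\mathbf{w}}$ and an isotropic Gaussian coding distribution $\mathcal{N}(\nu, \rho^2)$. This is a minimal illustration, not the reference implementation: the function names, the use of `numpy.random.default_rng` as the shared generator, rounding $K$ up to an integer, and passing $K$ to the decoder explicitly (in MIRACLE it is fixed by the budget) are our choices.

```python
import numpy as np


def log_gaussian(w, mean, std):
    """Log density of a diagonal Gaussian, summed over the last axis."""
    return np.sum(-0.5 * ((w - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2 * np.pi), axis=-1)


def kl_gaussian(mu, sigma, nu, rho):
    """Closed-form D_KL[N(mu, sigma^2) || N(nu, rho^2)] in nats, summed over dimensions."""
    return np.sum(np.log(rho / sigma) + (sigma**2 + (mu - nu) ** 2) / (2 * rho**2) - 0.5)


def shared_candidates(num_samples, dim, nu, rho, seed):
    """K samples from the coding distribution P, reproducible from the shared seed."""
    return nu + rho * np.random.default_rng(seed).standard_normal((num_samples, dim))


def mrc_encode(mu, sigma, nu, rho, seed, rng):
    """Pick an index k* ~ Q_k, where Q_k(k) is proportional to r_k = q(w_k) / p(w_k)."""
    num_samples = int(np.ceil(np.exp(kl_gaussian(mu, sigma, nu, rho))))  # K = exp(KL)
    candidates = shared_candidates(num_samples, mu.size, nu, rho, seed)
    log_r = log_gaussian(candidates, mu, sigma) - log_gaussian(candidates, nu, rho)
    probs = np.exp(log_r - log_r.max())  # subtract the max for numerical stability
    probs /= probs.sum()
    return int(rng.choice(num_samples, p=probs)), num_samples


def mrc_decode(k_star, num_samples, dim, nu, rho, seed):
    """Regenerate the same candidates from the shared seed and return the k*-th one."""
    return shared_candidates(num_samples, dim, nu, rho, seed)[k_star]


mu, sigma = np.array([0.3, -0.2, 0.1]), np.array([0.2, 0.3, 0.4])
k_star, K = mrc_encode(mu, sigma, nu=0.0, rho=1.0, seed=42, rng=np.random.default_rng(0))
w_tilde = mrc_decode(k_star, K, dim=3, nu=0.0, rho=1.0, seed=42)
print(K, k_star, np.log2(K))  # the index costs about log2(K) bits
```

Only the integer `k_star` needs to be transmitted; the decoder reconstructs the weight sample from it and the shared seed.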

Refining Mean-Field Posteriors

An important choice in practice is the variational family over which we optimize Equation 1. Since we are interested in studying the behavior of samples using MIRACLE, we also adopt the variational family suggested by Havasi et al. (2019). Concretely, assume that we have already partitioned the weight vector as $\mathbf{w} = \mathbf{w}_{1:B} = \mathbf{w}_1 \oplus \mathbf{w}_2 \oplus \ldots \oplus \mathbf{w}_B$, where $B$ denotes the number of blocks and $\oplus$ denotes vector concatenation. To begin, we use a mean-field Gaussian variational approximation, i.e. we parameterize the means $\mu_{1:B} = \mu_1 \oplus \ldots \oplus \mu_B$ and marginal variances $\sigma^2_{1:B} = \sigma^2_1 \oplus \ldots \oplus \sigma^2_B$ (Mean-Var). Once variational training converges, we compress the first block $\mathbf{w}_1$, resulting in a sample $\tilde{\mathbf{w}}_1$.
Keeping $\tilde{\mathbf{w}}_1$ fixed, we resume optimization to fine-tune the remaining means $\mu_{2:B}$ and variances $\sigma^2_{2:B}$. We repeat this process $B$ times in total, where at step $b$, $\tilde{\mathbf{w}}_1, \ldots, \tilde{\mathbf{w}}_{b-1}$ are fixed, means $\mu_{b:B}$ and variances $\sigma^2_{b:B}$ are optimized, and a random sample from block $b$ is encoded. Note that the variational posterior $Q_{\mathbf{w}_{b:B} \mid \tilde{\mathbf{w}}_{1:b-1}}$ at step $b$ is only factorized conditionally on the weight samples in the first $b-1$ blocks, which results in a much better variational approximation than a fully factorized mean-field one.
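The sequential procedure above can be summarized as a short loop. In this Python sketch, `fine_tune` and `encode_block` are hypothetical callbacks standing in for variational training and for MRC encoding of one block; blocks are 0-indexed, and the first call to `fine_tune` plays the role of the initial variational training. The toy usage only records the order of operations.

```python
from typing import Callable, List

import numpy as np


def compress_blockwise(
    num_blocks: int,
    fine_tune: Callable[[int, List[np.ndarray]], None],
    encode_block: Callable[[int], np.ndarray],
) -> List[np.ndarray]:
    """Compress blocks one at a time, re-optimizing the not-yet-compressed blocks in between.

    fine_tune(b, fixed): hypothetical callback that optimizes the variational parameters
        of blocks b, ..., B-1 while the samples in `fixed` stay frozen.
    encode_block(b): hypothetical callback that draws a sample of block b and encodes it.
    """
    fixed: List[np.ndarray] = []
    for b in range(num_blocks):
        fine_tune(b, fixed)            # optimize the remaining blocks given the frozen samples
        fixed.append(encode_block(b))  # this block's sample stays fixed for all later steps
    return fixed


# Toy callbacks that only record what would happen at each step.
log: List[str] = []
samples = compress_blockwise(
    num_blocks=3,
    fine_tune=lambda b, fixed: log.append(f"optimize blocks {b}.. given {len(fixed)} fixed"),
    encode_block=lambda b: np.full(4, float(b)),
)
print(log)
```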

Mean-KL Parameterization for Gaussians

Flamich et al. (2022) show that, given a univariate Gaussian coding distribution $P_w = \mathcal{N}(w \mid \nu, \rho^2)$ with mean $\nu$ and variance $\rho^2$, a variational distribution $Q_w = \mathcal{N}(w \mid \mu, \sigma^2)$ can be uniquely parameterized by its mean $\mu$ and $D_{\mathrm{KL}}[Q_w \| P_w] = \kappa$ if

$$|\mu - \nu| < \rho \sqrt{2\kappa} \tag{2}$$

is satisfied. The variance $\sigma^2$ of $Q_w$ can be recovered via

$$\sigma^2 = -\rho^2\, W\left(-\exp(z^2 - 2\kappa - 1)\right), \tag{3}$$

where $z = (\mu - \nu)/\rho$ and $W$ is the principal branch of the Lambert $W$ function (Corless et al., 1996), defined by the relation $W(x)e^{W(x)} = x$ (see Appendix B for details).
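Equation 3 can be checked numerically against the closed-form Gaussian KL divergence $D_{\mathrm{KL}}[Q_w \| P_w] = \tfrac{1}{2}\big(\log(\rho^2/\sigma^2) + (\sigma^2 + (\mu - \nu)^2)/\rho^2 - 1\big)$. The sketch below uses `scipy.special.lambertw` with `k=0` (the principal branch); the function names are hypothetical and the chosen values of $\mu$ and $\kappa$ are arbitrary.

```python
import numpy as np
from scipy.special import lambertw


def mean_kl_to_variance(mu, kappa, nu=0.0, rho=1.0):
    """Recover sigma^2 of Q_w = N(mu, sigma^2) from its mean and KL budget kappa (Equation 3)."""
    z = (mu - nu) / rho
    if np.any(z**2 >= 2 * kappa):
        raise ValueError("Equation 2 requires |mu - nu| < rho * sqrt(2 * kappa)")
    # Principal branch k=0; lambertw returns complex values whose imaginary part is 0 here.
    w = lambertw(-np.exp(z**2 - 2 * kappa - 1), k=0).real
    return -rho**2 * w


def gaussian_kl(mu, sigma2, nu=0.0, rho=1.0):
    """Closed-form D_KL[N(mu, sigma2) || N(nu, rho^2)] in nats."""
    return 0.5 * (np.log(rho**2 / sigma2) + (sigma2 + (mu - nu) ** 2) / rho**2 - 1.0)


mu, kappa = 0.4, 2.0
sigma2 = mean_kl_to_variance(mu, kappa)
print(sigma2, gaussian_kl(mu, sigma2))  # the KL value matches kappa up to rounding
```

The KL equation has two solutions for $\sigma^2$. Using `k=-1` would select the other real branch, for which $W \leq -1$ and hence $\sigma^2 \geq \rho^2$; the principal branch picks the solution with $\sigma^2 \leq \rho^2$.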

3 Mean-KL Parameterization for MIRACLE

Recognizing that the main goal of minimizing Equation 1 combined with KL annealing is to solve

$$\underset{Q_{\mathbf{w}}}{\arg\min} \quad \mathbb{E}_{\mathbf{w} \sim Q_{\mathbf{w}}} \sum_{(\mathbf{x}, \mathbf{y}) \in \mathcal{D}} \Delta(\mathbf{y}, h(\mathbf{x}, \mathbf{w})), \tag{4}$$
$$\text{subject to} \quad D_{\mathrm{KL}}[Q_{\mathbf{w}} \| P_{\mathbf{w}}] = C, \tag{5}$$

we propose to use Mean-KL parameterization (Flamich et al., 2022) to enforce the constraint $D_{\mathrm{KL}}[Q_{\mathbf{w}} \| P_{\mathbf{w}}] = C$ by construction instead of performing computationally expensive KL annealing. Because $Q_{\mathbf{w}}$ and $P_{\mathbf{w}}$ factorize over the weights, the total KL divergence is the sum of the per-weight KL divergences. Hence, the total information budget $C = \kappa$ must be distributed among the weights, resulting in local information budgets $\kappa_w$. Thus, in Mean-KL parameterization, each weight has a mean parameter $\mu_w$ and a local information budget $\kappa_w$, matching the number of parameters of the conventional Mean-Var parameterization, albeit with one fewer degree of freedom because $\sum_{w \in \mathbf{w}} \kappa_w = \kappa$.

In practice, we introduce an information quota parameter $\gamma_w$ per weight, which satisfies $\sum_{w \in \mathbf{w}} \gamma_w = 1$ and defines the relative share of the total information budget assigned to $w$, that is $\kappa_w = \gamma_w \kappa$. The constraint on the information quota parameters is implemented using a softmax function. To ensure that $|\mu_w - \nu| < \rho\sqrt{2\kappa_w}$ (Equation 2), we define

$$\mu_w = \nu + \rho\sqrt{2\kappa_w}\, \tanh(\tau_w), \tag{6}$$

as suggested by Flamich et al. (2022), leaving $\tau_w$ and $\gamma_w$ as trainable parameters. In combination with blockwise partitioning of $\mathbf{w}$, each block has its own constraint and $\kappa$ is simply replaced by $\kappa_{\mathrm{block}}$. When drawing samples from $Q_{\mathbf{w}}$ or evaluating the density of $Q_{\mathbf{w}}$, we convert $\tau_w$ and $\gamma_w$ to $\mu_w$ and $\sigma_w^2$ using Equation 6 and Equation 3, respectively, followed by the same computations as with conventional Mean-Var parameterization.
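Putting the softmax constraint, Equation 6 and Equation 3 together, the following NumPy/SciPy sketch maps the parameters of one block to means and variances and checks that the block's total KL divergence equals $\kappa_{\mathrm{block}}$ by construction. It assumes the same scalar coding distribution $\mathcal{N}(\nu, \rho^2)$ for every weight in the block (here $\nu = 0$, $\rho = 1$) and represents the quotas $\gamma_w$ by unnormalized logits passed through `scipy.special.softmax`; the function name `mean_kl_block` is hypothetical, and this is a forward-pass illustration rather than a training implementation.

```python
import numpy as np
from scipy.special import lambertw, softmax


def mean_kl_block(tau, logits, kappa_block, nu=0.0, rho=1.0):
    """Map unconstrained per-weight parameters of one block to (mu_w, sigma_w^2).

    tau: trainable parameters tau_w; logits: unnormalized information quotas.
    """
    gamma = softmax(logits)                  # information quotas gamma_w, sum to 1
    kappa_w = gamma * kappa_block            # local budgets kappa_w = gamma_w * kappa
    mu = nu + rho * np.sqrt(2 * kappa_w) * np.tanh(tau)  # Equation 6, so Equation 2 holds
    z = (mu - nu) / rho
    sigma2 = -rho**2 * lambertw(-np.exp(z**2 - 2 * kappa_w - 1), k=0).real  # Equation 3
    return mu, sigma2


rng = np.random.default_rng(0)
tau, logits = rng.normal(size=20), rng.normal(size=20)
kappa_block = 20 * np.log(2)  # e.g. a 20-bit budget expressed in nats
mu, sigma2 = mean_kl_block(tau, logits, kappa_block)
# Per-weight KL to the coding distribution N(0, 1); the sum should equal kappa_block.
kl = 0.5 * (np.log(1.0 / sigma2) + sigma2 + mu**2 - 1.0)
print(kl.sum(), kappa_block)
```

To train with gradient descent, the same map would be implemented in an automatic differentiation framework; the derivative of the Lambert $W$ function is $W'(x) = W(x) / \big(x(1 + W(x))\big)$ for $x \notin \{0, -1/e\}$.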

4 Experiments

We empirically demonstrate advantages of Mean-KL compared to conventional Mean-Var parameterization: We show that variational training with Mean-KL parameterization converges faster than Mean-Var while maintaining predictive performance, we illustrate that Mean-KL leads to more meaningful distributions with heavier tails, and we demonstrate that these more meaningful distributions translate to improved robustness when pruning weights to zero.

Training Dynamics and Predictive Performance

We adopt the experimental setup of Havasi et al. (2019) and train a LeNet-5 on MNIST. The distortion function $\Delta$ is the cross-entropy, which is commonly used as a loss function in image classification. Matching Havasi et al. (2019), we used a local information budget of $C_{\mathrm{block}} = \kappa_{\mathrm{block}} = 20$ bits. We varied the block size among 20, 30, and 40. For both parameterizations, we used Adam with a learning rate of 0.001 and a mini-batch size of 200. For KL divergence annealing with Mean-Var, we used $\epsilon_{\beta_0} = 10^{-8}$ and $\epsilon_{\beta} = 5 \times 10^{-5}$, as suggested by Havasi et al. (2019). See Appendix C for further implementation details.
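Since the earlier sections measure the budget in nats, a short back-of-the-envelope computation helps relate the 20-bit block budget to the quantities used there. The Python snippet below converts $C_{\mathrm{block}}$ to nats, computes the number $K$ of candidate samples MRC draws per block, and prints the average number of bits spent per weight for each block size; the variable names are ours.

```python
import math

budget_bits = 20
budget_nats = budget_bits * math.log(2)        # kappa_block in nats, about 13.86
num_candidates = round(math.exp(budget_nats))  # K = exp(kappa_block) = 2**20 per block
for block_size in (20, 30, 40):
    # Average number of bits spent per weight inside one encoded block.
    print(block_size, budget_bits / block_size)
print(num_candidates)
```

The bits per weight fall as 1, 2/3 and 1/2 for block sizes 20, 30 and 40, so their inverses grow as 1 : 1.5 : 2, the same proportions as the compression ratios 555x, 833x and 1111x in Table 1. The absolute ratios depend on the full setup and are not reproduced by this arithmetic alone.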

Figure 2 illustrates how Mean-Var spends most of the optimization on minimizing and annealing the KL divergence to the desired coding cost, whereas with Mean-KL the whole optimization process focuses on minimizing cross-entropy, given that the parameterization already constrains the KL divergence to the desired coding cost. Crucially, KL divergence annealing with Mean-Var takes a large number of iterations, while minimizing cross-entropy with Mean-KL converges in half as many (100,000 versus 200,000 iterations, see Table 1). Table 1 shows that Mean-KL maintains predictive performance comparable to Mean-Var across different compression ratios, being slightly better in the low compression ratio setting and slightly worse in the high compression ratio settings, albeit with overlapping standard errors.

Table 1: MNIST classification error after compression (lower is better). Mean ± standard error over 10 seeds.
Block Size   Compression Ratio   Mean-Var          Mean-KL
20           555x                0.82 ± 0.07 %     0.77 ± 0.05 %
30           833x                0.79 ± 0.05 %     0.87 ± 0.08 %
40           1111x               0.87 ± 0.07 %     0.96 ± 0.08 %
Optimizer Iterations             200,000           100,000
Figure 2: Training dynamics of Mean-Var and Mean-KL parameterizations. Mean-Var requires a large number of iterations to anneal the KL divergence to the desired coding cost. Mean-KL constrains $D_{\mathrm{KL}}[Q_{\mathbf{w}} \| P_{\mathbf{w}}]$ to the desired value and focuses on minimizing cross-entropy, converging in half the number of iterations.

Visualizing Variational Posteriors

Figure 3: Predictive performance of compressed weight samples from Mean-Var and Mean-KL parameterizations under pruning, where weights are set to zero by selecting them uniformly at random (left), based on the smallest absolute values (middle), or based on minimizing the KL divergence to a Dirac delta centered at zero (right). Mean ± standard error over block sizes 20, 30, and 40 with 10 random seeds per block size.

To qualitatively investigate the variational posterior distributions, we plot layerwise histograms of the learned parameters after the compressed weight sample has been generated. For comparison, both Mean-Var and Mean-KL parameters have been converted to mean and log standard deviation.

Figure 1 reveals striking differences between the layerwise Mean-Var and Mean-KL parameter distributions. In terms of the means, Mean-Var parameters collapse to sharp peaks at zero for all layers without any visible tails. In contrast, Mean-KL mean parameters manifest much wider, symmetric distributions centered around zero with heavier tails, resembling Laplace, Gaussian or Student's $t$-distributions. Similarly, in terms of the log standard deviations, Mean-Var parameters form peaked distributions around a particular value with virtually no tails. The distributions of Mean-KL log standard deviations are more spread out, forming distinct shapes for each layer. In general, Mean-Var standard deviations seem to be higher than Mean-KL standard deviations. Furthermore, despite similar predictive performance, the stark differences in distributional shapes suggest potential qualitative differences between the learned variational posteriors.

Robustness to Pruning

To study potential qualitative differences between variational posteriors learned using Mean-Var and Mean-KL parameterizations, we analyze the robustness of the compressed weight sample by setting certain weights to zero using three different strategies:

  1. Random Uniform: Select the pruned weights uniformly at random. Since this strategy is uninformed, it reflects a general notion of robustness.

  2. Absolute Value: Set the weight with the smallest absolute value to zero. This strategy is a simple yet competitive pruning baseline (Blalock et al., 2020), which only depends on the compressed weight sample itself. If the same sample were generated by two different distributions, it would still be pruned in the same way.

  3. KL Divergence: Prune the weight $w_i$ that minimizes the KL divergence between a Dirac delta $\delta_w$ at zero and its variational posterior $Q_{w_i}$, i.e. $\arg\min_i D_{\mathrm{KL}}[\delta_w \| Q_{w_i}]$. For a Gaussian variational posterior with diagonal covariance matrix, this is equivalent to finding the weight with maximal density at zero (see Appendix A for details). This strategy depends on the variational posterior, implying that the same compressed sample would be pruned differently if it were generated by two different distributions.
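The three strategies can be sketched as functions that return the order in which weights are set to zero. This NumPy sketch assumes a diagonal Gaussian posterior with per-weight means `mu` and standard deviations `sigma`; following the equivalence stated for the KL Divergence strategy, it ranks weights by their log density at zero, $\log \mathcal{N}(0 \mid \mu_i, \sigma_i^2)$, dropping the additive constant. Because each weight's score does not depend on the others, pruning one weight at a time is equivalent to a single ranking. The function names and toy values are hypothetical.

```python
import numpy as np


def pruning_order(strategy, w_tilde, mu, sigma, rng=None):
    """Indices of the weights in the order in which they are set to zero."""
    if strategy == "random":
        return rng.permutation(w_tilde.size)  # uninformed baseline
    if strategy == "absolute":
        return np.argsort(np.abs(w_tilde))    # smallest |w~_i| first
    if strategy == "kl":
        # log N(0 | mu_i, sigma_i^2) without the constant -0.5 * log(2 * pi);
        # the highest density at zero corresponds to the smallest KL from the Dirac delta.
        log_density_at_zero = -np.log(sigma) - 0.5 * (mu / sigma) ** 2
        return np.argsort(-log_density_at_zero)
    raise ValueError(f"unknown strategy: {strategy}")


def prune(w_tilde, order, fraction):
    """Return a copy of the compressed sample with the first `fraction` of `order` zeroed."""
    pruned = w_tilde.copy()
    pruned[order[: int(round(fraction * w_tilde.size))]] = 0.0
    return pruned


w_tilde = np.array([0.05, -1.2, 0.4, 0.9])  # toy compressed sample
mu = np.array([0.0, -1.0, 0.5, 0.8])
sigma = np.array([0.1, 0.2, 0.05, 1.0])
print(pruning_order("absolute", w_tilde, mu, sigma))  # [0 2 3 1]
print(pruning_order("kl", w_tilde, mu, sigma))        # [0 3 1 2]
print(prune(w_tilde, pruning_order("kl", w_tilde, mu, sigma), 0.5))
```

In this toy example, weight 2 has a small magnitude but a sharply peaked posterior away from zero, so Absolute Value pruning removes it second while KL Divergence pruning removes it last.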

Figure 3 illustrates how the test accuracy changes as more weights in the compressed sample are pruned to zero. With Random Uniform pruning, Mean-Var test accuracy drops off quickly, losing more than half of its performance after about 20% of the weights have been pruned and falling to the level of uniform random guessing after roughly 70% of the weights have been set to zero. Mean-KL performance also decreases rapidly, albeit more gracefully: after setting 30% of all weights to zero, a test accuracy of 80% is maintained, and performance equal to random guessing is only reached after more than 80% of the weights have been pruned. This suggests a general notion of improved robustness of the compressed sample produced by Mean-KL compared to Mean-Var.

With Absolute Value pruning, Mean-Var and Mean-KL perform nearly identically. Both parameterizations roughly maintain full predictive performance until 50% of the weights have been pruned and decay towards random guessing as more weights are set to zero. Since this pruning strategy does not depend on the variational posterior and is only informed by the compressed weight sample itself, this shows that both parameterizations produce compressed samples which can maintain performance to some degree under pruning.

Finally, the two parameterizations perform drastically differently under KL Divergence pruning. While Mean-Var test accuracy quickly falls almost to random guessing after only 50% of the weights have been set to zero, Mean-KL maintains close to 90% test accuracy after pruning 90% of the weights, even outperforming the competitive Absolute Value baseline. Since this pruning strategy is informed by the variational posterior, the results strongly suggest that, compared to Mean-Var, Mean-KL parameterization leads to a superior variational posterior which produces more robust compressed samples. Given that this pruning strategy outperforms the competitive baseline, this property is not a mere peculiarity and could potentially be leveraged to design more robust algorithms.

5 Conclusion

We demonstrated that MIRACLE with Mean-KL parameterization bypasses the need for time-consuming KL annealing, converging in half the number of optimization steps while maintaining predictive performance. Furthermore, Mean-KL parameterization produces more meaningful variational posterior distributions with heavy tails, whereas standard Mean-Var parameterization produces distributions which are sharply peaked at particular values. We illustrated that these qualitative differences result in different behavior under pruning, suggesting that compressed weight samples from Mean-KL are more robust than samples from Mean-Var. Future work should investigate whether the faster convergence carries over to larger models and explore Mean-KL parameterization for Bayesian neural networks independent of compression. Explicitly exploiting the robustness of Mean-KL to design pruning or compression algorithms is another possible avenue.

References

  • Blalock et al. (2020) Blalock, D., Ortiz, J. J. G., Frankle, J., and Guttag, J. What is the State of Neural Network Pruning? In Proceedings of Machine Learning and Systems, 2020.
  • Chen et al. (2015) Chen, W., Wilson, J. T., Tyree, S., Weinberger, K. Q., and Chen, Y. Compressing Neural Networks with the Hashing Trick. In International Conference on Machine Learning, 2015.
  • Corless et al. (1996) Corless, R. M., Gonnet, G. H., Hare, D. E., Jeffrey, D. J., and Knuth, D. E. On the Lambert W Function. Advances in Computational Mathematics, 1996.
  • Dillon et al. (2017) Dillon, J. V., Langmore, I., Tran, D., Brevdo, E., Vasudevan, S., Moore, D., Patton, B., Alemi, A., Hoffman, M., and Saurous, R. A. TensorFlow Distributions. In arXiv:1711.10604, 2017.
  • Flamich (2023) Flamich, G. Greedy Poisson Rejection Sampling. In arXiv:2305.15313, 2023.
  • Flamich et al. (2020) Flamich, G., Havasi, M., and Hernández-Lobato, J. M. Compressing Images by Encoding their Latent Representations with Relative Entropy Coding. In Advances in Neural Information Processing Systems, 2020.
  • Flamich et al. (2022) Flamich, G., Markou, S., and Hernández-Lobato, J. M. Fast Relative Entropy Coding with A* Coding. In International Conference on Machine Learning, 2022.
  • Havasi et al. (2019) Havasi, M., Peharz, R., and Hernández-Lobato, J. M. Minimal Random Code Learning: Getting Bits Back from Compressed Model Parameters. In International Conference on Learning Representations, 2019.
  • Higgins et al. (2017) Higgins, I., Matthey, L., Pal, A., Burgess, C. P., Glorot, X., Botvinick, M. M., Mohamed, S., and Lerchner, A. β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework. In International Conference on Learning Representations, 2017.
  • Hinton & Van Camp (1993) Hinton, G. E. and Van Camp, D. Keeping Neural Networks Simple by Minimizing the Description Length of the Weights. In Conference on Computational Learning Theory, 1993.
  • Kenton & Toutanova (2019) Kenton, J. D. M.-W. C. and Toutanova, L. K. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. In Conference of the North American Chapter of the Association for Computational Linguistics - Human Language Technologies, 2019.
  • Paszke et al. (2019) Paszke, A., Gross, S., Massa, F., Lerer, A., Bradbury, J., Chanan, G., Killeen, T., Lin, Z., Gimelshein, N., Antiga, L., Desmaison, A., Kopf, A., Yang, E., DeVito, Z., Raison, M., Tejani, A., Chilamkurthy, S., Steiner, B., Fang, L., Bai, J., and Chintala, S. PyTorch: An Imperative Style, High-Performance Deep Learning Library. In Advances in Neural Information Processing Systems, 2019.
  • Winitzki (2003) Winitzki, S. Uniform Approximations for Transcendental Functions. In Computational Science and Its Applications, 2003.

Appendix A KL Divergence Pruning

Given a variational posterior $Q_{\mathbf{w}}$ in the form of a multivariate Gaussian distribution $\mathcal{N}(\mathbf{w} \mid \boldsymbol{\mu}, \boldsymbol{\Sigma})$ with diagonal covariance $\boldsymbol{\Sigma} = \mathrm{diag}(\boldsymbol{\sigma}^2)$, we want to select the dimension $i$ which minimizes the KL divergence to a Dirac delta centered at zero, that is, $D_{\mathrm{KL}}[\delta_{\mathbf{w}} \,\|\, Q_{\mathbf{w}}]$. Because the distribution of $\mathbf{w}$ is mean-field factorized, it suffices to consider each dimension independently of the others. To this end, let $Q_{w_i} = \mathcal{N}(w_i \mid \mu_i, \sigma_i^2)$ and $P_{w_i} = P_w = \mathcal{N}(w \mid \nu, \rho^2)$; then

$$D_{\mathrm{KL}}[P_w \,\|\, Q_{w_i}] = \log\frac{\sigma_i}{\rho} + \frac{\rho^2 + (\nu - \mu_i)^2}{2\sigma_i^2} - \frac{1}{2}, \tag{7}$$

which can be simplified if we are only interested in finding the minimizer, because $\log\rho$ and $\frac{1}{2}$ are constant with respect to $i$:

$$\underset{i}{\arg\min}\; D_{\mathrm{KL}}[P_w \,\|\, Q_{w_i}] = \underset{i}{\arg\min}\; \log\sigma_i + \frac{\rho^2 + (\nu - \mu_i)^2}{2\sigma_i^2}. \tag{8}$$

Now, to let $P_w \to \delta_w$, we first set $\nu = 0$ and then let $\rho \to 0$, yielding

$$\underset{i}{\arg\min}\; D_{\mathrm{KL}}[\delta_w \,\|\, Q_{w_i}] = \underset{i}{\arg\min}\; \log\sigma_i + \frac{\mu_i^2}{2\sigma_i^2} = \underset{i}{\arg\max}\; \log\mathcal{N}(0 \mid \mu_i, \sigma_i^2), \tag{9}$$

such that choosing the dimension $i$ that minimizes $\log\sigma_i + \mu_i^2 / (2\sigma_i^2)$ prunes the weight whose marginal distribution has the lowest KL divergence to a Dirac delta centered at zero or, equivalently, the highest log density at zero.
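The selection rule in Equation 9 can be sketched in plain Python as follows. The sketch ranks weights by $\log\sigma_i + \mu_i^2/(2\sigma_i^2)$, checks that this ordering agrees with both the log density at zero and the Gaussian KL of Equation 7 with $\nu = 0$ and a small $\rho$, and zeroes the compressed sample at the lowest-scoring positions. The helper names, the toy numbers, and the global ranking across all weights are illustrative assumptions rather than the authors' implementation.

```python
import math
from statistics import NormalDist
from typing import List


def gaussian_kl(nu: float, rho: float, mu: float, sigma: float) -> float:
    """D_KL[N(nu, rho^2) || N(mu, sigma^2)], as in Equation 7."""
    return math.log(sigma / rho) + (rho**2 + (nu - mu) ** 2) / (2.0 * sigma**2) - 0.5


def kl_pruning_score(mu: float, sigma: float) -> float:
    """Selection criterion from Equation 9: log(sigma) + mu^2 / (2 sigma^2)."""
    return math.log(sigma) + mu**2 / (2.0 * sigma**2)


def log_density_at_zero(mu: float, sigma: float) -> float:
    """log N(0 | mu, sigma^2), computed independently with the standard library."""
    return math.log(NormalDist(mu, sigma).pdf(0.0))


def prune_by_kl(weights: List[float], mus: List[float], sigmas: List[float],
                fraction: float) -> List[float]:
    """Hypothetical helper: zero the sample entries with the lowest scores."""
    num_pruned = int(round(fraction * len(weights)))
    order = sorted(range(len(weights)), key=lambda i: kl_pruning_score(mus[i], sigmas[i]))
    pruned = list(weights)
    for i in order[:num_pruned]:
        pruned[i] = 0.0
    return pruned


if __name__ == "__main__":
    mus = [0.0, 0.5, 0.0, 2.0]
    sigmas = [0.1, 0.1, 1.0, 1.0]
    sample = [0.03, 0.48, -0.9, 2.1]  # hypothetical compressed weight sample

    # Lowest score <=> highest log density at zero <=> smallest KL as rho -> 0.
    by_score = sorted(range(4), key=lambda i: kl_pruning_score(mus[i], sigmas[i]))
    by_density = sorted(range(4), key=lambda i: -log_density_at_zero(mus[i], sigmas[i]))
    by_kl = sorted(range(4), key=lambda i: gaussian_kl(0.0, 1e-6, mus[i], sigmas[i]))
    print(by_score, by_density, by_kl)            # [0, 2, 3, 1] for all three
    print(prune_by_kl(sample, mus, sigmas, 0.5))  # [0.0, 0.48, 0.0, 2.1]
```

In this toy example, the third weight (-0.9) is pruned even though its magnitude exceeds that of the second (0.48), because its marginal is broad and centered at zero; magnitude-based pruning would have removed the second weight instead.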

Appendix B Padé Approximation to the Lambert W Function

Since the Lambert $W$ function, defined by $W(x)e^{W(x)} = x$, cannot be expressed in terms of elementary functions, it has to be implemented by, for example, numerical or analytical approximations. We considered three different approximations to the principal branch of the Lambert $W$ function: Winitzki's approximation for real $x > 0$ (Winitzki, 2003, Eq. (38)), Halley's method for numerical root-finding, which converges cubically, and a Padé approximation of order [3/2]. TensorFlow Probability (Dillon et al., 2017) uses Winitzki's approximation for real $x > 0$ as the initialization for Halley's method; however, we found that the former on its own is not accurate enough and that the latter can be slow and exhibit numerical issues. Instead, we used a Padé approximation of order [3/2], given by

$$W(x) \approx \frac{\frac{13}{720}t(x)^3 + \frac{257}{720}t(x)^2 + \frac{1}{6}t(x) - 1}{\frac{103}{720}t(x)^2 + \frac{5}{6}t(x) + 1}, \tag{10}$$
$$\text{where}\qquad t(x) = \sqrt{2ex + 2}, \tag{11}$$

which we found to be fast and accurate. We did not consider Winitzki's approximation for $-e^{-1} \leq x \leq 1$ (Winitzki, 2003, Eq. (39)).
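To make Equations 10 and 11 concrete, the sketch below implements the Padé approximation in plain Python and compares it with a reference obtained by Halley's method initialized with Winitzki's approximation, i.e. the combination described above for TensorFlow Probability. The closed form used for Winitzki's approximation, $W(x) \approx \ln(1+x)\left(1 - \frac{\ln(1+\ln(1+x))}{2+\ln(1+x)}\right)$, is our assumption and should be checked against Winitzki (2003); the function names are hypothetical, and this is not the authors' implementation. The residual $W e^{W} - x$ offers a quick way to check accuracy over the range of arguments actually needed.

```python
import math


def lambertw_pade(x: float) -> float:
    """Order-[3/2] Padé approximation to the principal branch W(x) (Eqs. 10-11)."""
    if x < -1.0 / math.e:
        raise ValueError("the principal branch is real only for x >= -1/e")
    # Clamp tiny negative values caused by rounding at the branch point x = -1/e.
    t = math.sqrt(max(2.0 * math.e * x + 2.0, 0.0))
    numerator = (13.0 / 720.0) * t**3 + (257.0 / 720.0) * t**2 + t / 6.0 - 1.0
    denominator = (103.0 / 720.0) * t**2 + (5.0 / 6.0) * t + 1.0
    return numerator / denominator


def winitzki_approx(x: float) -> float:
    """Assumed form of Winitzki's approximation for real x >= 0 (check Eq. (38))."""
    l = math.log1p(x)
    return l * (1.0 - math.log1p(l) / (2.0 + l))


def lambertw_halley(x: float, tol: float = 1e-12, max_iter: int = 50) -> float:
    """Reference W(x) for x >= 0: Halley iterations on f(w) = w e^w - x."""
    if x < 0.0:
        raise ValueError("this reference sketch only handles x >= 0")
    w = winitzki_approx(x)  # initial guess
    for _ in range(max_iter):
        ew = math.exp(w)
        f = w * ew - x
        fp = ew * (w + 1.0)   # f'(w)
        fpp = ew * (w + 2.0)  # f''(w)
        step = 2.0 * f * fp / (2.0 * fp * fp - f * fpp)
        w -= step
        if abs(step) <= tol * (1.0 + abs(w)):
            break
    return w


if __name__ == "__main__":
    for x in [-1.0 / math.e, 0.0, 1.0, math.e, 10.0]:
        w = lambertw_pade(x)
        residual = w * math.exp(w) - x  # zero for the exact W(x)
        ref = lambertw_halley(x) if x >= 0.0 else -1.0  # W(-1/e) = -1
        print(f"x={x:+.4f}  pade={w:+.6f}  reference={ref:+.6f}  residual={residual:+.2e}")
```

By construction, $t(-1/e) = 0$, so the Padé approximation returns exactly $-1 = W(-1/e)$ at the branch point (up to floating-point rounding). Since the rational function grows like $t \propto \sqrt{x}$ for large $x$ while $W$ grows only logarithmically, its accuracy should be verified on the inputs encountered in practice.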

Appendix C Implementation Details

Our implementation uses PyTorch (Paszke et al., 2019) and closely follows Havasi et al. (2019). The LeNet-5 model consists of two convolutional layers followed by two linear layers, applied sequentially. The first convolutional layer has 1 input channel, 20 output channels, a 5x5 kernel, a stride of 1, and no padding. It is followed by a ReLU activation and a 2D max pooling layer with a kernel size of 2 and a stride of 2. The second convolutional layer has 20 input channels, 50 output channels, and likewise a 5x5 kernel, a stride of 1, and no padding; it is also followed by a ReLU activation and a 2D max pooling layer with a kernel size of 2 and a stride of 2. The first linear layer has 800 input features, matching the flattened output of the previous layer (50 channels of size 4x4), and 500 output features, and is followed by a ReLU activation. The second linear layer has 500 input features and 10 output features, matching the number of classes in the MNIST dataset, and is followed by a softmax layer to produce class probabilities.

Additionally, weight hashing (Chen et al., 2015) is used in the second convolutional layer and the first linear layer to reduce the effective number of weights by factors of 2 and 64, respectively.

The layerwise log standard deviation parameters of the coding distribution were initialized to $-2$. For Mean-Var parameters, the means were initialized using PyTorch's default initialization and the log standard deviations were initialized to $-10$. For Mean-KL parameters, $\tau_w$ was initialized by passing PyTorch's default initialization through the analytical inverse of Equation 6, and $\gamma_w$ was initialized to $0$. After initial variational training, we perform 100 fine-tuning steps between compressing consecutive blocks.
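As a quick sanity check of the architecture described above, the following sketch recomputes the spatial sizes to confirm the 800 input features of the first linear layer, and counts weight-tensor entries before and after hashing. It assumes 28x28 MNIST inputs, excludes biases, and treats hashing as dividing each layer's weight count exactly by its stated factor; these are simplifying assumptions, and the helper names are hypothetical rather than taken from the authors' code.

```python
from typing import Dict


def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution or pooling layer (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1


def lenet5_summary(input_size: int = 28) -> Dict[str, int]:
    # Sizes in the comments assume the default 28x28 MNIST input.
    s = conv_out(input_size, kernel=5)   # conv1 (5x5, stride 1): 28 -> 24
    s = conv_out(s, kernel=2, stride=2)  # max pool (2, stride 2): 24 -> 12
    s = conv_out(s, kernel=5)            # conv2 (5x5, stride 1): 12 -> 8
    s = conv_out(s, kernel=2, stride=2)  # max pool (2, stride 2): 8 -> 4
    flat = 50 * s * s                    # 50 channels of 4x4 -> 800 features

    # Weight-tensor entries per layer (biases excluded).
    weights = {
        "conv1": 20 * 1 * 5 * 5,
        "conv2": 50 * 20 * 5 * 5,
        "fc1": flat * 500,
        "fc2": 500 * 10,
    }
    # Assumed accounting of weight hashing: conv2 shares weights 2x, fc1 64x.
    hashing_factor = {"conv1": 1, "conv2": 2, "fc1": 64, "fc2": 1}
    effective = sum(n // hashing_factor[name] for name, n in weights.items())
    return {
        "flat_features": flat,
        "weights": sum(weights.values()),
        "effective_weights": effective,
    }


if __name__ == "__main__":
    print(lenet5_summary())
    # {'flat_features': 800, 'weights': 430500, 'effective_weights': 24250}
```

Under these assumptions, hashing shrinks the weight count from 430,500 to 24,250 entries, almost entirely because of the 64x reduction in the first linear layer.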