Minimal Random Code Learning with Mean-KL Parameterization
Abstract
This paper studies the qualitative behavior and robustness of two variants of Minimal Random Code Learning (MIRACLE) used to compress variational Bayesian neural networks. MIRACLE implements a powerful, conditionally Gaussian variational approximation $Q_{\mathbf{w}}$ for the weight posterior and uses relative entropy coding to compress a weight sample from the posterior using a Gaussian coding distribution $P_{\mathbf{w}}$. To achieve the desired compression rate, $D_{\mathrm{KL}}[Q_{\mathbf{w}} \Vert P_{\mathbf{w}}]$ must be constrained, which requires a computationally expensive annealing procedure under the conventional mean-variance (Mean-Var) parameterization for $Q_{\mathbf{w}}$. Instead, we parameterize $Q_{\mathbf{w}}$ by its mean and KL divergence from $P_{\mathbf{w}}$ to constrain the compression cost to the desired value by construction. We demonstrate that variational training with Mean-KL parameterization converges twice as fast and maintains predictive performance after compression. Furthermore, we show that Mean-KL leads to more meaningful variational distributions with heavier tails and compressed weight samples which are more robust to pruning.
1 Introduction
With the ever-growing size of neural network architectures, such as large language models (e.g. BERT, Kenton & Toutanova, 2019), it is now a key challenge to ensure their memory and energy efficiency. While there is a large literature on model compression, almost all works rely on some form of quantization scheme. In this paper, we consider an alternative method to quantization, namely Minimal Random Code Learning (MIRACLE, Havasi et al., 2019), which has recently demonstrated state-of-the-art performance for neural network compression. The MIRACLE framework employs a powerful, conditionally Gaussian variational distribution $Q_{\mathbf{w}}$ over the weights of a neural network and uses relative entropy coding (REC, Flamich et al., 2020) with a Gaussian coding distribution $P_{\mathbf{w}}$ to encode a random weight sample from $Q_{\mathbf{w}}$. The average coding cost of encoding a weight sample is approximately $D_{\mathrm{KL}}[Q_{\mathbf{w}} \Vert P_{\mathbf{w}}]$, which needs to be carefully controlled in a practical compression scheme. To this end, we propose to use Mean-KL parameterization for Gaussians (Flamich et al., 2022) to parameterize $Q_{\mathbf{w}}$, allowing explicit control over $D_{\mathrm{KL}}[Q_{\mathbf{w}} \Vert P_{\mathbf{w}}]$ by construction. We demonstrate that Mean-KL leads to many practical benefits over the conventional mean-variance (Mean-Var) parameterization used by Havasi et al. 2019, which requires a computationally expensive annealing procedure to control the coding cost. In particular, we show that, compared to Mean-Var parameterization, variational training converges in half the number of iterations using Mean-KL parameterization while maintaining predictive performance after compression. Furthermore, we illustrate that the resulting variational distribution exhibits more meaningful shapes with heavy tails, which makes the compressed weight sample more robust against zero pruning.
2 Background
Minimal Random Code Learning
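Havasi et al. 2019 consider a setting akin to the $\beta$-VAE (Higgins et al., 2017) to encode neural network weights with a limited information budget $C$. To this end, let $\mathcal{X}$, $\mathcal{Y}$ and $\mathcal{W}$ be the input, output and weight spaces, respectively, let $\mathcal{D} = \{(\mathbf{x}_n, \mathbf{y}_n)\}_{n=1}^{N}$ be a dataset and let $f: \mathcal{X} \times \mathcal{W} \to \mathcal{Y}$ be a neural network with input $\mathbf{x}$ and weights $\mathbf{w}$. To control the information content of the weights, let $P_{\mathbf{w}}$ be the coding distribution and $Q_{\mathbf{w}}$ be the variational distribution over $\mathcal{W}$. In this setting, Hinton & Van Camp 1993 show that the information content of the weights is $D_{\mathrm{KL}}[Q_{\mathbf{w}} \Vert P_{\mathbf{w}}]$. Further, let $\Delta: \mathcal{Y} \times \mathcal{Y} \to \mathbb{R}$ be a distortion function. MIRACLE minimizes
$$\mathcal{L}_{\beta}(Q_{\mathbf{w}}) = \mathbb{E}_{\mathbf{w} \sim Q_{\mathbf{w}}}\Big[\sum_{(\mathbf{x}, \mathbf{y}) \in \mathcal{D}} \Delta\big(\mathbf{y}, f(\mathbf{x}, \mathbf{w})\big)\Big] + \beta \cdot D_{\mathrm{KL}}[Q_{\mathbf{w}} \Vert P_{\mathbf{w}}] \tag{1}$$
with respect to $Q_{\mathbf{w}}$ to minimize distortion within the given information budget of $C$ nats. During optimization, $\beta$ is dynamically adapted to anneal the KL divergence, such that the constraint $D_{\mathrm{KL}}[Q_{\mathbf{w}} \Vert P_{\mathbf{w}}] = C$ is eventually satisfied.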
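To make the role of the annealing concrete, the following is a minimal sketch of a multiplicative update of the penalty weight $\beta$ toward a budget $C$. The update rule, the step size `eps_beta`, and the function `toy_kl` (a hypothetical stand-in for the KL divergence the optimizer reaches at a given $\beta$) are assumptions made for illustration; the text above only states that $\beta$ is adapted dynamically.

```python
def anneal_beta(beta, kl, budget, eps_beta=0.05):
    """Assumed rule: raise beta while the KL exceeds the budget, lower it otherwise."""
    if kl > budget:
        return beta * (1.0 + eps_beta)
    return beta / (1.0 + eps_beta)


def toy_kl(beta):
    """Hypothetical response of the KL divergence to the penalty weight beta."""
    return 100.0 / (1.0 + beta)


beta, budget = 1e-3, 20.0
for _ in range(500):
    beta = anneal_beta(beta, toy_kl(beta), budget)

final_kl = toy_kl(beta)  # hovers around the budget instead of matching it exactly
```

Even in this toy setting, many iterations are spent moving $\beta$ before the KL divergence settles near the budget, and it then oscillates around the target rather than meeting it exactly.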
In this paper, we encode the samples using minimal random coding (MRC, Havasi et al., 2019) for simplicity, though more sophisticated approaches, such as A* coding (Flamich et al., 2022) or greedy Poisson rejection sampling (Flamich, 2023), have been developed. Given a suitable $Q_{\mathbf{w}}$, a random sample from $Q_{\mathbf{w}}$ is compressed by first drawing $K = \lfloor \exp(C) \rfloor$ samples $\mathbf{w}^{(1)}, \dots, \mathbf{w}^{(K)}$ from $P_{\mathbf{w}}$. These samples are then used to construct a discrete distribution $\tilde{Q}$ whose probability mass function is defined by the importance weights $\tilde{q}(k) \propto \frac{dQ_{\mathbf{w}}}{dP_{\mathbf{w}}}(\mathbf{w}^{(k)})$, where $\frac{dQ_{\mathbf{w}}}{dP_{\mathbf{w}}}$ is the Radon-Nikodym derivative, i.e. the density ratio, of $Q_{\mathbf{w}}$ with respect to $P_{\mathbf{w}}$. The compressed weight sample $\mathbf{w}^{(k^*)}$ is represented by an index $k^* \sim \tilde{Q}$. Since $1 \le k^* \le K$, it is always possible to encode $k^*$ using $C$ nats. The weight sample can be decoded by drawing the $k^*$th sample from $P_{\mathbf{w}}$ using a shared random number generator with a shared random seed. Due to the exponential scaling, simulating $K$ samples is intractable if $\mathbf{w}$ has many dimensions. Havasi et al. 2019 solve this issue by partitioning $\mathbf{w}$ dimensionwise into smaller blocks with local information budgets $C_{\mathrm{loc}}$, such that drawing $K_{\mathrm{loc}} = \lfloor \exp(C_{\mathrm{loc}}) \rfloor$ samples is feasible.
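The encoding and decoding steps described above can be sketched for a single scalar weight as follows. This is an illustrative sketch under stated assumptions, not the authors' implementation: the function names, the use of Python's `random` module as the shared generator, and the parameter names (`mu`, `sigma` for the variational Gaussian, `nu`, `rho` for the coding Gaussian) are all hypothetical.

```python
import math
import random


def log_ratio(w, mu, sigma, nu, rho):
    """log dQ/dP at w for Q = N(mu, sigma^2) and P = N(nu, rho^2)."""
    log_q = -math.log(sigma) - 0.5 * ((w - mu) / sigma) ** 2
    log_p = -math.log(rho) - 0.5 * ((w - nu) / rho) ** 2
    return log_q - log_p


def candidates(seed, count, nu, rho):
    """Draw `count` samples from the coding distribution with a shared seed."""
    rng = random.Random(seed)
    return [rng.gauss(nu, rho) for _ in range(count)]


def mrc_encode(mu, sigma, nu, rho, budget_nats, seed, select_rng):
    """Return the index of the selected candidate and the number of candidates."""
    num = max(1, math.floor(math.exp(budget_nats)))
    samples = candidates(seed, num, nu, rho)
    logs = [log_ratio(w, mu, sigma, nu, rho) for w in samples]
    top = max(logs)  # subtract the maximum for numerical stability
    weights = [math.exp(value - top) for value in logs]
    index = select_rng.choices(range(num), weights=weights, k=1)[0]
    return index, num


def mrc_decode(index, seed, nu, rho):
    """Regenerate the candidates up to `index` and return the selected one."""
    return candidates(seed, index + 1, nu, rho)[index]


index, num = mrc_encode(0.5, 0.3, 0.0, 1.0, budget_nats=6.0, seed=7,
                        select_rng=random.Random(1))
decoded = mrc_decode(index, seed=7, nu=0.0, rho=1.0)
```

Only `index` needs to be transmitted; the decoder reproduces the candidate list from the shared seed. The number of candidates grows exponentially with the budget, which is why the blockwise partitioning is needed in practice.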
Refining Mean-Field Posteriors
An important choice in practice is the variational family over which we optimize Equation 1. Since we are interested in studying the behavior of samples using MIRACLE, we also adopt the variational family suggested by Havasi et al. (2019). Concretely, assume that we have already partitioned the weight vector as $\mathbf{w} = [\mathbf{w}_1 \,\Vert\, \mathbf{w}_2 \,\Vert\, \dots \,\Vert\, \mathbf{w}_B]$, where $B$ denotes the number of blocks, and $\Vert$ denotes vector concatenation. To begin, we use a mean-field Gaussian variational approximation, i.e. we parameterize the means $\boldsymbol{\mu}_{1:B}$ and marginal variances $\boldsymbol{\sigma}^2_{1:B}$ (Mean-Var). Once variational training converges, we compress the first block $\mathbf{w}_1$, resulting in a sample $\hat{\mathbf{w}}_1$. Keeping $\hat{\mathbf{w}}_1$ fixed, we resume optimization to fine-tune the remaining means $\boldsymbol{\mu}_{2:B}$ and variances $\boldsymbol{\sigma}^2_{2:B}$. We repeat this process $B$ times in total, where at step $b$, $\hat{\mathbf{w}}_{1:b-1}$ are fixed, means $\boldsymbol{\mu}_{b:B}$ and variances $\boldsymbol{\sigma}^2_{b:B}$ are optimized, and a random sample from block $b$ is encoded. Note that the variational posterior at step $b$ is only factorized conditionally on the weight samples in the first $b-1$ blocks, which results in a much better variational approximation.
Mean-KL Parameterization for Gaussians
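Flamich et al. 2022 show that, given a univariate Gaussian coding distribution $P = \mathcal{N}(\nu, \rho^2)$ with mean $\nu$ and variance $\rho^2$, a variational distribution $Q = \mathcal{N}(\mu, \sigma^2)$ can be uniquely parameterized by its mean $\mu$ and $\kappa = D_{\mathrm{KL}}[Q \Vert P]$ if
$$|\mu - \nu| < \rho \sqrt{2\kappa} \tag{2}$$
is satisfied. The variance of $Q$ can be recovered via
$$\sigma^2 = -\rho^2 \, W\big(-\exp(\delta^2 - 2\kappa - 1)\big), \tag{3}$$
where $\delta = (\mu - \nu)/\rho$ and $W$ is the principal branch of the Lambert $W$ function (Corless et al., 1996), defined by the relation $W(x)\,e^{W(x)} = x$ (see Appendix B for details).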
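As a numerical sanity check of the Mean-KL parameterization, the sketch below recovers the variance of a univariate Gaussian from its mean and its KL divergence to the coding distribution, and then verifies that the closed-form Gaussian KL divergence returns the requested value. The symbols `mu`, `kappa`, `nu`, `rho` (variational mean, target KL in nats, coding mean, coding standard deviation) and all function names are assumptions for illustration. The principal branch of the Lambert $W$ function is evaluated by plain bisection here, which is a simple substitute and not the approximation discussed in Appendix B.

```python
import math


def lambert_w0_neg(x, iters=80):
    """Principal branch W_0(x) for -1/e <= x <= 0, by bisection on w * exp(w) = x."""
    lo, hi = -1.0, 0.0  # w * exp(w) is increasing on [-1, 0]
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if mid * math.exp(mid) < x:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def sigma2_from_mean_kl(mu, kappa, nu=0.0, rho=1.0):
    """Variance of Q = N(mu, sigma^2) such that KL[Q || N(nu, rho^2)] = kappa."""
    delta = (mu - nu) / rho
    if delta * delta >= 2.0 * kappa:
        raise ValueError("need |mu - nu| < rho * sqrt(2 * kappa)")
    return -rho ** 2 * lambert_w0_neg(-math.exp(delta ** 2 - 2.0 * kappa - 1.0))


def gaussian_kl(mu, sigma2, nu, rho):
    """Closed-form KL[N(mu, sigma2) || N(nu, rho^2)] in nats."""
    return 0.5 * (math.log(rho ** 2 / sigma2)
                  + (sigma2 + (mu - nu) ** 2) / rho ** 2 - 1.0)


sigma2 = sigma2_from_mean_kl(mu=0.3, kappa=0.5)
recovered_kl = gaussian_kl(0.3, sigma2, nu=0.0, rho=1.0)
```

With the principal branch, the recovered variance never exceeds the variance of the coding distribution. The bisection has limited relative accuracy when `kappa` is very large, so it is meant as a reference only.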
3 Mean-KL Parameterization for MIRACLE
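Recognizing that the main goal of minimizing Equation 1 combined with KL annealing is to solve
$$\min_{Q_{\mathbf{w}}} \; \mathbb{E}_{\mathbf{w} \sim Q_{\mathbf{w}}}\Big[\sum_{(\mathbf{x}, \mathbf{y}) \in \mathcal{D}} \Delta\big(\mathbf{y}, f(\mathbf{x}, \mathbf{w})\big)\Big] \tag{4}$$
$$\text{subject to} \quad D_{\mathrm{KL}}[Q_{\mathbf{w}} \Vert P_{\mathbf{w}}] = C, \tag{5}$$
we propose to use Mean-KL parameterization (Flamich et al., 2022) to enforce the constraint mathematically instead of performing computationally expensive KL annealing. To this end, the total information budget $C$ must be distributed to each weight, resulting in local information budgets $\kappa_i$. Thus, in Mean-KL parameterization, each weight $w_i$ has a mean parameter $\mu_i$ and a local information budget $\kappa_i$, matching the number of parameters for the conventional Mean-Var parameterization, albeit with one fewer degree of freedom because $\sum_i \kappa_i = C$.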
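In practice, we introduce an information quota parameter $\alpha_i$ per weight, which satisfies $\alpha_i \ge 0$ and $\sum_i \alpha_i = 1$ and defines the relative share of the total information budget assigned to $w_i$, that is $\kappa_i = \alpha_i C$. The constraint on the information quota parameters is implemented using a softmax function. To ensure that $|\mu_i - \nu| < \rho\sqrt{2\kappa_i}$ (Equation 2), where $\nu$ and $\rho$ denote the mean and standard deviation of the coding distribution, we define
$$\mu_i = \nu + \rho \sqrt{2\kappa_i} \, \tanh(\tau_i) \tag{6}$$
as suggested by Flamich et al. (2022), leaving $\tau_i$ and $\alpha_i$ as trainable parameters. In combination with blockwise partitioning of $\mathbf{w}$, each block has its own constraint and $C$ is simply replaced by $C_{\mathrm{loc}}$. When drawing samples from $Q_{\mathbf{w}}$ or evaluating the density of $Q_{\mathbf{w}}$, we convert $\tau_i$ and $\alpha_i$ to $\mu_i$ and $\sigma_i^2$ using Equation 6 and Equation 3, respectively, followed by the same computations as with conventional Mean-Var parameterization.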
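The mapping from unconstrained trainable parameters to per-weight means and local budgets within one block can be sketched as follows. The names `taus` and `alpha_logits`, and the use of `tanh` as the squashing function that keeps each mean inside its admissible interval, are assumptions of this sketch; any other map onto the open interval $(-1, 1)$ would enforce the same bound.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    top = max(logits)
    exps = [math.exp(value - top) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mean_kl_block(taus, alpha_logits, budget, nu=0.0, rho=1.0):
    """Map unconstrained parameters of one block to means and local KL budgets."""
    quotas = softmax(alpha_logits)            # non-negative shares that sum to one
    kappas = [q * budget for q in quotas]     # local budgets sum to the block budget
    mus = [nu + rho * math.sqrt(2.0 * k) * math.tanh(t)
           for t, k in zip(taus, kappas)]     # keeps |mu - nu| < rho * sqrt(2 * kappa)
    return mus, kappas


mus, kappas = mean_kl_block(taus=[0.0, 3.0, -10.0],
                            alpha_logits=[0.1, -0.2, 1.5],
                            budget=2.0)
```

Because the quotas come from a softmax, the block's total KL divergence equals `budget` for every setting of the trainable parameters, so no annealing is needed to meet the constraint.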
4 Experiments
We empirically demonstrate advantages of Mean-KL compared to conventional Mean-Var parameterization: We show that variational training with Mean-KL parameterization converges faster than Mean-Var while maintaining predictive performance, we illustrate that Mean-KL leads to more meaningful distributions with heavier tails, and we demonstrate that these more meaningful distributions translate to improved robustness when pruning weights to zero.
Training Dynamics and Predictive Performance
We adopt the experimental setup of Havasi et al. 2019 and train a LeNet-5 on MNIST. The distortion function is the cross-entropy, which is commonly used as a loss function in image classification. Matching Havasi et al. 2019, we used a local information budget of bits. We varied the block size between 20, 30, and 40. For both parameterizations, we used Adam with a learning rate of 0.001 and a mini-batch size of 200. For KL divergence annealing with Mean-Var, we used and , as suggested by Havasi et al. 2019. See Appendix C for further implementation details.
Figure 2 illustrates how Mean-Var spends most of the optimization on minimizing and annealing the KL divergence to the desired coding cost, whereas for Mean-KL, the whole optimization process focuses on minimizing cross entropy, given that the parameterization already constrains the KL divergence to the desired coding cost. Crucially, KL divergence annealing with Mean-Var takes a tremendous amount of time while minimizing cross entropy with Mean-KL converges in just half the number of iterations. Table 1 shows that Mean-KL maintains predictive performance comparable to Mean-Var across different compression ratios, being slightly better in the low compression ratio setting and slightly worse in the high compression ratio settings, albeit within standard error.
Block Size | Ratio | Mean-Var | Mean-KL |
---|---|---|---|
20 | 555x | % | % |
30 | 833x | % | % |
40 | 1111x | % | % |
Optimizer Iterations | | 200,000 | 100,000 |
Visualizing Variational Posteriors
To qualitatively investigate the variational posterior distributions, we plot layerwise histograms of learned parameters after the compressed weight sample has been generated. For purposes of comparison, both Mean-Var and Mean-KL parameters have been converted to mean and log standard deviation.
Figure 1 reveals striking differences between layerwise Mean-Var and Mean-KL parameter distributions. In terms of the means, Mean-Var parameters collapse to sharp peaks at zero for all layers without any visible tails. In contrast, Mean-KL mean parameters manifest much wider, symmetric distributions centered around zero with heavier tails, resembling shapes akin to Laplace, Gaussian or Student’s $t$-distributions. In terms of the log standard deviation, similarly, Mean-Var parameters form peaked distributions around a particular value with virtually no tails. The distributions of Mean-KL log standard deviations are more spread out, forming distinct shapes for each layer. In general, Mean-Var standard deviations seem to be higher than Mean-KL standard deviations. Furthermore, despite resulting in similar predictive performance, the stark differences in distributional shapes suggest potential qualitative differences between the learned variational posteriors.
Robustness to Pruning
To study potential qualitative differences between variational posteriors learned using Mean-Var and Mean-KL parameterizations, we analyze the robustness of the compressed weight sample by setting certain weights to zero using three different strategies:
1. Random Uniform: Select pruned weights uniformly at random. Because this strategy is uninformed, it reflects a general notion of robustness.
2. Absolute Value: Set the weight with the smallest absolute value to zero. This strategy is a simple yet competitive pruning baseline (Blalock et al., 2020), which only depends on the compressed weight sample itself. If the same sample were generated by two different distributions, it would still be pruned in the same way.
3. KL Divergence: Prune the weight which minimizes the KL divergence from the variational posterior to a Dirac delta at zero, $\arg\min_i D_{\mathrm{KL}}[\delta_0 \Vert Q_{w_i}]$. For a Gaussian variational posterior with diagonal covariance matrix, this is equivalent to finding the weight with maximal density at zero (see Appendix A for details). This strategy depends on the variational posterior, implying that the same compressed sample would be pruned differently if it were generated by two different distributions.
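The three strategies differ only in how they rank the weights. A minimal sketch of the ranking is given below, assuming a compressed sample `weights` together with the per-weight means `mus` and standard deviations `sigmas` of a diagonal Gaussian variational posterior. The function names and the toy values are hypothetical, and the KL strategy is implemented through the equivalent "highest log density at zero" criterion.

```python
import math
import random


def log_density_at_zero(mu, sigma):
    """Log density of N(mu, sigma^2) evaluated at zero."""
    return -0.5 * math.log(2.0 * math.pi * sigma ** 2) - mu ** 2 / (2.0 * sigma ** 2)


def pruning_order(weights, mus, sigmas, strategy, rng=None):
    """Return weight indices in the order in which they would be set to zero."""
    order = list(range(len(weights)))
    if strategy == "random":
        if rng is None:
            rng = random.Random(0)
        rng.shuffle(order)
        return order
    if strategy == "abs":
        return sorted(order, key=lambda i: abs(weights[i]))
    if strategy == "kl":
        # Highest density at zero first, i.e. smallest negative log density.
        return sorted(order, key=lambda i: -log_density_at_zero(mus[i], sigmas[i]))
    raise ValueError(f"unknown strategy: {strategy}")


def prune(weights, order, fraction):
    """Zero out the first `fraction` of the weights according to `order`."""
    count = int(round(fraction * len(weights)))
    pruned = list(weights)
    for i in order[:count]:
        pruned[i] = 0.0
    return pruned


weights = [0.5, -0.01, 1.2, 0.3]
mus = [0.6, 0.0, 0.1, 1.0]
sigmas = [0.1, 0.5, 0.05, 0.2]
abs_order = pruning_order(weights, mus, sigmas, "abs")
kl_order = pruning_order(weights, mus, sigmas, "kl")
half_pruned = prune(weights, abs_order, 0.5)
```

In this toy example the two informed strategies disagree: the weight with the largest sampled magnitude is pruned first under the KL criterion because its posterior is narrow and centered near zero.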
Figure 3 illustrates how the test accuracy changes as more weights in the compressed sample are pruned to zero. With Random Uniform pruning, Mean-Var test accuracy quickly drops off, already losing more than half the performance after about 20% of the weights have been pruned, and diminishing to performance equal to guessing uniformly at random after roughly 70% of the weights have been set to zero. Mean-KL performance also reduces rapidly, albeit more gracefully. After setting 30% of all weights to zero, a test accuracy of 80% is maintained. Performance equal to guessing is reached after more than 80% of the weights have been pruned. This suggests a general notion of improved robustness of the compressed sample produced by Mean-KL compared to Mean-Var.
With Absolute Value pruning, Mean-Var and Mean-KL perform nearly identically. Both parameterizations roughly maintain full predictive performance until 50% of the weights have been pruned and decay towards random guessing as more weights are set to zero. In particular, this pruning strategy does not depend on the variational posterior and is only informed by the compressed weight sample itself, demonstrating that both parameterizations produce compressed samples which are generally capable of maintaining performance to some degree under pruning.
Finally, the two parameterizations perform drastically differently under KL Divergence pruning. While Mean-Var test accuracy quickly falls off almost to random guessing after only 50% of the weights have been set to zero, Mean-KL maintains close to 90% test accuracy after pruning 90% of the weights, even outperforming the competitive Absolute Value baseline. Since this pruning strategy is informed by the variational posterior, the results strongly suggest that, compared to Mean-Var, Mean-KL parameterization leads to a superior variational posterior which produces more robust compressed samples. Given that this pruning strategy outperforms the competitive baseline, this property is also not a mere peculiarity but could potentially be leveraged to design more robust algorithms.
5 Conclusion
We demonstrated that MIRACLE with Mean-KL parameterization bypasses the need for time-consuming KL annealing, leading to training convergence after half the number of optimization steps while maintaining predictive performance. Furthermore, Mean-KL parameterization produces more meaningful variational posterior distributions with heavy tails, whereas standard Mean-Var parameterization produces distributions which are sharply peaked at particular values. We illustrated that these qualitative differences result in different properties when exposed to pruning, suggesting that compressed weight samples from Mean-KL are more robust than samples from Mean-Var. Future work should investigate whether the faster convergence scales to larger models and explore Mean-KL parameterization for Bayesian neural networks independent of compression. Explicitly utilizing Mean-KL’s robustness to design pruning or compression algorithms is another possible avenue.
References
- Blalock et al. (2020) Blalock, D., Ortiz, J. J. G., Frankle, J., and Guttag, J. What is the State of Neural Network Pruning? In Proceedings of Machine Learning and Systems, 2020.
- Chen et al. (2015) Chen, W., Wilson, J. T., Tyree, S., Weinberger, K. Q., and Chen, Y. Compressing Neural Networks with the Hashing Trick. In International Conference on Machine Learning, 2015.
- Corless et al. (1996) Corless, R. M., Gonnet, G. H., Hare, D. E., Jeffrey, D. J., and Knuth, D. E. On the Lambert W Function. Advances in Computational Mathematics, 1996.
- Dillon et al. (2017) Dillon, J. V., Langmore, I., Tran, D., Brevdo, E., Vasudevan, S., Moore, D., Patton, B., Alemi, A., Hoffman, M., and Saurous, R. A. TensorFlow Distributions. In arXiv:1711.10604, 2017.
- Flamich (2023) Flamich, G. Greedy Poisson Rejection Sampling. In arXiv:2305.15313, 2023.
- Flamich et al. (2020) Flamich, G., Havasi, M., and Hernández-Lobato, J. M. Compressing Images by Encoding their Latent Representations with Relative Entropy Coding. In Advances in Neural Information Processing Systems, 2020.
- Flamich et al. (2022) Flamich, G., Markou, S., and Hernández-Lobato, J. M. Fast Relative Entropy Coding with A* Coding. In International Conference on Machine Learning, 2022.
- Havasi et al. (2019) Havasi, M., Peharz, R., and Hernández-Lobato, J. M. Minimal Random Code Learning: Getting Bits Back from Compressed Model Parameters. In International Conference on Learning Representations, 2019.
- Higgins et al. (2017) Higgins, I., Matthey, L., Pal, A., Burgess, C. P., Glorot, X., Botvinick, M. M., Mohamed, S., and Lerchner, A. β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework. In International Conference on Learning Representations, 2017.
- Hinton & Van Camp (1993) Hinton, G. E. and Van Camp, D. Keeping Neural Networks Simple by Minimizing the Description Length of the Weights. In Conference on Computational Learning Theory, 1993.
- Kenton & Toutanova (2019) Kenton, J. D. M.-W. C. and Toutanova, L. K. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. In Conference of the North American Chapter of the Association for Computational Linguistics - Human Language Technologies, 2019.
- Paszke et al. (2019) Paszke, A., Gross, S., Massa, F., Lerer, A., Bradbury, J., Chanan, G., Killeen, T., Lin, Z., Gimelshein, N., Antiga, L., Desmaison, A., Kopf, A., Yang, E., DeVito, Z., Raison, M., Tejani, A., Chilamkurthy, S., Steiner, B., Fang, L., Bai, J., and Chintala, S. PyTorch: An Imperative Style, High-Performance Deep Learning Library. In Advances in Neural Information Processing Systems, 2019.
- Winitzki (2003) Winitzki, S. Uniform Approximations for Transcendental Functions. In Computational Science and Its Applications, 2003.
Appendix A KL Divergence Pruning
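Given a variational posterior $Q_{\mathbf{w}}$ as a multivariate Gaussian distribution with diagonal covariance, we want to select the dimension $i$ which minimizes the KL divergence to a Dirac delta centered at zero, that is $\arg\min_i D_{\mathrm{KL}}[\delta_0 \Vert Q_{w_i}]$. Because the distribution of $\mathbf{w}$ is mean-field factorized, it suffices to consider individual dimensions independently of each other. To this end, let $Q_{w_i} = \mathcal{N}(\mu_i, \sigma_i^2)$ and $R = \mathcal{N}(m, s^2)$, then
$$D_{\mathrm{KL}}[R \Vert Q_{w_i}] = \log\frac{\sigma_i}{s} + \frac{s^2 + (m - \mu_i)^2}{2\sigma_i^2} - \frac{1}{2}, \tag{7}$$
which can be simplified if we are only interested in finding the minimizer because $\log s$ and $\tfrac{1}{2}$ are constant with respect to $i$,
$$\arg\min_i D_{\mathrm{KL}}[R \Vert Q_{w_i}] = \arg\min_i \left[\log \sigma_i + \frac{s^2 + (m - \mu_i)^2}{2\sigma_i^2}\right]. \tag{8}$$
Now, to let $R \to \delta_0$, we first set $m = 0$ and let $s \to 0$, yielding
$$c_i = \log \sigma_i + \frac{\mu_i^2}{2\sigma_i^2} = -\log \mathcal{N}(0;\, \mu_i, \sigma_i^2) - \frac{1}{2}\log(2\pi), \tag{9}$$
such that choosing the dimension by minimizing $c_i$ will prune the weight whose marginal distribution has the lowest KL divergence to a Dirac delta centered at zero or, equivalently, has the highest log density at zero.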
Appendix B Padé Approximation to the Lambert W Function
Since the Lambert $W$ function, defined by $W(x)\,e^{W(x)} = x$, cannot be expressed using elementary functions, it has to be implemented by, for example, numerical or analytical approximations. We considered three different approximations to the principal branch of the Lambert $W$ function: Winitzki’s approximation for real (Winitzki 2003, (38)), Halley’s method for numerical root-finding with cubic rate of convergence, and a Padé approximation of order [3/2]. Winitzki’s approximation for real is used as initialization for Halley’s method in the implementation of TensorFlow Probability (Dillon et al., 2017); however, we found that the former by itself is not accurate enough and that the latter can be slow and exhibit numerical issues. Instead, we used a Padé approximation of order [3/2], given by
(10) | ||||
(11) |
which was fast and accurate. We did not consider Winitzki’s approximation for (Winitzki 2003, (39)).
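For illustration, the sketch below implements the textbook [3/2] Padé approximant of the principal branch about $x = 0$, whose coefficients follow from matching the Taylor series $W(x) = x - x^2 + \tfrac{3}{2}x^3 - \tfrac{8}{3}x^4 + \tfrac{125}{24}x^5 - \dots$ through fifth order. This choice of expansion point is an assumption, so the coefficients are not necessarily those used in our implementation. The function name and the residual check are likewise hypothetical.

```python
import math


def lambert_w0_pade32(x):
    """[3/2] Pade approximant of W_0 about x = 0 (matches the Taylor series through x**5)."""
    numerator = x * (60.0 + 114.0 * x + 17.0 * x * x)
    denominator = 60.0 + 174.0 * x + 101.0 * x * x
    return numerator / denominator


def residual(x):
    """Defect |w * exp(w) - x| of the approximation w = lambert_w0_pade32(x)."""
    w = lambert_w0_pade32(x)
    return abs(w * math.exp(w) - x)


residuals = {x: residual(x) for x in (-0.3, -0.1, 0.05, 0.1)}
```

The approximant is cheap and differentiable, and its denominator has no real root at or above $-1/e$. Its accuracy degrades as $x$ approaches the branch point at $-1/e$, which is visible in the larger residual at $x = -0.3$.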
Appendix C Implementation Details
Our implementation uses PyTorch (Paszke et al., 2019) and follows Havasi et al. 2019 closely. The LeNet-5 model consists of two convolutional layers and two linear layers, which are applied sequentially. The first convolutional layer has 1 input channel, 20 output channels, a kernel size of 5x5, a stride of 1, and no padding. It is followed by a ReLU activation and a 2D max pooling layer with a kernel size of 2 and a stride of 2. The second convolutional layer has 20 input channels, 50 output channels, and also a kernel size of 5x5, a stride of 1, and no padding. It is also followed by a ReLU activation and a 2D max pooling layer with a kernel size of 2 and a stride of 2. The first linear layer has 800 input features, matching the flattened outputs from the previous layer, and 500 output features, and it is followed by a ReLU activation. The second linear layer has 500 input features and 10 output features, matching the number of classes in the MNIST dataset. It is followed by a softmax layer to produce class probabilities. Additionally, weight hashing (Chen et al., 2015) is used in the second convolutional layer and the first linear layer to reduce the effective number of weights by a factor of 2x and 64x respectively. The layerwise log standard deviation parameters of the coding distribution were initialized to . For Mean-Var parameters, the means were initialized using PyTorch’s default initialization and the log standard deviations were initialized to . For Mean-KL parameters, was initialized by passing PyTorch’s default initialization through the analytical inverse of Equation 6 and was initialized to . After initial variational training, we perform 100 fine-tuning steps in-between compressing blocks.
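The layer sizes above can be checked with a few lines of arithmetic. The sketch below assumes 28x28 MNIST inputs and counts weights only, ignoring biases; the helper names are hypothetical.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2, stride=2):
    """Spatial output size of a max pooling layer along one dimension."""
    return (size - kernel) // stride + 1


side = 28                            # MNIST images are 28x28 pixels
side = pool_out(conv_out(side, 5))   # first 5x5 convolution, then 2x2 pooling
side = pool_out(conv_out(side, 5))   # second 5x5 convolution, then 2x2 pooling
flat_features = 50 * side * side     # 50 channels feed the first linear layer

weights = {
    "conv1": 1 * 20 * 5 * 5,
    "conv2": 20 * 50 * 5 * 5,
    "fc1": flat_features * 500,
    "fc2": 500 * 10,
}
# Weight hashing shrinks conv2 by 2x and fc1 by 64x (biases are ignored here).
effective = (weights["conv1"] + weights["conv2"] // 2
             + weights["fc1"] // 64 + weights["fc2"])
```

The flattened feature count agrees with the 800 input features of the first linear layer, and the hashing factors reduce the weights that have to be encoded from 430,500 to 24,250 under these assumptions.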