Affordable Bayesian Neural Networks
Posted by Yuling Yao on May 20, 2019.

Recently Oscar Chang and I, along with some others, wrote a paper titled Ensemble Model Patching: A Parameter-Efficient Variational Bayesian Neural Network. It mostly addresses the challenge of large-scale Bayesian inference on deep neural networks.
The motivation is clear: 1) we want to fit Bayesian neural networks, but we cannot afford to run MCMC. Indeed we cannot even afford to run variational inference, as a mean-field Gaussian approximation doubles the parameter count and makes it infeasible for state-of-the-art networks; and 2) I really wanted to write a deep learning paper, whatever that means.
As a statistician, I rarely think about this kind of computational bottleneck too seriously, but Oscar reminded me that it is indeed a concern:
A 100\% parameter overhead will double the amount of memory required, and many deep learning models already utilize the available hardware to the fullest, often maxing out GPU memory. In 2015, the highest-end consumer-grade GPU (Nvidia's Titan X Maxwell) had 12GB of memory; by 2018, this number had doubled to 24GB with Nvidia's Titan RTX. This means that in the period 2015-2018, a programmer could only have Bayesianized state-of-the-art networks from prior to 2015: any newer design, once its parameter use had been doubled, would not fit into her GPU's memory. With GPU memory capacity doubling approximately every 3 years (an eternity in the rapidly progressing field of deep learning), the programmer's Bayesianized networks will be up to three years behind the state of the art.
OK, so here is the goal: we want to do Bayesian inference on networks that are so large that we can hardly do anything beyond point estimation; therefore we wish the approximate inference to take nearly as few computational resources as non-Bayesian point estimation.
MAP estimation is variational inference?
Apparently, this is hardly possible. Let's recall what point estimation can do at all. Given any model $p(y \mid x, \theta)$, we can always optimize over $\theta$ and find the MAP estimate. For example, when $p(y \mid x, \theta)$ is a Gaussian likelihood with fixed output precision and $\theta$ has a Gaussian prior, MAP estimation amounts to minimizing MSE plus $L^2$ regularization: MSE $+ \lambda \|\theta\|_2^2$. A hypothetical Bayesian will argue that point estimation leads to overfitting, especially in this neural network context.
By non-Bayesian point estimation I mean the MAP estimate, without even aiming to estimate the posterior variance as in a Laplace approximation.
But suppose I needed to justify this approach; I could also claim that the point estimation above is indeed (variational) Bayes. That is, I use the mean-field Gaussian approximation $q(\theta) = \mathrm{N}(\theta \mid M, \sigma^2)$ and further restrict my variational family by fixing $\sigma$. Then minimizing $\mathrm{KL}(q \,\|\, p)$ over $M$ leads to

$$\min_M \; \mathrm{MSE} + \lambda \|M\|_2^2 + \mathrm{const},$$

which is exactly the objective we use in point estimation.
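Here is a minimal numerical sketch of this equivalence, under an assumed toy model (a 1-D linear regression with Gaussian likelihood and Gaussian prior; all names and numbers below are illustrative, not from the paper): minimizing the fixed-$\sigma$ negative ELBO over $M$ recovers the MAP / ridge solution.

```python
import numpy as np

# Toy check: MAP estimation equals variational inference with
# q(theta) = N(M, sigma^2) and sigma held fixed.
# Assumed model: y_i ~ N(theta * x_i, s^2), prior theta ~ N(0, tau^2).
rng = np.random.default_rng(0)
n, s, tau, sigma = 50, 1.0, 2.0, 1e-3   # sigma is fixed, not optimized
x = rng.normal(size=n)
y = 0.7 * x + rng.normal(scale=s, size=n)

def neg_elbo(M):
    # E_q[-log p(y | theta)] + KL(q || prior), dropping M-independent
    # terms; with sigma fixed this is squared error plus an L2 penalty.
    fit = np.sum((y - M * x) ** 2) / (2 * s**2) + sigma**2 * np.sum(x**2) / (2 * s**2)
    reg = (M**2 + sigma**2) / (2 * tau**2)
    return fit + reg

# Minimize the negative ELBO over M by grid search ...
grid = np.linspace(-2.0, 2.0, 40001)
M_vi = grid[np.argmin([neg_elbo(m) for m in grid])]

# ... and compare with the MAP / ridge closed form, lambda = s^2 / tau^2:
M_map = np.sum(x * y) / (np.sum(x**2) + s**2 / tau**2)
print(M_vi, M_map)   # the two minimizers coincide up to grid resolution
```

The variational "posterior" here carries no information beyond its mean, which is the point of the objection that follows.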
In this paper we also manage to extend this variational-Bayes view to several other methods that were not previously considered Bayesian.
Or maybe not.
If it looks like a duck, swims like a duck, and quacks like a duck, then it probably is a duck.
Variational inference can be problematic, and the argument above is one such example: the variational family is so restricted that it does nothing more than point estimation.
We need a more flexible variational family while maintaining computational convenience. Dropout, or its Monte Carlo variant MC dropout, essentially approximates $\theta$ by a mixture of two components with the following variational family:

$$q(\theta) = p\, \mathrm{N}(\theta \mid 0, \sigma^2) + (1-p)\, \mathrm{N}(\theta \mid M, \sigma^2).$$

($\theta$ can and always should be high dimensional; I just do not want to write matrix expressions.)
It is more flexible, and therefore more computationally expensive, than mean-field. So the usual practice of dropout fixes $p$, and also fixes $\sigma$ at machine precision. Roughly speaking, we are approximating any posterior distribution by a mixture of two spikes, among which the first spike is centered at zero, and the second one, at $M$, is allowed to move.
As a variational inference problem, we should write down the ELBO, and it turns out that the negative ELBO is MSE $+ \lambda \|M\|_2^2$. So magically we are solving an $L^2$-regularized problem but end up with a distribution estimate $q(\theta)$, fully parametrized by $M$.
Parameter overhead
The two variational distributions mentioned above are both fully parametrized by the same number of variational parameters $M$ as model parameters $\theta$:

$$q(\theta) = \mathrm{N}(\theta \mid M, \sigma^2),$$

or

$$q(\theta) = p\, \mathrm{N}(\theta \mid 0, \sigma^2) + (1-p)\, \mathrm{N}(\theta \mid M, \sigma^2).$$
The point estimation (the first line) is more problematic because it restricts the variance to a fixed value. By contrast, dropout can match the posterior variance, although it suffers from the Bernoulli property:
Only one degree of freedom ($M$) can be optimized over: a large posterior variance has to come with a large magnitude of the posterior mean. In other words, this restriction enforces a tradeoff between predictive accuracy and calibration error.
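The tied mean and spread are easy to see from the Bernoulli moments. A small sketch (the dropout rate $p = 0.5$ below is an arbitrary illustrative choice):

```python
import numpy as np

# The two-spike dropout family ties mean and spread together.  With
# sigma ~ 0, theta = z * M, z ~ Bernoulli(1 - p), so E[theta] = (1-p) M
# and sd[theta] = sqrt(p (1-p)) |M|: the ratio sd / |mean| is fixed at
# sqrt(p / (1-p)) no matter how M is optimized.
p = 0.5
ratios = []
for M in [0.1, 1.0, 10.0]:
    mean = (1 - p) * M                    # E[theta]
    sd = np.sqrt(p * (1 - p)) * abs(M)    # sd[theta]
    ratios.append(sd / abs(mean))
print(ratios)   # -> [1.0, 1.0, 1.0] when p = 0.5
```

Whatever value $M$ takes, the coefficient of variation is pinned by the fixed dropout rate, which is exactly the one-degree-of-freedom restriction described above.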
Ensemble-based mixture
It is natural to extend the binary-component mixture to a $K$-component mixture for each margin:

$$q(\theta) = \sum_{k=1}^{K} p_k\, \mathrm{N}(\theta \mid M_k, \sigma^2).$$
Again we fix $p_k$ and $\sigma$.
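With $K \geq 2$ free means, the family can moment-match any marginal mean and variance, unlike the two-spike dropout family. A minimal sketch with $K = 2$ equal weights (the target values are arbitrary):

```python
import numpy as np

# Moment matching for the K-component family: with K = 2 equal weights
# and small fixed sigma, choosing M_{1,2} = m -/+ sqrt(v) matches any
# target marginal mean m and variance v.
m_target, v_target, sigma = 1.3, 0.4, 1e-6
M1 = m_target - np.sqrt(v_target)
M2 = m_target + np.sqrt(v_target)

mix_mean = 0.5 * M1 + 0.5 * M2
mix_var = sigma**2 + 0.5 * (M1**2 + M2**2) - mix_mean**2
print(mix_mean, mix_var)   # ~ (1.3, 0.4)
```

Here both the mean and the variance are free, decoupling predictive accuracy from calibration in a way dropout cannot.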
And, of course, using mixture distributions to enrich the expressiveness of variational Bayes is not a new idea. There are many mixture mean-field methods.
To some extent, such mixtures form an extremely flexible family. In the high-dimensional setting, a product of $D$ margins, each a $K$-component mixture, yields $K^D$ joint mixture components, where $D$ is the dimension and $K$ is the number of components per margin. A slightly weaker version models the weights in each layer as a mixture of mean-field Gaussians; then the total number of mixture components jointly is $K^L$, where $L$ is the number of layers.
Loosely speaking, as long as $K \geq 2$, we are able to approximate the mean and variance of each margin. (Of course, it will never match the full joint posterior.)
Our method
The goal is not to find the most flexible variational family. The goal is to find a practically useful approach. We do add a $K$-fold parameter overhead. Is that a concern?
The short answer is no, for two reasons.

1. Deep networks are over-parameterized: not only over-parameterized for point prediction but also over-parameterized for distribution estimation. This suggests we only need to place this relatively expensive variational family on some model patch.

2. The likelihood of a mixture model requires marginalizing out all discrete assignment variables. We compute this through Monte Carlo samples: each time we draw one realization $k$ and only optimize over $M_k$. The ELBO in this case is MSE $+ \lambda \sum_k \|M_k\|_2^2$.
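The one-draw training loop can be sketched as follows, on an assumed toy 1-D linear model (the settings $K$, `lr`, `lam` are illustrative, not from the paper): each SGD step samples one component $k$ and updates only $M_k$ against MSE plus its $L^2$ penalty.

```python
import numpy as np

# Sketch of the one-draw Monte Carlo training loop: at each SGD step we
# sample a mixture component k and update only its variational mean M_k,
# minimizing MSE + lam * M_k^2.  Toy 1-D linear regression data:
rng = np.random.default_rng(1)
n = 200
x = rng.normal(size=n)
y = 1.5 * x + rng.normal(scale=0.5, size=n)

K, lam, lr = 2, 0.1, 0.05
p = np.full(K, 1.0 / K)          # fixed mixture weights
M = rng.normal(size=K)           # variational means, one per component

for step in range(4000):
    k = rng.choice(K, p=p)       # one Monte Carlo draw of the assignment
    resid = y - M[k] * x
    grad = -2 * np.mean(resid * x) + 2 * lam * M[k]
    M[k] -= lr * grad

# In this unimodal toy problem every component's mean drifts to the same
# penalized least-squares solution:
M_ridge = np.sum(x * y) / (np.sum(x**2) + lam * n)
print(M, M_ridge)
```

In a multimodal posterior the components would instead settle on different modes; the toy problem only illustrates that the one-draw updates optimize the right objective.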
This interpretation makes it easy for deep learning programmers to incorporate Bayesian training into their neural networks without much additional programming overhead. However, we emphasize the fundamental differences between conventional point estimation with $L^2$ regularization and our Bayesian approach:

- The $L^2$ regularization is over the variational parameters $\{M_k\}$, not over the neural network weights $\theta$.

- Even if we run SGD with a single Monte Carlo draw of the categorical variable at each iteration, the Monte Carlo gradient is still unbiased. Thus, the optimization converges to the desired variational distribution.

- In the testing phase, we draw $S > 1$ samples $\Theta_1, \dots, \Theta_S$ from the approximate posterior distribution.
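The test phase amounts to ancestral sampling from the fitted mixture and averaging predictions. A minimal sketch, with made-up fitted values for a single scalar weight:

```python
import numpy as np

# Test-phase sketch: draw S samples Theta_1..Theta_S from the fitted
# mixture q and average predictions.  The weights, means, and the test
# input below are illustrative placeholders, not fitted values.
rng = np.random.default_rng(2)
p = np.array([0.5, 0.5])         # fixed mixture weights
M = np.array([1.2, 1.8])         # fitted variational means
sigma = 1e-3                     # fixed, near machine precision

S = 100
x_new = 2.0
theta = np.array([rng.normal(M[rng.choice(2, p=p)], sigma) for _ in range(S)])
preds = theta * x_new            # S posterior-predictive draws
print(preds.mean(), preds.std())  # predictive mean and its uncertainty
```

Unlike a point estimate, the $S$ draws give a predictive spread, which is what feeds the calibration comparisons reported below.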
From a practical perspective, the proposed methods are better alternatives to dropout, especially in batch-normalized networks. They outperform MC dropout and non-Bayesian networks on test accuracy, probability calibration, and robustness against common image corruptions. The parameter overhead and computational cost are nearly the same as in non-Bayesian training, making Bayesian networks affordable on large-scale datasets and modern deep residual networks with hundreds of layers. To our knowledge, we are the first to scale a Bayesian neural network to the ImageNet dataset, more than halving the non-Bayesian calibration error.
Learning uncertainty using point estimation?
At some point I started wondering about the following question:
To what extent can we approximate posterior distributions by (distributions of) point estimates, which in many deep models are the only thing we are able to handle?
It is not as crazy as it sounds. For example:
- The bootstrap distribution of the MAP estimate is the same as the true posterior in the normal-normal model.
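A quick simulation of this normal-normal claim, under an assumed setup (unit observation variance, a diffuse $\mathrm{N}(0, \tau^2)$ prior, and toy numbers throughout); it is a sanity check, not a proof:

```python
import numpy as np

# y_i ~ N(mu, 1), prior mu ~ N(0, tau^2).  Compare the bootstrap
# distribution of the MAP estimate with the closed-form posterior
# N(post_mean, post_sd^2); with diffuse tau they approximately agree.
rng = np.random.default_rng(3)
n, tau = 2000, 10.0
y = rng.normal(loc=0.3, scale=1.0, size=n)

def map_est(data):
    # MAP of mu under the N(0, tau^2) prior, known unit variance
    return data.sum() / (len(data) + 1.0 / tau**2)

boot = np.array([map_est(rng.choice(y, size=n, replace=True))
                 for _ in range(2000)])

post_mean = y.sum() / (n + 1.0 / tau**2)
post_sd = 1.0 / np.sqrt(n + 1.0 / tau**2)
print(boot.mean(), post_mean)   # centers agree
print(boot.std(), post_sd)      # spreads agree (approximately)
```

The agreement is only approximate at finite $n$ and with an informative prior, but it illustrates why ensembles of point estimates can stand in for posterior draws.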
Subsampling is, in general, a bad idea, as it only uses part of the data for training. We also have the following constructions in the deep learning literature that essentially approximate posterior distributions by ensembles of point estimates:

- learning the MAP using SGD with a constant learning rate, which mimics Langevin dynamics;

- ensembles in which each member comes from $L^2$ regularization, but the center of the $L^2$ penalty comes from the prior;

- artificially creating weak identification by modeling the mean and variance together.
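The first construction above is easy to visualize in one dimension: with a constant learning rate and noisy gradients, SGD never converges to a point but hovers around the optimum like a sampler. A toy sketch on an assumed quadratic loss (all numbers illustrative); for this linear iteration the stationary variance is $\eta s^2 / (2 - \eta)$:

```python
import numpy as np

# Constant-learning-rate SGD on 0.5 * (theta - mu)^2 with Gaussian
# gradient noise of sd s.  The iterate theta_t is an AR(1) process whose
# stationary distribution is N(mu, eta * s^2 / (2 - eta)): the iterates
# behave like (approximate) posterior samples, not a point estimate.
rng = np.random.default_rng(4)
mu, s, eta = 1.0, 1.0, 0.1      # optimum, gradient-noise sd, learning rate
theta, draws = 0.0, []
for t in range(60000):
    grad = (theta - mu) + rng.normal(scale=s)   # noisy gradient
    theta -= eta * grad
    if t >= 10000:              # discard burn-in, keep the rest as "samples"
        draws.append(theta)
draws = np.array(draws)
print(draws.mean(), draws.var())  # ~ mu and ~ eta * s^2 / (2 - eta)
```

Shrinking $\eta$ concentrates the stationary distribution around the optimum, recovering point estimation as a limit; keeping $\eta$ constant trades optimization accuracy for sampling.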
In our paper, we distinguish explicit and implicit ensembles: the former are post-inference ensembles (i.e., each member is trained independently), while the latter are more of a mixture model. For explicit ensembles, stacking is a more promising approach. For implicit ensembles, we show that the effective ensemble size in our approach is exponentially larger ($K^L$) than in the explicit one, which makes it more likely to adapt to the multimodal nature of deep networks.
But again, we hope this work will draw more attention to computationally efficient methods in large-scale Bayesian inference.