Affordable Bayesian Neural Networks

Posted by Yuling Yao on May 20, 2019. Tag: computation
We wrote a new paper on approximate Bayesian inference in deep networks, with a computational cost similar to that of point estimation.

Recently, Oscar Chang, I, and several others wrote a paper titled Ensemble Model Patching: A Parameter-Efficient Variational Bayesian Neural Network. It mainly addresses the challenge of large-scale Bayesian inference in deep neural networks.

The motivation is clear: 1) we want to fit Bayesian neural networks, but we cannot afford to run MCMC. Indeed, we cannot even afford to run variational inference, as a mean-field Gaussian approximation doubles the number of parameters (a mean and a variance for each weight), making it infeasible for state-of-the-art networks; and 2) I really wanted to write a deep learning paper, whatever that means.

As a statistician, I rarely think about this kind of computational bottleneck seriously, but Oscar reminded me that it is indeed a concern:

A 100% parameter overhead doubles the amount of memory required. This matters because many deep learning models already use the available hardware resources to the fullest, often maxing out GPU memory. In 2015, the highest-end consumer-grade GPU (Nvidia's Titan X Maxwell) had 12GB of memory; by 2018, with Nvidia's Titan RTX, this number had doubled to 24GB. This means that in the period 2015-2018, a programmer could only have Bayesianized state-of-the-art neural networks from prior to 2015. Any newer design, once its parameter use had been doubled, would not fit into her GPU's memory. With GPU memory capacity doubling approximately every 3 years (an eternity in the rapidly progressing field of deep learning), the programmer's Bayesianized networks will be up to three years behind the state of the art.

OK, so here is the goal: we want to do Bayesian inference on networks so large that we can hardly afford anything beyond point estimation, so we want the approximate inference to require nearly the same computational resources as non-Bayesian point estimation.

MAP estimation is variational inference?

At first glance, this seems hardly possible. Let's recall what point estimation can do at all. Given any model $p(y|x, \theta)$, we can always optimize over $\theta$ and find the MAP estimate. For example, when $p(y|x, \theta)$ is a Gaussian likelihood with fixed output precision and $\theta$ has a Gaussian prior, MAP estimation amounts to minimizing the MSE plus an L2 regularization: MSE $+\lambda ||\theta||_2^2$. A hypothetical Bayesian would argue that point estimation leads to overfitting, especially in the neural network context.
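As a quick check of the MAP-equals-regularized-MSE statement, here is a minimal sketch in plain Python for a one-parameter linear model $y \approx \theta x$. The data, the precisions `tau` (likelihood) and `alpha` (prior), and the function names are illustrative assumptions, not from the paper. With $\lambda = \alpha / (\tau n)$ the negative log posterior equals $(\tau n / 2)(\mathrm{MSE} + \lambda \theta^2)$, so both objectives share the same minimizer.

```python
import random

def neg_log_posterior(theta, xs, ys, tau, alpha):
    """-log p(theta | y) up to a constant: Gaussian likelihood with precision tau,
    Gaussian prior N(0, 1/alpha) on theta, linear model y ~ theta * x."""
    sse = sum((y - theta * x) ** 2 for x, y in zip(xs, ys))
    return 0.5 * tau * sse + 0.5 * alpha * theta ** 2

def mse_plus_l2(theta, xs, ys, lam):
    """The usual regularized training loss: MSE + lam * theta^2."""
    mse = sum((y - theta * x) ** 2 for x, y in zip(xs, ys)) / len(xs)
    return mse + lam * theta ** 2

def ridge_solution(xs, ys, lam):
    """Closed-form minimizer of MSE + lam * theta^2 for the model y ~ theta * x."""
    n = len(xs)
    sxx = sum(x * x for x in xs) / n
    sxy = sum(x * y for x, y in zip(xs, ys)) / n
    return sxy / (sxx + lam)

random.seed(0)
xs = [random.gauss(0, 1) for _ in range(100)]
ys = [2.0 * x + random.gauss(0, 0.5) for x in xs]
tau, alpha = 4.0, 1.0              # assumed likelihood and prior precisions
lam = alpha / (tau * len(xs))      # lambda under which the two objectives agree

theta_map = ridge_solution(xs, ys, lam)
for theta in (-1.0, 0.5, theta_map, 3.0):
    # Same ratio at every theta: the objectives differ only by a positive factor.
    print(round(neg_log_posterior(theta, xs, ys, tau, alpha) / mse_plus_l2(theta, xs, ys, lam), 6))  # 200.0
```

Because the ratio is the constant $\tau n / 2$, any optimizer run on the regularized MSE returns the MAP estimate.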

By non-Bayesian point estimation I mean MAP estimation, which does not even attempt to estimate the posterior variance, as the Laplace approximation does.

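But suppose I needed to justify this approach: I could also claim that the point estimation above is indeed (variational) Bayes. That is, I use the mean-field Gaussian approximation $q(\theta) = N(\theta \mid M, \sigma^2)$ and further restrict the variational family by fixing $\sigma$. Then minimizing $\mathrm{KL}(q \,\|\, p(\theta \mid y))$ over $M$ leads to

$$
\min_M \; \mathrm{E}_{\theta \sim N(M, \sigma^2)}\left[\mathrm{MSE}(\theta)\right] + \lambda \|M\|_2^2 \;\approx\; \min_M \; \mathrm{MSE}(M) + \lambda \|M\|_2^2 \quad \text{for small fixed } \sigma,
$$

which is exactly the objective we use in point estimation.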
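The equivalence can be checked numerically in a toy setting. The sketch below (plain Python; the one-parameter linear model, the data, and the precisions are illustrative assumptions) computes the negative ELBO of $q(\theta) = N(\theta \mid M, \sigma^2)$ with $\sigma$ fixed, dropping the entropy of $q$ (a constant once $\sigma$ is fixed) and $\log p(y)$. For a linear model the Gaussian expectations have closed forms, and the negative ELBO differs from the negative log posterior at $M$ only by a constant, so in this case the optimal $M$ is exactly the MAP estimate for any fixed $\sigma$.

```python
import random

def neg_log_posterior(theta, xs, ys, tau, alpha):
    """-log p(theta | y) up to a constant for y ~ theta * x, Gaussian likelihood
    with precision tau and N(0, 1/alpha) prior."""
    sse = sum((y - theta * x) ** 2 for x, y in zip(xs, ys))
    return 0.5 * tau * sse + 0.5 * alpha * theta ** 2

def neg_elbo_fixed_sigma(m, sigma, xs, ys, tau, alpha):
    """Negative ELBO of q(theta) = N(m, sigma^2) with sigma held fixed.

    Drops the entropy of q (constant in m) and log p(y). Uses the closed forms
      E_q[(y - theta x)^2] = (y - m x)^2 + sigma^2 x^2  and  E_q[theta^2] = m^2 + sigma^2.
    """
    expected_sse = sum((y - m * x) ** 2 + sigma ** 2 * x ** 2 for x, y in zip(xs, ys))
    return 0.5 * tau * expected_sse + 0.5 * alpha * (m ** 2 + sigma ** 2)

rng = random.Random(1)
xs = [rng.gauss(0, 1) for _ in range(100)]
ys = [2.0 * x + rng.gauss(0, 0.5) for x in xs]
tau, alpha, sigma = 4.0, 1.0, 0.3   # assumed precisions and fixed variational scale

for m in (-2.0, 0.0, 1.0, 2.5):
    gap = neg_elbo_fixed_sigma(m, sigma, xs, ys, tau, alpha) - neg_log_posterior(m, xs, ys, tau, alpha)
    print(round(gap, 6))  # the same for every m, so both are minimized at the MAP estimate
```

For a nonlinear network the expectation has no closed form, which is where the small-$\sigma$ approximation comes in.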

In the paper, we also extend this variational Bayes view to several other methods that are not usually considered Bayesian.

Or maybe not.

If it looks like a duck, swims like a duck, and quacks like a duck, then it probably is a duck.

Variational inference can be problematic, and the argument above is one such example: the variational family is so restricted that it does nothing more than point estimation.

We need a more flexible variational family that keeps the computational convenience. Dropout, or its variant MC dropout, essentially approximates $\theta$ by a mixture of two components, using the variational family

$$
q(\theta) = p\, N(\theta \mid 0, \sigma^2) + (1-p)\, N(\theta \mid M, \sigma^2).
$$

($\theta$ can be, and in practice always is, high dimensional; I just do not want to write matrix expressions.)

It is more flexible, and therefore more computationally expensive, than mean-field. So the usual practice of dropout fixes $p$, and also fixes $\sigma$ at machine precision. Roughly speaking, we are approximating any posterior distribution by a mixture of two spikes: the first spike is centered at zero, and the second one, centered at $M$, is allowed to move.

As a variational inference problem, we can write down the ELBO, and it turns out that the negative ELBO is (up to constants) the MSE, averaged over random dropout masks, plus an $L^2$ penalty on $M$. So magically we are solving an $L^2$-regularized problem, but we end up with a distribution estimate $q(\theta)$, which is fully parametrized by $M$.
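A minimal sketch of this objective in plain Python, assuming a linear model with weight vector $M$ and $\sigma \to 0$, so that a draw from $q$ is $M$ with each coordinate independently zeroed with probability $p$. The function names and toy data are hypothetical, and `lam` is assumed to absorb the prior precision and other constants. Each call returns one Monte Carlo estimate of the negative ELBO, which is what a single SGD step with dropout sees.

```python
import random

def sample_dropout_weights(M, p, rng):
    """One draw from q(theta) = p N(0, sigma^2) + (1 - p) N(M, sigma^2) with sigma -> 0:
    each coordinate is set to 0 with probability p and to M_j otherwise."""
    return [m if rng.random() >= p else 0.0 for m in M]

def neg_elbo_one_draw(M, X, y, p, lam, rng):
    """Single Monte Carlo estimate of E_q[MSE(theta)] + lam * ||M||^2."""
    theta = sample_dropout_weights(M, p, rng)
    preds = [sum(t * xij for t, xij in zip(theta, row)) for row in X]
    mse = sum((yi - pi) ** 2 for yi, pi in zip(y, preds)) / len(y)
    return mse + lam * sum(m * m for m in M)

X = [[1.0, 2.0], [0.5, -1.0], [2.0, 0.0]]
y = [1.0, 0.0, 2.0]
M = [1.0, 0.5]
rng = random.Random(0)
draws = [neg_elbo_one_draw(M, X, y, p=0.5, lam=0.1, rng=rng) for _ in range(20000)]
print(sum(draws) / len(draws))  # close to 1.0, the exact expectation for this toy data
```

Since the mask distribution does not depend on $M$, the gradient of a single draw is an unbiased estimate of the gradient of the full expectation.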

Parameter overhead

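The two variational distributions mentioned above are both fully parametrized by $M$, which has the same number of entries as the model parameters $\theta$:

$$
\begin{aligned}
\text{point estimation:}\quad & q(\theta) = N(\theta \mid M, \sigma^2), \\
\text{dropout:}\quad & q(\theta) = p\, N(\theta \mid 0, \sigma^2) + (1-p)\, N(\theta \mid M, \sigma^2).
\end{aligned}
$$

The point-estimation family is more problematic because it restricts the variance to a fixed value. By contrast, dropout can match the posterior variance, although it suffers from the properties of the Bernoulli distribution: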

Only one degree of freedom ($M$) per margin can be optimized: a large posterior variance has to come with a large magnitude of the posterior mean. In other words, this restriction enforces a tradeoff between predictive accuracy and calibration error.
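The tradeoff follows from the moments of one margin of the dropout family: the mean is $(1-p)M$ and the variance is $\sigma^2 + p(1-p)M^2$. The plain-Python sketch below (function names are hypothetical) computes both and shows that, with $p$ and $\sigma$ fixed, a target variance pins down $|M|$ and hence the magnitude of the mean.

```python
import math
import random

def dropout_marginal_moments(M, p, sigma):
    """Mean and variance of one margin of q = p N(0, sigma^2) + (1 - p) N(M, sigma^2)."""
    mean = (1 - p) * M
    var = sigma ** 2 + p * (1 - p) * M ** 2
    return mean, var

def mean_forced_by_variance(v, p, sigma):
    """With p and sigma fixed, variance v > sigma^2 requires
    |M| = sqrt((v - sigma^2) / (p (1 - p))), which fixes |mean| = (1 - p) |M|."""
    return (1 - p) * math.sqrt((v - sigma ** 2) / (p * (1 - p)))

def sample_margin(M, p, sigma, rng):
    """One draw from the two-spike mixture, handy for sanity-checking the formulas."""
    return (0.0 if rng.random() < p else M) + rng.gauss(0.0, sigma)

p, sigma = 0.5, 0.0   # dropout rate fixed at 0.5; sigma at (essentially) zero
for v in (0.01, 1.0, 4.0):
    # prints 0.1, 1.0, 2.0 for the mean: a larger variance forces a larger |mean|
    print(v, round(mean_forced_by_variance(v, p, sigma), 6))
```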

Ensemble based mixture

It is natural to extend the two-component mixture to a $K$-component mixture for each margin:

$$
q(\theta) = \sum_{k=1}^{K} p_k\, N(\theta \mid M_k, \sigma^2).
$$

Again we fix $p_k$ and $\sigma$.

And, of course, using mixture distributions to enrich the expressiveness of variational Bayes is not a new idea; there are many mixture mean-field methods.

To some extent, such mixtures form an extremely flexible family. In the high-dimensional setting, a product of independent per-margin mixtures yields $K^D$ joint mixture components, where $D$ is the dimension and $K$ is the number of mixture components. A slightly weaker choice is to model the weights in each layer as a mixture of mean-field Gaussians, in which case the total number of joint mixture components is $K^L$, where $L$ is the number of layers.

Loosely speaking, as long as $K \geq 2$ and the component locations are free to move, we are able to match the mean and variance of each margin (though, of course, the full posterior will never be matched exactly).
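To illustrate the moment-matching claim, here is a small plain-Python sketch; the target moments and function names are illustrative assumptions. Two equally weighted components placed symmetrically around the target mean reproduce any mean and any variance of at least $\sigma^2$, which a single movable spike plus a spike fixed at zero cannot do.

```python
import math

def mixture_moments(Ms, ps, sigma):
    """Mean and variance of one margin of q = sum_k p_k N(M_k, sigma^2)."""
    mean = sum(p * m for p, m in zip(ps, Ms))
    second = sum(p * (m ** 2 + sigma ** 2) for p, m in zip(ps, Ms))
    return mean, second - mean ** 2

def two_component_match(mu, v, sigma):
    """Place two equally weighted components at mu -/+ sqrt(v - sigma^2);
    this matches any target mean mu and any variance v >= sigma^2."""
    d = math.sqrt(v - sigma ** 2)
    return [mu - d, mu + d], [0.5, 0.5]

# A small mean with a large variance: dropout with p = 0.5 would force the
# mean magnitude to about 2 here, while two free components match it exactly.
Ms, ps = two_component_match(mu=0.1, v=4.0, sigma=0.01)
print(mixture_moments(Ms, ps, sigma=0.01))  # approximately (0.1, 4.0)
```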

Our method

The goal is not to find the most flexible variational family, but a practically useful approach. The mixture does multiply the parameter count by $K$. Is that a concern?

The short answer is no, for two reasons.

  1. Deep networks are overparameterized, not only for point prediction but also for distributional estimation. This suggests that we only need to use this relatively expensive variational family on a model patch, a small subset of the parameters.

  2. The likelihood of a mixture model requires marginalizing out all discrete variables (the component assignments). We approximate this with Monte Carlo samples: each time we draw one realization $k$ and only optimize over $M_k$. The negative ELBO in this case is MSE $+ \lambda \sum_k ||M_k||_2^2$.
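The one-draw training step can be sketched in plain Python for a toy problem with a single scalar "patch" weight replicated $K$ times, uniform weights $p_k = 1/K$, and $\sigma \to 0$; all names, data, and hyperparameters are hypothetical. One assumption worth flagging: so that each step touches only $M_k$, the single-draw surrogate here is $\mathrm{MSE}(M_k) + K\lambda\|M_k\|_2^2$, whose expectation over $k$ is $\frac{1}{K}\sum_k \mathrm{MSE}(M_k) + \lambda \sum_k \|M_k\|_2^2$; the factor $K$ is a choice made for this sketch, not taken from the paper.

```python
import random

def mse(w, xs, ys):
    """Mean squared error of the toy model y ~ w * x."""
    return sum((y - w * x) ** 2 for x, y in zip(xs, ys)) / len(xs)

def mse_grad(w, xs, ys):
    """d MSE / d w for the toy model."""
    return -2.0 * sum(x * (y - w * x) for x, y in zip(xs, ys)) / len(xs)

def neg_elbo(Ms, xs, ys, lam):
    """(1/K) sum_k MSE(M_k) + lam * sum_k M_k^2 (uniform p_k, sigma -> 0)."""
    return sum(mse(m, xs, ys) for m in Ms) / len(Ms) + lam * sum(m * m for m in Ms)

def one_draw_grad(Ms, k, xs, ys, lam):
    """Gradient of the single-draw surrogate MSE(M_k) + K * lam * M_k^2.

    Only entry k is nonzero. Averaging over k ~ Uniform{0, ..., K-1} gives the
    exact gradient of neg_elbo, so this estimate is unbiased.
    """
    K = len(Ms)
    g = [0.0] * K
    g[k] = mse_grad(Ms[k], xs, ys) + 2.0 * K * lam * Ms[k]
    return g

rng = random.Random(0)
xs = [rng.gauss(0, 1) for _ in range(50)]
ys = [2.0 * x + rng.gauss(0, 0.3) for x in xs]
K, lam, lr = 4, 0.01, 0.05
Ms = [rng.gauss(0, 1) for _ in range(K)]  # K copies of one hypothetical patch weight

start = neg_elbo(Ms, xs, ys, lam)
for _ in range(2000):
    k = rng.randrange(K)                  # one categorical draw per iteration
    g = one_draw_grad(Ms, k, xs, ys, lam)
    Ms = [m - lr * gi for m, gi in zip(Ms, g)]
print(start, neg_elbo(Ms, xs, ys, lam))   # the objective decreases
```

In this convex toy problem all copies drift toward the same value, so the sketch only shows the mechanics of the one-draw update, not the benefit of the ensemble.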

This interpretation makes it easy for deep learning programmers to incorporate Bayesian training into their neural networks without much additional programming overhead. However, we emphasize the fundamental differences between conventional point estimation with $L^2$ regularization and our Bayesian approach:

  1. The $L^2$ regularization is over all variational parameters $\{M_k\}$, not over the neural network weights $\theta$.

  2. Even if we run SGD with a single Monte Carlo draw (one realization of the categorical variable) at each iteration, the Monte Carlo gradient is still unbiased. Thus, the stochastic optimization still targets the desired variational distribution.

  3. In the testing phase, we draw $S>1$ samples $\theta_1, \dots, \theta_S$ from the approximate posterior distribution and average their predictions.
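At test time, each draw from the approximate posterior picks one component per layer, so $K$ copies in each of $L$ layers define $K^L$ possible networks while only $K \times L$ weight copies are stored. A minimal plain-Python sketch, assuming uniform weights $p_k = 1/K$, $\sigma \to 0$, and a hypothetical toy network with one scalar weight per layer:

```python
import math
import random

def forward(x, weights):
    """Toy network with one scalar weight per layer: tanh hidden layers, sigmoid output."""
    h = x
    for w in weights[:-1]:
        h = math.tanh(w * h)
    return 1.0 / (1.0 + math.exp(-weights[-1] * h))

def predictive_prob(x, Ms, S, rng):
    """Average predicted probability over S draws from the approximate posterior.

    Ms[l][k] is the k-th copy of layer l's weight; each draw picks one copy per
    layer uniformly and independently (sigma -> 0).
    """
    total = 0.0
    for _ in range(S):
        weights = [rng.choice(layer) for layer in Ms]
        total += forward(x, weights)
    return total / S

Ms = [[0.5, 1.5, -1.0],   # layer 1: K = 3 hypothetical copies
      [2.0, 1.0, 3.0]]    # layer 2 (output)
K, L = len(Ms[0]), len(Ms)
n_networks = K ** L       # 9 distinct networks from 6 stored weights
print(n_networks, predictive_prob(1.0, Ms, S=100, rng=random.Random(0)))
```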

From a practical perspective, the proposed methods are better alternatives to dropout, especially in batch-normalized networks. They outperform MC dropout and non-Bayesian networks on test accuracy, probability calibration, and robustness against common image corruptions. The parameter overhead and computational cost are nearly the same as in non-Bayesian training, making Bayesian networks affordable on large-scale datasets and modern deep residual networks with hundreds of layers. To our knowledge, we are the first to scale a Bayesian neural network to the ImageNet dataset, more than halving the non-Bayesian calibration error.

Learning uncertainty using point estimation?

At some point I started wondering about the following question:

To what extent can we approximate the posterior distribution by (the distribution of) point estimates, which in many deep models are the only thing we are able to handle?

It is not as crazy as it sounds. For example:

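  • In the normal-normal model, the bootstrap distribution of the MAP estimate approximates the true posterior; with known variance and a flat prior, the parametric bootstrap distribution of the sample mean coincides with the posterior exactly.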
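A quick simulation of the bootstrap-versus-posterior comparison in the normal model, in plain Python; the sample size, true mean, unit noise scale, and prior variance `tau2` are illustrative assumptions. With known $\sigma$ and a flat prior the posterior is $N(\bar y, \sigma^2/n)$, and the parametric bootstrap of $\bar y$ targets the same distribution. Under a $N(0, \tau^2)$ prior the MAP is $c\,\bar y$ with $c = n/(n + \sigma^2/\tau^2)$; its bootstrap standard deviation is $c\,\sigma/\sqrt{n}$, while the posterior one is $\sqrt{c}\,\sigma/\sqrt{n}$, so they differ by a factor $\sqrt{c}$ that tends to 1 as $n$ grows.

```python
import math
import random

def sd(values):
    """Sample standard deviation."""
    m = sum(values) / len(values)
    return math.sqrt(sum((v - m) ** 2 for v in values) / (len(values) - 1))

def parametric_bootstrap_means(ybar, sigma, n, B, rng):
    """Refit the sample mean on B datasets simulated from N(ybar, sigma^2)."""
    return [sum(rng.gauss(ybar, sigma) for _ in range(n)) / n for _ in range(B)]

rng = random.Random(1)
n, sigma, tau2 = 50, 1.0, 1.0
data = [rng.gauss(0.7, sigma) for _ in range(n)]
ybar = sum(data) / n

# Flat prior: MAP = ybar, posterior N(ybar, sigma^2 / n).
boot = parametric_bootstrap_means(ybar, sigma, n, B=4000, rng=rng)
print(sd(boot), sigma / math.sqrt(n))  # close to each other

# N(0, tau2) prior: MAP = c * ybar, posterior sd = sqrt(c) * sigma / sqrt(n).
c = n / (n + sigma ** 2 / tau2)
boot_map = [c * b for b in boot]
print(sd(boot_map) / (math.sqrt(c) * sigma / math.sqrt(n)))  # roughly sqrt(c), about 0.99, up to Monte Carlo error
```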

Subsampling is, in general, a bad idea, as it only uses part of the data for training. The deep learning literature also contains the following constructions, which essentially approximate the posterior distribution by an ensemble of point estimates:

In our paper, we distinguish explicit and implicit ensembles: the former are post-inference ensembles (each member is trained independently), while the latter are closer to a mixture model. For explicit ensembles, stacking is a more promising approach. For implicit ensembles, we show that the effective ensemble size in our approach is exponentially larger ($K^L$) than that of an explicit ensemble, which makes it more likely to adapt to the multimodal nature of deep networks.

Again, we hope this work will draw more attention to computationally efficient methods for large-scale Bayesian inference.