Bayesian Neural Networks (BNNs) extend traditional neural networks by treating weights as probability distributions rather than fixed values. This lets the model quantify uncertainty in its predictions and helps reduce overfitting. Variational Inference (VI) provides a scalable method to approximate the intractable posterior distribution of these weights.
- Traditional Neural Networks: Each weight has a single fixed value (point estimate).
- Bayesian Neural Networks: Each weight is treated as a probability distribution, representing uncertainty about its true value.
Why Use Variational Inference?
- Challenge: Computing the exact "posterior" distribution over all weights (what we believe about weights after seeing the data) is mathematically intractable for neural networks.
- Solution: Variational Inference (VI) approximates this complex posterior with a simpler, tractable distribution, usually a Gaussian.
How Does Variational Inference Work in BNNs?
Choose a Simple Distribution: Pick a family of distributions (e.g., diagonal Gaussian) to approximate the true posterior over weights. Each weight now has a mean and standard deviation, not just a single value.
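As a rough sketch of what "a mean and standard deviation per weight" can look like in code, the NumPy snippet below stores two arrays for one layer. The names `mu`, `rho`, and `softplus`, the layer shape, and the initial values are illustrative assumptions; passing an unconstrained value through a softplus is one common way to keep the standard deviation positive, not something this article prescribes.

```python
import numpy as np

def softplus(x):
    """Map any real number to a positive one: log(1 + exp(x))."""
    return np.logaddexp(0.0, x)

rng = np.random.default_rng(0)

# Hypothetical layer with 3 inputs and 2 outputs.
shape = (3, 2)

# Variational parameters of a diagonal Gaussian: one mean and one
# unconstrained "rho" per weight. sigma = softplus(rho) is always > 0.
mu = rng.normal(0.0, 0.1, size=shape)
rho = np.full(shape, -3.0)
sigma = softplus(rho)

# A point-estimate layer would store mu.size numbers; this one stores twice that.
n_point_params = mu.size
n_variational_params = mu.size + rho.size

# One draw of concrete weights from the approximate posterior.
weights = rng.normal(mu, sigma)
```

The doubling of the parameter count is the main storage cost of the diagonal Gaussian family compared with a point-estimate network.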
Optimization Objective: Instead of maximizing likelihood (as in standard neural nets), VI maximizes a different objective, the evidence lower bound (ELBO), which balances two things:
- Fit to Data: How well the network explains the observed data (like usual training).
- Closeness to Prior: How close the chosen distribution is to a prior belief about weights (regularization).
Gradient-Based Training: VI uses gradient descent, just like regular neural networks, but updates both the means and standard deviations of the weight distributions.
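One common way to get gradients with respect to a mean and a standard deviation is the reparameterization trick, which the text above does not name; the sketch below assumes it. A sample is written as `w = mu + sigma * eps` with `eps` drawn from a standard normal, so the chain rule can pass through the sample. The toy objective `f(w) = w**2` and all variable names are assumptions chosen because the exact gradients are known and can be compared against the estimates.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy objective f(w) = w**2 with w ~ N(mu, sigma**2).
# Its expectation is mu**2 + sigma**2, so the exact gradients are
# d/dmu = 2 * mu and d/dsigma = 2 * sigma.
mu, sigma = 1.5, 0.5

# Reparameterize: w = mu + sigma * eps with eps ~ N(0, 1).
eps = rng.standard_normal(200_000)
w = mu + sigma * eps

# Chain rule through the sample: df/dw = 2w, dw/dmu = 1, dw/dsigma = eps.
df_dw = 2.0 * w
grad_mu = np.mean(df_dw * 1.0)
grad_sigma = np.mean(df_dw * eps)

print(f"grad_mu    ~ {grad_mu:.3f} (exact {2 * mu:.3f})")
print(f"grad_sigma ~ {grad_sigma:.3f} (exact {2 * sigma:.3f})")
```

In a real network the derivative `df_dw` would come from backpropagation rather than a hand-written formula; the two extra multiplications are what turn it into updates for the mean and the standard deviation.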
Prediction: At test time, predictions are made by averaging over several samples of weights from the learned distribution, capturing model uncertainty.
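The prediction step can be sketched as follows, assuming a tiny linear model in place of a full network. The "learned" values in `mu` and `sigma`, the helper `predict`, and the test inputs are all made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical "learned" variational parameters for a linear model y = x @ w.
mu = np.array([2.0, -1.0])
sigma = np.array([0.1, 0.3])

def predict(x, n_samples=1000):
    """Average predictions over weight samples drawn from N(mu, sigma^2)."""
    # Draw n_samples weight vectors, shape (n_samples, n_features).
    w = mu + sigma * rng.standard_normal((n_samples, mu.size))
    preds = w @ x.T  # shape (n_samples, n_points)
    # The mean is the prediction; the spread reflects model uncertainty.
    return preds.mean(axis=0), preds.std(axis=0)

x_test = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
mean, std = predict(x_test)
print("predictive mean:", mean)
print("predictive std: ", std)
```

Inputs that depend on the weight with the larger standard deviation get a wider predictive spread, which is the uncertainty signal a point-estimate model cannot provide.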
Key Points
- Posterior Consistency: Under certain conditions, the variational approximation concentrates around the true parameter values as the amount of data grows.
- Trade-off: VI must balance fitting the data and staying close to the prior, especially important in large (overparameterized) networks.
- Choice of Approximation: Simpler distributions (like independent Gaussians) are easier to train but may not capture all uncertainty; more complex ones (like normalizing flows) can be more accurate but harder to optimize.
Practical Implementation of Variational Inference in BNNs
Main Formula (ELBO)
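$$\text{ELBO}(\phi) = \mathbb{E}_{q_\phi(\theta)}\left[\log p(D \mid \theta)\right] - \mathrm{KL}\left(q_\phi(\theta) \,\|\, p(\theta)\right)$$
where
- $q_\phi(\theta)$: The variational (approximate) posterior distribution over the network weights $\theta$, with learnable parameters $\phi$ (what we’re learning).
- $p(D \mid \theta)$: The likelihood, i.e., how likely the observed data $D$ is given the weights (model fit).
- $p(\theta)$: The prior distribution over weights (our initial belief, e.g., a standard normal distribution).
- $\mathbb{E}_{q_\phi(\theta)}[\log p(D \mid \theta)]$: The expected log-likelihood, which encourages the model to fit the data.
- $\mathrm{KL}(q_\phi(\theta) \,\|\, p(\theta))$: The Kullback-Leibler divergence, which regularizes $q_\phi(\theta)$ to stay close to $p(\theta)$.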
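To see why maximizing this objective pulls the approximation toward the true posterior, a short derivation may help. The notation is an assumption chosen for this sketch: $\theta$ for the weights, $D$ for the data, $q_\phi(\theta)$ for the variational distribution, $p(\theta)$ for the prior, and $p(\theta \mid D)$ for the exact posterior. Substituting Bayes' rule, $p(\theta \mid D) = p(D \mid \theta)\,p(\theta) / p(D)$, gives:

$$
\begin{aligned}
\mathrm{KL}\left(q_\phi(\theta) \,\|\, p(\theta \mid D)\right)
&= \mathbb{E}_{q_\phi(\theta)}\left[\log q_\phi(\theta) - \log p(\theta \mid D)\right] \\
&= \mathbb{E}_{q_\phi(\theta)}\left[\log q_\phi(\theta) - \log p(\theta)\right] - \mathbb{E}_{q_\phi(\theta)}\left[\log p(D \mid \theta)\right] + \log p(D) \\
&= \log p(D) - \text{ELBO}(\phi)
\end{aligned}
$$

Since $\log p(D)$ does not depend on $\phi$, raising the ELBO lowers the KL divergence to the exact posterior by the same amount. Because that divergence is never negative, the ELBO is a lower bound on $\log p(D)$, which is where its name comes from.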
Practical Training Steps
1. Choose Priors and Variational Family: Set the prior $p(\theta)$ (e.g., $\mathcal{N}(0, 1)$ for each weight).
Choose the variational family $q_\phi(\theta)$ (e.g., a Gaussian with learnable mean and variance per weight).
2. Sample Weights: For each mini-batch, sample weights $\theta$ from the variational distribution $q_\phi(\theta)$.
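3. Compute Expected Log-Likelihood: Approximate the expectation with a Monte Carlo average over sampled weights:
$$\mathbb{E}_{q_\phi(\theta)}\left[\log p(D \mid \theta)\right] \approx \frac{1}{S} \sum_{s=1}^{S} \log p\left(D \mid \theta^{(s)}\right)$$
- $S$ = number of samples,
- $\theta^{(s)}$ = $s$-th sample from the variational distribution $q_\phi(\theta)$.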
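4. Compute KL Divergence: Evaluate $\mathrm{KL}\left(q_\phi(\theta) \,\|\, p(\theta)\right)$ between the variational distribution and the prior.
For Gaussians, this has a closed-form expression. For example, for a single weight with $q = \mathcal{N}(\mu, \sigma^2)$ and a standard normal prior, it is $\frac{1}{2}\left(\sigma^2 + \mu^2 - 1\right) - \log \sigma$.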
5. Optimize ELBO: Use a stochastic gradient optimizer (e.g., SGD or Adam) to maximize the ELBO (or, equivalently, minimize $-\text{ELBO}$).
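The five steps above can be sketched end to end on a deliberately small problem. This is a minimal illustration under several assumptions, not a reference implementation: a single linear layer stands in for a neural network, the gradients are derived by hand with the reparameterization trick instead of an autodiff library, the data are synthetic, and every name and hyperparameter (`adam_step`, `log_sigma`, the learning rate, `b2=0.99`, the batch size) is chosen only for this example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic regression data (hypothetical): y = X @ w_true + Gaussian noise.
n_points, n_features, noise_std = 200, 2, 0.5
w_true = np.array([2.0, -1.0])
X = rng.standard_normal((n_points, n_features))
y = X @ w_true + noise_std * rng.standard_normal(n_points)

# Step 1: prior p(w) = N(0, 1) per weight; q(w) = N(mu, sigma^2) per weight,
# with sigma = exp(log_sigma) so that it stays positive.
mu = np.zeros(n_features)
log_sigma = np.full(n_features, -3.0)

def kl_to_standard_normal(mu, log_sigma):
    """Closed-form KL( N(mu, sigma^2) || N(0, 1) ), summed over weights."""
    sigma2 = np.exp(2.0 * log_sigma)
    return 0.5 * np.sum(sigma2 + mu**2 - 1.0 - 2.0 * log_sigma)

def adam_step(param, grad, state, t, lr=0.01, b1=0.9, b2=0.99, eps=1e-8):
    """One Adam update that moves `param` downhill along `grad`."""
    state["m"] = b1 * state["m"] + (1 - b1) * grad
    state["v"] = b2 * state["v"] + (1 - b2) * grad**2
    m_hat = state["m"] / (1 - b1**t)
    v_hat = state["v"] / (1 - b2**t)
    return param - lr * m_hat / (np.sqrt(v_hat) + eps)

opt_mu = {"m": np.zeros(n_features), "v": np.zeros(n_features)}
opt_ls = {"m": np.zeros(n_features), "v": np.zeros(n_features)}
batch_size, n_steps = 50, 3000

for t in range(1, n_steps + 1):
    idx = rng.choice(n_points, size=batch_size, replace=False)
    Xb, yb = X[idx], y[idx]

    # Step 2: sample weights with w = mu + sigma * eps, eps ~ N(0, 1).
    sigma = np.exp(log_sigma)
    eps = rng.standard_normal(n_features)
    w = mu + sigma * eps

    # Step 3: one-sample estimate of the expected log-likelihood gradient,
    # rescaled from the mini-batch to the full data set.
    scale = n_points / batch_size
    dloglik_dw = scale * Xb.T @ (yb - Xb @ w) / noise_std**2

    # Step 4: gradients of the closed-form KL term.
    dkl_dmu = mu
    dkl_dls = np.exp(2.0 * log_sigma) - 1.0

    # Step 5: minimize -ELBO = -(expected log-likelihood) + KL.
    grad_mu = -dloglik_dw + dkl_dmu
    grad_ls = -dloglik_dw * eps * sigma + dkl_dls
    mu = adam_step(mu, grad_mu, opt_mu, t)
    log_sigma = adam_step(log_sigma, grad_ls, opt_ls, t)

print("variational mean:", mu)
print("variational std: ", np.exp(log_sigma))
print("KL to prior:     ", kl_to_standard_normal(mu, log_sigma))
```

A linear-Gaussian model is used here because its exact posterior is known, so the learned mean and standard deviation can be checked against it. For an actual neural network, the hand-written `dloglik_dw` would be replaced by backpropagation through the sampled weights.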
Advantages
- Uncertainty Quantification: BNNs can say how confident they are in their predictions, which is useful for safety-critical tasks or when data is scarce.
- Regularization: The prior acts as a built-in regularizer, helping prevent overfitting.
- Scalability: VI allows Bayesian ideas to be used in deep learning at scale, since it works with standard training tools and hardware.