Variational Optimization & Bayesian Inference
The intractability of marginal likelihoods
Bayesian inference boils down to computing integrals. The central one is the marginal likelihood,
$$p(x) = \int p(x \mid z)\, p(z)\, dz.$$
This marginal likelihood drives model selection and generalization, but the integral is intractable for all but the simplest models. In high dimensions the typical set is a thin shell, and MCMC has a hard time jumping between modes.
The variational approach swaps sampling for optimization. We pick a family of distributions $\mathcal{Q}$ and find the member closest to the posterior,
$$q^*(z) = \operatorname*{arg\,min}_{q \in \mathcal{Q}} \; \mathrm{KL}\!\left(q(z) \,\|\, p(z \mid x)\right).$$
This turns inference into gradient descent.
Deriving the ELBO
We work in the space of density functions. The KL divergence gives us a functional,
$$\mathcal{F}[q] = \int q(z) \log \frac{q(z)}{p(z \mid x)}\, dz.$$
Take the functional derivative. The integrand is $q(z)\log q(z) - q(z)\log p(z \mid x)$, and so
$$\frac{\delta \mathcal{F}}{\delta q(z)} = \log q(z) + 1 - \log p(z \mid x).$$
Enforce normalization $\int q(z)\, dz = 1$ with a Lagrange multiplier $\lambda$,
$$\log q(z) + 1 - \log p(z \mid x) + \lambda = 0.$$
Solve for $q$,
$$q(z) = p(z \mid x)\, e^{-(1 + \lambda)} \propto p(z \mid x).$$
As you would expect, the optimum over all densities is the posterior itself. Since we cannot evaluate $p(z \mid x)$ directly, we restrict $q$ to a parametric family $q_\phi(z)$ and work with the evidence lower bound instead.
The ELBO. Start from
$$\log p(x) = \log \int p(x, z)\, dz = \log \mathbb{E}_{q_\phi}\!\left[\frac{p(x, z)}{q_\phi(z)}\right].$$
Jensen’s inequality kicks in because $\log$ is concave and gives
$$\log p(x) \;\ge\; \mathbb{E}_{q_\phi}\!\left[\log p(x, z) - \log q_\phi(z)\right] \;=\; \mathrm{ELBO}(\phi).$$
Maximizing the ELBO pulls $q_\phi$ toward regions where $p(x, z)$ is large, while the entropy term $-\mathbb{E}_{q_\phi}[\log q_\phi(z)]$ punishes collapse and keeps $q_\phi$ spread out. The gap between $\log p(x)$ and the ELBO is exactly $\mathrm{KL}\!\left(q_\phi(z) \,\|\, p(z \mid x)\right)$.
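A quick check of that gap, expanding the KL term against the joint:
$$\mathrm{KL}\!\left(q_\phi(z) \,\|\, p(z \mid x)\right) = \mathbb{E}_{q_\phi}\!\left[\log q_\phi(z) - \log p(x, z)\right] + \log p(x) = \log p(x) - \mathrm{ELBO}(\phi),$$
so $\log p(x) = \mathrm{ELBO}(\phi) + \mathrm{KL}\!\left(q_\phi \,\|\, p(\cdot \mid x)\right)$, and maximizing the ELBO over $\phi$ is the same as minimizing the KL.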
Optimization methods
Mean field approximation
Assume $q(z) = \prod_i q_i(z_i)$ and fix all the factors except $q_j$. The best update is
$$q_j^*(z_j) \;\propto\; \exp\!\left(\mathbb{E}_{q_{-j}}\!\left[\log p(x, z)\right]\right).$$
No gradients are needed, just expectations. It works well for conjugate exponential families, and it is the approach that sits under LDA.
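As a minimal sketch of these coordinate updates, consider approximating a correlated 2-D Gaussian $p(z) = \mathcal{N}(m, \Lambda^{-1})$ with a factorized $q(z_1)\,q(z_2)$; the closed-form updates below follow the standard result for this toy case, and the specific numbers are purely illustrative.

```python
import jax.numpy as jnp

# Toy target: p(z) = N(m, inv(Lambda)) with correlated components.
m = jnp.array([1.0, -1.0])
Lam = jnp.array([[2.0, 0.9],
                 [0.9, 2.0]])   # precision matrix

# Mean-field family: q(z) = N(mu1, 1/Lam[0,0]) * N(mu2, 1/Lam[1,1]).
# Coordinate ascent (CAVI): each update is an expectation, no gradients.
mu = jnp.zeros(2)
for _ in range(20):
    mu = mu.at[0].set(m[0] - (Lam[0, 1] / Lam[0, 0]) * (mu[1] - m[1]))
    mu = mu.at[1].set(m[1] - (Lam[1, 0] / Lam[1, 1]) * (mu[0] - m[0]))

print(mu)  # converges to the true mean (1, -1); variances stay fixed at 1/Lam_ii
```

The factorized approximation recovers the exact mean but underestimates the marginal variances, which is the classic mean-field failure mode.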
Black-box variational inference
Take a parametric family $q_\phi(z)$ and differentiate the ELBO straight through,
$$\nabla_\phi\, \mathrm{ELBO}(\phi) = \nabla_\phi\, \mathbb{E}_{q_\phi}\!\left[\log p(x, z) - \log q_\phi(z)\right].$$
The score function estimator (REINFORCE) uses
$$\nabla_\phi\, \mathrm{ELBO}(\phi) = \mathbb{E}_{q_\phi}\!\left[\left(\log p(x, z) - \log q_\phi(z)\right)\nabla_\phi \log q_\phi(z)\right],$$
but it has high variance in practice, often bad enough to be unusable.
The reparameterization trick fixes this. Write $z = g_\phi(\epsilon)$ with $\epsilon \sim p(\epsilon)$, and then
$$\nabla_\phi\, \mathrm{ELBO}(\phi) = \mathbb{E}_{p(\epsilon)}\!\left[\nabla_\phi\!\left(\log p(x, g_\phi(\epsilon)) - \log q_\phi(g_\phi(\epsilon))\right)\right].$$
The gradient now runs through the model via $\nabla_z \log p(x, z)$, and this cuts variance by orders of magnitude. Reparameterization is the trick that makes VAEs actually work.
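A minimal sketch of this estimator for a diagonal Gaussian $q_\phi$, assuming a hypothetical `log_joint` that stands in for $\log p(x, z)$ (a standard normal here, purely for illustration):

```python
import jax
import jax.numpy as jnp

def log_joint(z):
    # Hypothetical unnormalized log p(x, z): a standard normal stands in here.
    return -0.5 * jnp.sum(z**2)

def elbo(phi, key, n_samples=64):
    mu, log_sigma = phi
    eps = jax.random.normal(key, (n_samples, mu.shape[0]))
    z = mu + jnp.exp(log_sigma) * eps                # reparameterization: z = g_phi(eps)
    logq = -0.5 * jnp.sum(((z - mu) / jnp.exp(log_sigma))**2
                          + 2 * log_sigma + jnp.log(2 * jnp.pi), axis=1)
    return jnp.mean(jax.vmap(log_joint)(z) - logq)   # E_q[log p(x,z) - log q(z)]

phi = (jnp.ones(2), jnp.zeros(2))                    # (mu, log_sigma)
grads = jax.grad(elbo)(phi, jax.random.PRNGKey(0))   # gradient flows through z = g_phi(eps)
print(grads)
```

Because `z` is an explicit differentiable function of `phi`, `jax.grad` differentiates straight through the sampling step.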
Natural gradients
Standard gradient descent treats parameter space as Euclidean. But the effect of nudging a parameter depends entirely on the scale of the distribution. For a Gaussian $q = \mathcal{N}(\mu, \sigma^2)$, when $\sigma$ is tiny, shifting $\mu$ by a fixed amount is a huge change in KL terms, and when $\sigma$ is large the same shift barely matters.
The natural gradient fixes this by bounding each step by KL divergence rather than Euclidean distance,
$$\Delta\phi^* = \operatorname*{arg\,max}_{\Delta\phi}\; \mathrm{ELBO}(\phi + \Delta\phi) \quad \text{subject to} \quad \mathrm{KL}\!\left(q_\phi \,\|\, q_{\phi + \Delta\phi}\right) \le \varepsilon.$$
To second order, $\mathrm{KL}\!\left(q_\phi \,\|\, q_{\phi + \Delta\phi}\right) \approx \tfrac{1}{2}\, \Delta\phi^\top F(\phi)\, \Delta\phi$, where
$$F(\phi) = \mathbb{E}_{q_\phi}\!\left[\nabla_\phi \log q_\phi(z)\, \nabla_\phi \log q_\phi(z)^\top\right]$$
is the Fisher information matrix. The best direction turns out to be
$$\tilde{\nabla}_\phi\, \mathrm{ELBO} = F(\phi)^{-1}\, \nabla_\phi\, \mathrm{ELBO}.$$
For exponential families, the natural gradient with respect to the natural parameters equals the ordinary gradient with respect to the mean parameters, so no matrix inversion is needed.
This is what Stochastic Variational Inference (Hoffman et al. 2013) is built on, and it pushed VI out to datasets with billions of observations. The global parameter gradient breaks into a sum over local terms,
$$\nabla_\lambda \mathcal{L}(\lambda) = \sum_{n=1}^{N} \nabla_\lambda\, \mathcal{L}_n(\lambda).$$
Estimate it with a minibatch $B$ of size $S$,
$$\nabla_\lambda \mathcal{L}(\lambda) \approx \frac{N}{S} \sum_{n \in B} \nabla_\lambda\, \mathcal{L}_n(\lambda).$$
Taking steps $\rho_t$ that satisfy the Robbins-Monro conditions $\sum_t \rho_t = \infty$ and $\sum_t \rho_t^2 < \infty$ gives you Bayesian inference at stochastic gradient descent scale.
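A minimal sketch of this recipe on a toy conjugate model (Gaussian observations with a Gaussian prior on the mean); the model, dataset, and step-size schedule below are illustrative stand-ins, not the exact setup from Hoffman et al.:

```python
import jax
import jax.numpy as jnp

# Toy conjugate model: x_n ~ N(theta, 1), theta ~ N(0, 1), q(theta) = N(lam, v).
# Given ALL the data the optimal global parameter is lam* = sum(x) / (N + 1),
# so the minibatch "intermediate" estimate replicates the batch N/S times.
N, S = 10_000, 32
key = jax.random.PRNGKey(1)
data = jax.random.normal(key, (N,)) + 3.0

lam = 0.0
for t in range(1, 201):
    key, sub = jax.random.split(key)
    idx = jax.random.choice(sub, N, (S,), replace=False)
    lam_hat = (N / S) * jnp.sum(data[idx]) / (N + 1.0)   # minibatch estimate of lam*
    rho = (t + 10.0) ** (-0.7)                           # Robbins-Monro step sizes
    lam = (1.0 - rho) * lam + rho * lam_hat              # stochastic natural-gradient step
print(lam)   # approaches sum(data) / (N + 1), roughly 3.0
```

Each update moves the global parameter a fraction $\rho_t$ of the way toward the estimate implied by the current minibatch, which for conjugate exponential families coincides with the natural-gradient step.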
Normalizing flows
Standard VI usually uses a Gaussian $q$, which is unimodal and light-tailed. Normalizing flows fix this by stacking invertible maps,
$$z_K = f_K \circ \cdots \circ f_1(z_0), \qquad z_0 \sim q_0.$$
The density transforms by the change of variables formula,
$$\log q_K(z_K) = \log q_0(z_0) - \sum_{k=1}^{K} \log\left|\det \frac{\partial f_k}{\partial z_{k-1}}\right|.$$
The maps have to be expressive and they also have to let you work out the Jacobian determinant cheaply. Planar flows (Rezende & Mohamed 2015) use
$$f(z) = z + u\, h(w^\top z + b),$$
and the matrix determinant lemma gives $\left|\det \frac{\partial f}{\partial z}\right| = \left|1 + u^\top h'(w^\top z + b)\, w\right|$. Stacking 10 to 20 of these layers lets $q_K$ pick up multimodal targets.
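A minimal sketch of one planar layer and its log-determinant, following the $f(z) = z + u\, h(w^\top z + b)$ form above with $h = \tanh$; the parameter values are illustrative only.

```python
import jax
import jax.numpy as jnp

def planar_flow(params, z):
    # One planar layer: f(z) = z + u * tanh(w^T z + b), with
    # log|det df/dz| = log|1 + u^T h'(w^T z + b) w|  (matrix determinant lemma).
    # (A full implementation would also constrain u so that u^T w >= -1 for invertibility.)
    u, w, b = params
    a = jnp.dot(w, z) + b
    f = z + u * jnp.tanh(a)
    psi = (1.0 - jnp.tanh(a) ** 2) * w          # h'(a) * w
    logdet = jnp.log(jnp.abs(1.0 + jnp.dot(u, psi)))
    return f, logdet

def flow_log_density(layers, z0, base_logpdf):
    # Push z0 through a stack of planar layers and track log q_K(z_K).
    z, logq = z0, base_logpdf(z0)
    for params in layers:
        z, logdet = planar_flow(params, z)
        logq = logq - logdet                    # change-of-variables correction
    return z, logq

# Example: two layers on a 2-D standard-normal base distribution.
layers = [(jnp.array([0.5, -0.3]), jnp.array([1.0, 0.2]), 0.1),
          (jnp.array([-0.4, 0.7]), jnp.array([0.3, -1.0]), -0.2)]
base_logpdf = lambda z: -0.5 * jnp.sum(z**2) - jnp.log(2 * jnp.pi)
zK, logqK = flow_log_density(layers, jax.random.normal(jax.random.PRNGKey(0), (2,)), base_logpdf)
print(zK, logqK)
```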
Stein variational gradient descent
SVGD (Liu & Wang 2016) takes a totally different angle and drops parametric families in favor of particles.
Let particles $\{x_i\}_{i=1}^{n}$ stand in for $q$. We push them with a small perturbation $x \mapsto x + \epsilon\, \phi(x)$ to bring down the KL divergence. Writing $T(x) = x + \epsilon\, \phi(x)$ and $q_{[T]}$ for the pushforward of $q$ under $T$,
$$\left.\frac{d}{d\epsilon}\, \mathrm{KL}\!\left(q_{[T]} \,\|\, p\right)\right|_{\epsilon = 0} = -\,\mathbb{E}_{x \sim q}\!\left[\mathcal{A}_p\, \phi(x)\right],$$
where $\mathcal{A}_p\, \phi(x) = \nabla_x \log p(x)^\top \phi(x) + \nabla_x \cdot \phi(x)$ is the Stein operator. The best perturbation inside the unit ball of an RKHS with kernel $k$ is
$$\phi^*(\cdot) \;\propto\; \mathbb{E}_{x \sim q}\!\left[k(x, \cdot)\, \nabla_x \log p(x) + \nabla_x k(x, \cdot)\right].$$
The update rule is
$$x_i \;\leftarrow\; x_i + \frac{\epsilon}{n} \sum_{j=1}^{n}\left[k(x_j, x_i)\, \nabla_{x_j} \log p(x_j) + \nabla_{x_j} k(x_j, x_i)\right].$$
Two forces drive the dynamics. The first term is the kernel-weighted score and it pushes particles toward high-probability regions. The second term acts as a repulsive force and shoves particles apart and stops mode collapse. Starting from a tight cluster near the origin the particles split up and spread out to cover each mode of the posterior.
```python
import jax
import jax.numpy as jnp
from jax import grad, vmap, jit
from jax import random
from functools import partial

# 1. Define the RBF kernel and its gradient
def rbf_kernel(X, h=-1):
    # X: (N, D) particle positions
    # returns K: (N, N), grad_K: (N, N, D) gradient of k(x_i, x_j) w.r.t. x_j
    diff = X[:, None, :] - X[None, :, :]     # (N, N, D) pairwise differences x_i - x_j
    sq_dist = jnp.sum(diff**2, axis=-1)      # (N, N) squared distances
    # Median heuristic for the bandwidth
    if h < 0:
        h = jnp.median(sq_dist) / jnp.log(X.shape[0])
    K = jnp.exp(-sq_dist / h)
    # Gradient of k(x_i, x_j) w.r.t. the second argument x_j: (2/h) * K * (x_i - x_j)
    grad_K = K[..., None] * diff * (2 / h)
    return K, grad_K

# 2. SVGD step (log_prob_grad is a function, so it must be a static argument under jit)
@partial(jit, static_argnums=1)
def svgd_step(particles, log_prob_grad, step_size):
    grad_logp = log_prob_grad(particles)     # (N, D) score at each particle
    K, grad_K = rbf_kernel(particles)        # (N, N), (N, N, D)
    term1 = K @ grad_logp                    # (N, D) kernel-weighted score (driving term)
    term2 = jnp.sum(grad_K, axis=1)          # (N, D) repulsive term, sum over j
    phi = (term1 + term2) / particles.shape[0]
    return particles + step_size * phi

# 3. Target distribution: bimodal Gaussian mixture (unnormalized log-density)
def target_log_prob(x):
    # Equal-weight mixture of two unit Gaussians at (-2, -2) and (2, 2)
    mu1 = jnp.array([-2.0, -2.0])
    mu2 = jnp.array([2.0, 2.0])
    log_p1 = -0.5 * jnp.sum((x - mu1)**2)
    log_p2 = -0.5 * jnp.sum((x - mu2)**2)
    # Log-sum-exp of the two components; additive constants drop out of the score
    return jax.scipy.special.logsumexp(jnp.array([log_p1, log_p2]))

# Vectorized score function over all particles
dist_grad = vmap(grad(target_log_prob))

def run_simulation():
    key = random.PRNGKey(42)
    # Start all particles in a tight cluster near (0, 0) (a mode-collapsed state)
    particles = random.normal(key, (100, 2)) * 0.1
    history = []
    step_size = 0.1
    for i in range(200):
        particles = svgd_step(particles, dist_grad, step_size)
        if i % 100 == 0:
            history.append(particles)
    # Plotting omitted; the particles should split into two groups and cover both modes.
    return particles

# Observation: even when starting effectively from a single point, the repulsive
# force of the kernel (grad_K) pushes particles apart, forcing exploration of the
# second mode. For multimodal targets like this, SVGD often mixes across modes
# more easily than a single MCMC chain.
```

Variance reduction and the reparameterization trick
The practical success of VAEs (Kingma & Welling 2014) hangs on variance reduction. The core problem is computing $\nabla_\phi\, \mathbb{E}_{q_\phi(z)}\!\left[f(z)\right]$.
Score function estimator (REINFORCE).
$$\nabla_\phi\, \mathbb{E}_{q_\phi(z)}\!\left[f(z)\right] = \mathbb{E}_{q_\phi(z)}\!\left[f(z)\, \nabla_\phi \log q_\phi(z)\right].$$
This works for any $f$, even for discrete distributions over $z$. But the variance scales with the magnitude of $f(z)$, and that is usually large. The estimator pokes at the function with random samples and never uses its local structure.
Pathwise derivative (reparameterization). Write $z = g_\phi(\epsilon)$ with $\epsilon \sim p(\epsilon)$ for some diffeomorphism $g_\phi$ and you get
$$\nabla_\phi\, \mathbb{E}_{q_\phi(z)}\!\left[f(z)\right] = \mathbb{E}_{p(\epsilon)}\!\left[\nabla_z f\!\left(g_\phi(\epsilon)\right)\, \nabla_\phi\, g_\phi(\epsilon)\right].$$
This runs gradient information through the analytic gradient of the model and cuts variance dramatically. The catch is that both $f$ and $g_\phi$ have to be differentiable. For discrete latent variables you need relaxations like Gumbel-Softmax.
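A quick numerical check of the variance gap, assuming the toy objective $f(z) = z^2$ with $z \sim \mathcal{N}(\mu, 1)$, for which the true gradient $\nabla_\mu\, \mathbb{E}[z^2] = 2\mu$ is known:

```python
import jax
import jax.numpy as jnp

mu = 1.5
z = mu + jax.random.normal(jax.random.PRNGKey(0), (100_000,))   # z ~ N(mu, 1)
f = z**2                                                        # true gradient is 2*mu = 3.0

# Score-function (REINFORCE): f(z) * d/dmu log N(z; mu, 1) = f(z) * (z - mu)
score_est = f * (z - mu)

# Pathwise (reparameterization): z = mu + eps, so d/dmu f(z) = 2*z
path_est = 2.0 * z

print(score_est.mean(), score_est.var())   # mean ~3.0, noticeably larger variance
print(path_est.mean(), path_est.var())     # mean ~3.0, variance ~4
```

Both estimators are unbiased; the pathwise one is far tighter because it uses the local slope of $f$ instead of only its sampled values.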
Summary
Variational inference turns posterior computation into optimization. Instead of waiting for a Markov chain to mix, you set up a loss, the negative ELBO, and descend. The tradeoff is real: accuracy is capped by the approximation family, and the KL objective is mode-seeking. But it scales, and normalizing flows or SVGD can close much of the approximation gap.
For big models and big datasets, optimization-based inference will probably remain the dominant approach over sampling for most practical purposes.
VAE Loss Derivation
The VAE has generative model $p_\theta(x \mid z)$, prior $p(z) = \mathcal{N}(0, I)$, and approximate posterior $q_\phi(z \mid x) = \mathcal{N}\!\left(\mu_\phi(x), \operatorname{diag}(\sigma_\phi^2(x))\right)$.
The KL term. Let $q = \mathcal{N}(\mu_1, \Sigma_1)$ and $p = \mathcal{N}(\mu_2, \Sigma_2)$. The general Gaussian KL is
$$\mathrm{KL}(q \,\|\, p) = \frac{1}{2}\left[\operatorname{tr}\!\left(\Sigma_2^{-1}\Sigma_1\right) + (\mu_2 - \mu_1)^\top \Sigma_2^{-1} (\mu_2 - \mu_1) - d + \log\frac{\det \Sigma_2}{\det \Sigma_1}\right].$$
With $\mu_2 = 0$ and $\Sigma_2 = I$ this reduces to
$$\mathrm{KL}\!\left(q_\phi(z \mid x) \,\|\, p(z)\right) = \frac{1}{2}\sum_{j=1}^{d}\left(\mu_j^2 + \sigma_j^2 - 1 - \log \sigma_j^2\right),$$
which punishes $\mu$ and $\sigma$ for drifting away from the standard normal and works as a regularizer that keeps the latent space organized.
The reconstruction term. Monte Carlo sample $z = \mu + \sigma \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$. If $p_\theta(x \mid z) = \mathcal{N}\!\left(x;\, \hat{x}_\theta(z), I\right)$ then $-\log p_\theta(x \mid z) = \tfrac{1}{2}\|x - \hat{x}_\theta(z)\|^2 + \text{const}$, and that is just MSE.
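A minimal sketch of the resulting per-datapoint loss, assuming hypothetical `encode` and `decode` stand-ins for the encoder and decoder networks (the toy linear maps at the bottom exist only to make the sketch runnable):

```python
import jax
import jax.numpy as jnp

def vae_loss(x, encode, decode, key):
    mu, logvar = encode(x)                               # q_phi(z|x) = N(mu, diag(exp(logvar)))
    eps = jax.random.normal(key, mu.shape)
    z = mu + jnp.exp(0.5 * logvar) * eps                 # reparameterized sample
    x_hat = decode(z)

    recon = 0.5 * jnp.sum((x - x_hat)**2)                # -log p(x|z) up to a constant (Gaussian decoder -> MSE)
    kl = 0.5 * jnp.sum(mu**2 + jnp.exp(logvar) - 1.0 - logvar)   # closed-form KL(q || N(0, I))
    return recon + kl                                    # negative ELBO for this datapoint

# Toy usage with linear stand-in "networks".
encode = lambda x: (0.5 * x[:2], jnp.zeros(2))
decode = lambda z: jnp.concatenate([z, z])
x = jnp.arange(4.0)
print(vae_loss(x, encode, decode, jax.random.PRNGKey(0)))
```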
Gumbel-Softmax for discrete variables
Reparameterization needs continuous variables. For categoricals you use the Gumbel-Max trick,
$$z = \operatorname*{arg\,max}_{i}\left(\log \pi_i + g_i\right), \qquad g_i \sim \mathrm{Gumbel}(0, 1).$$
Then relax the $\arg\max$ with a softmax at temperature $\tau$ (Jang et al. 2016),
$$y_i = \frac{\exp\!\left((\log \pi_i + g_i)/\tau\right)}{\sum_k \exp\!\left((\log \pi_k + g_k)/\tau\right)}.$$
$\tau \to 0$ brings back discrete one-hot samples and $\tau \to \infty$ gives uniform samples. In practice $\tau$ gets annealed during training.
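A minimal sketch of one Gumbel-Softmax sample; the class probabilities `pi` and temperatures are illustrative.

```python
import jax
import jax.numpy as jnp

def gumbel_softmax_sample(key, logits, tau):
    # Gumbel(0, 1) noise via -log(-log(U)), then relax the argmax with a softmax.
    u = jax.random.uniform(key, logits.shape, minval=1e-6, maxval=1.0 - 1e-6)
    g = -jnp.log(-jnp.log(u))
    return jax.nn.softmax((logits + g) / tau)

pi = jnp.array([0.7, 0.2, 0.1])                            # illustrative class probabilities
key = jax.random.PRNGKey(0)
print(gumbel_softmax_sample(key, jnp.log(pi), tau=1.0))    # soft, differentiable sample
print(gumbel_softmax_sample(key, jnp.log(pi), tau=0.1))    # nearly one-hot as tau -> 0
```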
Alpha-divergences
The “exclusive” KL $\mathrm{KL}(q \,\|\, p)$ is mode-seeking and just collapses onto one mode of $p$. The “inclusive” KL $\mathrm{KL}(p \,\|\, q)$ is mass-covering, but it needs expectations under $p$, which is exactly what we cannot compute. Rényi’s $\alpha$-divergence interpolates between these behaviors,
$$D_\alpha(q \,\|\, p) = \frac{1}{\alpha - 1} \log \int q(z)^{\alpha}\, p(z)^{1-\alpha}\, dz.$$
$\alpha \to 1$ brings back $\mathrm{KL}(q \,\|\, p)$ and $\alpha \to 0$ brings back $-\log \int_{q > 0} p(z)\, dz$, which measures support overlap. Values in between slide between mass-covering behavior for small $\alpha$ and mode-seeking behavior for large $\alpha$. Black Box Alpha-VI (Hernandez-Lobato et al. 2016) minimizes $\alpha$-divergences of this family with stochastic, minibatch-friendly Monte Carlo estimates of the corresponding energy, and this ties VI to importance weighting, particle filters and sequential Monte Carlo.
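As one concrete handle on this family, here is a minimal sketch of a Monte Carlo estimate of the generic $\alpha$-dependent bound built from importance weights $w_k = p(x, z_k)/q(z_k)$; this is a standard Rényi-bound estimator, not necessarily the exact estimator of the paper cited above, and the target, proposal, and sample count are illustrative.

```python
import jax
import jax.numpy as jnp

def renyi_bound(log_p_joint, log_q, z_samples, alpha):
    # Monte Carlo estimate of (1 / (1 - alpha)) * log E_q[(p(x, z) / q(z))^(1 - alpha)]
    # using importance weights w_k = p(x, z_k) / q(z_k).
    log_w = jax.vmap(log_p_joint)(z_samples) - jax.vmap(log_q)(z_samples)
    K = z_samples.shape[0]
    return (jax.scipy.special.logsumexp((1.0 - alpha) * log_w) - jnp.log(K)) / (1.0 - alpha)

# Illustrative unnormalized target and proposal: p is a standard normal, q is N(1, 1).
log_p = lambda z: -0.5 * jnp.sum(z**2)
log_q = lambda z: -0.5 * jnp.sum((z - 1.0)**2)
z = 1.0 + jax.random.normal(jax.random.PRNGKey(0), (1000, 1))
print(renyi_bound(log_p, log_q, z, alpha=0.5))   # in the limit alpha -> 1 this recovers the ELBO estimate
```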
Development of variational methods
| When | What | Why it mattered |
|---|---|---|
| 1998 | Shun-Ichi Amari publishes natural gradient descent | Showed that parameter space geometry determines the correct gradient direction |
| 1999 | Michael Jordan et al. formalize variational methods for graphical models | Established VI as a principled alternative to MCMC for structured models |
| 2013 | Hoffman et al. introduce Stochastic Variational Inference | Extended VI to datasets with billions of observations via stochastic optimization |
| 2014 | Kingma & Welling introduce VAEs and the reparameterization trick | Connected deep learning (backprop) with variational inference |
| 2015 | Rezende & Mohamed develop normalizing flows | Enabled arbitrarily complex approximate posteriors via invertible transformations |
| 2016 | Liu & Wang propose SVGD | Particle-based VI using kernelized Stein discrepancy, requiring no parametric family |
Terms and definitions
ELBO (Evidence Lower Bound) is the objective function for VI, defined as $\mathrm{ELBO}(q) = \mathbb{E}_q[\log p(x, z)] - \mathbb{E}_q[\log q(z)]$. Maximizing the ELBO is the same thing as minimizing $\mathrm{KL}\!\left(q(z) \,\|\, p(z \mid x)\right)$.
Mean Field is the assumption that $q(z) = \prod_i q_i(z_i)$ factors all the way across latent dimensions. It is simple and it scales but it cannot pick up posterior correlations.
Natural Gradient is the gradient step on the Riemannian manifold induced by the Fisher information. For exponential families, the natural gradient in natural parameters equals the ordinary gradient in mean parameters, so no Fisher matrix inversion is needed.
Normalizing Flow is a chain of invertible transformations applied to a simple base distribution to make a complex density. The key constraint is that you have to compute the Jacobian determinant cheaply.
Reparameterization Trick writes $z = g_\phi(\epsilon)$ with $\epsilon \sim p(\epsilon)$ so that $\nabla_\phi\, \mathbb{E}_{q_\phi}[f(z)] = \mathbb{E}_{p(\epsilon)}\!\left[\nabla_\phi f\!\left(g_\phi(\epsilon)\right)\right]$. It cuts gradient variance by orders of magnitude compared to score-function estimators.
SVGD (Stein Variational Gradient Descent) is a deterministic particle method that moves samples toward the posterior using a kernelized velocity field that comes from Stein’s identity.
BBVI (Black-Box Variational Inference) uses score-function gradients of the ELBO and it works for any model with a log-density you can compute.
LDA (Latent Dirichlet Allocation) is the classic topic model and historically it was one of the first large-scale applications of mean field VI.
References
1. Blei, D. M. et al. (2017). “Variational Inference: A Review for Statisticians”. Extensive overview of VI as an alternative to MCMC. Covers Mean Field, Stochastic VI, and connections to convex optimization.
2. Liu, Q., & Wang, D. (2016). “Stein Variational Gradient Descent”. Introduces SVGD. Uses Stein’s identity and RKHS machinery to derive a deterministic particle flow that decreases the KL divergence to the posterior.
3. Kingma, D. P., & Welling, M. (2014). “Auto-Encoding Variational Bayes”. The paper that introduced VAEs and the Reparameterization Trick. It connected Deep Learning (backprop) with Variational Inference.
4. Rezende, D. J., & Mohamed, S. (2015). “Variational Inference with Normalizing Flows”. Showed how to construct arbitrarily complex posteriors by transforming simple Gaussians through invertible neural networks.
5. Amari, S. I. (1998). “Natural Gradient Works Efficiently in Learning”. The seminal paper defining the Natural Gradient for neural networks. It shows that the parameter space is Riemannian and the Fisher Information Metric is the unique metric invariant to reparameterization.
6. Jordan, M. I. et al. (1999). “An Introduction to Variational Methods for Graphical Models”. The classic tutorial that established the Mean Field approximation as the standard tool for exponential families in graphical models.