Variational Optimization & Bayesian Inference
1. The Geometry of Integration
Bayesian inference is fundamentally about integration. The posterior over parameters $\theta$ given data $\mathcal{D}$ is

$$p(\theta \mid \mathcal{D}) = \frac{p(\mathcal{D} \mid \theta)\, p(\theta)}{p(\mathcal{D})}, \qquad p(\mathcal{D}) = \int p(\mathcal{D} \mid \theta)\, p(\theta)\, d\theta.$$

This marginal likelihood (evidence) is the holy grail: it enables model selection and explains generalization. However, in high dimensions the posterior mass concentrates on a typical set that occupies a vanishingly small fraction of the volume, and MCMC (sampling) methods struggle to traverse the surrounding void. Variational Inference (VI) proposes a radical shift: don't sample, optimize. We replace the integration problem with a functional optimization problem over a family of distributions $\mathcal{Q}$.
This transforms inference into a problem typically solvable by Stochastic Gradient Descent.
2. Functional Calculus and the ELBO
Before the algebraic derivation, we examine the functional analysis. We operate on the space of probability density functions. Consider the Kullback-Leibler divergence functional:

$$\mathrm{KL}\!\left(q \,\|\, p(\cdot \mid \mathcal{D})\right) = \int q(\theta)\, \log \frac{q(\theta)}{p(\theta \mid \mathcal{D})}\, d\theta.$$

To minimize this, we take the functional (variational) derivative $\delta \mathrm{KL}/\delta q$. Using the Euler-Lagrange framework, let the integrand be $F(q, \theta) = q(\theta)\log q(\theta) - q(\theta)\log p(\theta \mid \mathcal{D})$.
We impose the normalization constraint $\int q(\theta)\, d\theta = 1$ with a Lagrange multiplier $\lambda$, which gives the stationarity condition

$$\frac{\delta}{\delta q}\left[F(q, \theta) + \lambda\, q(\theta)\right] = \log q(\theta) + 1 - \log p(\theta \mid \mathcal{D}) + \lambda = 0.$$

Solving for $q$:

$$q^*(\theta) = p(\theta \mid \mathcal{D})\, e^{-(1+\lambda)} \;\Longrightarrow\; q^*(\theta) = p(\theta \mid \mathcal{D})$$

after enforcing normalization.
This confirms that the unconstrained functional minimum is indeed the posterior. Since we cannot evaluate $p(\theta \mid \mathcal{D})$ (its normalizer is the intractable evidence), we restrict $q$ to a tractable parametric family $\mathcal{Q} = \{q_\phi\}$.
The Evidence Lower Bound (Algebraic View):

$$\log p(x) = \log \int p(x, \theta)\, d\theta = \log\, \mathbb{E}_{q(\theta)}\!\left[\frac{p(x, \theta)}{q(\theta)}\right].$$

By Jensen's Inequality (since $\log$ is concave):

$$\log p(x) \;\ge\; \mathbb{E}_{q}\!\left[\log p(x, \theta)\right] - \mathbb{E}_{q}\!\left[\log q(\theta)\right] \;=:\; \mathrm{ELBO}(q).$$

Maximizing the ELBO forces $q$ to put mass where $p(x, \theta)$ is high (Reconstruction), but also to stay spread out (Entropy). The gap is exactly the KL divergence: $\log p(x) = \mathrm{ELBO}(q) + \mathrm{KL}\!\left(q(\theta)\,\|\,p(\theta \mid x)\right)$.
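As a small, concrete sketch of the objective: a single-batch Monte Carlo estimate of the ELBO for a diagonal-Gaussian $q_\phi$, assuming a user-supplied `log_joint` that evaluates $\log p(x, \theta)$ (the toy `log_joint` below is a placeholder, not a model from this article).

```python
import jax
import jax.numpy as jnp

def elbo_estimate(key, mu, log_sigma, log_joint, num_samples=32):
    """Monte Carlo ELBO for q_phi(theta) = N(mu, diag(sigma^2)).

    ELBO = E_q[log p(x, theta)] + H[q]
    """
    sigma = jnp.exp(log_sigma)
    eps = jax.random.normal(key, (num_samples, mu.shape[0]))
    theta = mu + sigma * eps                                  # reparameterized samples
    expected_log_joint = jnp.mean(jax.vmap(log_joint)(theta)) # E_q[log p(x, theta)]
    # Entropy of a diagonal Gaussian: 0.5 * sum(log(2*pi*e) + 2*log sigma_j)
    entropy = 0.5 * jnp.sum(jnp.log(2 * jnp.pi * jnp.e) + 2 * log_sigma)
    return expected_log_joint + entropy

# Toy stand-in joint: an unnormalized standard Gaussian in 2-D.
toy_log_joint = lambda th: -0.5 * jnp.sum(th ** 2)
key = jax.random.PRNGKey(0)
print(elbo_estimate(key, jnp.zeros(2), jnp.zeros(2), toy_log_joint))
```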
3. Optimization Strategies
3.1 Coordinate Ascent (Mean Field)
Assume $q(\theta) = \prod_j q_j(\theta_j)$. With the other factors held fixed, the optimal $q_j$ is

$$q_j^*(\theta_j) \propto \exp\!\left(\mathbb{E}_{q_{-j}}\!\left[\log p(\theta, \mathcal{D})\right]\right).$$

This is tractable for conjugate exponential-family models: gradients are not required, only expectations. It is the workhorse behind Latent Dirichlet Allocation (LDA); a small worked example follows.
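As a standard textbook illustration (not taken from the LDA setting above): apply the mean-field update to a bivariate Gaussian target $p(z) = \mathcal{N}(z \mid \mu, \Lambda^{-1})$ with precision matrix $\Lambda$. Each factor stays Gaussian, and the fixed point is

$$q_1^*(z_1) = \mathcal{N}\!\left(z_1 \,\middle|\, \mu_1 - \Lambda_{11}^{-1}\Lambda_{12}\big(\mathbb{E}[z_2] - \mu_2\big),\; \Lambda_{11}^{-1}\right),$$

and symmetrically for $q_2^*(z_2)$. The means converge to the true means, but the factor variance $\Lambda_{11}^{-1}$ is smaller than the true marginal variance $(\Lambda^{-1})_{11}$ whenever $\Lambda_{12} \neq 0$: mean field systematically underestimates uncertainty.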
3.2 Black-Box Variational Inference (BBVI)
Assume $q$ is Gaussian: $q_\phi(\theta) = \mathcal{N}(\mu, \sigma^2)$ with parameters $\phi = (\mu, \sigma)$. We maximize the ELBO over $\phi$ by (stochastic) gradient ascent.
Log-Derivative Trick (REINFORCE):

$$\nabla_\phi\, \mathbb{E}_{q_\phi}[f(\theta)] = \mathbb{E}_{q_\phi}\!\left[f(\theta)\,\nabla_\phi \log q_\phi(\theta)\right].$$

High variance? Use the Reparameterization Trick: write $\theta = g_\phi(\epsilon)$ with $\epsilon$ drawn from a fixed base distribution, e.g., $\theta = \mu + \sigma \odot \epsilon$, $\epsilon \sim \mathcal{N}(0, I)$.
This estimator uses the gradient of the model, $\nabla_\theta f$, and has much lower variance. It is the engine of Variational Autoencoders (VAEs); a minimal sketch follows.
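A minimal sketch of the reparameterized (pathwise) gradient in JAX, assuming a diagonal-Gaussian $q$ and a toy integrand $f(\theta) = \sum_i \theta_i^2$; all names here are illustrative only.

```python
import jax
import jax.numpy as jnp

def f(theta):
    # Toy integrand; in a VAE this would be the decoder log-likelihood.
    return jnp.sum(theta ** 2)

def reparam_objective(params, eps):
    # theta = mu + sigma * eps  (the reparameterization trick)
    mu, log_sigma = params
    theta = mu + jnp.exp(log_sigma) * eps
    return f(theta)

key = jax.random.PRNGKey(0)
eps = jax.random.normal(key, (128, 2))        # 128 noise samples in 2-D
params = (jnp.ones(2), jnp.zeros(2))          # mu = 1, sigma = 1

# Average the pathwise gradient over the noise samples with jax.grad:
grad_fn = jax.grad(lambda p, e: jnp.mean(jax.vmap(lambda ei: reparam_objective(p, ei))(e)))
g_mu, g_log_sigma = grad_fn(params, eps)
# Analytic values for comparison: d/dmu E[f] = 2*mu = 2, d/dlog_sigma E[f] = 2*sigma^2 = 2
print(g_mu, g_log_sigma)
```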
4. The Geometry of Natural Gradients
Standard gradient descent assumes the parameter space is Euclidean ($\ell_2$ distance). But the parameters of a distribution live on a statistical manifold: a small change in $\mu$ (the mean) means something very different if $\sigma$ is small versus large. Trust-region optimization makes this explicit:

$$\phi_{t+1} = \arg\max_{\phi}\; \mathcal{L}(\phi) \quad \text{s.t.}\quad \mathrm{KL}\!\left(q_{\phi_t} \,\|\, q_{\phi}\right) \le \epsilon.$$

Expanded to second order, the KL divergence is governed by the Fisher Information Matrix $F$:

$$\mathrm{KL}\!\left(q_{\phi} \,\|\, q_{\phi + d\phi}\right) \approx \tfrac{1}{2}\, d\phi^{\top} F(\phi)\, d\phi,$$

where $F(\phi) = \mathbb{E}_{q_\phi}\!\left[\nabla_\phi \log q_\phi(\theta)\,\nabla_\phi \log q_\phi(\theta)^{\top}\right]$. The KL constraint becomes a quadratic bound, and the optimal update direction is the Natural Gradient:

$$\tilde{\nabla}_\phi \mathcal{L} = F(\phi)^{-1}\,\nabla_\phi \mathcal{L}.$$

Gaussian Natural Gradients: for $q = \mathcal{N}(\mu, \Sigma)$ the Fisher matrix is block diagonal in $(\mu, \Sigma)$. Remarkably, for exponential families we never need to invert $F$ explicitly: writing the family in its canonical (natural) parameters $\eta$, ordinary gradients with respect to the mean parameters are exactly the natural gradients with respect to $\eta$. The updates therefore reduce to simple steps in natural-parameter space.
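For intuition, here is the standard univariate computation (not derived above): for $q = \mathcal{N}(\mu, \sigma^2)$ with $\phi = (\mu, \sigma)$,

$$F(\mu, \sigma) = \begin{pmatrix} 1/\sigma^2 & 0 \\ 0 & 2/\sigma^2 \end{pmatrix},
\qquad
\tilde{\nabla}\mathcal{L} = F^{-1}\nabla\mathcal{L} = \begin{pmatrix} \sigma^2\, \partial_\mu \mathcal{L} \\ \tfrac{\sigma^2}{2}\, \partial_\sigma \mathcal{L} \end{pmatrix},$$

so the step in $\mu$ is automatically rescaled by the current uncertainty $\sigma^2$, exactly the behaviour motivated above.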
Natural-gradient updates in this canonical parameterization are the basis of SVI (Stochastic Variational Inference), which scales to billions of examples.
5. Stochastic Variational Inference (SVI)
Traditional VI iterates through the entire dataset to compute each coordinate-ascent update. For large document collections (LDA), this is impractical. Hoffman et al. (2013) introduced SVI. Key idea: the natural gradient of the global variational parameters decomposes into a sum of local contributions, one per data point.
We can therefore form a noisy estimate of this gradient from a minibatch of size $B$, rescaling the local terms by $N/B$ so the estimate stays unbiased, and computing an intermediate global parameter $\hat{\lambda}$ as if the minibatch were the whole dataset.
Because we take natural-gradient steps (respecting the Riemannian geometry of the variational family), a simple convex combination $\lambda_{t+1} = (1 - \rho_t)\lambda_t + \rho_t \hat{\lambda}_t$ with step sizes satisfying the Robbins-Monro conditions ($\sum_t \rho_t = \infty$, $\sum_t \rho_t^2 < \infty$) converges to a local optimum of the ELBO. This showed that Bayesian inference could scale to "Big Data" comparably to deep learning. A schematic of the global update is sketched below.
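A schematic of that global step, under the assumptions just stated. `lam_hat` would come from a hypothetical `local_step` that processes a minibatch as if it were the whole dataset; the names are placeholders, not Hoffman et al.'s notation.

```python
import jax.numpy as jnp

def svi_update(lam, lam_hat, t, tau=1.0, kappa=0.7):
    """One SVI step on the global natural parameters.

    lam     : current global natural parameters
    lam_hat : intermediate estimate from a minibatch, rescaled by N / B
    rho_t   : Robbins-Monro step size (sum rho = inf, sum rho^2 < inf
              holds for 0.5 < kappa <= 1)
    """
    rho_t = (t + tau) ** (-kappa)
    return (1.0 - rho_t) * lam + rho_t * lam_hat

# Usage sketch (local_step is hypothetical):
#   lam = svi_update(lam, local_step(minibatch, lam), t)
```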
6. Normalizing Flows (Beyond Mean Field)
The biggest limitation of VI is the choice of the family $\mathcal{Q}$. A Gaussian $q$ limits us to unimodal, light-tailed approximations. Normalizing Flows construct complex densities by pushing a simple base distribution (e.g., a Gaussian) through a sequence of invertible maps $f_1, \dots, f_K$.
Change of Variables Formula:

$$\log q_K(z_K) = \log q_0(z_0) - \sum_{k=1}^{K} \log\left|\det \frac{\partial f_k}{\partial z_{k-1}}\right|, \qquad z_k = f_k(z_{k-1}).$$

The challenge is finding functions $f_k$ with (1) high expressivity and (2) Jacobian determinants computable in linear time.
Planar Flows (Rezende & Mohamed 2015):

$$f(z) = z + u\, h(w^{\top} z + b),$$

with $h$ a smooth nonlinearity such as $\tanh$.
The new density concentrates or spreads mass around the hyperplane $w^{\top} z + b = 0$. By the matrix determinant lemma:

$$\left|\det \frac{\partial f}{\partial z}\right| = \left|1 + u^{\top} \psi(z)\right|, \qquad \psi(z) = h'(w^{\top} z + b)\, w.$$

Stacking 10-20 planar flows allows the approximate posterior to snake around complex, multi-modal landscapes.
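A minimal planar-flow sketch under the formulas above, using a $\tanh$ nonlinearity. For brevity the invertibility condition $w^\top u \ge -1$ is assumed rather than enforced, and the parameter values are arbitrary.

```python
import jax.numpy as jnp

def planar_flow(z, u, w, b):
    """One planar flow f(z) = z + u * tanh(w.z + b) and its log |det Jacobian|.

    z, u, w: (D,) arrays, b: scalar. Assumes w.u >= -1 (not enforced here).
    """
    a = jnp.dot(w, z) + b
    z_new = z + u * jnp.tanh(a)
    # psi(z) = h'(w.z + b) * w, and |det J| = |1 + u.psi(z)|
    psi = (1.0 - jnp.tanh(a) ** 2) * w
    log_det = jnp.log(jnp.abs(1.0 + jnp.dot(u, psi)))
    return z_new, log_det

# Change of variables: log q_K(z_K) = log q_0(z_0) - sum_k log_det_k
z0 = jnp.array([0.5, -1.0])
zk, ld = planar_flow(z0, u=jnp.array([0.3, 0.1]), w=jnp.array([1.0, -0.5]), b=0.0)
```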
7. Stein Variational Gradient Descent (SVGD)
BBVI is limited by the chosen family (a Gaussian $q$, for instance, rules out multi-modality). We want non-parametric VI: use particles. Let $\{x_i\}_{i=1}^{n}$ be a set of particles approximating $q$. We want to move the particles so as to decrease the KL to the posterior $p$. Consider the smooth perturbation $T(x) = x + \epsilon\,\phi(x)$ and let $q_{[T]}$ denote the pushforward of $q$ under $T$. The derivative of the KL along this perturbation is

$$\left.\frac{d}{d\epsilon}\,\mathrm{KL}\!\left(q_{[T]} \,\|\, p\right)\right|_{\epsilon = 0} = -\,\mathbb{E}_{x\sim q}\!\left[\mathcal{A}_p\,\phi(x)\right],$$

where $\mathcal{A}_p$ is the Stein operator: $\mathcal{A}_p\,\phi(x) = \phi(x)^{\top}\nabla_x \log p(x) + \nabla_x \cdot \phi(x)$. We seek the optimal perturbation $\phi$ within the unit ball of an RKHS $\mathcal{H}_k$.
Result (Liu & Wang 2016): the steepest-descent direction has the closed form

$$\phi^*(\cdot) \;\propto\; \mathbb{E}_{x\sim q}\!\left[k(x, \cdot)\,\nabla_x \log p(x) + \nabla_x k(x, \cdot)\right].$$

Algorithm (SVGD): with the empirical distribution of the particles standing in for $q$, iterate

$$x_i \leftarrow x_i + \epsilon\,\hat{\phi}^*(x_i), \qquad \hat{\phi}^*(x) = \frac{1}{n}\sum_{j=1}^{n}\left[k(x_j, x)\,\nabla_{x_j}\log p(x_j) + \nabla_{x_j} k(x_j, x)\right].$$

Interpretation:
- Driver: the kernel-smoothed score $\nabla \log p$ moves particles towards high-probability regions.
- Repulsion: the $\nabla_x k$ term pushes particles away from each other (a repulsive force). This prevents mode collapse! It allows the particle cloud to cover the posterior.
8. Variance Reduction: The Pathwise Derivative
Why did VAEs (2014) revolutionize the field? Because standard gradient estimation for expectations is noisy. We want $\nabla_\phi\, \mathbb{E}_{q_\phi(\theta)}[f(\theta)]$.
Method 1: Score Function (REINFORCE)
The likelihood-ratio trick:

$$\nabla_\phi\, \mathbb{E}_{q_\phi}[f(\theta)] = \mathbb{E}_{q_\phi}\!\left[f(\theta)\,\nabla_\phi \log q_\phi(\theta)\right].$$

Valid for any density, even over discrete variables. Variance: proportional to the magnitude of $f$; even near the optimum the gradient noise remains high unless baselines (control variates) are used.
Method 2: Pathwise Derivative (Reparameterization). Assume $\theta$ is continuous and expressed via a diffeomorphism $\theta = g_\phi(\epsilon)$, $\epsilon \sim p(\epsilon)$.
The gradient moves inside the expectation:

$$\nabla_\phi\, \mathbb{E}_{q_\phi}[f(\theta)] = \mathbb{E}_{\epsilon \sim p(\epsilon)}\!\left[\nabla_\theta f\big(g_\phi(\epsilon)\big)\, \nabla_\phi\, g_\phi(\epsilon)\right].$$

Variance: it relies on $\nabla_\theta f$, the gradient of the model, and is typically orders of magnitude lower. The score function "probes" $f$ by random sampling, whereas the pathwise derivative exploits analytic knowledge of the function's slope. Limitation: it requires $f$ (the model) and $g_\phi$ to be differentiable, so discrete latent variables need relaxations such as Gumbel-Softmax (see Appendix B). A toy variance comparison is sketched below.
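A toy comparison of the two estimators, assuming a 1-D Gaussian $q_\phi = \mathcal{N}(\mu, 1)$ and $f(\theta) = \theta^2$; the setup is purely illustrative.

```python
import jax
import jax.numpy as jnp

f = lambda theta: theta ** 2
mu, sigma = 1.0, 1.0
key = jax.random.PRNGKey(0)
theta = mu + sigma * jax.random.normal(key, (100_000,))
eps = (theta - mu) / sigma

# Method 1: score function.  grad_mu log q(theta) = (theta - mu) / sigma^2
score_grads = f(theta) * (theta - mu) / sigma**2

# Method 2: pathwise.  d/dmu f(mu + sigma*eps) = f'(theta) = 2*theta
pathwise_grads = 2.0 * (mu + sigma * eps)

# Both estimators are unbiased for d/dmu E[theta^2] = 2*mu = 2,
# but their spread differs dramatically:
print(jnp.mean(score_grads), jnp.var(score_grads))        # ~2, variance around 30
print(jnp.mean(pathwise_grads), jnp.var(pathwise_grads))  # ~2, variance = 4*sigma^2 = 4
```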
9. JAX Implementation: Particles vs Distribution
```python
import jax
import jax.numpy as jnp
from jax import grad, vmap, jit
from jax import random
from functools import partial  # lets us mark the gradient function as static under jit
import matplotlib.pyplot as plt
# 1. Define Kernels
def rbf_kernel(X, h=-1):
    """Vectorized RBF kernel k(x_i, x_j) = exp(-||x_i - x_j||^2 / h).

    X: (N, D)
    Returns K: (N, N) and grad_K: (N, N, D),
    where grad_K[i, j] = grad_{x_i} k(x_i, x_j).
    """
    # Pairwise differences and squared distances
    diff = X[:, None, :] - X[None, :, :]   # (N, N, D)
    sq_dist = jnp.sum(diff**2, axis=-1)    # (N, N)
    # Median heuristic for the bandwidth
    if h < 0:
        h = jnp.median(sq_dist) / jnp.log(X.shape[0])
    K = jnp.exp(-sq_dist / h)
    # Gradient of the kernel w.r.t. its first argument: -(2/h) * k * (x_i - x_j)
    grad_K = -K[..., None] * diff * (2 / h)
    return K, grad_K
# 2. SVGD Step
@partial(jit, static_argnums=(1,))  # the gradient function is a Python callable, not an array
def svgd_step(particles, log_prob_grad, step_size):
    # Score function at every particle
    grad_logp = log_prob_grad(particles)   # (N, D)
    # Kernel interaction
    K, grad_K = rbf_kernel(particles)      # (N, N), (N, N, D)
    # The Stein update:
    #   phi(x_i) = (1/N) sum_j [ k(x_j, x_i) * score(x_j) + grad_{x_j} k(x_j, x_i) ]
    # K is symmetric, so K @ grad_logp gives the driving term directly.
    term1 = K @ grad_logp                  # (N, D)
    # rbf_kernel returns grad_K[j, i] = grad_{x_j} k(x_j, x_i), so the repulsive
    # term sums over the FIRST axis (summing axis=1 flips the sign and attracts).
    term2 = jnp.sum(grad_K, axis=0)        # (N, D)
    phi = (term1 + term2) / particles.shape[0]
    return particles + step_size * phi
# 3. Target Distribution: Bimodal Mixture
def target_log_prob(x):
    # Equal-weight mixture of two unit Gaussians at (-2, -2) and (2, 2).
    # With equal weights the constant log(0.5) is dropped (it does not affect gradients).
    mu1 = jnp.array([-2.0, -2.0])
    mu2 = jnp.array([2.0, 2.0])
    log_p1 = -0.5 * jnp.sum((x - mu1)**2)
    log_p2 = -0.5 * jnp.sum((x - mu2)**2)
    # LogSumExp for numerical stability
    return jax.scipy.special.logsumexp(jnp.array([log_p1, log_p2]))

# Per-particle score function: vmap the gradient over the particle axis
dist_grad = vmap(grad(target_log_prob))
def run_simulation():
    key = random.PRNGKey(42)
    # Start all particles in a tight cluster near (0, 0) (a "collapsed" initialization)
    particles = random.normal(key, (100, 2)) * 0.1
    history = []
    step_size = 0.1
    for i in range(200):
        particles = svgd_step(particles, dist_grad, step_size)
        if i % 100 == 0:
            history.append(particles)
    # Plotting omitted; the particles should split into two groups and cover both modes.
    return particles
# Observation: even starting from what is effectively a single point, the repulsive
# force of the kernel (the grad_K term) pushes particles apart, forcing exploration
# of the second mode. In multimodal problems like this, SVGD often mixes across
# modes more readily than a single MCMC chain.
```

10. Conclusion: The Optimization Perspective
Variational Inference represents a philosophical shift in Statistics. Instead of simulating the process (MCMC), we design an optimization landscape (ELBO) whose geometry entices the solution to reveal itself. From the functional calculus of the Euler-Lagrange equations to the differential geometry of Natural Gradients and the particle physics of Stein Flows, VI provides a rich playground where Geometry meets Probability. As we scale to massive datasets and complex deep generative models, the “Optimization” view of inference is likely to dominate the “Sampling” view for the foreseeable future.
Historical Timeline
| Year | Authors | Contribution |
|---|---|---|
| 1999 | Michael Jordan et al. | Formalize variational methods for graphical models. |
| 1998 | Shun-Ichi Amari | Natural Gradient Descent. |
| 2013 | Hoffman et al. | Stochastic Variational Inference (SVI). |
| 2014 | Kingma & Welling | VAEs (Reparameterization Trick). |
| 2015 | Rezende & Mohamed | Normalizing Flows. |
| 2016 | Liu & Wang | Stein Variational Gradient Descent (SVGD). |
Appendix A: VAE Loss Derivation (Full Details)
The Variational Autoencoder has generative model $p_\theta(x \mid z)$ with prior $p(z) = \mathcal{N}(0, I)$, and an amortized approximate posterior (encoder) $q_\phi(z \mid x)$. ELBO:

$$\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z \mid x)}\!\left[\log p_\theta(x \mid z)\right] - \mathrm{KL}\!\left(q_\phi(z \mid x)\,\|\,p(z)\right).$$
A.1 The Gaussian KL Term
Let $q_\phi(z \mid x) = \mathcal{N}(\mu, \Sigma)$ and $p(z) = \mathcal{N}(0, I)$. The general formula for the KL between two $d$-dimensional Gaussians $\mathcal{N}_0 = \mathcal{N}(\mu_0, \Sigma_0)$ and $\mathcal{N}_1 = \mathcal{N}(\mu_1, \Sigma_1)$ is:

$$\mathrm{KL}(\mathcal{N}_0 \,\|\, \mathcal{N}_1) = \tfrac{1}{2}\left[\operatorname{tr}\!\left(\Sigma_1^{-1}\Sigma_0\right) + (\mu_1 - \mu_0)^{\top}\Sigma_1^{-1}(\mu_1 - \mu_0) - d + \log\frac{\det \Sigma_1}{\det \Sigma_0}\right].$$

Plugging in the usual diagonal representation $\Sigma_0 = \operatorname{diag}(\sigma_1^2, \dots, \sigma_d^2)$ and $\mathcal{N}_1 = \mathcal{N}(0, I)$:

$$\mathrm{KL}\!\left(q_\phi(z \mid x)\,\|\,p(z)\right) = \tfrac{1}{2}\sum_{j=1}^{d}\left(\sigma_j^2 + \mu_j^2 - 1 - \log \sigma_j^2\right).$$

This loss term is minimized when $\mu = 0$ and $\sigma^2 = 1$. It acts as a regularizer keeping the latent space compact.
A.2 The Reconstruction Term
Monte Carlo estimate with the reparameterized sample $z = \mu + \sigma \odot \epsilon$, $\epsilon \sim \mathcal{N}(0, I)$. If $p_\theta(x \mid z) = \mathcal{N}\!\left(x; \hat{x}_\theta(z), I\right)$:

$$-\log p_\theta(x \mid z) = \tfrac{1}{2}\,\|x - \hat{x}_\theta(z)\|^2 + \text{const}.$$
This recovers the MSE loss.
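A sketch of the resulting per-example loss under these assumptions (diagonal-Gaussian encoder outputs `mu`, `log_var`; Gaussian decoder with identity covariance, hence MSE). The encoder and decoder themselves are placeholders for whatever networks one uses.

```python
import jax
import jax.numpy as jnp

def vae_loss(x, mu, log_var, x_hat):
    """Negative ELBO for one example.

    mu, log_var : encoder outputs for q(z|x) = N(mu, diag(exp(log_var)))
    x_hat       : decoder mean for p(x|z) = N(x_hat, I)  -> MSE reconstruction
    """
    # Closed-form KL(q(z|x) || N(0, I)) from Appendix A.1
    kl = 0.5 * jnp.sum(jnp.exp(log_var) + mu**2 - 1.0 - log_var)
    # Reconstruction term: -log p(x|z) up to an additive constant
    recon = 0.5 * jnp.sum((x - x_hat) ** 2)
    return recon + kl

def sample_z(key, mu, log_var):
    # Reparameterization: z = mu + sigma * eps
    eps = jax.random.normal(key, mu.shape)
    return mu + jnp.exp(0.5 * log_var) * eps
```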
Appendix B: The Gumbel-Softmax Trick (Discrete Reparameterization)
Reparameterization works for continuous variables. What about categories? We cannot backpropagate through sampling a categorical variable $z \sim \mathrm{Categorical}(\pi)$. Gumbel-Max Trick:

$$z = \operatorname*{arg\,max}_{i}\;\left(\log \pi_i + g_i\right),$$

where $g_i \sim \mathrm{Gumbel}(0, 1)$ i.i.d. Concrete Relaxation (Jang et al. 2016): replace the $\arg\max$ with a softmax at temperature $\tau$:

$$y_i = \frac{\exp\!\left((\log \pi_i + g_i)/\tau\right)}{\sum_{j}\exp\!\left((\log \pi_j + g_j)/\tau\right)}.$$

As $\tau \to 0$, $y$ approaches the one-hot sample. As $\tau \to \infty$, $y$ approaches the uniform distribution. A sampling sketch follows.
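A short sketch of drawing a relaxed sample under the formula above; the temperature `tau` is a free choice, and the example probabilities are arbitrary.

```python
import jax
import jax.numpy as jnp

def gumbel_softmax_sample(key, logits, tau=0.5):
    """Differentiable relaxed sample from Categorical(softmax(logits))."""
    # g_i ~ Gumbel(0, 1) via inverse CDF: g = -log(-log(U))
    u = jax.random.uniform(key, logits.shape, minval=1e-6, maxval=1.0 - 1e-6)
    g = -jnp.log(-jnp.log(u))
    # Softmax with temperature; as tau -> 0 this approaches a one-hot argmax sample
    return jax.nn.softmax((logits + g) / tau)

key = jax.random.PRNGKey(0)
y = gumbel_softmax_sample(key, jnp.log(jnp.array([0.2, 0.3, 0.5])))
```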
Appendix C: Generalization to Alpha-Divergences
Why minimize $\mathrm{KL}(q \,\|\, p)$ and not $\mathrm{KL}(p \,\|\, q)$ (Expectation Propagation) or something else? The "exclusive" KL ($q \,\|\, p$) forces $q$ to fit inside a single mode of $p$ (zero-forcing). The "inclusive" KL ($p \,\|\, q$) forces $q$ to cover the entire mass of $p$ (mass-covering). We can generalize using Rényi's $\alpha$-divergence:

$$D_\alpha(q \,\|\, p) = \frac{1}{\alpha - 1}\,\log \int q(\theta)^{\alpha}\, p(\theta)^{1-\alpha}\, d\theta.$$

Limit $\alpha \to 0$: inclusive, mass-covering behaviour (to first order in $\alpha$, $D_\alpha(q \,\|\, p) \approx \alpha\,\mathrm{KL}(p \,\|\, q)$). Limit $\alpha \to 1$: the exclusive KL ($q \,\|\, p$), i.e., standard VI. Black-Box Alpha-VI (Hernandez-Lobato et al. 2016): we can minimize $\alpha$-divergences directly using the VR-max bound or importance sampling.
This connects VI to Particle Filters and Sequential Monte Carlo. By tuning , we can control the behavior of the approximator: from mode-seeking to mean-seeking to heavy-tailed covering.
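For reference, the variational Rényi (VR) bound behind the VR-max objective mentioned above (due to Li & Turner, 2016), stated in the parameterization used here, is

$$\mathcal{L}_\alpha(q; x) = \frac{1}{1-\alpha}\,\log\, \mathbb{E}_{q(\theta)}\!\left[\left(\frac{p(x, \theta)}{q(\theta)}\right)^{1-\alpha}\right],$$

which recovers the standard ELBO as $\alpha \to 1$ and the exact log evidence at $\alpha = 0$.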
Appendix D: Glossary of Terms
- BBVI: Black-Box Variational Inference. Uses gradients of ELBO.
- ELBO: Evidence Lower Bound. The objective function for VI.
- LDA: Latent Dirichlet Allocation. A classic topic model using Mean Field VI.
- Mean Field: Assumption that $q$ factorizes fully across latent dimensions.
- Natural Gradient: Gradient step in Riemannian manifold defined by Fisher Information.
- Normalizing Flow: A sequence of invertible transformations to model complex densities.
- Reparameterization Trick: write $\theta = g_\phi(\epsilon)$ with $\epsilon$ drawn from a fixed base distribution, so gradients flow through $g_\phi$. Low variance.
- SVGD: Particle-based VI using kernelized Stein discrepancy.
References
1. Blei, D. M. et al. (2017). “Variational Inference: A Review for Statisticians”. Extensive overview of VI as an alternative to MCMC. Covers Mean Field, Stochastic VI, and connections to convex optimization.
2. Liu, Q., & Wang, D. (2016). "Stein Variational Gradient Descent". Introduces SVGD. Uses Stein's identity and RKHS machinery to derive a deterministic particle flow that iteratively decreases the KL divergence to the posterior.
3. Kingma, D. P., & Welling, M. (2014). “Auto-Encoding Variational Bayes”. The paper that introduced VAEs and the Reparameterization Trick. It connected Deep Learning (backprop) with Variational Inference.
4. Rezende, D. J., & Mohamed, S. (2015). "Variational Inference with Normalizing Flows". Shows how to construct arbitrarily complex posteriors by transforming simple Gaussians through invertible neural networks.
5. Amari, S. I. (1998). “Natural Gradient Works Efficiently in Learning”. The seminal paper defining the Natural Gradient for neural networks. It shows that the parameter space is Riemannian and the Fisher Information Metric is the unique metric invariant to reparameterization.
6. Jordan, M. I. et al. (1999). “An Introduction to Variational Methods for Graphical Models”. The classic tutorial that established the Mean Field approximation as the standard tool for exponential families in graphical models.