Variational Optimization & Bayesian Inference

The intractability of marginal likelihoods

Bayesian inference boils down to computing integrals.

p(x) = \int p(x|z) p(z) dz

This marginal likelihood drives model selection and generalization. In high dimensions the posterior's typical set is a thin shell, and MCMC samplers have a hard time jumping between well-separated modes.

The variational approach swaps sampling for optimization. We pick a family of distributions \mathcal{Q} and find the member closest to the posterior,

q^* = \text{argmin}_{q \in \mathcal{Q}} D(q \| p(\cdot | x))

This turns inference into gradient descent.

Deriving the ELBO

We work in the space of density functions. The KL divergence gives us a functional,

\mathbb{J}[q] = \int q(z) \log \frac{q(z)}{p(z|x)} dz

Take the functional derivative. The integrand is L(z, q, q') = q \log q - q \log p and so

\frac{\delta L}{\delta q} = (1 + \log q) - \log p

Enforce the normalization constraint \int q(z) dz = 1 with a Lagrange multiplier \lambda,

\frac{\delta}{\delta q} \left( \mathbb{J}[q] + \lambda \left( \int q - 1 \right) \right) = 1 + \log q(z) - \log p(z|x) + \lambda = 0

Solve for q,

\log q(z) = \log p(z|x) - (1 + \lambda) \quad \Longrightarrow \quad q^*(z) \propto p(z|x)

As you would expect, the unconstrained optimum is the posterior itself. Since we cannot evaluate p(z|x) directly we stick to a parametric family \mathcal{Q} = \{q_\phi\} and work with the evidence lower bound instead.

The ELBO. Start from

\log p(x) = \log \int p(x, z) dz = \log \mathbb{E}_q \left[ \frac{p(x, z)}{q(z)} \right]

Jensen’s inequality kicks in because \log is concave and gives

\log p(x) \ge \mathbb{E}_q \left[ \log \frac{p(x, z)}{q(z)} \right] = \text{ELBO}(\phi)

\text{ELBO} = \underbrace{\mathbb{E}_q [\log p(x, z)]}_{\text{reconstruction}} - \underbrace{\mathbb{E}_q [\log q(z)]}_{\text{negative entropy}}

Maximizing the ELBO pulls q toward regions where p(x,z) is big and the entropy term punishes collapse and keeps q spread out. The gap between \log p(x) and the ELBO is exactly D_{KL}(q \| p(z|x)).
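
A quick numerical check of that identity, as a sketch: for a conjugate 1-D model where everything is available in closed form, log p(x) minus a Monte Carlo ELBO estimate should match the closed-form KL between q and the exact posterior. The model (prior z ~ N(0,1), likelihood x|z ~ N(z,1)) and the variational parameters below are illustrative assumptions, not from the text.

```python
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import norm

x = 1.5                                    # observed data point
m, s = 0.3, 0.8                            # variational parameters of q(z) = N(m, s^2)

# Closed forms for this conjugate model
log_px = norm.logpdf(x, loc=0.0, scale=jnp.sqrt(2.0))     # marginal likelihood: N(0, 2)
post_mu, post_sigma = x / 2.0, jnp.sqrt(0.5)              # exact posterior: N(x/2, 1/2)

# Monte Carlo estimate of the ELBO under q
key = random.PRNGKey(0)
z = m + s * random.normal(key, (100_000,))
elbo = jnp.mean(norm.logpdf(x, loc=z, scale=1.0)          # log p(x|z)
                + norm.logpdf(z, loc=0.0, scale=1.0)      # log p(z)
                - norm.logpdf(z, loc=m, scale=s))         # -log q(z)

# Closed-form KL(q || posterior) for two univariate Gaussians
kl = jnp.log(post_sigma / s) + (s**2 + (m - post_mu)**2) / (2 * post_sigma**2) - 0.5

print(f"log p(x) - ELBO = {float(log_px - elbo):.4f}")    # these two numbers agree
print(f"KL(q || p(z|x)) = {float(kl):.4f}")               # up to Monte Carlo error
```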


Optimization methods

Mean field approximation

Assume q(z) = \prod q_i(z_i) and fix all the factors except q_j. The best update is

\log q_j^*(z_j) = \mathbb{E}_{-j} [ \log p(x, z) ] + \text{const}

No gradients are needed, just expectations. It works well for conjugate exponential families and it is the approach that sits under LDA.
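
To make the "expectations only" point concrete, here is a minimal coordinate-ascent (CAVI) sketch for a conjugate model. The model, priors, and update equations are standard textbook assumptions (x_i ~ N(\mu, 1/\tau), \mu|\tau ~ N(\mu_0, 1/(\lambda_0 \tau)), \tau ~ Gamma(a_0, b_0), mean field q(\mu)q(\tau)), not something taken from this text.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=0.5, size=500)        # synthetic data
N, xbar = len(x), x.mean()
mu0, lam0, a0, b0 = 0.0, 1.0, 1.0, 1.0              # prior hyperparameters

E_tau = 1.0                                         # initialize E_q[tau]
for _ in range(50):
    # q(mu) = N(mu_N, 1/lam_N): depends on the data and on E_q[tau] only
    mu_N = (lam0 * mu0 + N * xbar) / (lam0 + N)
    lam_N = (lam0 + N) * E_tau
    # q(tau) = Gamma(a_N, b_N): depends only on the first two moments of q(mu)
    E_mu, E_mu2 = mu_N, mu_N**2 + 1.0 / lam_N
    a_N = a0 + (N + 1) / 2
    b_N = b0 + 0.5 * (np.sum(x**2) - 2 * E_mu * np.sum(x) + N * E_mu2
                      + lam0 * (E_mu2 - 2 * mu0 * E_mu + mu0**2))
    E_tau = a_N / b_N

print(mu_N, 1.0 / np.sqrt(E_tau))   # roughly the true mean (2.0) and noise scale (0.5)
```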

Black-box variational inference

Take q_\lambda = \mathcal{N}(\mu, \sigma^2) and differentiate the ELBO straight through,

\nabla_\lambda \mathcal{L} = \nabla_\lambda \mathbb{E}_{q_\lambda} [ \log p(x, z) - \log q_\lambda(z) ]

The score function estimator (REINFORCE) uses

\nabla_\lambda \mathbb{E}_q [f] = \mathbb{E}_q [ f(z) \nabla_\lambda \log q_\lambda(z) ]

but it has wild variance in practice, often bad enough to be unusable.

The reparameterization trick fixes this. Write z = \mu + \sigma \epsilon with \epsilon \sim \mathcal{N}(0,1) and then

\nabla_\lambda \mathbb{E}_{p(\epsilon)} [ f(g(\epsilon, \lambda)) ] = \mathbb{E}_{p(\epsilon)} [ \nabla_z f \nabla_\lambda g ]

The gradient now runs through the model via \nabla_z \log p and this cuts variance by orders of magnitude. Reparameterization is the trick that makes VAEs actually work.
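
A compact sketch of that pathwise gradient in JAX. The model is an assumption for illustration (prior z ~ N(0,1), likelihood x|z ~ N(z,1)), with q_\lambda = N(\mu, \sigma^2) and \sigma = exp(log_sigma); the true posterior is N(x/2, 1/2), so the fitted parameters can be checked by eye.

```python
import jax
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import norm

x_obs = 1.5

def log_joint(z):
    # log p(x, z) for the toy model: z ~ N(0,1), x | z ~ N(z, 1)
    return norm.logpdf(x_obs, loc=z, scale=1.0) + norm.logpdf(z, loc=0.0, scale=1.0)

def neg_elbo(lam, key, n_samples=64):
    mu, log_sigma = lam
    eps = random.normal(key, (n_samples,))
    z = mu + jnp.exp(log_sigma) * eps                       # z = g(eps, lambda)
    log_q = norm.logpdf(z, loc=mu, scale=jnp.exp(log_sigma))
    return -jnp.mean(log_joint(z) - log_q)                  # Monte Carlo -ELBO

grad_fn = jax.jit(jax.grad(neg_elbo))                       # gradient flows through g
lam = jnp.array([0.0, 0.0])
key = random.PRNGKey(0)
for _ in range(500):
    key, sub = random.split(key)
    lam = lam - 0.05 * grad_fn(lam, sub)

mu, sigma = lam[0], jnp.exp(lam[1])
print(mu, sigma)   # -> roughly 0.75 (= x/2) and 0.71 (= sqrt(1/2)), the exact posterior
```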


Natural gradients

Standard gradient descent treats parameter space as Euclidean. But the effect of nudging \mu depends entirely on the scale of \sigma. When \sigma = 0.001, shifting \mu by 0.1 is a huge change in KL terms; when \sigma = 100, the same shift barely matters.

The natural gradient fixes this by bounding each step by KL divergence rather than Euclidean distance,

\phi_{new} = \text{argmax}_\phi \mathcal{L}(\phi) \quad \text{s.t. } D_{KL}(q_\phi \| q_{\phi_{old}}) < \epsilon

To second order, D_{KL}(q_{\phi+d\phi} \| q_\phi) \approx \frac{1}{2} d\phi^T F(\phi) d\phi where F(\phi) = \mathbb{E}_q [ \nabla \log q \nabla \log q^T ] is the Fisher information matrix. The best direction turns out to be

g_{nat} = F^{-1}(\phi) \nabla_\phi \mathcal{L}

For exponential families the natural gradient with respect to the natural parameters \eta equals the ordinary gradient with respect to the mean (expectation) parameters, so no explicit matrix inversion is needed.
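
A small sketch of those two ideas: the Fisher matrix of q = N(\mu, \sigma^2) under the parameterization \phi = (\mu, \log\sigma), estimated by Monte Carlo as E_q[\nabla\log q\, \nabla\log q^T], and a natural-gradient step obtained by solving against it. The numbers are illustrative assumptions; for this parameterization the Fisher should come out close to diag(1/\sigma^2, 2).

```python
import jax
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import norm

def log_q(phi, z):
    mu, log_sigma = phi
    return norm.logpdf(z, loc=mu, scale=jnp.exp(log_sigma))

phi = jnp.array([0.0, jnp.log(0.01)])            # tiny sigma: Euclidean steps in mu are misleading
key = random.PRNGKey(0)
z = phi[0] + jnp.exp(phi[1]) * random.normal(key, (50_000,))

# Monte Carlo Fisher: F = E_q[ grad log q  grad log q^T ]
scores = jax.vmap(jax.grad(log_q), in_axes=(None, 0))(phi, z)    # (S, 2)
F = scores.T @ scores / scores.shape[0]
print(F)                                         # close to diag(1/sigma^2, 2)

grad_elbo = jnp.array([1.0, 1.0])                # a placeholder ELBO gradient
nat_grad = jnp.linalg.solve(F, grad_elbo)        # g_nat = F^{-1} grad
print(nat_grad)                                  # the mu component is shrunk by ~sigma^2
```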

Natural gradients are what Stochastic Variational Inference (Hoffman et al. 2013) is built on, and they pushed VI out to massive datasets. The ELBO gradient for the global parameter breaks into a sum over local terms,

\nabla \mathcal{L} = \mathbb{E}_q [\eta_g] - \lambda_g + \sum_{i=1}^N (\mathbb{E}_{q_i} [\eta_l] - \lambda_l)

Estimate it with a minibatch of size M,

\hat{\nabla} \mathcal{L} = \mathbb{E}_q [\eta_g] - \lambda_g + \frac{N}{M} \sum_{m=1}^M (\mathbb{E}_{q_m} [\eta_l] - \lambda_l)

Taking steps \rho_t that satisfy the Robbins-Monro conditions \sum \rho_t = \infty and \sum \rho_t^2 < \infty gives you Bayesian inference at stochastic gradient descent scale.
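
A sketch of the minibatch scaling and the Robbins-Monro schedule. To keep it self-contained it uses a toy global latent variable and plain reparameterized gradients on a per-datapoint ELBO rather than the conjugate natural-gradient updates of SVI proper; the model and all names are assumptions for illustration.

```python
import jax
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import norm

N, M = 10_000, 100                                   # dataset size, minibatch size
key = random.PRNGKey(0)
data = 2.0 + random.normal(key, (N,))                # synthetic data, true mean = 2

def neg_elbo_minibatch(lam, batch, key):
    m, log_s = lam
    eps = random.normal(key, (32,))
    z = m + jnp.exp(log_s) * eps                     # reparameterized samples of the global latent
    loglik = jnp.mean(norm.logpdf(batch[None, :], loc=z[:, None], scale=1.0), axis=0)
    expected_loglik = (N / M) * jnp.sum(loglik)      # rescale the minibatch term by N/M
    log_prior = jnp.mean(norm.logpdf(z, 0.0, 1.0))
    entropy = -jnp.mean(norm.logpdf(z, m, jnp.exp(log_s)))
    return -(expected_loglik + log_prior + entropy) / N   # per-datapoint ELBO keeps steps O(1)

grad_fn = jax.jit(jax.grad(neg_elbo_minibatch))
lam = jnp.array([0.0, 0.0])
for t in range(2_000):
    key, k1, k2 = random.split(key, 3)
    idx = random.choice(k1, N, (M,), replace=False)  # random minibatch
    rho_t = (t + 100.0) ** -0.75                     # Robbins-Monro: sum rho_t = inf, sum rho_t^2 < inf
    lam = lam - rho_t * grad_fn(lam, data[idx], k2)

print(lam)   # lam[0], the variational mean, converges toward the posterior mean (~2.0)
```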


Normalizing flows

Standard VI usually uses a Gaussian q, which is unimodal and light-tailed. Normalizing flows fix this by stacking invertible maps,

z_K = f_K \circ \dots \circ f_1(z_0), \quad z_0 \sim \mathcal{N}(0, I)

The density transforms by the change of variables formula,

\log q_K(z_K) = \log q_0(z_0) - \sum_{k=1}^K \log \left| \det \frac{\partial f_k}{\partial z_{k-1}} \right|

The maps have to be expressive and they also have to let you work out the Jacobian determinant cheaply. Planar flows (Rezende & Mohamed 2015) use

f(z) = z + u \tanh(w^T z + b)

and, with \psi(z) = \tanh'(w^T z + b)\, w, the matrix determinant lemma gives \det(I + u\psi(z)^T) = 1 + u^T \psi(z). Stacking 10 to 20 of these layers lets q pick up multimodal targets.
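
A sketch of one planar-flow layer and the running log-density bookkeeping in JAX, following the f(z) = z + u tanh(w^T z + b) form above. The parameter values are arbitrary illustration choices, kept in the w^T u \ge -1 region where the map stays invertible.

```python
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import norm

def planar_flow(params, z):
    # f(z) = z + u * tanh(w^T z + b)
    u, w, b = params
    h = jnp.tanh(jnp.dot(w, z) + b)
    z_new = z + u * h
    # psi(z) = tanh'(w^T z + b) * w, so |det J| = |1 + u^T psi(z)| by the matrix determinant lemma
    psi = (1.0 - h**2) * w
    log_det = jnp.log(jnp.abs(1.0 + jnp.dot(u, psi)))
    return z_new, log_det

def flow_log_density(layers, z0, log_q0):
    # log q_K(z_K) = log q_0(z_0) - sum_k log |det df_k / dz_{k-1}|
    z, log_q = z0, log_q0
    for params in layers:
        z, log_det = planar_flow(params, z)
        log_q = log_q - log_det
    return z, log_q

# Two layers acting on a 2-D standard-normal sample (parameter values are arbitrary)
z0 = random.normal(random.PRNGKey(0), (2,))
log_q0 = jnp.sum(norm.logpdf(z0))
layers = [(jnp.array([0.5, -0.3]), jnp.array([1.0, 0.2]), 0.1),
          (jnp.array([-0.2, 0.4]), jnp.array([0.3, -1.0]), -0.5)]
zK, log_qK = flow_log_density(layers, z0, log_q0)
print(zK, log_qK)
```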


Stein variational gradient descent

SVGD (Liu & Wang 2016) takes a totally different angle and drops parametric families in favor of particles.

Let \{z_i\}_{i=1}^N stand in for q. We push them with z \to z + \epsilon \phi(z) to bring down the KL divergence. Writing T(z) = z + \epsilon \phi(z) and q_{[T]} = T_\# q,

\nabla_\epsilon D_{KL}(q_{[T]} \| p) |_{\epsilon=0} = - \mathbb{E}_q [ \text{trace}(\mathcal{A}_p \phi) ]

where \mathcal{A}_p \phi = \phi \nabla \log p + \nabla \cdot \phi is the Stein operator. The best perturbation inside the unit ball of an RKHS \mathcal{H}_k is

\phi^*(x) \propto \mathbb{E}_{y \sim q} [ k(x, y) \nabla_y \log p(y) + \nabla_y k(x, y) ]

The update rule is

z_i \leftarrow z_i + \epsilon \left( \frac{1}{N} \sum_j [ k(z_i, z_j) \nabla \log p(x, z_j) + \nabla_{z_j} k(z_i, z_j) ] \right)

Two forces drive the dynamics. The first term is the kernel-weighted score and it pushes particles toward high-probability regions. The second term \nabla k acts as a repulsive force that shoves particles apart and stops mode collapse. Starting from a tight cluster near the origin, the particles split up and spread out to cover each mode of the posterior.

```python
import jax
import jax.numpy as jnp
from functools import partial
from jax import grad, vmap, jit, random


# 1. RBF kernel and its gradient
def rbf_kernel(X, h=-1.0):
    # X: (N, D) particles. Returns K: (N, N) and grad_K: (N, N, D),
    # where grad_K[i, j] = grad_{z_j} k(z_i, z_j) -- the repulsive direction.
    diff = X[:, None, :] - X[None, :, :]          # (N, N, D), diff[i, j] = z_i - z_j
    sq_dist = jnp.sum(diff**2, axis=-1)           # (N, N)
    if h < 0:                                     # median heuristic for the bandwidth
        h = jnp.median(sq_dist) / jnp.log(X.shape[0])
    K = jnp.exp(-sq_dist / h)
    # grad_{z_j} exp(-||z_i - z_j||^2 / h) = (2/h) * (z_i - z_j) * K[i, j]
    grad_K = K[..., None] * diff * (2.0 / h)
    return K, grad_K


# 2. SVGD step (log_prob_grad is a function, so it must be a static argument for jit)
@partial(jit, static_argnums=1)
def svgd_step(particles, log_prob_grad, step_size):
    grad_logp = log_prob_grad(particles)          # (N, D) score at each particle
    K, grad_K = rbf_kernel(particles)             # (N, N), (N, N, D)
    term1 = K @ grad_logp                         # kernel-weighted pull toward high density
    term2 = jnp.sum(grad_K, axis=1)               # repulsive term, summed over j
    phi = (term1 + term2) / particles.shape[0]
    return particles + step_size * phi


# 3. Target distribution: bimodal Gaussian mixture at (-2, -2) and (2, 2)
def target_log_prob(x):
    mu1 = jnp.array([-2.0, -2.0])
    mu2 = jnp.array([2.0, 2.0])
    log_p1 = jnp.log(0.5) - 0.5 * jnp.sum((x - mu1) ** 2)
    log_p2 = jnp.log(0.5) - 0.5 * jnp.sum((x - mu2) ** 2)
    return jax.scipy.special.logsumexp(jnp.array([log_p1, log_p2]))


dist_grad = vmap(grad(target_log_prob))           # vectorized score function


def run_simulation(n_particles=100, n_steps=200, step_size=0.1):
    key = random.PRNGKey(42)
    # Start all particles in a tight cluster around the origin (a mode-collapsed state)
    particles = random.normal(key, (n_particles, 2)) * 0.1
    for _ in range(n_steps):
        particles = svgd_step(particles, dist_grad, step_size)
    return particles


# Observation: even starting from effectively a single point, the repulsive grad_K term
# pushes particles apart, so they spread out and cover both modes of the mixture --
# a regime where a single MCMC chain can easily get stuck in one mode.
particles = run_simulation()
```

Variance reduction and the reparameterization trick

The practical success of VAEs (Kingma & Welling 2014) hangs on variance reduction. The core problem is computing \nabla_\phi \mathbb{E}_{q_\phi} [f(z)].

Score function estimator (REINFORCE).

\nabla_\phi \mathbb{E}_q [f] = \mathbb{E}_q [ f(z) \nabla_\phi \log q_\phi(z) ]

This works for any q and even discrete distributions. But the variance scales with \text{Var}(f(z)) and that is usually big. The estimator pokes at the function with random samples and never uses its local structure.

Pathwise derivative (reparameterization). Write z = g(\epsilon, \phi) for some diffeomorphism and you get

\nabla_\phi \mathbb{E} = \mathbb{E}_{p(\epsilon)} [ \nabla_z f(z) \nabla_\phi g(\epsilon, \phi) ]

This runs gradient information through the analytic gradient of the model and cuts variance dramatically. The catch is that both f and q have to be differentiable. For discrete latent variables you need relaxations like Gumbel-Softmax.
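
An empirical comparison of the two estimators, as a sketch. The setup is an illustrative assumption: f(z) = z^2 and q_\phi = N(\mu, 1) with \mu = 0.5, so the exact gradient d/d\mu E_q[z^2] = 2\mu = 1 and both estimators can be checked for bias and variance.

```python
import jax.numpy as jnp
from jax import random

mu, n, trials = 0.5, 100, 2000
eps = random.normal(random.PRNGKey(0), (trials, n))
z = mu + eps                                       # samples from q = N(mu, 1)

# Score function (REINFORCE): f(z) * d/dmu log q(z) = z^2 * (z - mu)
score_est = jnp.mean(z**2 * (z - mu), axis=1)
# Pathwise (reparameterization): d/dz f(z) * d/dmu g(eps, mu) = 2z * 1
path_est = jnp.mean(2 * z, axis=1)

print("score-function: mean", jnp.mean(score_est), "variance", jnp.var(score_est))
print("pathwise:       mean", jnp.mean(path_est), "variance", jnp.var(path_est))
# Both estimators are unbiased (mean ~1), but the pathwise variance is far smaller.
```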


Summary

Variational inference turns posterior computation into optimization. Instead of waiting for a Markov chain to mix you write down an objective, the ELBO, and optimize it with gradients. The tradeoff is real: the result is capped by the approximation family and the KL objective is mode-seeking. But it scales, and normalizing flows or SVGD can close the approximation gap.

For big models and big datasets, optimization-based inference will probably remain the dominant approach over sampling for most practical purposes.


VAE Loss Derivation

The VAE has a generative model p_\theta(x|z), a prior p(z), and an approximate posterior q_\phi(z|x).

\mathcal{L} = \mathbb{E}_{q_\phi(z|x)} [ \log p_\theta(x|z) ] - D_{KL}( q_\phi(z|x) \| p(z) )

The KL term. Let q(z|x) = \mathcal{N}(\mu, \Sigma) and p(z) = \mathcal{N}(0, I). The general Gaussian KL is

D_{KL}(N_0 \| N_1) = \frac{1}{2} \left[ \text{tr}(\Sigma_1^{-1} \Sigma_0) + (\mu_1 - \mu_0)^T \Sigma_1^{-1} (\mu_1 - \mu_0) - k + \ln \frac{\det \Sigma_1}{\det \Sigma_0} \right]

With \mu_1 = 0, \Sigma_1 = I, and a diagonal \Sigma_0 = \text{diag}(\sigma_1^2, \dots, \sigma_k^2) this reduces to -\frac{1}{2} \sum_j (1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2). It punishes \mu and \sigma for drifting away from the standard normal and works as a regularizer that keeps the latent space organized.

The reconstruction term. Monte Carlo sample z^{(l)} = \mu + \sigma \odot \epsilon^{(l)}. If p(x|z) = \mathcal{N}(x; \text{Decoder}(z), I) then \log p(x|z) = -\frac{1}{2}\|x - \text{Decoder}(z)\|^2 + \text{const}, which is just a (negative) mean squared error.
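
Putting the two terms together, here is a per-example VAE loss sketch with a diagonal Gaussian encoder and a unit-variance Gaussian decoder. The linear encoder/decoder and the shapes are stand-in assumptions, not the architecture from the original paper.

```python
import jax
import jax.numpy as jnp
from jax import random

def vae_loss(params, x, key):
    enc_W, enc_b, dec_W, dec_b = params
    # Encoder outputs mu and log sigma^2 of q(z|x)
    h = enc_W @ x + enc_b
    mu, logvar = jnp.split(h, 2)
    # Reparameterized sample z = mu + sigma * eps
    eps = random.normal(key, mu.shape)
    z = mu + jnp.exp(0.5 * logvar) * eps
    # Reconstruction: log N(x; Decoder(z), I) = -0.5 * ||x - x_hat||^2 + const
    x_hat = dec_W @ z + dec_b
    recon = -0.5 * jnp.sum((x - x_hat) ** 2)
    # Closed-form KL( N(mu, diag(sigma^2)) || N(0, I) )
    kl = -0.5 * jnp.sum(1.0 + logvar - mu**2 - jnp.exp(logvar))
    return -(recon - kl)                # negative per-example ELBO

# Tiny usage example: x in R^4, z in R^2, random linear encoder/decoder
key = random.PRNGKey(0)
k1, k2, k3 = random.split(key, 3)
params = (0.1 * random.normal(k1, (4, 4)), jnp.zeros(4),
          0.1 * random.normal(k2, (4, 2)), jnp.zeros(4))
x = random.normal(k3, (4,))
print(vae_loss(params, x, key))
print(jax.grad(vae_loss)(params, x, key)[0].shape)   # gradients flow to the encoder weights
```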


Gumbel-Softmax for discrete variables

Reparameterization needs continuous variables. For categoricals z \sim \text{Cat}(\pi) you use the Gumbel-Max trick,

z = \text{one\_hot}( \text{argmax}_i ( \log \pi_i + g_i ) ), \quad g_i \sim \text{Gumbel}(0, 1)

Then relax it with a temperature \tau (Jang et al. 2016),

y_i = \frac{\exp((\log \pi_i + g_i)/\tau)}{\sum_j \exp((\log \pi_j + g_j)/\tau)}

\tau \to 0 brings back discrete samples and \tau \to \infty gives uniform. In practice \tau gets annealed during training.
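
A short sketch contrasting the hard Gumbel-Max sample with its Gumbel-Softmax relaxation. The class probabilities \pi and the temperature \tau are arbitrary illustration values.

```python
import jax
import jax.numpy as jnp
from jax import random

def sample_gumbel(key, shape):
    u = random.uniform(key, shape, minval=1e-8, maxval=1.0)
    return -jnp.log(-jnp.log(u))                    # Gumbel(0, 1) noise

log_pi = jnp.log(jnp.array([0.1, 0.6, 0.3]))        # class log-probabilities
key = random.PRNGKey(0)
g = sample_gumbel(key, log_pi.shape)

# Hard sample (Gumbel-Max): exactly categorical, but argmax blocks gradients
hard = jax.nn.one_hot(jnp.argmax(log_pi + g), 3)

# Relaxed sample (Gumbel-Softmax): differentiable w.r.t. log_pi
tau = 0.5
soft = jax.nn.softmax((log_pi + g) / tau)

print(hard)
print(soft)   # approaches the hard one-hot as tau -> 0, the uniform vector as tau -> infinity
```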


Alpha-divergences

The “exclusive” KL (q \| p) is mode-seeking and q just collapses onto one mode of p. The “inclusive” KL (p \| q) is mass-covering but it needs you to evaluate p. Rényi’s \alpha-divergence interpolates between them,

D_\alpha(p \| q) = \frac{1}{\alpha - 1} \log \int p(z)^\alpha q(z)^{1-\alpha} dz

\alpha \to 1 brings back D_{KL}(p \| q) and \alpha \to 0 brings back -\log \int_{\text{supp}(p)} q\, dz, which measures support overlap. Values in between slide between mass-covering when \alpha < 1 and mode-seeking when \alpha > 1. Black Box Alpha-VI (Hernandez-Lobato et al. 2016) minimizes D_\alpha with

\mathcal{L}_\alpha(q) \approx \frac{1}{\alpha} \log \frac{1}{K} \sum_{k=1}^K \left( \frac{p(x, z_k)}{q(z_k)} \right)^\alpha

and this ties VI to particle filters and sequential Monte Carlo.
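
A sketch that evaluates the Monte Carlo objective above with a logsumexp for numerical stability. The model and q are illustrative univariate Gaussians chosen only so the importance weights are easy to compute; nothing here reproduces the paper's setup.

```python
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import norm
from jax.scipy.special import logsumexp

def alpha_objective(alpha, log_joint, log_q, z):
    # (1/alpha) * log( (1/K) * sum_k (p(x, z_k) / q(z_k))^alpha ), computed in log space
    log_w = log_joint(z) - log_q(z)
    K = z.shape[0]
    return (1.0 / alpha) * (logsumexp(alpha * log_w) - jnp.log(K))

x_obs = 1.5
log_joint = lambda z: norm.logpdf(x_obs, z, 1.0) + norm.logpdf(z, 0.0, 1.0)
m, s = 0.6, 0.9
log_q = lambda z: norm.logpdf(z, m, s)

z = m + s * random.normal(random.PRNGKey(0), (5_000,))   # samples from q
for alpha in (0.5, 1.0, 2.0):                            # evaluate the stated objective at a few alphas
    print(alpha, float(alpha_objective(alpha, log_joint, log_q, z)))
```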


Development of variational methods

When | What | Why it mattered
---- | ---- | ---------------
1998 | Shun-Ichi Amari publishes natural gradient descent | Showed that parameter space geometry determines the correct gradient direction
1999 | Michael Jordan et al. formalize variational methods for graphical models | Established VI as a principled alternative to MCMC for structured models
2013 | Hoffman et al. introduce Stochastic Variational Inference | Extended VI to massive datasets via stochastic optimization
2014 | Kingma & Welling introduce VAEs and the reparameterization trick | Connected deep learning (backprop) with variational inference
2015 | Rezende & Mohamed develop normalizing flows | Enabled arbitrarily complex approximate posteriors via invertible transformations
2016 | Liu & Wang propose SVGD | Particle-based VI using kernelized Stein discrepancy, requiring no parametric family

Terms and definitions

ELBO (Evidence Lower Bound) is the objective function for VI and it is defined as \mathcal{L}(q) = \mathbb{E}_q[\log p(x,z)] - \mathbb{E}_q[\log q(z)]. Maximizing the ELBO is the same thing as minimizing \text{KL}(q \| p).

Mean Field is the assumption that q factors all the way across latent dimensions. It is simple and it scales but it cannot pick up posterior correlations.

Natural Gradient is the gradient step on the Riemannian manifold set up by Fisher information. In natural parameters of exponential families ordinary gradients are already natural gradients.

Normalizing Flow is a chain of invertible transformations applied to a simple base distribution to make a complex density. The key constraint is that you have to compute the Jacobian determinant cheaply.

Reparameterization Trick writes z = g(\epsilon, \lambda) with \epsilon \sim p(\epsilon) so that \nabla_\lambda \mathbb{E}[f(z)] = \mathbb{E}[\nabla f \cdot \nabla_\lambda g]. It cuts gradient variance by orders of magnitude compared to score-function estimators.

SVGD (Stein Variational Gradient Descent) is a deterministic particle method that moves samples toward the posterior using a kernelized velocity field that comes from Stein’s identity.

BBVI (Black-Box Variational Inference) uses score-function gradients of the ELBO and it works for any model with a log-density you can compute.

LDA (Latent Dirichlet Allocation) is the classic topic model and historically it was one of the first large-scale applications of mean field VI.


References

1. Blei, D. M. et al. (2017). “Variational Inference: A Review for Statisticians”. Extensive overview of VI as an alternative to MCMC. Covers Mean Field, Stochastic VI, and connections to convex optimization.

2. Liu, Q., & Wang, D. (2016). “Stein Variational Gradient Descent”. Introduces SVGD. Uses Stein’s identity and RKHS logic to derive a deterministic particle flow that simulates the heat flow of the posterior.

3. Kingma, D. P., & Welling, M. (2014). “Auto-Encoding Variational Bayes”. The paper that introduced VAEs and the Reparameterization Trick. It connected Deep Learning (backprop) with Variational Inference.

4. Rezende, D. J., & Mohamed, S. (2015). “Variational Inference with Normalizing Flows”. Showed how to construct arbitrarily complex posteriors q(z) by transforming simple Gaussians through invertible neural networks.

5. Amari, S. I. (1998). “Natural Gradient Works Efficiently in Learning”. The seminal paper defining the Natural Gradient for neural networks. It shows that the parameter space is Riemannian and the Fisher Information Metric is the unique metric invariant to reparameterization.

6. Jordan, M. I. et al. (1999). “An Introduction to Variational Methods for Graphical Models”. The classic tutorial that established the Mean Field approximation as the standard tool for exponential families in graphical models.