Variational Optimization & Bayesian Inference

1. The Geometry of Integration

Bayesian Inference is fundamentally about integration.

p(x) = \int p(x|z)\, p(z)\, dz

This marginal likelihood (the Evidence) is the holy grail: it enables model selection and explains generalization. However, in high dimensions the typical set occupies an exponentially vanishing fraction of the volume, and MCMC methods (sampling) struggle to traverse the empty space around it. Variational Inference (VI) proposes a radical shift: don't sample, optimize. We replace the integration problem with a functional optimization problem over a family of distributions \mathcal{Q}.

q^* = \text{argmin}_{q \in \mathcal{Q}}\; D_{KL}(q \,\|\, p(\cdot \mid x))

This transforms inference into a problem typically solvable by Stochastic Gradient Descent.

2. Functional Calculus and the ELBO

Before the algebraic derivation, we examine the functional analysis. We operate on the space of probability density functions. Consider the Kullback-Leibler Divergence functional:

\mathbb{J}[q] = \int q(z) \log \frac{q(z)}{p(z|x)}\, dz

To minimize this, we take the functional (variational) derivative \frac{\delta \mathbb{J}}{\delta q}. Using the Euler-Lagrange framework, let the integrand be L(z, q) = q \log q - q \log p(z|x); since L does not depend on q', only the \partial L / \partial q term survives:

\frac{\partial L}{\partial q} = (1 + \log q) - \log p(z|x)

We impose the normalization constraint \int q(z)\, dz = 1 using a Lagrange multiplier \lambda.

\frac{\delta}{\delta q}\left( \mathbb{J}[q] + \lambda \left( \int q(z)\, dz - 1 \right) \right) = 1 + \log q(z) - \log p(z|x) + \lambda = 0

Solving for q(z):

\log q(z) = \log p(z|x) - (1 + \lambda) \quad \Longrightarrow \quad q^*(z) \propto p(z|x)

This confirms that the unconstrained functional minimum is indeed the posterior. Since we cannot evaluate p(z|x), we restrict q to a tractable parametric family \mathcal{Q} = \{q_\phi\}.

The Evidence Lower Bound (Algebraic View):

\log p(x) = \log \int p(x, z)\, dz = \log \mathbb{E}_q \left[ \frac{p(x, z)}{q(z)} \right]

By Jensen's inequality (since \log is concave):

\log p(x) \ge \mathbb{E}_q \left[ \log \frac{p(x, z)}{q(z)} \right] = \text{ELBO}(\phi)

\text{ELBO} = \mathbb{E}_q[\log p(x, z)] - \mathbb{E}_q[\log q(z)] = \text{Reconstruction} + \text{Entropy}

(the second term, -\mathbb{E}_q[\log q(z)], is the entropy \mathbb{H}[q]).

Maximizing the ELBO forces q to put mass where p(x, z) is high (Reconstruction), but also to stay spread out (Entropy). The gap is exactly the KL divergence: \log p(x) - \text{ELBO} = D_{KL}(q \,\|\, p(z|x)).
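To make the bound concrete, here is a minimal numerical sketch in JAX. It assumes a toy conjugate model (z ~ N(0,1), x|z ~ N(z,1)) and a hypothetical observation x_obs = 1.5, so that log p(x) is available in closed form; the Monte Carlo ELBO estimates never exceed it and become tight when q equals the exact posterior N(x/2, 1/2).

import jax.numpy as jnp
from jax import random
from jax.scipy.stats import norm

# Toy conjugate model (an assumption for illustration): z ~ N(0,1), x|z ~ N(z,1).
# The evidence is then available in closed form: p(x) = N(x; 0, 2).
x_obs = 1.5

def log_joint(z, x):
    return norm.logpdf(z, 0.0, 1.0) + norm.logpdf(x, z, 1.0)

def elbo_mc(key, m, s, n_samples=10_000):
    # Monte Carlo estimate of E_q[log p(x,z) - log q(z)] with q = N(m, s^2)
    z = m + s * random.normal(key, (n_samples,))
    return jnp.mean(log_joint(z, x_obs) - norm.logpdf(z, m, s))

key = random.PRNGKey(0)
log_evidence = norm.logpdf(x_obs, 0.0, jnp.sqrt(2.0))
print("log p(x)        :", log_evidence)
print("ELBO (bad q)    :", elbo_mc(key, 0.0, 1.0))                   # loose lower bound
print("ELBO (optimal q):", elbo_mc(key, x_obs / 2, jnp.sqrt(0.5)))   # q = posterior: gap is zero

When q is the exact posterior, log p(x,z) - log q(z) equals log p(x) for every sample, so the gap (the KL divergence above) vanishes.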


3. Optimization Strategies

3.1 Coordinate Ascent (Mean Field)

Assume q(z) = \prod_i q_i(z_i). The optimal factor q_j^*(z_j), holding the others fixed, is:

\log q_j^*(z_j) = \mathbb{E}_{-j}[\log p(x, z)] + \text{const}

This is tractable for conjugate exponential families: no gradients are required, only expectations. It is the workhorse behind Latent Dirichlet Allocation (LDA). A toy coordinate-ascent (CAVI) loop is sketched below.
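For illustration, here is a minimal CAVI sketch. The "posterior" is a hypothetical correlated 2-D Gaussian (a stand-in for a conjugate model), for which the mean-field updates have the well-known closed form q_j(z_j) = N(m_j, 1/\Lambda_{jj}) with \Lambda the precision matrix; the loop simply iterates those updates.

import jax.numpy as jnp

# Mean-field CAVI on a toy 2-D Gaussian "posterior" (assumed for illustration):
# p(z|x) = N(mu, Sigma). Each factor update is log q_j*(z_j) = E_{-j}[log p(z)] + const.
mu = jnp.array([1.0, -1.0])
Sigma = jnp.array([[1.0, 0.8],
                   [0.8, 1.0]])
Lam = jnp.linalg.inv(Sigma)               # precision matrix

m = jnp.zeros(2)                          # variational means
for _ in range(50):                       # coordinate-ascent sweeps
    m = m.at[0].set(mu[0] - (Lam[0, 1] / Lam[0, 0]) * (m[1] - mu[1]))
    m = m.at[1].set(mu[1] - (Lam[1, 0] / Lam[1, 1]) * (m[0] - mu[0]))

print("CAVI means      :", m)                      # converges to mu
print("CAVI variances  :", 1.0 / jnp.diag(Lam))    # underestimates the marginals
print("true variances  :", jnp.diag(Sigma))

The means converge to the true posterior means, while the variances are underestimated, which is the classic failure mode of the fully factorized (mean-field) assumption.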

3.2 Black-Box Variational Inference (BBVI)

Assume q_\lambda is Gaussian \mathcal{N}(\mu, \sigma^2) with parameters \lambda = (\mu, \sigma), and maximize \mathcal{L}(\lambda) by gradient ascent.

\nabla_\lambda \mathcal{L} = \nabla_\lambda\, \mathbb{E}_{q_\lambda}[\log p(x, z) - \log q_\lambda(z)]

Log-Derivative Trick (REINFORCE):

\nabla_\lambda \mathbb{E}_q[f] = \mathbb{E}_q[f(z)\, \nabla_\lambda \log q_\lambda(z)]

This estimator has high variance. Reparameterization Trick: write z = g(\epsilon, \lambda), e.g. z = \mu + \sigma \epsilon with \epsilon \sim \mathcal{N}(0, 1).

\nabla_\lambda\, \mathbb{E}_{p(\epsilon)}[f(g(\epsilon, \lambda))] = \mathbb{E}_{p(\epsilon)}[\nabla_z f\; \nabla_\lambda g]

This estimator uses the gradient of the model, \nabla_z \log p, and typically has much lower variance. It is the engine of Variational Autoencoders (VAEs); a minimal fitting loop follows.
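The sketch below puts these pieces together: a reparameterized, single-latent BBVI loop that fits q = N(mu, sigma^2) to an unnormalized target. The target (chosen so the true posterior is N(2, 0.5^2)), step size, and sample count are illustrative assumptions, not prescriptions.

import jax
import jax.numpy as jnp
from jax import grad, jit, random

# Minimal reparameterized-BBVI sketch (assumption: 1-D latent, Gaussian target below).
def log_joint(z):
    return -0.5 * ((z - 2.0) / 0.5) ** 2        # unnormalized log p(x, z)

def neg_elbo(params, eps):
    mu, log_sigma = params
    sigma = jnp.exp(log_sigma)
    z = mu + sigma * eps                         # z = g(eps, lambda)
    # one-sample estimate of -(E_q[log p(x,z)] + H[q]); the Gaussian entropy is closed form
    entropy = log_sigma + 0.5 * jnp.log(2 * jnp.pi * jnp.e)
    return -(jnp.mean(log_joint(z)) + entropy)

@jit
def step(params, key, lr=0.05):
    eps = random.normal(key, (16,))              # small batch of noise per step
    g = grad(neg_elbo)(params, eps)
    return (params[0] - lr * g[0], params[1] - lr * g[1])

params, key = (jnp.array(0.0), jnp.array(0.0)), random.PRNGKey(1)
for _ in range(2000):
    key, sub = random.split(key)
    params = step(params, sub)
print("mu, sigma:", params[0], jnp.exp(params[1]))   # approaches 2.0 and 0.5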


4. The Geometry of Natural Gradients

Standard gradient descent \phi \leftarrow \phi - \alpha \nabla_\phi \mathcal{L} assumes the parameter space is Euclidean (L_2 distance). But the parameters of a distribution live on a statistical manifold: a small change in the mean \mu means something very different when \sigma is small versus large. Trust-region optimization:

\phi_{new} = \text{argmax}_\phi\, \mathcal{L}(\phi) \quad \text{s.t. } D_{KL}(q_\phi \,\|\, q_{\phi_{old}}) < \epsilon

Expanding the KL divergence to second order brings in the Fisher Information Matrix F:

D_{KL}(q_{\phi+d\phi} \,\|\, q_\phi) \approx \frac{1}{2}\, d\phi^T F(\phi)\, d\phi

where F(\phi) = \mathbb{E}_q[\nabla \log q\, \nabla \log q^T]. The constraint becomes a quadratic bound, and the optimal update direction is the Natural Gradient:

g_{nat} = F^{-1}(\phi)\, \nabla_\phi \mathcal{L}

Gaussian Natural Gradients: for q(z) = \mathcal{N}(\mu, \Sigma), the Fisher matrix is block diagonal. Remarkably, for exponential families we do not need to invert F explicitly: the natural gradient with respect to the canonical parameters \eta equals the plain gradient with respect to the mean (expectation) parameters, so each update reduces to computing expected sufficient statistics, with no matrix inversion at all.

\lambda_{new} = (1 - \rho)\, \lambda_{old} + \rho\, \big( \mathbb{E}_q[T(z)] + \dots \big)

This is the basis of SVI (Stochastic Variational Inference), which scales to billions of examples. A sketch of the Gaussian Fisher computation follows.
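A small sketch of these quantities for a one-dimensional Gaussian: it estimates F by Monte Carlo as E_q[∇log q ∇log q^T] in the parameters (mu, log sigma), compares it to the known closed form diag(1/sigma^2, 2), and preconditions a placeholder gradient. The parameter values and the placeholder ELBO gradient are assumptions for illustration.

import jax.numpy as jnp
from jax import grad, vmap, random
from jax.scipy.stats import norm

# Fisher matrix of q = N(mu, sigma^2) in phi = (mu, log sigma),
# estimated as E_q[(d log q)(d log q)^T] and compared with the closed form.
def log_q(phi, z):
    mu, log_sigma = phi
    return norm.logpdf(z, mu, jnp.exp(log_sigma))

phi = jnp.array([0.5, jnp.log(0.7)])
key = random.PRNGKey(0)
z = phi[0] + jnp.exp(phi[1]) * random.normal(key, (100_000,))

scores = vmap(grad(log_q), in_axes=(None, 0))(phi, z)        # (N, 2) score vectors
F_mc = scores.T @ scores / z.shape[0]
F_exact = jnp.diag(jnp.array([1.0 / 0.7**2, 2.0]))           # known result for (mu, log sigma)
print(F_mc)
print(F_exact)

# Natural gradient: precondition the Euclidean gradient by F^{-1}
euclid_grad = jnp.array([1.0, 1.0])                          # placeholder ELBO gradient
nat_grad = jnp.linalg.solve(F_mc, euclid_grad)
print(nat_grad)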


5. Stochastic Variational Inference (SVI)

Traditional VI requires a full pass over the dataset for every coordinate-ascent update. For N = 10^9 documents (LDA), this is impossible. Hoffman et al. (2013) introduced SVI. Key idea: the natural gradient of the global parameters decomposes into a sum of local contributions.

\nabla \mathcal{L} = \mathbb{E}_q[\eta_g] - \lambda_g + \sum_{i=1}^N \big( \mathbb{E}_{q_i}[\eta_l] - \lambda_l \big)

We can form a noisy estimate of this gradient from a minibatch of size M:

\hat{\nabla} \mathcal{L} = \mathbb{E}_q[\eta_g] - \lambda_g + \frac{N}{M} \sum_{m=1}^M \big( \mathbb{E}_{q_m}[\eta_l] - \lambda_l \big)

Because we are taking steps in the natural-gradient (Riemannian) geometry, step sizes \rho_t satisfying the Robbins-Monro conditions (\sum_t \rho_t = \infty, \sum_t \rho_t^2 < \infty) guarantee convergence. This showed that Bayesian inference could scale to "Big Data" on a par with deep learning; a schematic update loop follows.
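Schematically, the global update looks as follows. The model-specific pieces are placeholders and assumptions: minibatch_stats stands in for the expected sufficient statistics of the minibatch under the current local variational optima, and prior_nat for the prior's natural parameter; only the Robbins-Monro schedule and the convex-combination update come from the text.

import jax.numpy as jnp

def robbins_monro(t, tau0=1.0, kappa=0.7):
    # Step sizes with sum rho_t = inf and sum rho_t^2 < inf (requires 0.5 < kappa <= 1).
    return (t + tau0) ** (-kappa)

def svi_update(lam, minibatch_stats, t, N, M, prior_nat):
    # "Intermediate" global parameter: the coordinate-ascent optimum obtained by
    # pretending the minibatch were replicated N/M times (the noisy natural-gradient step).
    lam_hat = prior_nat + (N / M) * minibatch_stats
    rho = robbins_monro(t)
    return (1.0 - rho) * lam + rho * lam_hat

# Hypothetical usage with made-up statistics:
lam = jnp.full(3, 0.1)
lam = svi_update(lam, minibatch_stats=jnp.ones(3), t=1, N=1_000_000, M=100,
                 prior_nat=jnp.full(3, 0.1))
print(lam)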


6. Normalizing Flows (Beyond Mean Field)

The biggest limitation of VI is the choice of q. A Gaussian q limits us to unimodal, light-tailed approximations. Normalizing Flows construct complex densities by pushing a simple base distribution (e.g. a Gaussian) through a sequence of invertible maps f_k.

z_K = f_K \circ \dots \circ f_1(z_0), \quad z_0 \sim \mathcal{N}(0, I)

Change of Variables Formula:

\log q_K(z_K) = \log q_0(z_0) - \sum_{k=1}^K \log \left| \det \frac{\partial f_k}{\partial z_{k-1}} \right|

The challenge is finding functions with (1) High expressivity and (2) Linear-time Jacobian determinants.

Planar Flows (Rezende & Mohamed 2015):

f(z) = z + u \tanh(w^T z + b)

The transformation contracts or expands the density around the hyperplane w^T z + b = 0. With \psi(z) = \big(1 - \tanh^2(w^T z + b)\big)\, w, the matrix determinant lemma gives:

\det\left( I + u\, \psi(z)^T \right) = 1 + u^T \psi(z)

Stacking 10-20 planar flows allows the approximate posterior to snake around complex, multi-modal landscapes; a minimal implementation follows.
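A minimal planar-flow sketch, with arbitrary (assumed) parameter values: it applies K stacked transformations and accumulates the log-determinant terms from the change-of-variables formula. (Invertibility additionally requires u^T w >= -1, which is not enforced here.)

import jax
import jax.numpy as jnp
from jax import random

def planar_flow(params, z):
    # f(z) = z + u * tanh(w^T z + b), with log|det J| = log|1 + u^T psi(z)|
    u, w, b = params
    a = jnp.tanh(jnp.dot(w, z) + b)
    f_z = z + u * a
    psi = (1.0 - a ** 2) * w
    log_det = jnp.log(jnp.abs(1.0 + jnp.dot(u, psi)))
    return f_z, log_det

def flow_log_prob(param_list, z0):
    # log q_K(z_K) = log q_0(z_0) - sum_k log|det df_k/dz_{k-1}|
    log_q = jnp.sum(jax.scipy.stats.norm.logpdf(z0))     # standard normal base density
    z = z0
    for params in param_list:
        z, log_det = planar_flow(params, z)
        log_q = log_q - log_det
    return z, log_q

key = random.PRNGKey(0)
D, K = 2, 8
params = [(0.5 * random.normal(random.fold_in(key, 3 * k), (D,)),
           0.5 * random.normal(random.fold_in(key, 3 * k + 1), (D,)),
           0.0) for k in range(K)]
z0 = random.normal(random.fold_in(key, 999), (D,))
zK, log_qK = flow_log_prob(params, z0)
print(zK, log_qK)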


7. Stein Variational Gradient Descent (SVGD)

BBVI is limited by the chosen family q (a Gaussian, for instance, cannot be multi-modal). We want non-parametric VI, using particles. Let \{z_i\}_{i=1}^N be particles approximating q. We move each particle z \to z + \epsilon \phi(z) so as to decrease the KL divergence. Let T(z) = z + \epsilon \phi(z) and let q_{[T]} = T_\# q denote the pushforward of q. The derivative of the KL divergence is:

\nabla_\epsilon D_{KL}(q_{[T]} \,\|\, p)\big|_{\epsilon=0} = -\, \mathbb{E}_q[\operatorname{trace}(\mathcal{A}_p \phi)]

where \mathcal{A}_p is the Stein operator, \mathcal{A}_p \phi = \phi\, \nabla \log p^T + \nabla \phi, so that \operatorname{trace}(\mathcal{A}_p \phi) = \phi^T \nabla \log p + \nabla \cdot \phi. We seek the optimal perturbation \phi within the unit ball of an RKHS \mathcal{H}_k:

\phi^* = \text{argmax}_{\phi \in \mathcal{H}_k,\; \|\phi\|_{\mathcal{H}_k} \le 1}\; \mathbb{E}_q[\operatorname{trace}(\mathcal{A}_p \phi)]

The closed-form solution (Liu & Wang, 2016) is:

\phi^*(x) \propto \mathbb{E}_{y \sim q}\big[ k(x, y)\, \nabla_y \log p(y) + \nabla_y k(x, y) \big]

Algorithm (SVGD):

z_i \leftarrow z_i + \epsilon \left( \frac{1}{N} \sum_j \big[ k(z_j, z_i)\, \nabla_{z_j} \log p(x, z_j) + \nabla_{z_j} k(z_j, z_i) \big] \right)

Interpretation:

  1. Driver: moves z_i towards regions of high probability (the \nabla \log p term).
  2. Repulsion: moves z_i away from z_j (the \nabla k term acts as a repulsive force). This prevents mode collapse and allows the particles to cover the posterior.

8. Variance Reduction: The Pathwise Derivative

Why did VAEs (2014) revolutionize the field? Because standard gradient estimation for expectations is noisy. We want \nabla_\phi \mathbb{E}_{q_\phi}[f(z)].

Method 1: Score Function (REINFORCE)

\nabla_\phi \mathbb{E}_q[f] = \mathbb{E}_q[f(z)\, \nabla_\phi \log q_\phi(z)]

This is the likelihood-ratio trick. It is valid for any density q, even over discrete variables. Variance: proportional to \text{Var}(f(z)), so the noise stays high unless f(z) \approx 0 near the optimum, which is why control-variate baselines are needed in practice.

Method 2: Pathwise Derivative (Reparameterization). Assume z is continuous and there exists a diffeomorphism z = g(\epsilon, \phi).

\mathbb{E}_{q_\phi}[f(z)] = \mathbb{E}_{p(\epsilon)}[f(g(\epsilon, \phi))]

The gradient moves inside the expectation:

\nabla_\phi \mathbb{E} = \mathbb{E}_{p(\epsilon)}[\nabla_z f(z)\, \nabla_\phi g(\epsilon, \phi)]

Variance: typically orders of magnitude lower. The score-function estimator only probes f through random function evaluations, whereas the pathwise derivative uses analytic knowledge of the function's slope to guide the update. Limitation: it requires both the model f and q to be differentiable, so discrete latent variables need a relaxation such as Gumbel-Softmax (Appendix B). The sketch below compares the two estimators empirically.
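The following sketch compares the two estimators for the toy choice f(z) = z^2 and q = N(mu, 1) (assumptions for illustration), where the exact gradient is 2*mu. Both estimators are unbiased; the score-function samples are visibly noisier.

import jax.numpy as jnp
from jax import grad, vmap, random
from jax.scipy.stats import norm

f = lambda z: z ** 2
mu, n = 1.5, 1_000

key = random.PRNGKey(0)
eps = random.normal(key, (n,))
z = mu + eps

# Method 1: score function / REINFORCE: f(z) * d/dmu log q(z)
score_samples = f(z) * vmap(grad(lambda m, zi: norm.logpdf(zi, m, 1.0)), (None, 0))(mu, z)

# Method 2: pathwise / reparameterization: d/dmu f(mu + eps) = f'(z)
path_samples = vmap(grad(lambda m, e: f(m + e)), (None, 0))(mu, eps)

print("true gradient      :", 2 * mu)
print("score-fn  mean/var :", score_samples.mean(), score_samples.var())
print("pathwise  mean/var :", path_samples.mean(), path_samples.var())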


9. JAX Implementation: Particles vs Distribution

import jax
import jax.numpy as jnp
from jax import grad, vmap, jit
from jax import random
from functools import partial

# 1. Define Kernels
def rbf_kernel(X, h=-1):
    # Vectorized RBF kernel.
    # X: (N, D); returns K: (N, N) and grad_K: (N, N, D), where
    # grad_K[i, j] = gradient of k(z_i, z_j) w.r.t. z_j (the repulsive direction for particle i).
    diff = X[:, None, :] - X[None, :, :]          # (N, N, D) pairwise differences
    sq_dist = jnp.sum(diff ** 2, axis=-1)         # (N, N)
    if h < 0:                                     # median heuristic for the bandwidth
        h = jnp.median(sq_dist) / jnp.log(X.shape[0])
    K = jnp.exp(-sq_dist / h)
    grad_K = K[..., None] * diff * (2.0 / h)      # sign chosen so the force points away from z_j
    return K, grad_K

# 2. SVGD Step
@partial(jit, static_argnums=1)                   # log_prob_grad is a function, so mark it static
def svgd_step(particles, log_prob_grad, step_size):
    # Score function at each particle
    grad_logp = log_prob_grad(particles)          # (N, D)
    # Kernel interaction
    K, grad_K = rbf_kernel(particles)             # (N, N), (N, N, D)
    # The Stein force: phi(z_i) = (1/N) sum_j [ k(z_j, z_i) score(z_j) + grad_{z_j} k(z_j, z_i) ]
    term1 = K @ grad_logp                         # (N, D) driving term
    term2 = jnp.sum(grad_K, axis=1)               # (N, D) repulsive term, summed over j
    phi = (term1 + term2) / particles.shape[0]
    return particles + step_size * phi

# 3. Target Distribution: Bimodal Mixture
def target_log_prob(x):
    # Equal-weight mixture of two unit Gaussians at (-2, -2) and (2, 2);
    # the equal weights only shift the log-density by a constant, so they are omitted.
    mu1 = jnp.array([-2.0, -2.0])
    mu2 = jnp.array([2.0, 2.0])
    log_p1 = -0.5 * jnp.sum((x - mu1) ** 2)
    log_p2 = -0.5 * jnp.sum((x - mu2) ** 2)
    return jax.scipy.special.logsumexp(jnp.array([log_p1, log_p2]))

# Wrapper for per-particle gradients
dist_grad = vmap(grad(target_log_prob))

def run_simulation():
    key = random.PRNGKey(42)
    # Start all particles near (0, 0), i.e. a "mode collapse" initialization
    particles = random.normal(key, (100, 2)) * 0.1
    step_size = 0.1
    history = []
    for i in range(200):
        particles = svgd_step(particles, dist_grad, step_size)
        if i % 100 == 0:
            history.append(particles)
    # Plotting the history shows the particles splitting into two groups covering both modes.
    return particles

# Observation: even when starting effectively from a single point, the repulsive
# grad_K term pushes particles apart, forcing exploration of the second mode.
# In many multimodal problems this mixes far more readily than a single MCMC chain.

10. Conclusion: The Optimization Perspective

Variational Inference represents a philosophical shift in Statistics. Instead of simulating the process (MCMC), we design an optimization landscape (ELBO) whose geometry entices the solution to reveal itself. From the functional calculus of the Euler-Lagrange equations to the differential geometry of Natural Gradients and the particle physics of Stein Flows, VI provides a rich playground where Geometry meets Probability. As we scale to massive datasets and complex deep generative models, the “Optimization” view of inference is likely to dominate the “Sampling” view for the foreseeable future.



Historical Timeline

Year | Event | Significance
1998 | Shun-Ichi Amari | Natural Gradient Descent.
1999 | Michael Jordan et al. | Formalize variational methods for graphical models.
2013 | Hoffman et al. | Stochastic Variational Inference (SVI).
2014 | Kingma & Welling | VAEs (Reparameterization Trick).
2015 | Rezende & Mohamed | Normalizing Flows.
2016 | Liu & Wang | Stein Variational Gradient Descent (SVGD).

Appendix A: VAE Loss Derivation (Full Details)

The Variational Autoencoder has a generative model p_\theta(x|z) with prior p(z), and an approximate posterior (encoder) q_\phi(z|x). The ELBO is:

\mathcal{L} = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - D_{KL}\big( q_\phi(z|x) \,\|\, p(z) \big)

A.1 The Gaussian KL Term

Let q(z|x) = \mathcal{N}(\mu, \Sigma) and p(z) = \mathcal{N}(0, I). The general formula for the KL divergence between two Gaussians \mathcal{N}_0, \mathcal{N}_1 is:

D_{KL}(\mathcal{N}_0 \,\|\, \mathcal{N}_1) = \frac{1}{2}\left[ \operatorname{tr}(\Sigma_1^{-1}\Sigma_0) + (\mu_1 - \mu_0)^T \Sigma_1^{-1} (\mu_1 - \mu_0) - k + \ln \frac{\det \Sigma_1}{\det \Sigma_0} \right]

Plugging in \mu_0 = \mu, \Sigma_0 = \operatorname{diag}(\sigma^2) and \mu_1 = 0, \Sigma_1 = I gives the familiar closed form D_{KL} = \frac{1}{2}\sum_i \big(\sigma_i^2 + \mu_i^2 - 1 - \ln \sigma_i^2\big). This term is minimized when \mu \to 0 and \sigma \to 1; it acts as a regularizer keeping the latent space compact.

A.2 The Reconstruction Term

Monte Carlo estimate with z^{(l)} = \mu + \sigma \odot \epsilon^{(l)}. If p(x|z) = \mathcal{N}(x; \text{Decoder}(z), I):

\log p(x|z) = -\frac{1}{2}\, \| x - \text{Decoder}(z) \|^2 + \text{const}

This recovers the MSE loss.
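A minimal per-example VAE objective under these assumptions (diagonal Gaussian encoder, standard normal prior, unit-variance Gaussian decoder). The decoder here is a tanh placeholder standing in for a real neural network.

import jax.numpy as jnp
from jax import random

def decoder(z):
    return jnp.tanh(z)              # placeholder mapping latent -> data space

def kl_std_normal(mu, log_var):
    # Closed form: 0.5 * sum(sigma^2 + mu^2 - 1 - log sigma^2)
    return 0.5 * jnp.sum(jnp.exp(log_var) + mu ** 2 - 1.0 - log_var)

def neg_elbo(mu, log_var, x, key):
    eps = random.normal(key, mu.shape)
    z = mu + jnp.exp(0.5 * log_var) * eps                  # reparameterization
    recon = 0.5 * jnp.sum((x - decoder(z)) ** 2)           # -log p(x|z) up to a constant
    return recon + kl_std_normal(mu, log_var)

# Hypothetical usage on a 2-D toy example:
x = jnp.array([0.2, -0.1])
mu, log_var = jnp.zeros(2), jnp.zeros(2)
print(neg_elbo(mu, log_var, x, random.PRNGKey(0)))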


Appendix B: The Gumbel-Softmax Trick (Discrete Regularization)

Reparameterization works for continuous variables. What about categories? We cannot backpropagate through sampling a categorical variable z \sim \text{Cat}(\pi). Gumbel-Max Trick:

z = \text{one\_hot}\big( \text{argmax}_i (\log \pi_i + g_i) \big)

where g_i \sim \text{Gumbel}(0, 1). Concrete Relaxation (Jang et al., 2016): replace the argmax with a softmax at temperature \tau:

y_i = \frac{\exp\big((\log \pi_i + g_i)/\tau\big)}{\sum_j \exp\big((\log \pi_j + g_j)/\tau\big)}

As \tau \to 0, y approaches a one-hot sample; as \tau \to \infty, y approaches the uniform distribution.
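A small sketch of the relaxed sample; the class probabilities and temperatures below are arbitrary assumptions.

import jax
import jax.numpy as jnp
from jax import random

def gumbel_softmax_sample(key, logits, tau):
    # Gumbel(0,1) noise via -log(-log(U)), with a small floor on U for numerical safety
    u = random.uniform(key, logits.shape, minval=1e-6, maxval=1.0)
    g = -jnp.log(-jnp.log(u))
    return jax.nn.softmax((logits + g) / tau)

logits = jnp.log(jnp.array([0.2, 0.5, 0.3]))
key = random.PRNGKey(0)
print(gumbel_softmax_sample(key, logits, tau=1.0))   # soft, differentiable sample
print(gumbel_softmax_sample(key, logits, tau=0.1))   # nearly one-hot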


Appendix C: Generalization to Alpha-Divergences

Why minimize D_{KL}(q \| p) rather than D_{KL}(p \| q) (as in Expectation Propagation), or something else entirely? The "exclusive" KL (q \| p) forces q to fit inside a single mode of p (zero-forcing). The "inclusive" KL (p \| q) forces q to cover the entire mass of p (mass-covering). Renyi's \alpha-divergence generalizes both:

D_\alpha(p \,\|\, q) = \frac{1}{\alpha - 1} \log \int p(z)^\alpha\, q(z)^{1-\alpha}\, dz

The limit \alpha \to 1 recovers the inclusive KL; small \alpha behaves like the exclusive KL (q \| p), i.e. mode-seeking. Black-Box Alpha-VI (Hernandez-Lobato et al., 2016): we can minimize D_\alpha directly using the VR-max bound or importance sampling.

\mathcal{L}_\alpha(q) \approx \frac{1}{\alpha} \log \frac{1}{K} \sum_{k=1}^K \left( \frac{p(x, z_k)}{q(z_k)} \right)^\alpha

This connects VI to particle filters and Sequential Monte Carlo. By tuning \alpha, we can control the behavior of the approximator: from mode-seeking to mass-covering to heavy-tailed. A Monte Carlo sketch of this bound follows.
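A Monte Carlo sketch of the bound exactly as written above, on a toy conjugate Gaussian model (an assumption: z ~ N(0,1), x|z ~ N(z,1), q = N(m, s^2)) so that the exact log p(x) is available for comparison.

import jax.numpy as jnp
from jax import random
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

def renyi_bound(key, alpha, m, s, x, K=5_000):
    # (1/alpha) * log( (1/K) * sum_k (p(x, z_k)/q(z_k))^alpha ), z_k ~ q
    z = m + s * random.normal(key, (K,))
    log_w = (norm.logpdf(z, 0.0, 1.0) + norm.logpdf(x, z, 1.0)    # log p(x, z)
             - norm.logpdf(z, m, s))                              # - log q(z)
    return (logsumexp(alpha * log_w) - jnp.log(K)) / alpha

key, x = random.PRNGKey(0), 1.5
for alpha in [0.1, 0.5, 0.999]:
    print(alpha, renyi_bound(key, alpha, 0.0, 1.0, x))
print("log p(x):", norm.logpdf(x, 0.0, jnp.sqrt(2.0)))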


Appendix D: Glossary of Terms

  • BBVI: Black-Box Variational Inference. Uses gradients of ELBO.
  • ELBO: Evidence Lower Bound. The objective function for VI.
  • LDA: Latent Dirichlet Allocation. A classic topic model using Mean Field VI.
  • Mean Field: the assumption that q factorizes fully.
  • Natural Gradient: Gradient step in Riemannian manifold defined by Fisher Information.
  • Normalizing Flow: A sequence of invertible transformations to model complex densities.
  • Reparameterization Trick: \nabla_\phi \mathbb{E}_{q_\phi}[f(z)] = \mathbb{E}_{p(\epsilon)}[\nabla_\phi f(g(\epsilon, \phi))]. Low variance.
  • SVGD: Particle-based VI using kernelized Stein discrepancy.

References

1. Blei, D. M. et al. (2017). “Variational Inference: A Review for Statisticians”. Extensive overview of VI as an alternative to MCMC. Covers Mean Field, Stochastic VI, and connections to convex optimization.

2. Liu, Q., & Wang, D. (2016). "Stein Variational Gradient Descent". Introduces SVGD. Uses Stein's identity and RKHS machinery to derive a deterministic particle flow that decreases the KL divergence to the posterior.

3. Kingma, D. P., & Welling, M. (2014). “Auto-Encoding Variational Bayes”. The paper that introduced VAEs and the Reparameterization Trick. It connected Deep Learning (backprop) with Variational Inference.

4. Rezende, D. J., & Mohamed, S. (2015). "Variational Inference with Normalizing Flows". Showed how to construct arbitrarily complex posteriors q(z) by transforming simple Gaussians through invertible neural networks.

5. Amari, S. I. (1998). “Natural Gradient Works Efficiently in Learning”. The seminal paper defining the Natural Gradient for neural networks. It shows that the parameter space is Riemannian and the Fisher Information Metric is the unique metric invariant to reparameterization.

6. Jordan, M. I. et al. (1999). “An Introduction to Variational Methods for Graphical Models”. The classic tutorial that established the Mean Field approximation as the standard tool for exponential families in graphical models.