Information Geometry & Natural Gradients
The coordinate dependence of gradient descent
Let $(\mathcal{X}, \mathcal{A}, \mu)$ be a measure space. We have a parametric family $\{P_\theta : \theta \in \Theta\}$ dominated by $\mu$, with $\Theta \subseteq \mathbb{R}^d$ open, with densities:

$$p(x;\theta) = \frac{dP_\theta}{d\mu}(x).$$
We want a Riemannian metric $g$ and a connection $\nabla$ on $\Theta$ that sit inside the family itself and do not care about how we happen to write the parameters down.
The trouble is that gradient descent takes the distance between $p_\theta$ and $p_{\theta+\delta}$ to be the Euclidean $\|\delta\|^2$, and under a reparameterization this Euclidean distance shifts around even though the distributions have not moved at all (the sketch after the list below makes the shift concrete). So we need a divergence whose induced metric is
- Reparameterization-invariant, so that $ds^2 = d\theta^\top G(\theta)\, d\theta$ is a scalar invariant.
- Sufficient-statistic-invariant, so that if $T$ from $\mathcal{X}$ to $\mathcal{Y}$ is sufficient for $\theta$, the geometry on the family over $\mathcal{X}$ matches the geometry on the induced family over $\mathcal{Y}$.
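To see the chart dependence numerically, here is a minimal sketch (hypothetical parameter values) that writes the same pair of univariate Gaussians in two charts, $(\mu, \sigma)$ and $(\mu, \log\sigma)$: the KL divergence between the two distributions is the same number in either chart, while the Euclidean step length is not.

```python
import jax.numpy as jnp

# Two nearby univariate Gaussians, written in two charts:
#   chart A: (mu, sigma)        chart B: (mu, log sigma)
p0 = {"mu": 0.0, "sigma": 1.0}   # hypothetical values
p1 = {"mu": 0.1, "sigma": 1.2}

def kl_gauss(m0, s0, m1, s1):
    """Closed-form KL( N(m0, s0^2) || N(m1, s1^2) )."""
    return jnp.log(s1 / s0) + (s0**2 + (m0 - m1) ** 2) / (2 * s1**2) - 0.5

# The distributions themselves do not care about the chart.
kl = kl_gauss(p0["mu"], p0["sigma"], p1["mu"], p1["sigma"])

# Euclidean step length in chart A versus chart B.
d_a = jnp.sqrt((p1["mu"] - p0["mu"]) ** 2 + (p1["sigma"] - p0["sigma"]) ** 2)
d_b = jnp.sqrt((p1["mu"] - p0["mu"]) ** 2
               + (jnp.log(p1["sigma"]) - jnp.log(p0["sigma"])) ** 2)

print("KL (chart-free):        ", float(kl))
print("Euclidean, (mu, sigma): ", float(d_a))   # ~0.224
print("Euclidean, (mu, log s): ", float(d_b))   # ~0.208, same distributions, different number
```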
Regularity conditions and their failures
The standard theory sits on top of the following assumptions and each one quietly kicks out a bunch of cases we actually care about.
A1 (Identifiability). The map $\theta \mapsto P_\theta$ is injective. Neural networks break this all over the place through permutation symmetry and overparameterization and entire manifolds of solutions that all give the same distribution.
A2 (Common support). The support $\{x : p(x;\theta) > 0\}$ does not depend on $\theta$. The uniform distribution $U(0,\theta)$ blows this up right away and the likelihood stops being differentiable at the boundary.
A3 (Smoothness). $p(x;\theta)$ is at least $C^2$ in $\theta$.
A4 (Interchange of derivative and integral).

$$\nabla_\theta \int p(x;\theta)\, d\mu = \int \nabla_\theta\, p(x;\theta)\, d\mu.$$
This needs the score to be uniformly integrable and it gets assumed way more often than anyone bothers to check it.
Deriving the metric from KL divergence
KL divergence is asymmetric and it breaks the triangle inequality so it is not a metric, but if you zoom in and do a second-order expansion you get back a real Riemannian metric.
Set $f(\delta) = D_{\mathrm{KL}}\!\left(p_\theta \,\|\, p_{\theta+\delta}\right)$ and Taylor-expand around $\delta = 0$ to see what comes out.
Plugging in $\log p(x;\theta+\delta) = \log p(x;\theta) + \delta^\top \nabla_\theta \log p + \tfrac{1}{2}\,\delta^\top \nabla_\theta^2 \log p\;\delta + O(\|\delta\|^3)$ gives

$$D_{\mathrm{KL}}\!\left(p_\theta \,\|\, p_{\theta+\delta}\right) = -\,\mathbb{E}_{p_\theta}\!\left[\delta^\top \nabla_\theta \log p\right] - \tfrac{1}{2}\,\delta^\top\, \mathbb{E}_{p_\theta}\!\left[\nabla_\theta^2 \log p\right]\delta + O(\|\delta\|^3).$$
The linear term drops out and the calculation below shows why.
By A4,

$$\mathbb{E}_{p_\theta}\!\left[\nabla_\theta \log p(x;\theta)\right] = \int \nabla_\theta\, p(x;\theta)\, d\mu = \nabla_\theta \int p(x;\theta)\, d\mu = \nabla_\theta 1 = 0.$$
So $D_{\mathrm{KL}}\!\left(p_\theta \,\|\, p_{\theta+\delta}\right)$ sits at a local minimum at $\delta = 0$ like you would expect, and what is left over is the quadratic form

$$D_{\mathrm{KL}}\!\left(p_\theta \,\|\, p_{\theta+\delta}\right) = \tfrac{1}{2}\,\delta^\top G(\theta)\,\delta + O(\|\delta\|^3), \qquad G(\theta) = -\,\mathbb{E}_{p_\theta}\!\left[\nabla_\theta^2 \log p(x;\theta)\right].$$
Now for the Hessian-to-outer-product identity. Differentiating $\mathbb{E}_{p_\theta}\!\left[\nabla_\theta \log p\right] = 0$ one more time gives

$$0 = \nabla_\theta \int p\,\nabla_\theta \log p\; d\mu = \int p\,\nabla_\theta \log p\,\left(\nabla_\theta \log p\right)^\top d\mu + \int p\,\nabla_\theta^2 \log p\; d\mu,$$

and so

$$-\,\mathbb{E}_{p_\theta}\!\left[\nabla_\theta^2 \log p\right] = \mathbb{E}_{p_\theta}\!\left[\nabla_\theta \log p\,\left(\nabla_\theta \log p\right)^\top\right].$$
The local distance $ds^2 = \delta^\top G(\theta)\,\delta$ gives a Riemannian metric on $\Theta$, and the matrix

$$G(\theta) = \mathbb{E}_{p_\theta}\!\left[\nabla_\theta \log p\,\left(\nabla_\theta \log p\right)^\top\right]$$

is the Fisher Information Matrix.
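As a quick numerical check of the identity, here is a minimal sketch using an Exponential(rate) model, chosen only because both sides are easy to sample and the true Fisher information is $1/\text{rate}^2$; the evaluation point is a hypothetical choice.

```python
import jax
import jax.numpy as jnp
from jax import random

# Exponential(rate) model: log p(x; rate) = log(rate) - rate * x on x >= 0.
def log_pdf(rate, x):
    return jnp.log(rate) - rate * x

rate = 1.5                                                      # hypothetical evaluation point
x = random.exponential(random.PRNGKey(0), (500_000,)) / rate    # samples from the model itself

score = jax.vmap(jax.grad(log_pdf), in_axes=(None, 0))(rate, x)
hess = jax.vmap(jax.hessian(log_pdf), in_axes=(None, 0))(rate, x)

print(float(jnp.mean(score**2)))   # E[s^2]          ~ 1/rate^2 = 0.444 (Monte Carlo)
print(float(-jnp.mean(hess)))      # -E[d^2 log p]   = 1/rate^2 exactly for this model
```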
Chentsov’s uniqueness theorem
Chentsov (1972) showed that the Fisher metric is the only Riemannian metric (up to scale) that stays invariant under congruent embeddings by Markov morphisms.
The idea is that a measurable map $T$ from $\mathcal{X}$ to $\mathcal{Y}$ pushes measures forward, and if $T$ is a sufficient statistic then no information is lost, so distances have to be preserved. The Fisher metric does this on its own and no other Riemannian metric does.
Dual connections
A metric gives lengths and angles but if you want geodesics and flatness you also need a connection.
In ordinary Riemannian geometry the obvious pick is the Levi-Civita connection, which is metric-compatible, torsion-free, and unique. Statistical manifolds carry a whole one-parameter family of connections $\nabla^{(\alpha)}$ with $\alpha \in \mathbb{R}$ (the case $\alpha = 0$ recovers the Levi-Civita connection of the Fisher metric), and this extra structure is where a lot of the action is.
The $\alpha$-connection has Christoffel symbols

$$\Gamma^{(\alpha)}_{ij,k} = \mathbb{E}_{p_\theta}\!\left[\left(\partial_i \partial_j \ell + \frac{1-\alpha}{2}\,\partial_i \ell\,\partial_j \ell\right)\partial_k \ell\right], \qquad \ell = \log p(x;\theta).$$

Using the skewness tensor $T_{ijk} = \mathbb{E}_{p_\theta}\!\left[\partial_i \ell\,\partial_j \ell\,\partial_k \ell\right]$ this cleans up to

$$\Gamma^{(\alpha)}_{ij,k} = \Gamma^{(0)}_{ij,k} - \frac{\alpha}{2}\, T_{ijk}.$$
$\nabla^{(\alpha)}$ and $\nabla^{(-\alpha)}$ are dual with respect to $g$, and so for any vector fields $X$, $Y$ and $Z$ we have

$$X\, g(Y, Z) = g\!\left(\nabla^{(\alpha)}_X Y,\, Z\right) + g\!\left(Y,\, \nabla^{(-\alpha)}_X Z\right).$$
The exponential connection $\nabla^{(e)} = \nabla^{(1)}$ at $\alpha = 1$ and the mixture connection $\nabla^{(m)} = \nabla^{(-1)}$ at $\alpha = -1$ are the pair that does most of the work.
The hyperbolic geometry of Gaussians
The univariate Gaussian makes the theory concrete. Take

$$p(x;\mu,\sigma) = \frac{1}{\sqrt{2\pi}\,\sigma}\exp\!\left(-\frac{(x-\mu)^2}{2\sigma^2}\right), \qquad \theta = (\mu, \sigma).$$
The scores are

$$\partial_\mu \ell = \frac{x-\mu}{\sigma^2}, \qquad \partial_\sigma \ell = \frac{(x-\mu)^2}{\sigma^3} - \frac{1}{\sigma}.$$
We work out the Fisher matrix one entry at a time.
$G_{\mu\mu} = \mathbb{E}\!\left[\left(\frac{x-\mu}{\sigma^2}\right)^2\right] = \frac{\sigma^2}{\sigma^4} = \frac{1}{\sigma^2}$.

$G_{\mu\sigma} = \mathbb{E}\!\left[\frac{x-\mu}{\sigma^2}\left(\frac{(x-\mu)^2}{\sigma^3} - \frac{1}{\sigma}\right)\right] = 0$ because the odd central moments of a Gaussian wipe out, so $\partial_\mu$ and $\partial_\sigma$ sit orthogonal to each other in the Riemannian sense.

For $G_{\sigma\sigma}$ we expand $\left(\frac{(x-\mu)^2}{\sigma^3} - \frac{1}{\sigma}\right)^2$ and use $\mathbb{E}\!\left[(x-\mu)^4\right] = 3\sigma^4$, which gives

$$G_{\sigma\sigma} = \frac{3\sigma^4}{\sigma^6} - \frac{2\sigma^2}{\sigma^4} + \frac{1}{\sigma^2} = \frac{2}{\sigma^2}.$$
The Fisher matrix is

$$G(\mu,\sigma) = \begin{pmatrix} \dfrac{1}{\sigma^2} & 0 \\[4pt] 0 & \dfrac{2}{\sigma^2} \end{pmatrix}$$

and the line element is

$$ds^2 = \frac{d\mu^2 + 2\,d\sigma^2}{\sigma^2}.$$
Put this next to the Poincaré upper half-plane, where $ds^2 = \frac{dx^2 + dy^2}{y^2}$. The factor of 2 just rescales the curvature but the shape is the same, and the manifold of univariate Gaussians has constant negative curvature.
Geodesics between $(\mu_0, \sigma_0)$ and $(\mu_1, \sigma_1)$ come from solving the Euler-Lagrange equations for the energy $\int \frac{\dot\mu^2 + 2\dot\sigma^2}{\sigma^2}\, dt$. When $\mu_0 = \mu_1$ the geodesic is a vertical line that just rescales the variance and otherwise the geodesics are semi-ellipses sitting on the $\mu$-axis.
Because of this negative curvature, averaging parameters coordinate by coordinate does not land you on the geometric center and the Frechet mean on this manifold can sit pretty far from the naive average.
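The diagonal matrix above is easy to confirm with autodiff. The sketch below (hypothetical evaluation point) takes the closed-form Gaussian KL and reads off the Fisher matrix as its Hessian in the second argument, recovering $\mathrm{diag}(1/\sigma^2,\, 2/\sigma^2)$.

```python
import jax
import jax.numpy as jnp

def kl_gauss(theta_p, theta_q):
    """Closed-form KL( N(mu_p, sp^2) || N(mu_q, sq^2) ) with theta = (mu, sigma)."""
    mu_p, sp = theta_p
    mu_q, sq = theta_q
    return jnp.log(sq / sp) + (sp**2 + (mu_p - mu_q) ** 2) / (2 * sq**2) - 0.5

theta = jnp.array([1.0, 0.7])    # hypothetical point (mu, sigma)

# Fisher matrix = Hessian of KL(p_theta || p_{theta + delta}) in the second slot at delta = 0.
fisher = jax.hessian(lambda q: kl_gauss(theta, q))(theta)

print(fisher)                                             # [[1/sigma^2, 0], [0, 2/sigma^2]]
print(jnp.diag(jnp.array([1.0, 2.0])) / theta[1] ** 2)    # analytic value for comparison
```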
Exponential families and e-flatness
Take the exponential family in natural parameters $\eta$ and write it out as

$$p(x;\eta) = h(x)\,\exp\!\left(\eta^\top T(x) - \psi(\eta)\right).$$
The second derivative of the log-likelihood is $\partial_i \partial_j \ell = -\,\partial_i \partial_j \psi(\eta)$, and this has no $x$ in it at all, so the Hessian is just a deterministic function of $\eta$.
So the e-connection Christoffel symbols collapse to

$$\Gamma^{(1)}_{ij,k} = \mathbb{E}\!\left[\left(\partial_i \partial_j \ell\right)\partial_k \ell\right] = -\,\partial_i \partial_j \psi(\eta)\;\mathbb{E}\!\left[\partial_k \ell\right] = 0$$

because the score has zero mean. So the manifold is flat under the e-connection, the natural parameters are affine coordinates, and geodesics are straight lines given by $\eta(t) = (1-t)\,\eta_0 + t\,\eta_1$.
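For a concrete instance (a standard worked example), take the Bernoulli family with success probability $\pi$ written in its natural parameter:

$$p(x;\eta) = \exp\!\left(\eta x - \psi(\eta)\right), \qquad x \in \{0,1\}, \qquad \eta = \log\frac{\pi}{1-\pi}, \qquad \psi(\eta) = \log\!\left(1 + e^\eta\right).$$

Here $\partial_\eta^2 \ell = -\psi''(\eta) = -\frac{e^\eta}{(1+e^\eta)^2}$ contains no $x$, so $\Gamma^{(1)} = 0$ and the e-geodesic between two Bernoullis is a straight line in the logit $\eta$, which is emphatically not a straight line in the mean parameter $\pi$.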
Mixture families and the Pythagorean theorem
And on the other side, the mixture family

$$p(x;\xi) = \sum_{i=1}^{k} \xi_i\, p_i(x), \qquad \xi_i \ge 0, \quad \sum_{i=1}^{k} \xi_i = 1,$$

is flat under the m-connection at $\alpha = -1$, the expectation parameters are affine coordinates, and geodesics are linear mixtures given by $p_t(x) = (1-t)\,p_0(x) + t\,p_1(x)$.
Now that we have dual flat structures sitting on the same space, a Pythagorean theorem falls out. When the e-geodesic from $q$ to $r$ hits the m-geodesic from $p$ to $q$ at right angles at $q$ we get

$$D_{\mathrm{KL}}(p\,\|\,r) = D_{\mathrm{KL}}(p\,\|\,q) + D_{\mathrm{KL}}(q\,\|\,r).$$
The KL divergence between members of an exponential family is a Bregman divergence on the cumulant potential $\psi$, and it reads

$$D_{\mathrm{KL}}\!\left(p_{\eta_1} \,\|\, p_{\eta_2}\right) = \psi(\eta_2) - \psi(\eta_1) - (\eta_2 - \eta_1)^\top \mu_1,$$

where $\mu_1 = \nabla\psi(\eta_1) = \mathbb{E}_{\eta_1}\!\left[T(x)\right]$.
If you write out $D_{\mathrm{KL}}(p\,\|\,q) + D_{\mathrm{KL}}(q\,\|\,r)$ and subtract $D_{\mathrm{KL}}(p\,\|\,r)$, a bunch of things cancel and what is left is

$$D_{\mathrm{KL}}(p\,\|\,q) + D_{\mathrm{KL}}(q\,\|\,r) - D_{\mathrm{KL}}(p\,\|\,r) = (\mu_p - \mu_q)^\top(\eta_r - \eta_q).$$

The vector $\mu_p - \mu_q$ is the tangent in dual coordinates along the m-geodesic from $q$ to $p$, and $\eta_r - \eta_q$ is the tangent in primal coordinates along the e-geodesic from $q$ to $r$. Orthogonality says their inner product wipes out and so $D_{\mathrm{KL}}(p\,\|\,r) = D_{\mathrm{KL}}(p\,\|\,q) + D_{\mathrm{KL}}(q\,\|\,r)$.
This gives you the projection theorem, which says that the m-projection of $p$ onto an e-flat submanifold is unique and satisfies the Pythagorean relation, and the MLE is just the m-projection of the empirical distribution onto the model manifold.
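Here is a minimal numerical check of the cross-term identity for the univariate Gaussian family, with hypothetical members $p$, $q$, $r$ and our own helper names `natural` and `expectation` for the two coordinate maps; the two sides agree to floating point, and both vanish exactly when the geodesics meet orthogonally at $q$.

```python
import jax.numpy as jnp

def kl_gauss(m0, s0, m1, s1):
    """Closed-form KL( N(m0, s0^2) || N(m1, s1^2) )."""
    return jnp.log(s1 / s0) + (s0**2 + (m0 - m1) ** 2) / (2 * s1**2) - 0.5

def natural(m, s):
    """Natural parameters eta for sufficient statistics T(x) = (x, x^2)."""
    return jnp.array([m / s**2, -1.0 / (2 * s**2)])

def expectation(m, s):
    """Expectation parameters mu = E[T(x)] = (E[x], E[x^2])."""
    return jnp.array([m, m**2 + s**2])

# Three hypothetical members p, q, r of the univariate Gaussian family, as (mean, std).
p, q, r = (0.0, 1.0), (1.0, 1.5), (2.0, 0.8)

lhs = kl_gauss(*p, *q) + kl_gauss(*q, *r) - kl_gauss(*p, *r)
rhs = jnp.dot(expectation(*p) - expectation(*q), natural(*r) - natural(*q))

# Equal up to floating point; both are zero exactly when the m-geodesic (q to p)
# and the e-geodesic (q to r) are orthogonal at q.
print(float(lhs), float(rhs))
```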
The uniform distribution and how regularity breaks
Breaking A2 changes the theory in a qualitative way. Take

$$p(x;\theta) = \frac{1}{\theta}\,\mathbf{1}\{0 \le x \le \theta\},$$

the uniform distribution on $[0,\theta]$.
The score does not have zero mean anymore and the whole derivation from the KL section falls apart, because $\partial_\theta \log p(x;\theta) = -\tfrac{1}{\theta}$ on the whole support, so $\mathbb{E}_\theta\!\left[\partial_\theta \log p\right] = -\tfrac{1}{\theta} \ne 0$, and Leibniz's rule picks up a boundary term $p(\theta;\theta) = \tfrac{1}{\theta}$ when you differentiate $\int_0^\theta p(x;\theta)\, dx$.
And the fallout is real. The Fisher information $\mathbb{E}_\theta\!\left[(\partial_\theta \log p)^2\right] = \tfrac{1}{\theta^2}$ looks finite on paper but the Cramer-Rao bound does not apply because A2 has failed. The MLE $\hat\theta = \max_i x_i$ has variance of order $n^{-2}$ and this beats the Fisher rate $n^{-1}$, because the support boundary holds information that a score-based story just cannot see.
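A quick simulation makes the rate visible. This is a sketch with hypothetical sample sizes and endpoint: the MLE is the sample maximum, and $n^2\,\mathrm{Var}(\hat\theta)$ stays roughly constant near $\theta^2$ as $n$ grows, which no estimator stuck at the $n^{-1}$ Fisher rate could manage.

```python
import jax.numpy as jnp
from jax import random

theta_true = 2.0   # hypothetical true endpoint of U(0, theta)
keys = random.split(random.PRNGKey(0), 3)

for key, n in zip(keys, (100, 1_000, 10_000)):
    # 1000 replications of an n-sample dataset from U(0, theta_true)
    samples = theta_true * random.uniform(key, (1_000, n))
    mle = jnp.max(samples, axis=1)          # the MLE is the sample maximum
    # n^2 * Var stays near theta^2 = 4: the estimator converges at rate n^-2.
    print(n, float(n**2 * jnp.var(mle)))
```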
Natural gradient descent
Say we want to minimize a loss $L(\theta)$ on $\Theta$. The naive update $\theta \leftarrow \theta - \eta\,\nabla_\theta L$ mashes covectors and vectors together because $\nabla_\theta L$ is a covector and the displacement $\delta\theta$ is a vector, and they do not live in the same space.
So instead we solve

$$\delta^* = \arg\min_{\delta}\; L(\theta + \delta) \quad \text{subject to} \quad D_{\mathrm{KL}}\!\left(p_\theta \,\|\, p_{\theta+\delta}\right) \le \epsilon,$$

and plugging in the quadratic approximation $D_{\mathrm{KL}} \approx \tfrac{1}{2}\,\delta^\top G(\theta)\,\delta$ (and linearizing $L$) gives the natural gradient update

$$\theta \leftarrow \theta - \eta\, G(\theta)^{-1}\,\nabla_\theta L(\theta).$$
Cramer-Rao as Cauchy-Schwarz on the tangent space
The Cramer-Rao bound has a clean geometric proof that is really just Cauchy-Schwarz on the tangent space.
The score $s(x;\theta) = \nabla_\theta \log p(x;\theta)$ sits in $L^2(p_\theta)$ and the Fisher information is its squared norm, $G(\theta) = \mathbb{E}_{p_\theta}\!\left[s\, s^\top\right]$.
For an unbiased estimator $\hat\theta(x)$ with $\mathbb{E}_\theta[\hat\theta] = \theta$ we work out the cross term $\mathbb{E}_\theta\!\left[\hat\theta\, s^\top\right]$. Using $\nabla_\theta\, p = p\, s$ and A4 we get

$$\mathbb{E}_\theta\!\left[\hat\theta\, s^\top\right] = \int \hat\theta(x)\,\nabla_\theta\, p(x;\theta)^\top\, d\mu = \nabla_\theta\, \mathbb{E}_\theta\!\left[\hat\theta\right] = I,$$

which equals $\mathrm{Cov}_\theta(\hat\theta, s)$ since the score has zero mean. And matrix Cauchy-Schwarz on this gives

$$\mathrm{Cov}_\theta\!\left(\hat\theta\right) \succeq \mathrm{Cov}_\theta\!\left(\hat\theta, s\right) G(\theta)^{-1}\, \mathrm{Cov}_\theta\!\left(s, \hat\theta\right) = G(\theta)^{-1}.$$
High curvature means large $G(\theta)$, and that says the distributions are easy to tell apart so estimators can be precise; low curvature means small $G(\theta)$ and the distributions look locally the same, so no estimator can do well.
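As a sanity check that the bound is attainable, here is a small Monte Carlo sketch with a hypothetical mean and a known $\sigma$: the sample mean is unbiased for the Gaussian mean and its variance sits right at the Cramer-Rao bound $\sigma^2/n$.

```python
import jax.numpy as jnp
from jax import random

mu_true, sigma, n = 1.0, 2.0, 50        # hypothetical setup; sigma is treated as known
key = random.PRNGKey(0)

# 20,000 independent datasets of size n from N(mu_true, sigma^2)
x = mu_true + sigma * random.normal(key, (20_000, n))
mu_hat = jnp.mean(x, axis=1)            # the sample mean is unbiased for mu

print(float(jnp.var(mu_hat)))           # observed estimator variance, ~0.08
print(sigma**2 / n)                     # Cramer-Rao bound 1/(n G) = sigma^2/n = 0.08
```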
Singular models and Watanabe’s RLCT
When identifiability fails and A1 goes out the window, the Fisher matrix goes singular and we get

$$\det G(\theta^*) = 0$$

at the non-identifiable parameters $\theta^*$.
The manifold dimension collapses at these points and the tangent space drops to a tangent cone, and in deep learning singular Fisher matrices are just the default rather than the exception.
Watanabe (2009) showed that near singularities the Bayesian posterior does not settle down to a Gaussian, and the standard BIC complexity term $\frac{d}{2}\log n$ gets replaced by $\lambda \log n$, where $\lambda$ is the Real Log Canonical Threshold (RLCT) and satisfies

$$0 < \lambda \le \frac{d}{2},$$

with equality in the regular case.
Singular models are simpler than their parameter count makes them look. The posterior volume is set by resolution of singularities and that is a question for algebraic geometry and not Riemannian geometry. Standard information geometry stops working here and you have to pick up Watanabe’s singular learning theory instead.
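A toy singular model makes the rank collapse visible. The sketch below is our own example (not from Watanabe): $x \sim \mathcal{N}(ab, 1)$ with parameters $(a, b)$, so only the product $ab$ is identifiable, every score vector is proportional to $(b, a)$, and the Monte Carlo Fisher matrix has one eigenvalue pinned at zero.

```python
import jax
import jax.numpy as jnp
from jax import random

def log_pdf(theta, x):
    """Singular model: x ~ N(a * b, 1) with theta = (a, b). Only a*b is identifiable."""
    a, b = theta
    return -0.5 * jnp.log(2 * jnp.pi) - 0.5 * (x - a * b) ** 2

def fisher_mc(theta, key, n=100_000):
    mean = theta[0] * theta[1]
    x = mean + random.normal(key, (n,))
    scores = jax.vmap(jax.grad(log_pdf), in_axes=(None, 0))(theta, x)
    return scores.T @ scores / n

G = fisher_mc(jnp.array([1.0, 2.0]), random.PRNGKey(0))
print(G)                          # approx [[b^2, ab], [ab, a^2]] = [[4, 2], [2, 1]]
print(jnp.linalg.eigvalsh(G))     # one eigenvalue near 5, the other at ~0: rank deficient
```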
Implementation (JAX)
We check e- and m-geodesic orthogonality and put Natural Gradient up against SGD on a warped Gaussian to see how they converge.
```python
import jax
import jax.numpy as jnp
from jax import random, grad, jit, vmap, lax
from jax.scipy.stats import multivariate_normal
from typing import NamedTuple, Tuple
# ------------------------------------------------------------------
# SYSTEM CONFIGURATION
# ------------------------------------------------------------------
SEED = 42
LEARNING_RATE = 0.1
NUM_STEPS = 100
DAMPING = 1e-4
# ------------------------------------------------------------------
# 1. Manifold Definition: Warped Gaussian
# ------------------------------------------------------------------
def get_sigma(theta: jax.Array) -> jax.Array:
"""
Constructs the covariance matrix Sigma(theta) = diag(theta^2 + 1).
Ensures positive definiteness everywhere.
"""
return jnp.diag(theta**2 + 1.0)
def log_likelihood(theta: jax.Array, x: jax.Array) -> jax.Array:
""" Computes sum of log-likelihoods for data x given theta. """
mu = theta
cov = get_sigma(theta)
return jnp.sum(multivariate_normal.logpdf(x, mu, cov))
# ------------------------------------------------------------------
# 2. Fisher Information Computation
# ------------------------------------------------------------------
@jit
def compute_fisher_mc(theta: jax.Array, key: jax.Array, num_samples: int = 1000) -> jax.Array:
"""
Approximates the Fisher Information Matrix using Monte Carlo integration.
G(theta) = E[score * score^T]
"""
cov = get_sigma(theta)
# Sampling from the model distribution at theta
samples = random.multivariate_normal(key, theta, cov, shape=(num_samples,))
def score_fn(t, x_single):
return grad(lambda p: multivariate_normal.logpdf(x_single, p, get_sigma(p)))(t)
# Vectorized score computation
scores = vmap(lambda x: score_fn(theta, x))(samples)
# Outer product expectation
outer_products = vmap(lambda s: jnp.outer(s, s))(scores)
return jnp.mean(outer_products, axis=0)
# ------------------------------------------------------------------
# 3. Optimization Loop (JIT-Compiled Scan)
# ------------------------------------------------------------------
def loss_fn(theta: jax.Array, batch: jax.Array) -> jax.Array:
""" Negative Log Likelihood Loss. """
return -log_likelihood(theta, batch) / batch.shape[0]
class OptState(NamedTuple):
theta_sgd: jax.Array
theta_ngd: jax.Array
key: jax.Array
@jit
def update_step_sgd(theta: jax.Array, batch: jax.Array) -> jax.Array:
grads = grad(loss_fn)(theta, batch)
return theta - LEARNING_RATE * grads
@jit
def update_step_ngd(theta: jax.Array, batch: jax.Array, key: jax.Array) -> jax.Array:
grads = grad(loss_fn)(theta, batch)
fisher = compute_fisher_mc(theta, key)
# Natural Gradient: G^-1 * grad
# Numerically stable solve: (G + damping * I) * update = grad
regularized_fisher = fisher + DAMPING * jnp.eye(fisher.shape[0])
nat_grad = jnp.linalg.solve(regularized_fisher, grads)
return theta - LEARNING_RATE * nat_grad
@jit
def run_experiment() -> Tuple[jax.Array, jax.Array]:
""" Fully compiled training loop using lax.scan. """
key = random.PRNGKey(SEED)
key, subkey_data, subkey_init = random.split(key, 3)
# Ground Truth
true_theta = jnp.array([2.0, 3.0])
data = random.multivariate_normal(
subkey_data, true_theta, get_sigma(true_theta), shape=(500,)
)
# Initialization
theta_0 = jnp.array([0.5, 0.5])
init_state = OptState(theta_sgd=theta_0, theta_ngd=theta_0, key=subkey_init)
def step_fn(state: OptState, _):
key, subkey_ngd = random.split(state.key)
# Parallel updates
next_sgd = update_step_sgd(state.theta_sgd, data)
next_ngd = update_step_ngd(state.theta_ngd, data, subkey_ngd)
new_state = OptState(theta_sgd=next_sgd, theta_ngd=next_ngd, key=key)
# Record trajectories
return new_state, (next_sgd, next_ngd)
# Execute simulation
final_state, (path_sgd, path_ngd) = lax.scan(step_fn, init_state, None, length=NUM_STEPS)
    return path_sgd, path_ngd
```

Implications for optimization
Pulling all the theory together we land on a handful of concrete takeaways.
- The Fisher metric is the unique valid choice (Chentsov).
- Exponential families are both e-flat and m-flat (Amari).
- The Cramer-Rao bound is Cauchy-Schwarz on the space of scores in $L^2(p_\theta)$.
- Deep learning lives in the singular regime where $G(\theta)$ is rank-deficient (Watanabe).
- The natural gradient is the unique first-order update that respects the geometry.
A quick word on Adam. It approximates $G(\theta)$ with a diagonal second-moment estimate $\hat v \approx \mathbb{E}[g \odot g]$ and then divides by its square root, and that has no clean geometric story behind it. Natural gradient scales by $G^{-1}$ with units $[\text{gradient}]^{-2}$, so the step $G^{-1}\nabla L$ carries the units of $\theta$; Adam scales by $\hat v^{-1/2}$ with units $[\text{gradient}]^{-1}$, so its step is dimensionless. In practice Adam is much closer to sign descent, where you just normalize gradient magnitudes per coordinate, and not really Riemannian steepest descent at all. Adam works, but for reasons that are mostly beside the curvature correction story.
Without information geometry, optimization is running around in a coordinate system that has no real meaning attached to it.
Timeline
| Year | Event | Significance |
|---|---|---|
| 1945 | C.R. Rao | Introduces Fisher Information Metric (Riemannian). |
| 1972 | N. Chentsov | Proves Uniqueness Theorem for the metric. |
| 1979 | B. Efron | Defines statistical curvature. |
| 1985 | Shun-ichi Amari | Develops Dualistic Geometry ($\alpha$-connections). |
| 1998 | Amari | Proposes Natural Gradient Descent. |
| 2009 | Sumio Watanabe | Singular Learning Theory (Algebraic Geometry of learning). |
| 2014 | Pascanu & Bengio | Revisits Natural Gradient for Neural Networks. |
Legendre Duality
The exponential and mixture duality is really just a Legendre transform.
The potential $\psi(\eta)$ is the cumulant generating function and it reads

$$\psi(\eta) = \log \int h(x)\, e^{\eta^\top T(x)}\, d\mu(x),$$

and its Legendre conjugate is

$$\varphi(\mu) = \sup_{\eta}\left\{\eta^\top \mu - \psi(\eta)\right\}.$$
The sup gets hit where $\nabla_\eta\!\left(\eta^\top \mu - \psi(\eta)\right) = 0$, which says $\mu = \nabla\psi(\eta) = \mathbb{E}_\eta\!\left[T(x)\right]$, and this is the map that takes you from natural parameters to expectation parameters.
$\varphi(\mu)$ is negative entropy up to constants and the inverse map goes the other way as $\eta = \nabla\varphi(\mu)$.
The Hessians give you the metric in each coordinate system,

$$G(\eta) = \nabla^2\psi(\eta), \qquad G^*(\mu) = \nabla^2\varphi(\mu),$$
and these two are inverses of each other up to coordinate-change Jacobians, and that confirms the Riemannian structure is consistent across the dual representations.
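Worked out for the Bernoulli family (success probability equal to the expectation parameter $\mu$, natural parameter $\eta$), the whole dictionary fits on three lines:

$$\psi(\eta) = \log\!\left(1 + e^\eta\right), \qquad \mu = \psi'(\eta) = \frac{e^\eta}{1+e^\eta},$$

$$\varphi(\mu) = \mu\log\mu + (1-\mu)\log(1-\mu), \qquad \eta = \varphi'(\mu) = \log\frac{\mu}{1-\mu},$$

$$\psi''(\eta) = \mu(1-\mu), \qquad \varphi''(\mu) = \frac{1}{\mu(1-\mu)},$$

so the two Hessians are literal reciprocals of each other, exactly as the duality promises.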
Fisher vs. Wasserstein
People mix these two up all the time but they actually live on different spaces.
Fisher geometry sits on parameter space and is about densities, and Wasserstein geometry sits on the sample space lifted up to measures and it needs a ground metric that Fisher geometry does not need at all.
Their geodesics are fundamentally different. The Fisher geodesic from the e-connection interpolates multiplicatively as

$$p_t(x) \propto p_0(x)^{1-t}\, p_1(x)^{t},$$

and the Wasserstein geodesic just slides mass around horizontally as

$$p_t = \left((1-t)\,\mathrm{id} + t\,T\right)_{\#}\, p_0,$$

where $T$ is the optimal transport map from $p_0$ to $p_1$.
Try interpolating two unit-variance Gaussians with well-separated means, $\mathcal{N}(\mu_0, 1)$ and $\mathcal{N}(\mu_1, 1)$, and see what happens. The Fisher m-geodesic runs through a bimodal mixture in the middle and the Wasserstein geodesic just slides the bump over smoothly and hits $\mathcal{N}\!\left(\tfrac{\mu_0+\mu_1}{2}, 1\right)$ at the halfway point.
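Here is a minimal sketch of that comparison, with our own choice of endpoints: evaluate both midpoint densities on a grid and compare the value at the centre to the peak.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

# Two unit-variance Gaussians with well-separated means (hypothetical choice of endpoints).
mu0, mu1, sigma = -4.0, 4.0, 1.0
xs = jnp.linspace(-8.0, 8.0, 1001)

# Fisher m-geodesic at t = 1/2: the equal-weight mixture of the endpoints.
mix_mid = 0.5 * norm.pdf(xs, mu0, sigma) + 0.5 * norm.pdf(xs, mu1, sigma)

# Wasserstein geodesic at t = 1/2: for equal-variance Gaussians the displacement
# interpolation is just the translated Gaussian N((mu0 + mu1)/2, sigma^2).
w2_mid = norm.pdf(xs, 0.5 * (mu0 + mu1), sigma)

centre = xs.shape[0] // 2
print(float(mix_mid[centre]), float(mix_mid.max()))  # centre value << peak: two bumps
print(float(w2_mid[centre]), float(w2_mid.max()))    # centre value == peak: one bump
```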
Which one you reach for depends on the question you are asking. Fisher geometry fits inference and asks how much a sample tells you about , and Wasserstein fits transport and asks what it costs to push one distribution into another. Fisher comes out of the entropy Hessian and Wasserstein comes out of kinetic energy minimization from Benamou-Brenier.
Vocabulary
An affine connection sets up parallel transport and covariant derivatives on a manifold and tells you how to compare tangent vectors at different points. The Fisher Information Metric is the only Riemannian metric on families of probability distributions (up to scale) that stays invariant under sufficient statistics, and it falls out of the Hessian of KL divergence. The natural gradient is the steepest descent direction that actually accounts for the curvature of the statistical manifold instead of pretending parameter space is flat. A statistical manifold is a family of probability distributions equipped with the Fisher metric and the $\alpha$-connections. Dual connections are a pair $(\nabla, \nabla^*)$ that satisfy the compatibility condition $X\,g(Y,Z) = g(\nabla_X Y, Z) + g(Y, \nabla^*_X Z)$ with respect to the metric. The Kullback-Leibler divergence is the asymmetric divergence whose second-order Taylor expansion spits out the Fisher metric, and even though it is not a distance function its local geometry is Riemannian.
References
1. Amari, S., & Nagaoka, H. (2000). “Methods of Information Geometry”. The standard reference. It sets up $\alpha$-connections and dually flat spaces and applications to estimation.
2. Chentsov, N. N. (1972). “Statistical Decision Rules and Optimal Inference”. Proves the uniqueness of the Fisher metric from Markov invariance.
3. Rao, C. R. (1945). “Information and the accuracy attainable in the estimation of statistical parameters”. The original paper that proposed the Riemannian metric.
4. Watanabe, S. (2009). “Algebraic Geometry and Statistical Learning Theory”. The foundation of Singular Learning Theory and it handles the place where regular information geometry falls over in neural networks.
5. Martens, J. (2014). “New insights and perspectives on the natural gradient method”. A modern take on why Natural Gradient works for deep learning (K-FAC).