Information Geometry & Natural Gradients
1. Problem Formulation
Let $(\mathcal{X}, \mathcal{B}, \mu)$ be a measure space. Consider a parametric family of probability measures $\{P_\theta : \theta \in \Theta\}$ dominated by $\mu$, where $\Theta \subseteq \mathbb{R}^n$ is an open set of parameters. We denote the Radon-Nikodym derivatives (probability densities) by:

$$p(x;\theta) = \frac{dP_\theta}{d\mu}(x).$$

Our objective is to define a geometric structure (a Riemannian metric $g$ and an affine connection $\nabla$) on the manifold $\mathcal{M} = \{p(x;\theta) : \theta \in \Theta\}$ that satisfies intrinsic invariance.
The tension arises from the arbitrary nature of the parameterization $\theta$. In standard Euclidean optimization (e.g., Gradient Descent), we implicitly assume the distance between $\theta$ and $\theta + d\theta$ is $\|d\theta\|_2$. This is physically unjustified. A change of coordinates distorts this distance metric. Furthermore, the Euclidean distance is not invariant to the geometry of the sample space $\mathcal{X}$.
We require a divergence functional $D(\theta \,\|\, \theta')$ such that the induced metric structure is:
- Invariant to Reparameterization: the line element $ds^2 = g_{ij}(\theta)\, d\theta^i d\theta^j$ is a scalar invariant.
- Invariant to Sufficient Statistics: If $T(X)$ is a sufficient statistic for $\theta$, then the geometry of $\{p(x;\theta)\}$ must be identical to the geometry of the induced family on $T(\mathcal{X})$.
2. The Tools (Definitions and Regularity)
To ensure the existence of the Fisher Information and the validity of the Taylor expansions used below, we explicitly state the required regularity conditions. Note that “standard” assumptions in machine learning often violate these (e.g., ReLU networks leading to singular Hessians, or uniform distributions violating support independence).
Assumption A1 (Identifiability): The map $\theta \mapsto P_\theta$ is injective. That is, if $\theta_1 \neq \theta_2$ then $p(x;\theta_1) \neq p(x;\theta_2)$ on a set of non-zero measure. Failure Mode: In neural networks, permutation symmetry of neurons violates this locally. Overparameterization violates this globally (manifolds of equivalent solutions).
Assumption A2 (Common Support): The support of the density, $\{x : p(x;\theta) > 0\}$, is independent of $\theta$. Failure Mode: the uniform distribution $U(0, \theta)$. The support boundary depends on $\theta$, making the likelihood non-differentiable.
Assumption A3 (Smoothness): The log-likelihood $\ell(\theta; x) = \log p(x;\theta)$ is $k$-times continuously differentiable with respect to $\theta$ (twice suffices for the metric; three times for the $\alpha$-connections of Section 5).
Assumption A4 (Regularity of Integration): Differentiation with respect to $\theta$ and integration with respect to $\mu$ commute. Specifically:

$$\partial_i \int p(x;\theta)\, d\mu(x) = \int \partial_i\, p(x;\theta)\, d\mu(x) = 0.$$

This assumes the score function $\partial_i \ell = \partial \log p(x;\theta) / \partial \theta^i$ exists and is uniformly integrable.
3. Derivation of the Metric
We define the geometry locally via the Kullback-Leibler divergence,

$$D_{\mathrm{KL}}(p_\theta \,\|\, p_{\theta + d\theta}) = \int p(x;\theta)\, \log \frac{p(x;\theta)}{p(x;\theta + d\theta)}\, d\mu(x).$$

We do not assume this is a metric distance a priori; KL is not symmetric and fails the triangle inequality. However, its second-order Taylor expansion induces a quadratic form.
Let $\ell(\theta) = \log p(x;\theta)$. Expand $\ell(\theta + d\theta)$ around $\theta$:

$$\ell(\theta + d\theta) = \ell(\theta) + \partial_i \ell\, d\theta^i + \tfrac{1}{2}\, \partial_i \partial_j \ell\, d\theta^i d\theta^j + O(\|d\theta\|^3).$$

Substituting this into the KL definition:

$$D_{\mathrm{KL}}(p_\theta \,\|\, p_{\theta + d\theta}) = -\,\mathbb{E}_\theta[\partial_i \ell]\, d\theta^i - \tfrac{1}{2}\, \mathbb{E}_\theta[\partial_i \partial_j \ell]\, d\theta^i d\theta^j + O(\|d\theta\|^3).$$
Step 3.1: The Vanishing Linear Term. We must verify that $\mathbb{E}_\theta[\partial_i \ell] = 0$.
By Assumption A4, we exchange derivative and integral:

$$\mathbb{E}_\theta[\partial_i \ell] = \int p(x;\theta)\, \frac{\partial_i p(x;\theta)}{p(x;\theta)}\, d\mu(x) = \partial_i \int p(x;\theta)\, d\mu(x) = \partial_i(1) = 0.$$

Thus, the first-order term vanishes. This is necessary for $D_{\mathrm{KL}}(p_\theta \,\|\, p_{\theta'})$ to have a local minimum at $\theta' = \theta$.
Step 3.2: The Quadratic Form. We are left with the Hessian of the log-likelihood:

$$D_{\mathrm{KL}}(p_\theta \,\|\, p_{\theta + d\theta}) \approx -\tfrac{1}{2}\, \mathbb{E}_\theta[\partial_i \partial_j \ell]\, d\theta^i d\theta^j.$$

We invoke the identity linking the Hessian to the outer product of scores. Differentiating the score identity $\int p(x;\theta)\, \partial_i \ell\, d\mu(x) = 0$ again with respect to $\theta^j$:

$$\int \partial_j p(x;\theta)\, \partial_i \ell\, d\mu(x) + \int p(x;\theta)\, \partial_i \partial_j \ell\, d\mu(x) = 0.$$

Using $\partial_j p = p\, \partial_j \ell$:

$$\mathbb{E}_\theta[\partial_i \ell\, \partial_j \ell] + \mathbb{E}_\theta[\partial_i \partial_j \ell] = 0.$$

Thus, the Fisher Information Matrix is defined equivalently as:

$$g_{ij}(\theta) = \mathbb{E}_\theta[\partial_i \ell\, \partial_j \ell] = -\,\mathbb{E}_\theta[\partial_i \partial_j \ell].$$

The local divergence is given by the quadratic form $D_{\mathrm{KL}}(p_\theta \,\|\, p_{\theta + d\theta}) \approx \tfrac{1}{2}\, g_{ij}(\theta)\, d\theta^i d\theta^j$. This defines a Riemannian metric on $\mathcal{M}$.
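As a sanity check, the following minimal sketch (the model and helper names are mine, not from the text) verifies the two equivalent expressions $\mathbb{E}[(\partial_\theta \ell)^2] = -\mathbb{E}[\partial_\theta^2 \ell]$ for a Bernoulli($p$) model, where the expectation is an exact two-term sum over $x \in \{0, 1\}$:

```python
import jax.numpy as jnp
from jax import grad

# Bernoulli(p) log-likelihood; expectations below are exact sums over x in {0, 1}.
def log_p(p, x):
    return x * jnp.log(p) + (1 - x) * jnp.log1p(-p)

p = 0.3
xs = jnp.array([0.0, 1.0])
probs = jnp.array([1 - p, p])

score_sq = jnp.array([grad(log_p)(p, x) ** 2 for x in xs])      # (d/dp log p)^2
neg_hess = jnp.array([-grad(grad(log_p))(p, x) for x in xs])    # -d^2/dp^2 log p

fisher_from_scores = jnp.dot(probs, score_sq)   # E[score^2]
fisher_from_hessian = jnp.dot(probs, neg_hess)  # -E[Hessian]
print(fisher_from_scores, fisher_from_hessian, 1.0 / (p * (1 - p)))
```

Both estimates equal $1/(p(1-p))$, the familiar Bernoulli Fisher information.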
4. Uniqueness: Chentsov’s Theorem
Is this the only valid metric? Chentsov (1972) proved that the Fisher Information Metric is the unique Riemannian metric (up to a scaling factor) that is invariant under congruent embeddings by Markov morphisms.
Construct: Let $T : \mathcal{X} \to \mathcal{Y}$ be a measurable map (statistic). This induces a mapping $T_*$ from measures on $\mathcal{X}$ to measures on $\mathcal{Y}$. If $T$ is a sufficient statistic, no information is lost. The distance between $P_{\theta_1}$ and $P_{\theta_2}$ must be identical to the distance between their images under $T_*$. Standard Euclidean distance fails this. The Fisher Metric, being defined by the covariance of the score, inherently respects sufficiency.
5. Dualistic Geometry and Affine Connections
A metric allows us to measure lengths and angles. To define “straight lines” (geodesics) and discuss flatness, we need an affine connection $\nabla$. Standard Riemannian geometry uses the Levi-Civita connection $\nabla^{(0)}$, which is determined uniquely by the conditions:
- Metric compatibility: $\nabla g = 0$.
- Torsion-freeness.
In Statistical Manifolds, however, we naturally encounter a family of connections $\nabla^{(\alpha)}$ parameterized by $\alpha \in \mathbb{R}$.
The $\alpha$-Connection. The Christoffel symbols of the first kind for the $\alpha$-connection are defined as:

$$\Gamma^{(\alpha)}_{ij,k}(\theta) = \mathbb{E}_\theta\!\left[\left(\partial_i \partial_j \ell + \frac{1-\alpha}{2}\, \partial_i \ell\, \partial_j \ell\right) \partial_k \ell\right].$$

This definition simplifies using the Skewness Tensor $T_{ijk} = \mathbb{E}_\theta[\partial_i \ell\, \partial_j \ell\, \partial_k \ell]$. Differentiating the metric identity $g_{ij} = \mathbb{E}_\theta[\partial_i \ell\, \partial_j \ell]$:

$$\partial_k g_{ij} = \mathbb{E}_\theta[\partial_k \partial_i \ell\, \partial_j \ell] + \mathbb{E}_\theta[\partial_i \ell\, \partial_k \partial_j \ell] + T_{ijk}.$$

Standard derivation leads to:

$$\Gamma^{(\alpha)}_{ij,k} = \Gamma^{(0)}_{ij,k} - \frac{\alpha}{2}\, T_{ijk},$$

where $\Gamma^{(0)}_{ij,k}$ is the Levi-Civita connection.
Duality: Two connections $\nabla$ and $\nabla^*$ are said to be dual with respect to the metric $g$ if for all vector fields $X, Y, Z$:

$$X\, g(Y, Z) = g(\nabla_X Y, Z) + g(Y, \nabla^*_X Z).$$

Theorem: The $\alpha$-connection and the $(-\alpha)$-connection are dual. Specifically, the Exponential Connection ($\alpha = +1$) and the Mixture Connection ($\alpha = -1$) are duals.
6. Case Study: The Hyperbolic Geometry of the Normal Family
We now apply our tools to the most fundamental object in statistics: the Univariate Gaussian. We derive the Riemannian structure directly.
Consider the manifold $\mathcal{M} = \{N(\mu, \sigma^2) : \mu \in \mathbb{R},\ \sigma > 0\}$ with coordinates $\theta = (\mu, \sigma)$. The density is:

$$p(x; \mu, \sigma) = \frac{1}{\sqrt{2\pi}\,\sigma} \exp\!\left(-\frac{(x-\mu)^2}{2\sigma^2}\right).$$

The log-likelihood $\ell(\mu, \sigma; x)$:

$$\ell = -\frac{(x-\mu)^2}{2\sigma^2} - \log \sigma - \frac{1}{2}\log 2\pi.$$

Step 6.1: The Score Function. We compute the partial derivatives (scores) with respect to the coordinates $(\mu, \sigma)$:

$$\frac{\partial \ell}{\partial \mu} = \frac{x-\mu}{\sigma^2}, \qquad \frac{\partial \ell}{\partial \sigma} = \frac{(x-\mu)^2}{\sigma^3} - \frac{1}{\sigma}.$$

Step 6.2: The Fisher Information Matrix. We compute the elements of $g_{ij} = \mathbb{E}[\partial_i \ell\, \partial_j \ell]$.
Element $g_{\mu\mu}$:

$$g_{\mu\mu} = \mathbb{E}\!\left[\left(\frac{x-\mu}{\sigma^2}\right)^{2}\right] = \frac{\mathbb{E}[(x-\mu)^2]}{\sigma^4}.$$

Since $\mathbb{E}[(x-\mu)^2] = \sigma^2$:

$$g_{\mu\mu} = \frac{1}{\sigma^2}.$$

Element $g_{\mu\sigma}$:

$$g_{\mu\sigma} = \mathbb{E}\!\left[\frac{x-\mu}{\sigma^2}\left(\frac{(x-\mu)^2}{\sigma^3} - \frac{1}{\sigma}\right)\right] = \frac{\mathbb{E}[(x-\mu)^3]}{\sigma^5} - \frac{\mathbb{E}[x-\mu]}{\sigma^3}.$$

This involves $\mathbb{E}[(x-\mu)^3]$ (skewness) and $\mathbb{E}[x-\mu]$ (the centered mean). For a Gaussian, odd central moments vanish, so $g_{\mu\sigma} = 0$.
This implies the parameters $\mu$ and $\sigma$ are orthogonal in the Riemannian sense.
Element $g_{\sigma\sigma}$:

$$g_{\sigma\sigma} = \mathbb{E}\!\left[\left(\frac{(x-\mu)^2}{\sigma^3} - \frac{1}{\sigma}\right)^{2}\right] = \frac{\mathbb{E}[(x-\mu)^4]}{\sigma^6} - \frac{2\,\mathbb{E}[(x-\mu)^2]}{\sigma^4} + \frac{1}{\sigma^2}.$$

Recall the Gaussian moments: $\mathbb{E}[(x-\mu)^2] = \sigma^2$ and $\mathbb{E}[(x-\mu)^4] = 3\sigma^4$. Hence $g_{\sigma\sigma} = \frac{3}{\sigma^2} - \frac{2}{\sigma^2} + \frac{1}{\sigma^2} = \frac{2}{\sigma^2}$.
Thus, the Fisher Information Matrix is:

$$G(\mu, \sigma) = \begin{pmatrix} \frac{1}{\sigma^2} & 0 \\ 0 & \frac{2}{\sigma^2} \end{pmatrix}.$$

Step 6.3: The Riemannian Metric and Distance. The line element is:

$$ds^2 = \frac{d\mu^2 + 2\, d\sigma^2}{\sigma^2}.$$

This closely resembles the metric of the Poincaré Upper Half-Plane model of Hyperbolic Geometry ($ds^2 = (dx^2 + dy^2)/y^2$). The factor of 2 indicates a difference in curvature scaling.
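A quick Monte Carlo check of the derived matrix (a sketch of my own; the sample size and seed are arbitrary) estimates $G(\mu, \sigma)$ from the outer product of autodiff scores and should approach $\mathrm{diag}(1/\sigma^2,\ 2/\sigma^2)$:

```python
import jax.numpy as jnp
from jax import grad, vmap, random

def log_pdf(theta, x):
    # Univariate Gaussian log-density in (mu, sigma) coordinates.
    mu, sigma = theta
    return -0.5 * ((x - mu) / sigma) ** 2 - jnp.log(sigma) - 0.5 * jnp.log(2 * jnp.pi)

theta = jnp.array([1.0, 2.0])
key = random.PRNGKey(0)
x = theta[0] + theta[1] * random.normal(key, (200_000,))   # samples from p(.; theta)

scores = vmap(grad(log_pdf), in_axes=(None, 0))(theta, x)  # shape (N, 2)
G = scores.T @ scores / x.shape[0]                         # E[score score^T]
print(G)   # ~ [[1/sigma^2, 0], [0, 2/sigma^2]] = [[0.25, 0], [0, 0.5]]
```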
Step 6.4: Geodesics. To find the shortest paths (geodesics) between distributions $N(\mu_1, \sigma_1^2)$ and $N(\mu_2, \sigma_2^2)$, we solve the Euler-Lagrange equations for the length functional $\int \sqrt{g_{ij}\,\dot\theta^i \dot\theta^j}\, dt$.
The geodesics correspond to:
- Vertical lines: If $\mu_1 = \mu_2$, the path simply scales the variance.
- Semi-ellipses: If $\mu_1 \neq \mu_2$, the geodesics are semi-ellipses centered on the $\mu$-axis.
This confirms that the manifold of Gaussian distributions has constant negative curvature. We are not operating in a flat space; we are operating in a hyperbolic space. Traditional Euclidean averaging of parameters ($\bar\mu = \frac{1}{n}\sum_i \mu_i$, $\bar\sigma = \frac{1}{n}\sum_i \sigma_i$) is not the geometric center (Fréchet mean) of the distributions.
7. The Exponential Family (1-flatness)
Consider the exponential family in canonical parameters $\theta$:

$$p(x;\theta) = h(x)\, \exp\!\big(\theta^i T_i(x) - \psi(\theta)\big).$$

We analyze the curvature. The log-likelihood is $\ell = \log h(x) + \theta^i T_i(x) - \psi(\theta)$.
The second derivative $\partial_i \partial_j \ell = -\partial_i \partial_j \psi(\theta)$ is constant with respect to $x$. This is the crucial property. Substitute into the definition of $\Gamma^{(1)}_{ij,k}$ (the e-connection):

$$\Gamma^{(1)}_{ij,k} = \mathbb{E}_\theta\big[\partial_i \partial_j \ell\ \partial_k \ell\big].$$

Since $\partial_i \partial_j \ell$ is deterministic (independent of $x$), it comes out of the expectation:

$$\Gamma^{(1)}_{ij,k} = -\partial_i \partial_j \psi(\theta)\ \mathbb{E}_\theta[\partial_k \ell].$$

By Step 3.1, $\mathbb{E}_\theta[\partial_k \ell] = 0$. Therefore:

$$\Gamma^{(1)}_{ij,k}(\theta) = 0 \quad \text{for all } \theta.$$

Conclusion: The exponential family manifold is flat under the e-connection. The canonical parameters $\theta$ are an affine coordinate system. Geodesics are straight lines in $\theta$: $\theta(t) = (1-t)\,\theta_0 + t\,\theta_1$.
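To see what e-flatness means concretely, here is a small sketch (my own illustration, not part of the original derivation) tracing the e-geodesic between two univariate Gaussians as a straight line in natural parameters; mapped back to $(\mu, \sigma)$ it is visibly not the naive linear interpolation of $(\mu, \sigma)$:

```python
import jax.numpy as jnp

def to_natural(mu, sigma):
    # Natural parameters of N(mu, sigma^2): (mu / sigma^2, -1 / (2 sigma^2)).
    return jnp.array([mu / sigma**2, -1.0 / (2.0 * sigma**2)])

def from_natural(theta):
    # Invert back to (mu, sigma).
    sigma2 = -1.0 / (2.0 * theta[1])
    return theta[0] * sigma2, jnp.sqrt(sigma2)

theta_a = to_natural(0.0, 1.0)
theta_b = to_natural(4.0, 3.0)
for t in jnp.linspace(0.0, 1.0, 5):
    mu_t, sigma_t = from_natural((1.0 - t) * theta_a + t * theta_b)
    print(float(t), float(mu_t), float(sigma_t))  # midpoint is far from the naive (2.0, 2.0)
```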
8. The Mixture Family (-1-flatness)
Consider the mixture family:

$$p(x;\eta) = \sum_{i=1}^{n} \eta^i\, p_i(x) + \Big(1 - \sum_{i=1}^{n} \eta^i\Big)\, p_0(x),$$

where $p_0, \dots, p_n$ are fixed densities.
This manifold is flat under the m-connection ($\alpha = -1$). The mixture weights (expectation parameters $\eta$) form the affine coordinate system. Geodesics are linear mixtures: $p_t = (1-t)\,p_a + t\,p_b$.
The Generalized Pythagorean Theorem: Since $\nabla^{(1)}$ and $\nabla^{(-1)}$ are dually flat, if we have a triangle $(p, q, r)$ in which the e-geodesic from $q$ to $r$ is orthogonal at $q$ to the m-geodesic from $p$ to $q$, then:

$$D_{\mathrm{KL}}(p \,\|\, r) = D_{\mathrm{KL}}(p \,\|\, q) + D_{\mathrm{KL}}(q \,\|\, r).$$
Proof: We provide a derivation below. Let $\theta$ be the $e$-affine coordinates (natural parameters) and $\eta$ be the $m$-affine coordinates (expectation parameters). Write $\theta_p, \theta_q, \theta_r$ and $\eta_p, \eta_q, \eta_r$ for the coordinates of the three points in the two systems.
Let the curve connecting $q$ and $r$ be an e-geodesic. In the $\theta$-coordinate system, this is a straight line:

$$\theta(t) = (1-t)\,\theta_q + t\,\theta_r.$$

The tangent vector is $\dot\theta = \theta_r - \theta_q$.
Let the curve connecting $p$ and $q$ be an m-geodesic. In the $\eta$-coordinate system, this is a straight line. (Note: $\theta$ and $\eta$ are dual coordinate systems.) Its tangent vector at $q$ is best expressed in the dual coordinates: $\dot\eta \propto \eta_p - \eta_q$.
Consider the KL divergence between members of an exponential family:

$$D_{\mathrm{KL}}(p_{\theta_1} \,\|\, p_{\theta_2}) = \psi(\theta_2) - \psi(\theta_1) - \eta_1^\top(\theta_2 - \theta_1).$$

Note that this is the Bregman divergence on the convex potential $\psi$:

$$B_\psi(\theta_2, \theta_1) = \psi(\theta_2) - \psi(\theta_1) - \nabla\psi(\theta_1)^\top(\theta_2 - \theta_1).$$

This uses $\eta_1 = \nabla\psi(\theta_1) = \mathbb{E}_{\theta_1}[T(x)]$.
Now expand the right-hand-side terms:

$$D_{\mathrm{KL}}(p \,\|\, q) = \psi(\theta_q) - \psi(\theta_p) - \eta_p^\top(\theta_q - \theta_p),$$

$$D_{\mathrm{KL}}(q \,\|\, r) = \psi(\theta_r) - \psi(\theta_q) - \eta_q^\top(\theta_r - \theta_q).$$

Summing them:

$$D_{\mathrm{KL}}(p \,\|\, q) + D_{\mathrm{KL}}(q \,\|\, r) = \psi(\theta_r) - \psi(\theta_p) - \eta_p^\top(\theta_q - \theta_p) - \eta_q^\top(\theta_r - \theta_q).$$

We want this to equal $D_{\mathrm{KL}}(p \,\|\, r) = \psi(\theta_r) - \psi(\theta_p) - \eta_p^\top(\theta_r - \theta_p)$.
The difference is:

$$D_{\mathrm{KL}}(p \,\|\, r) - \big[D_{\mathrm{KL}}(p \,\|\, q) + D_{\mathrm{KL}}(q \,\|\, r)\big] = -\eta_p^\top(\theta_r - \theta_p) + \eta_p^\top(\theta_q - \theta_p) + \eta_q^\top(\theta_r - \theta_q).$$

Grouping by $(\theta_q - \theta_r)$:

$$= \eta_p^\top(\theta_q - \theta_r) - \eta_q^\top(\theta_q - \theta_r) = (\eta_p - \eta_q)^\top(\theta_q - \theta_r).$$

For the Pythagorean theorem to hold (difference equal to zero), we require:

$$(\eta_p - \eta_q)^\top(\theta_q - \theta_r) = 0.$$
Geometric Interpretation:
- $\eta_p - \eta_q$: change in the dual (expectation) parameter along the path $p \to q$.
- $\theta_q - \theta_r$: change in the primal (natural) parameter along the path $q \to r$.
If $p \to q$ is an m-geodesic, then $\eta$ changes linearly, so $\eta_p - \eta_q$ is (proportional to) its tangent vector in $\eta$-space. If $q \to r$ is an e-geodesic, then $\theta$ changes linearly, so $\theta_q - \theta_r$ is (proportional to) its tangent vector in $\theta$-space.
Thus, if the m-geodesic $p \to q$ is orthogonal to the e-geodesic $q \to r$ at $q$, the divergence splits. Orthogonality here means the Euclidean dot product of the tangent vectors expressed in the dual coordinate systems is zero, which coincides with orthogonality under the Fisher metric at $q$. This justifies the Projection Theorem: the m-projection of $p$ onto an e-flat submanifold is unique and satisfies the Pythagorean relation.
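The derivation above can be checked numerically. The sketch below (helper names are mine) evaluates the three-point identity $D_{\mathrm{KL}}(p\,\|\,r) = D_{\mathrm{KL}}(p\,\|\,q) + D_{\mathrm{KL}}(q\,\|\,r) + (\eta_p - \eta_q)^\top(\theta_q - \theta_r)$ for univariate Gaussians written as an exponential family, so the Pythagorean relation is recovered exactly when the cross term vanishes:

```python
import jax.numpy as jnp
from jax import grad

def psi(theta):
    # Log-partition of N(mu, sigma^2) with theta = (mu/sigma^2, -1/(2 sigma^2)).
    return (-theta[0]**2 / (4.0 * theta[1])
            - 0.5 * jnp.log(-2.0 * theta[1]) + 0.5 * jnp.log(2.0 * jnp.pi))

def natural(mu, sigma):
    return jnp.array([mu / sigma**2, -1.0 / (2.0 * sigma**2)])

def kl(theta_a, theta_b):
    # KL(p_a || p_b) = psi(theta_b) - psi(theta_a) - <grad psi(theta_a), theta_b - theta_a>
    return psi(theta_b) - psi(theta_a) - jnp.dot(grad(psi)(theta_a), theta_b - theta_a)

p, q, r = natural(0.0, 1.0), natural(1.0, 2.0), natural(-1.0, 0.5)
cross = jnp.dot(grad(psi)(p) - grad(psi)(q), q - r)   # <eta_p - eta_q, theta_q - theta_r>
print(kl(p, r), kl(p, q) + kl(q, r) + cross)          # identical up to float error
```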
Application: The Maximum Likelihood Estimator (MLE) is the m-projection of the empirical distribution $\hat{p}$ onto the model manifold $\mathcal{M}$.
9. Pathologies: The Uniform Boundary
Violating Assumption A2 leads to pathologies. Consider the uniform distribution $U(0, \theta)$ with density $p(x;\theta) = \frac{1}{\theta}\,\mathbf{1}[0 \le x \le \theta]$.
Observe that on the support $\partial_\theta \log p(x;\theta) = -\frac{1}{\theta}$, so $\mathbb{E}_\theta[\partial_\theta \ell] = -\frac{1}{\theta} \neq 0$. This violates the zero-score condition, and the derivation in Section 3 collapses. Why? Because the support depends on $\theta$, differentiation and integration do not commute. The Leibniz integral rule picks up a boundary term:

$$\partial_\theta \int_0^\theta p(x;\theta)\, dx = \int_0^\theta \partial_\theta p(x;\theta)\, dx + p(\theta;\theta) = -\frac{1}{\theta} + \frac{1}{\theta} = 0.$$

The normalization identity is rescued only by the boundary term. Our tools must be adjusted to include boundary terms.
Fisher Information Singularity: The naive Fisher information $\mathbb{E}_\theta[(\partial_\theta \ell)^2] = 1/\theta^2$ appears finite, but the regularity conditions for the Cramer-Rao bound ($\mathrm{Var}(\hat\theta) \geq \frac{1}{n\, I(\theta)}$) require A2/A4. Since A2 fails, Cramer-Rao does not apply. The MLE is $\hat\theta = \max_i x_i$. The variance of $\hat\theta$ scales as $O(1/n^2)$, which is faster than the $O(1/n)$ rate predicted by Fisher. This phenomenon, known as “Super-efficiency,” violates the geometric intuition. The manifold has a boundary that contains information.
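A small simulation (my own sketch; the replication counts are arbitrary) illustrates the super-efficient $O(1/n^2)$ scaling of the MLE variance for $U(0, \theta)$: the product $n^2 \cdot \mathrm{Var}(\hat\theta)$ stays roughly constant as $n$ grows:

```python
import jax.numpy as jnp
from jax import random

theta_true = 2.0
key = random.PRNGKey(0)
for n in (100, 1000, 10000):
    key, subkey = random.split(key)
    samples = theta_true * random.uniform(subkey, (2000, n))   # 2000 replications of size n
    mle = samples.max(axis=1)                                  # MLE for U(0, theta) is max(x)
    print(n, float(jnp.var(mle)), float(jnp.var(mle)) * n**2)  # n^2 * Var stays O(1)
```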
10. Natural Gradient Descent
We perform optimization on $\mathcal{M}$. We wish to minimize a loss $L(\theta)$. The straightforward update $\theta_{t+1} = \theta_t - \epsilon\, \nabla L(\theta_t)$ is geometrically invalid because $\nabla L$ is a covariant vector (a 1-form), while the parameter increment $d\theta$ is a contravariant vector. They cannot be added without a metric.
We formulate the update as steepest descent under a divergence constraint:

$$d\theta^* = \arg\min_{d\theta}\ L(\theta + d\theta) \quad \text{subject to} \quad D_{\mathrm{KL}}(p_\theta \,\|\, p_{\theta + d\theta}) = \text{const}.$$

Approximating $D_{\mathrm{KL}}(p_\theta \,\|\, p_{\theta + d\theta}) \approx \tfrac{1}{2}\, d\theta^\top G(\theta)\, d\theta$ and solving the Lagrangian:
This yields the Natural Gradient update:

$$\theta_{t+1} = \theta_t - \epsilon\, G(\theta_t)^{-1}\, \nabla_\theta L(\theta_t).$$
11. Implementation (JAX)
We compare the convergence of Natural Gradient Descent against plain SGD on a warped Gaussian landscape, where the covariance of the model depends on the mean parameters.
```python
import jax
import jax.numpy as jnp
from jax import random, grad, jit, vmap, lax
from jax.scipy.stats import multivariate_normal
from typing import NamedTuple, Tuple

# ------------------------------------------------------------------
# SYSTEM CONFIGURATION
# ------------------------------------------------------------------
SEED = 42
LEARNING_RATE = 0.1
NUM_STEPS = 100
DAMPING = 1e-4

# ------------------------------------------------------------------
# 1. Manifold Definition: Warped Gaussian
# ------------------------------------------------------------------
def get_sigma(theta: jax.Array) -> jax.Array:
    """
    Constructs the covariance matrix Sigma(theta) = diag(theta^2 + 1).
    Ensures positive definiteness everywhere.
    """
    return jnp.diag(theta**2 + 1.0)

def log_likelihood(theta: jax.Array, x: jax.Array) -> jax.Array:
    """ Computes sum of log-likelihoods for data x given theta. """
    mu = theta
    cov = get_sigma(theta)
    return jnp.sum(multivariate_normal.logpdf(x, mu, cov))

# ------------------------------------------------------------------
# 2. Fisher Information Computation
# ------------------------------------------------------------------
@jit
def compute_fisher_mc(theta: jax.Array, key: jax.Array, num_samples: int = 1000) -> jax.Array:
    """
    Approximates the Fisher Information Matrix using Monte Carlo integration.
    G(theta) = E[score * score^T]
    """
    cov = get_sigma(theta)
    # Sampling from the model distribution at theta
    samples = random.multivariate_normal(key, theta, cov, shape=(num_samples,))

    def score_fn(t, x_single):
        return grad(lambda p: multivariate_normal.logpdf(x_single, p, get_sigma(p)))(t)

    # Vectorized score computation
    scores = vmap(lambda x: score_fn(theta, x))(samples)
    # Outer product expectation
    outer_products = vmap(lambda s: jnp.outer(s, s))(scores)
    return jnp.mean(outer_products, axis=0)

# ------------------------------------------------------------------
# 3. Optimization Loop (JIT-Compiled Scan)
# ------------------------------------------------------------------
def loss_fn(theta: jax.Array, batch: jax.Array) -> jax.Array:
    """ Negative Log Likelihood Loss. """
    return -log_likelihood(theta, batch) / batch.shape[0]

class OptState(NamedTuple):
    theta_sgd: jax.Array
    theta_ngd: jax.Array
    key: jax.Array

@jit
def update_step_sgd(theta: jax.Array, batch: jax.Array) -> jax.Array:
    grads = grad(loss_fn)(theta, batch)
    return theta - LEARNING_RATE * grads

@jit
def update_step_ngd(theta: jax.Array, batch: jax.Array, key: jax.Array) -> jax.Array:
    grads = grad(loss_fn)(theta, batch)
    fisher = compute_fisher_mc(theta, key)
    # Natural Gradient: G^-1 * grad
    # Numerically stable solve: (G + damping * I) * update = grad
    regularized_fisher = fisher + DAMPING * jnp.eye(fisher.shape[0])
    nat_grad = jnp.linalg.solve(regularized_fisher, grads)
    return theta - LEARNING_RATE * nat_grad

@jit
def run_experiment() -> Tuple[jax.Array, jax.Array]:
    """ Fully compiled training loop using lax.scan. """
    key = random.PRNGKey(SEED)
    key, subkey_data, subkey_init = random.split(key, 3)

    # Ground Truth
    true_theta = jnp.array([2.0, 3.0])
    data = random.multivariate_normal(
        subkey_data, true_theta, get_sigma(true_theta), shape=(500,)
    )

    # Initialization
    theta_0 = jnp.array([0.5, 0.5])
    init_state = OptState(theta_sgd=theta_0, theta_ngd=theta_0, key=subkey_init)

    def step_fn(state: OptState, _):
        key, subkey_ngd = random.split(state.key)
        # Parallel updates
        next_sgd = update_step_sgd(state.theta_sgd, data)
        next_ngd = update_step_ngd(state.theta_ngd, data, subkey_ngd)
        new_state = OptState(theta_sgd=next_sgd, theta_ngd=next_ngd, key=key)
        # Record trajectories
        return new_state, (next_sgd, next_ngd)

    # Execute simulation
    final_state, (path_sgd, path_ngd) = lax.scan(step_fn, init_state, None, length=NUM_STEPS)
    return path_sgd, path_ngd
```
12. The Cramer-Rao Bound (Geometric Interpretation)
The Cramer-Rao Lower Bound (CRLB) is the fundamental limit of frequentist inference. Usually derived via algebraic manipulation of covariance, it is geometrically the Cauchy-Schwarz inequality on the Statistical Manifold.
Consider an unbiased estimator $\hat\theta(x)$ for $\theta$. Let $\partial_i$ be a tangent vector. The score function $\partial_i \ell$ lives in the Hilbert space $L^2(P_\theta)$. The Fisher Information is the inner product (and norm) in this space:

$$g_{ij} = \langle \partial_i \ell,\ \partial_j \ell \rangle_{L^2(P_\theta)} = \mathbb{E}_\theta[\partial_i \ell\, \partial_j \ell].$$

Consider the covariance between the estimator error $\hat\theta^i(x) - \theta^i$ and the score $\partial_j \ell$:

$$\mathrm{Cov}_\theta\big(\hat\theta^i(x) - \theta^i,\ \partial_j \ell\big) = \mathbb{E}_\theta\big[(\hat\theta^i(x) - \theta^i)\, \partial_j \ell\big].$$

Using the identity $p\, \partial_j \ell = \partial_j p$:

$$\mathbb{E}_\theta\big[\hat\theta^i(x)\, \partial_j \ell\big] = \int \hat\theta^i(x)\, \partial_j p(x;\theta)\, d\mu(x) = \partial_j\, \mathbb{E}_\theta\big[\hat\theta^i(x)\big] = \partial_j \theta^i = \delta^i_j.$$

Using $\mathbb{E}_\theta[\partial_j \ell] = 0$, the term $\theta^i\, \mathbb{E}_\theta[\partial_j \ell]$ drops, so:

$$\mathrm{Cov}_\theta\big(\hat\theta^i(x),\ \partial_j \ell\big) = \delta^i_j.$$
(Assuming A4 holds).
Now apply the standard (Cauchy-Schwarz) matrix inequality for covariance matrices:

$$\mathrm{Cov}(A) \succeq \mathrm{Cov}(A, B)\, \mathrm{Cov}(B)^{-1}\, \mathrm{Cov}(B, A).$$

Here $A = \hat\theta(x)$, $B = \nabla_\theta \ell$, $\mathrm{Cov}(A, B) = I$, and $\mathrm{Cov}(B) = G(\theta)$. Hence $\mathrm{Cov}(\hat\theta) \succeq G(\theta)^{-1}$.
Interpretation: The variance of any estimator is bounded by the inverse curvature of the manifold.
- High curvature ($g_{ij}$ large) → nearby distributions are far apart → easy to distinguish → low variance.
- Low curvature ($g_{ij}$ small) → nearby distributions are similar → hard to distinguish → high variance.
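For a concrete instance (my own sketch; sample sizes are arbitrary), the sample mean of $N(\mu, \sigma^2)$ with known $\sigma$ is unbiased and attains the bound $\mathrm{Var}(\hat\mu) \ge 1/(n\, g_{\mu\mu}) = \sigma^2/n$:

```python
import jax.numpy as jnp
from jax import random

mu, sigma, n = 1.0, 2.0, 50
key = random.PRNGKey(1)
samples = mu + sigma * random.normal(key, (20000, n))   # 20000 replications of size n
sample_means = samples.mean(axis=1)                     # unbiased estimator of mu
print(float(jnp.var(sample_means)), sigma**2 / n)       # empirical variance vs CRLB
```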
13. Singular Learning Theory (The Geometry of Degeneracy)
What happens when Assumption A1 (Identifiability) fails? This is the case in Deep Learning. A neural network with permutable nodes is non-identifiable, and the Fisher matrix $G(\theta)$ becomes singular on the corresponding parameter sets.
At these points, the manifold dimension collapses. The “Tangent Space” is no longer a vector space; it is a tangent cone.
Watanabe’s Discovery: Sumio Watanabe (2009) proved that in singular regions the Bayesian posterior does not become asymptotically Gaussian as $n \to \infty$. Instead of the standard asymptotic expansion of the free energy (negative log marginal likelihood),

$$F_n \approx n\, L_n + \frac{d}{2} \log n,$$

the complexity term $\frac{d}{2}\log n$ is replaced by $\lambda \log n$, where $\lambda \le \frac{d}{2}$ is the Real Log Canonical Threshold (RLCT).
This means singular models are less complex than their parameter count suggests. Geometrically, the volume of the posterior contraction is determined by the resolution of singularities in algebraic geometry. Standard Information Geometry (Riemannian) fails here. We require Singular Information Geometry.
14. Conclusion: From Geometry to Optimization
We have established:
- Strict Construction: The Fisher metric arises uniquely from invariance requirements (Chentsov).
- Dual Structure: A dually flat manifold (e.g., an exponential family) is simultaneously $e$-flat in $\theta$ and $m$-flat in $\eta$ (Amari).
- Fundamental Bounds: The Cramer-Rao bound is the Cauchy-Schwarz inequality on $L^2(P_\theta)$.
- Singularity: Modern Deep Learning lives in the breakdown of this theory (Singular Learning Theory), where $G(\theta)$ is rank-deficient.
- Operationalization: The Natural Gradient is the only type-safe first-order optimization step.
The Critique of Adam: Adaptive methods like Adam approximate $G(\theta)$ by a diagonal matrix of squared gradients and precondition by its inverse square root, $\mathrm{diag}(\sqrt{\hat v})^{-1}$. Strictly, this is dimensionally inconsistent. $G$ is a $(0,2)$-tensor; its square root is not well-defined geometrically as a pre-conditioner in this way. Standard Natural Gradient scales the gradient by $G^{-1}$ (units of $[\nabla L]^{-2}$, so the update carries the units of $\theta$). Adam scales by $1/\sqrt{\hat v}$ (units of $[\nabla L]^{-1}$, so the update is dimensionless). This implies Adam is not approximating curvature; it is normalizing magnitude. It operates on a different heuristic (Sign Descent) rather than Riemannian steepest descent.
Final Thought: Information Geometry provides the rigorous language to discuss optimization in probability space. Without it, we are merely adjusting knobs in a coordinate system that doesn’t exist.
Historical Timeline
| Year | Contributor | Significance |
|---|---|---|
| 1945 | C.R. Rao | Introduces Fisher Information Metric (Riemannian). |
| 1972 | N. Chentsov | Proves Uniqueness Theorem for the metric. |
| 1979 | B. Efron | Defines statistical curvature. |
| 1985 | Shun-ichi Amari | Develops Dualistic Geometry ($\alpha$-connections). |
| 1998 | Amari | Proposes Natural Gradient Descent. |
| 2009 | Sumio Watanabe | Develops Singular Learning Theory (algebraic geometry of learning). |
| 2014 | Pascanu & Bengio | Revisit the Natural Gradient for Neural Networks. |
Appendix A: The Legendre Duality
The duality between Exponential and Mixture families is an instance of Legendre Transformation in convex analysis.
Let $\psi(\theta)$ be the convex function (cumulant generating function) defining the exponential family:

$$\psi(\theta) = \log \int h(x)\, \exp\!\big(\theta^\top T(x)\big)\, d\mu(x).$$

The dual potential $\varphi(\eta)$ is the Legendre conjugate of $\psi$:

$$\varphi(\eta) = \sup_{\theta}\ \big\{\theta^\top \eta - \psi(\theta)\big\}.$$

The supremum is attained at the point where the gradient matches:

$$\eta = \nabla\psi(\theta).$$

This mapping is the coordinate transformation from natural parameters $\theta$ to expectation parameters $\eta$.
The function $\varphi(\eta)$ turns out to be the negative entropy (plus constants). The convex duality implies:

$$\theta = \nabla\varphi(\eta), \qquad \psi(\theta) + \varphi(\eta) = \theta^\top \eta.$$

Thus, the transformation between coordinate systems is given by the gradient of a convex potential. The Hessian of the potential is the metric:

$$g_{ij}(\theta) = \partial_i \partial_j \psi(\theta), \qquad g^{ij}(\eta) = \partial_i \partial_j \varphi(\eta).$$

The matrices $\nabla^2\psi$ and $\nabla^2\varphi$ are inverses of each other (up to the coordinate-change Jacobian). This confirms the Riemannian structure is consistent across the dual representations.
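The following sketch (the Gaussian special case and helper names are mine) checks the duality numerically with autodiff: $\nabla\psi(\theta)$ recovers the expectation parameters $(\mathbb{E}[x], \mathbb{E}[x^2])$, and $\nabla^2\psi(\theta)$ gives the Fisher matrix in natural coordinates:

```python
import jax.numpy as jnp
from jax import grad, hessian

def psi(theta):
    # Log-partition of N(mu, sigma^2) in natural coordinates (mu/sigma^2, -1/(2 sigma^2)).
    return (-theta[0]**2 / (4.0 * theta[1])
            - 0.5 * jnp.log(-2.0 * theta[1]) + 0.5 * jnp.log(2.0 * jnp.pi))

mu, sigma = 1.0, 2.0
theta = jnp.array([mu / sigma**2, -1.0 / (2.0 * sigma**2)])

eta = grad(psi)(theta)           # expectation parameters
print(eta, (mu, mu**2 + sigma**2))

G_theta = hessian(psi)(theta)    # Fisher information in theta-coordinates
print(G_theta)
```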
Appendix B: Fisher vs. Wasserstein
Optimal Transport (Wasserstein Metric) is increasingly used for loss functions. How does it compare to Fisher geometry?
1. The Objects:
- Fisher Information describes the geometry of the Parameter Space $\Theta$. It is defined on the manifold of densities.
- Wasserstein Distance describes the geometry of the Sample Space $\mathcal{X}$ lifted to measures. It depends on the ground metric $d(x, x')$ on $\mathcal{X}$.
2. The Geodesics:
- Fisher Geodesic (e-connection): $p_t(x) \propto p_0(x)^{1-t}\, p_1(x)^{t}$, a multiplicative interpolation. Example: interpolating $N(\mu_0, \sigma^2)$ and $N(\mu_1, \sigma^2)$, the intermediate passes through $N\big(\tfrac{\mu_0+\mu_1}{2}, \sigma^2\big)$ if we stay in the Gaussian family. But the mixture distribution (the m-geodesic) is bimodal.
- Wasserstein Geodesic (Displacement): a horizontal displacement of probability mass. Example: transporting $N(\mu_0, \sigma^2)$ to $N(\mu_1, \sigma^2)$, the density physically slides across the axis, and $N\big(\tfrac{\mu_0+\mu_1}{2}, \sigma^2\big)$ is the midpoint. (A numerical comparison of these interpolations is sketched at the end of this appendix.)
3. When to use which?
- Fisher: When you care about inference. How much information does a sample give about $\theta$?
- Wasserstein: When you care about mass transport. How much work does it take to morph image A into image B?
The geometric distinction is categorical: Fisher comes from the entropy Hessian (Dualistic). Wasserstein comes from the Kinetic Energy minimization (Benamou-Brenier).
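As a concrete illustration (the endpoints $N(0,1)$ and $N(6,1)$ are my own choice), the sketch below evaluates the three midpoints on a grid: the m-geodesic midpoint is bimodal, while the e-geodesic and Wasserstein midpoints coincide here because the variances are equal:

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

xs = jnp.linspace(-4.0, 10.0, 1401)
dx = xs[1] - xs[0]
p0, p1 = norm.pdf(xs, 0.0, 1.0), norm.pdf(xs, 6.0, 1.0)

m_mid = 0.5 * p0 + 0.5 * p1                   # m-geodesic midpoint: bumps at 0 and 6
e_mid = jnp.sqrt(p0 * p1)
e_mid = e_mid / (jnp.sum(e_mid) * dx)         # e-geodesic midpoint: renormalized geometric mean
w_mid = norm.pdf(xs, 3.0, 1.0)                # Wasserstein midpoint: single bump slid to 3

probe = jnp.array([400, 700, 1000])           # grid indices of x = 0, 3, 6
print("mixture     ", m_mid[probe])
print("e-geodesic  ", e_mid[probe])
print("Wasserstein ", w_mid[probe])
```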
Appendix C: Glossary of Definitions
- Affine Connection: Geometric object defining parallel transport and derivatives.
- Fisher Information Metric: The unique (up to scale) invariant Riemannian metric on probability manifolds.
- Natural Gradient: Steepest descent direction accounting for curvature ($\tilde\nabla L = G^{-1}\nabla L$).
- Statistical Manifold: A family of probability distributions equipped with geometric structure.
- Dual Connections: Pair of connections ($\nabla, \nabla^*$) satisfying the duality condition w.r.t. the metric.
- Kullback-Leibler Divergence: The canonical divergence generating the Fisher metric.
References
1. Amari, S., & Nagaoka, H. (2000). “Methods of Information Geometry”. The Bible of the field. Defines -connections, dually flat spaces, and applications to estimation.
2. Chentsov, N. N. (1972). “Statistical Decision Rules and Optimal Inference”. Proved the uniqueness of the Fisher metric based on Markov invariance.
3. Rao, C. R. (1945). “Information and the accuracy attainable in the estimation of statistical parameters”. The original paper proposing the Riemannian metric.
4. Watanabe, S. (2009). “Algebraic Geometry and Statistical Learning Theory”. The foundation of Singular Learning Theory, handling the breakdown of regular information geometry in neural networks.
5. Martens, J. (2014). “New insights and perspectives on the natural gradient method”. A modern analysis of why Natural Gradient works for deep learning (K-FAC).