Information Geometry & Natural Gradients

The coordinate dependence of gradient descent

Let $(\Omega, \mathcal{F}, \mu)$ be a measure space. We have a parametric family $\mathcal{S} = \{ P_\xi \mid \xi \in \Xi \}$ dominated by $\mu$, with $\Xi \subseteq \mathbb{R}^d$ open, and densities:

$$p(x; \xi) = \frac{dP_\xi}{d\mu}(x)$$

We want a Riemannian metric $g$ and a connection $\nabla$ on $\mathcal{S}$ that sit inside the family itself and do not care about how we happen to write the parameters down.

The trouble is that gradient descent takes the distance between $P_\xi$ and $P_{\xi + \delta \xi}$ to be $\|\delta \xi\|_2$, and under a reparameterization $\xi \to \phi(\xi)$ this Euclidean distance shifts around even though the distributions have not moved at all (a small numerical sketch of this follows the list below). So we need a divergence $D[P : Q]$ whose induced metric is

  1. Reparameterization-invariant, so that $ds^2(\xi, \xi + d\xi)$ is a scalar invariant.
  2. Sufficient-statistic-invariant, so that if $T$ from $\Omega$ to $\Omega'$ is sufficient for $\xi$, the geometry on $\mathcal{S}$ matches the geometry on the induced family over $\Omega'$.
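Here is a minimal JAX sketch of that coordinate dependence (my illustration, not from the original): one vanilla gradient step on the negative log-likelihood of $N(\mu, \sigma^2)$, taken once in $(\mu, \sigma)$ coordinates and once in $(\mu, \log\sigma)$ coordinates, lands on two different distributions even though both parameterizations describe the exact same family.

```python
import jax.numpy as jnp
from jax import grad

# Toy data and a shared starting distribution N(mu=0, sigma=2).
x = jnp.array([1.0, 2.0, 3.0])
mu0, sigma0 = 0.0, 2.0
lr = 0.1

def nll(mu, sigma):
    # Negative log-likelihood of the data under N(mu, sigma^2).
    return jnp.sum(0.5 * jnp.log(2 * jnp.pi * sigma**2)
                   + (x - mu)**2 / (2 * sigma**2))

# Gradient step in (mu, sigma) coordinates.
g_mu, g_sigma = grad(nll, argnums=(0, 1))(mu0, sigma0)
step_a = (mu0 - lr * g_mu, sigma0 - lr * g_sigma)

# Same loss, but parameterized by (mu, log sigma).
def nll_log(mu, log_sigma):
    return nll(mu, jnp.exp(log_sigma))

g_mu2, g_rho = grad(nll_log, argnums=(0, 1))(mu0, jnp.log(sigma0))
step_b = (mu0 - lr * g_mu2, jnp.exp(jnp.log(sigma0) - lr * g_rho))

print(step_a)  # sigma moves by one amount here...
print(step_b)  # ...and by a different amount here: same model, different update.
```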

Regularity conditions and their failures

The standard theory sits on top of the following assumptions and each one quietly kicks out a bunch of cases we actually care about.

A1 (Identifiability). $\xi \mapsto P_\xi$ is injective. Neural networks break this all over the place through permutation symmetry and overparameterization and entire manifolds of solutions that all give the same distribution.

A2 (Common support). $\text{supp}(P_\xi)$ does not depend on $\xi$. The uniform distribution $U[0, \xi]$ blows this up right away and the likelihood stops being differentiable at the boundary.

A3 (Smoothness). $\ell(\xi; x) = \log p(x; \xi)$ is at least $C^3$ in $\xi$.

A4 (Interchange of derivative and integral).

$$\nabla_\xi \int_{\Omega} p(x; \xi)\, d\mu(x) = \int_{\Omega} \nabla_\xi p(x; \xi)\, d\mu(x)$$

This needs the score to be uniformly integrable and it gets assumed way more often than anyone bothers to check it.

Deriving the metric from KL divergence

KL divergence is asymmetric and it breaks the triangle inequality so it is not a metric, but if you zoom in and do a second-order expansion you get back a real Riemannian metric.

$$D_{KL}(\xi \| \xi') = \int p(x; \xi) \log \frac{p(x; \xi)}{p(x; \xi')}\, d\mu(x)$$

Set $\xi' = \xi + \delta \xi$ and Taylor-expand $\ell(x; \xi')$ around $\xi$ to see what comes out.

$$\ell(\xi + \delta \xi) = \ell(\xi) + (\nabla \ell)^T \delta \xi + \frac{1}{2} \delta \xi^T (\nabla^2 \ell) \delta \xi + O(\|\delta \xi\|^3)$$

Plugging in gives

$$D_{KL} \approx -\mathbb{E}_{\xi} [ (\nabla \ell)^T ] \delta \xi - \frac{1}{2} \delta \xi^T \mathbb{E}_{\xi} [ \nabla^2 \ell ] \delta \xi$$

The linear term drops out and the calculation below shows why.

$$\mathbb{E}_\xi [\nabla \log p(x; \xi)] = \int p(x; \xi) \frac{\nabla p(x; \xi)}{p(x; \xi)}\, d\mu(x) = \int \nabla p(x; \xi)\, d\mu(x)$$

By A4,

$$\nabla \int p(x; \xi)\, d\mu(x) = \nabla(1) = 0$$

So $D_{KL}$ sits at a local minimum at $\delta \xi = 0$ like you would expect, and what is left over is the quadratic form

$$D_{KL} \approx \frac{1}{2} \delta \xi^T \left( -\mathbb{E}_{\xi} [ \nabla^2 \ell ] \right) \delta \xi$$

Now for the Hessian-to-outer-product identity. Differentiating $\int (\nabla \ell)\, p\, dx = 0$ one more time gives

$$\int \left( \nabla^2 \ell + (\nabla \ell)(\nabla \ell)^T \right) p\, dx = 0$$

and so

$$G_{ij}(\xi) = \mathbb{E} \left[ \frac{\partial \ell}{\partial \xi_i} \frac{\partial \ell}{\partial \xi_j} \right] = - \mathbb{E} \left[ \frac{\partial^2 \ell}{\partial \xi_i \partial \xi_j} \right]$$

The local distance $ds^2 = \delta \xi^T G(\xi) \delta \xi$ gives a Riemannian metric on $\mathcal{S}$, and the matrix $G(\xi)$ is the Fisher Information Matrix.
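As a sanity check on the outer-product/Hessian identity, here is a small JAX sketch (an illustration, not from the original): for a Bernoulli model with success probability $\sigma(\xi)$, the Monte Carlo average of squared scores and the negative average Hessian of the log-likelihood both land on the closed form $G(\xi) = \sigma(\xi)(1 - \sigma(\xi))$.

```python
import jax.numpy as jnp
from jax import grad, vmap, random

def log_lik(xi, x):
    # Bernoulli(p) with p = sigmoid(xi); log-likelihood of one observation x in {0, 1}.
    p = 1.0 / (1.0 + jnp.exp(-xi))
    return x * jnp.log(p) + (1.0 - x) * jnp.log(1.0 - p)

xi = 0.3
p = 1.0 / (1.0 + jnp.exp(-xi))
key = random.PRNGKey(0)
x = random.bernoulli(key, p, shape=(100_000,)).astype(jnp.float32)

score = vmap(grad(log_lik), in_axes=(None, 0))          # d log p / d xi, per sample
hess = vmap(grad(grad(log_lik)), in_axes=(None, 0))     # d^2 log p / d xi^2, per sample

g_outer = jnp.mean(score(xi, x) ** 2)   # E[(score)^2]
g_hess = -jnp.mean(hess(xi, x))         # -E[Hessian]
g_exact = p * (1.0 - p)                 # closed form for this model

print(g_outer, g_hess, g_exact)         # all three should be close
```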

Chentsov’s uniqueness theorem

Chentsov (1972) showed that the Fisher metric is the only Riemannian metric (up to scale) that stays invariant under congruent embeddings by Markov morphisms.

The idea is that a measurable map $F$ from $\Omega$ to $\Omega'$ pushes measures forward, and if $F$ is a sufficient statistic then no information is lost, so distances have to be preserved. The Fisher metric does this on its own and no other Riemannian metric does.

$$g_{ij}^{(T(X))}(\theta) = g_{ij}^{(X)}(\theta) \quad \text{iff } T \text{ is sufficient.}$$

Dual connections

A metric gives lengths and angles but if you want geodesics and flatness you also need a connection.

In ordinary Riemannian geometry the obvious pick is the Levi-Civita connection which is metric-compatible and torsion-free and unique. Statistical manifolds carry a whole one-parameter family of connections $\nabla^{(\alpha)}$ with $\alpha \in \mathbb{R}$, and this extra structure is where a lot of the action is.

The $\alpha$-connection has Christoffel symbols

$$\Gamma_{ijk}^{(\alpha)} = \mathbb{E} \left[ \left( \partial_i \partial_j \ell + \frac{1-\alpha}{2} (\partial_i \ell)(\partial_j \ell) \right) (\partial_k \ell) \right]$$

Using the skewness tensor $T_{ijk} = \mathbb{E}[(\partial_i \ell)(\partial_j \ell)(\partial_k \ell)]$ this cleans up to

$$\Gamma_{ijk}^{(\alpha)} = \Gamma_{ijk}^{(0)} - \frac{\alpha}{2} T_{ijk}$$

$\nabla^{(\alpha)}$ and $\nabla^{(-\alpha)}$ are dual with respect to $g$, and so for any vector fields $X$, $Y$, and $Z$ we have

$$X \langle Y, Z \rangle_g = \langle \nabla_X Y, Z \rangle_g + \langle Y, \nabla_X^* Z \rangle_g$$

The exponential connection at $\alpha=1$ and the mixture connection at $\alpha=-1$ are the pair that does most of the work.

The hyperbolic geometry of Gaussians

The univariate Gaussian makes the theory concrete. Take $\mathcal{S} = \{ N(\mu, \sigma^2) \mid \mu \in \mathbb{R}, \sigma > 0 \}$.

$$p(x; \mu, \sigma) = \frac{1}{\sqrt{2\pi}\,\sigma} \exp \left( - \frac{(x-\mu)^2}{2\sigma^2} \right), \qquad \ell = -\frac{1}{2} \log(2\pi) - \log \sigma - \frac{(x-\mu)^2}{2\sigma^2}$$

The scores are

$$\partial_\mu \ell = \frac{x-\mu}{\sigma^2}, \qquad \partial_\sigma \ell = -\frac{1}{\sigma} + \frac{(x-\mu)^2}{\sigma^3}$$

We work out the Fisher matrix one entry at a time.

$$g_{\mu\mu} = \frac{1}{\sigma^4} \mathbb{E}[(x-\mu)^2] = \frac{1}{\sigma^2}$$

$g_{\mu\sigma} = 0$ because the odd central moments of a Gaussian wipe out, so $\mu$ and $\sigma$ sit orthogonal to each other in the Riemannian sense.

For $g_{\sigma\sigma}$ we expand and use $\mathbb{E}[(x-\mu)^4] = 3\sigma^4$, which gives

$$g_{\sigma\sigma} = \frac{3}{\sigma^2} - \frac{2}{\sigma^2} + \frac{1}{\sigma^2} = \frac{2}{\sigma^2}$$

The Fisher matrix is

$$G(\mu, \sigma) = \begin{pmatrix} 1/\sigma^2 & 0 \\ 0 & 2/\sigma^2 \end{pmatrix}$$

and the line element is

$$ds^2 = \frac{d\mu^2 + 2\,d\sigma^2}{\sigma^2}$$

Put this next to the Poincare upper half-plane where $ds^2 = \frac{dx^2 + dy^2}{y^2}$. The factor of 2 just rescales the curvature (it works out to $-1/2$ instead of $-1$) but the shape is the same, and the manifold of univariate Gaussians has constant negative curvature.
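A quick autodiff check of the closed form (illustrative, not part of the original): the Hessian of $\delta \mapsto D_{KL}(N(\mu,\sigma^2)\,\|\,N(\mu+\delta_1, (\sigma+\delta_2)^2))$ at $\delta = 0$ should reproduce $G(\mu,\sigma) = \mathrm{diag}(1/\sigma^2,\, 2/\sigma^2)$.

```python
import jax.numpy as jnp
from jax import hessian

def kl_gauss(mu0, s0, mu1, s1):
    # Closed-form KL( N(mu0, s0^2) || N(mu1, s1^2) ).
    return jnp.log(s1 / s0) + (s0**2 + (mu0 - mu1)**2) / (2 * s1**2) - 0.5

mu, sigma = 1.0, 2.0

def kl_of_delta(delta):
    # Perturb the second argument; the Hessian at zero is the Fisher metric.
    return kl_gauss(mu, sigma, mu + delta[0], sigma + delta[1])

G = hessian(kl_of_delta)(jnp.zeros(2))
print(G)                                              # ~ [[1/sigma^2, 0], [0, 2/sigma^2]]
print(jnp.array([[1 / sigma**2, 0.0], [0.0, 2 / sigma**2]]))
```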

Geodesics between $N(\mu_1, \sigma_1)$ and $N(\mu_2, \sigma_2)$ come from solving the Euler-Lagrange equations for $L = \int \frac{\sqrt{\dot{\mu}^2 + 2\dot{\sigma}^2}}{\sigma}\, dt$. When $\mu_1 = \mu_2$ the geodesic is a vertical line that just rescales the variance, and otherwise the geodesics are semi-ellipses sitting on the $\mu$-axis.

Because of this negative curvature, averaging parameters coordinate by coordinate does not land you on the geometric center, and the Frechet mean on this manifold can sit pretty far from the naive average.

Exponential families and e-flatness

Take the exponential family in natural parameters $\theta$ and write it out as

$$p(x; \theta) = \exp( \theta^i F_i(x) - \psi(\theta) )$$

The second derivative of the log-likelihood is $\partial_i \partial_j \ell = -\partial_i \partial_j \psi(\theta)$, and this has no $x$ in it at all, so the Hessian is just a deterministic function of $\theta$.

So the e-connection Christoffel symbols collapse to

$$\Gamma_{ijk}^{(1)} = \mathbb{E} [ \partial_i \partial_j \ell \cdot \partial_k \ell ] = (\partial_i \partial_j \ell)\, \mathbb{E}[\partial_k \ell] = 0$$

because the score has zero mean. So the manifold is flat under the e-connection, the natural parameters are affine coordinates, and geodesics are straight lines given by $\theta(t) = (1-t)\theta_1 + t\theta_2$.

Mixture families and the Pythagorean theorem

And on the other side, the mixture family

$$p(x; \eta) = \left(1 - \sum_i \eta_i\right) p_0(x) + \sum_i \eta_i p_i(x)$$

is flat under the m-connection at $\alpha=-1$, the expectation parameters $\eta$ are affine, and geodesics are linear mixtures given by $P(t) = (1-t)P_1 + tP_2$.

Now that we have dual flat structures sitting on the same space, a Pythagorean theorem falls out. When the m-geodesic $PQ$ hits the e-geodesic $QR$ at right angles at $Q$ we get

$$D_{KL}(P \| R) = D_{KL}(P \| Q) + D_{KL}(Q \| R)$$

The KL divergence between members of an exponential family is a Bregman divergence on the cumulant potential $\psi$, and it reads

$$D_{KL}(P_{\theta_P} \| P_{\theta_R}) = \psi(\theta_R) - \psi(\theta_P) - \nabla \psi(\theta_P) \cdot (\theta_R - \theta_P)$$

where $\eta_P = \nabla \psi(\theta_P)$.

If you write out $D_{KL}(P\|Q) + D_{KL}(Q\|R)$ and subtract $D_{KL}(P\|R)$, a bunch of things cancel and what is left is

$$\Delta = (\eta_P - \eta_Q) \cdot (\theta_R - \theta_Q)$$

The vector $\eta_P - \eta_Q$ is the tangent in dual coordinates along the m-geodesic $P \to Q$, and $\theta_R - \theta_Q$ is the tangent in primal coordinates along the e-geodesic $Q \to R$. Orthogonality says their inner product wipes out, and so $\Delta = 0$.
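A short JAX check of the Bregman bookkeeping (an illustration, assuming the univariate Gaussian written as a two-parameter exponential family with natural parameters $\theta = (\mu/\sigma^2,\, -1/(2\sigma^2))$ and $\psi(\theta) = -\theta_1^2/(4\theta_2) + \tfrac{1}{2}\log(\pi/(-\theta_2))$): the Bregman formula reproduces the closed-form Gaussian KL, and $D(P\|Q) + D(Q\|R) - D(P\|R)$ equals $(\eta_P - \eta_Q)\cdot(\theta_R - \theta_Q)$ for any three points, orthogonal or not.

```python
import jax.numpy as jnp
from jax import grad

def psi(theta):
    # Log-partition of N(mu, sigma^2) in natural coordinates
    # theta = (mu / sigma^2, -1 / (2 sigma^2)).
    t1, t2 = theta
    return -t1**2 / (4 * t2) + 0.5 * jnp.log(jnp.pi / (-t2))

def to_natural(mu, sigma):
    return jnp.array([mu / sigma**2, -1.0 / (2 * sigma**2)])

def bregman_kl(theta_p, theta_r):
    # D_KL(P || R) = psi(R) - psi(P) - grad psi(P) . (theta_R - theta_P)
    return psi(theta_r) - psi(theta_p) - grad(psi)(theta_p) @ (theta_r - theta_p)

def gauss_kl(mu0, s0, mu1, s1):
    # Closed-form KL between univariate Gaussians, for comparison.
    return jnp.log(s1 / s0) + (s0**2 + (mu0 - mu1)**2) / (2 * s1**2) - 0.5

tP, tQ, tR = to_natural(0.0, 1.0), to_natural(1.0, 1.5), to_natural(3.0, 0.7)

print(bregman_kl(tP, tR), gauss_kl(0.0, 1.0, 3.0, 0.7))      # should match

delta = bregman_kl(tP, tQ) + bregman_kl(tQ, tR) - bregman_kl(tP, tR)
eta_P, eta_Q = grad(psi)(tP), grad(psi)(tQ)                  # expectation parameters
print(delta, (eta_P - eta_Q) @ (tR - tQ))                    # should match
```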

This gives you the projection theorem, which says that the m-projection of $P$ onto an e-flat submanifold is unique and satisfies the Pythagorean relation, and the MLE is just the m-projection of the empirical distribution onto the model manifold.

The uniform distribution and how regularity breaks

Breaking A2 changes the theory in a qualitative way. Take $p(x; \theta) = U[0, \theta] = \frac{1}{\theta} \mathbb{I}(0 \le x \le \theta)$.

$$\ell(x; \theta) = -\log \theta, \qquad \nabla \ell = -\frac{1}{\theta}, \qquad \mathbb{E}[\nabla \ell] = \int_0^\theta \left(-\frac{1}{\theta}\right) \frac{1}{\theta}\, dx = -\frac{1}{\theta}$$

The score does not have zero mean anymore and the whole derivation from the KL section falls apart, because $\frac{d}{d\theta} \int_0^\theta p\, dx \neq \int_0^\theta \nabla p\, dx$ and Leibniz's rule picks up a boundary term $p(\theta; \theta) = 1/\theta$.

And the fallout is real. The Fisher information $G = 1/\theta^2$ looks finite on paper but the Cramer-Rao bound does not apply because A2 has failed. The MLE $\hat{\theta} = \max(X_i)$ has variance $O(n^{-2})$, and this beats the $O(n^{-1})$ Fisher rate because the support boundary holds information that a score-based story just cannot see.
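A quick simulation makes the $O(n^{-2})$ rate visible (my sketch, not from the original): doubling the sample size should cut the variance of $\hat{\theta} = \max(X_i)$ by roughly a factor of 4 rather than 2, and the empirical variance matches the exact formula $n\theta^2/((n+1)^2(n+2))$.

```python
import jax.numpy as jnp
from jax import random

theta = 2.0
key = random.PRNGKey(0)
k1, k2 = random.split(key)

def var_of_mle(key, n, trials=20_000):
    # Empirical variance of max(X_1..X_n) for X_i ~ U[0, theta].
    x = random.uniform(key, shape=(trials, n), maxval=theta)
    return jnp.var(jnp.max(x, axis=1))

v_n = var_of_mle(k1, 100)
v_2n = var_of_mle(k2, 200)

print(v_n / v_2n)                                   # ~4: the O(n^-2) signature
print(v_n, theta**2 * 100 / (101**2 * 102))         # matches n theta^2 / ((n+1)^2 (n+2))
```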

Natural gradient descent

Say we want to minimize a loss $\mathcal{L}(\theta)$ on $\mathcal{S}$. The update $\theta_{\text{new}} = \theta - \eta \nabla \mathcal{L}$ mashes covectors and vectors together, because $\nabla \mathcal{L}$ is a covector and $\Delta \theta$ is a vector and they do not live in the same space.

So instead we solve

$$\min_{\delta \theta} \mathcal{L}(\theta + \delta \theta) \quad \text{subject to} \quad D_{KL}(\theta \| \theta + \delta \theta) = \epsilon$$

and plugging in the quadratic approximation $D_{KL} \approx \frac{1}{2} \delta \theta^T G \delta \theta$ and solving the constrained problem with a Lagrange multiplier (the step is worked out just below) gives

$$\delta \theta = -\eta\, G^{-1}(\theta) \nabla \mathcal{L}(\theta)$$
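Filling in that step: linearize the loss, form the Lagrangian for the constrained problem, and set its derivative to zero; the multiplier gets absorbed into the learning rate $\eta$.

$$\mathcal{L}(\theta + \delta\theta) \approx \mathcal{L}(\theta) + \nabla\mathcal{L}^T \delta\theta, \qquad \Lambda(\delta\theta, \lambda) = \nabla\mathcal{L}^T \delta\theta + \lambda\left(\tfrac{1}{2}\,\delta\theta^T G\, \delta\theta - \epsilon\right)$$

$$\frac{\partial \Lambda}{\partial\, \delta\theta} = \nabla\mathcal{L} + \lambda\, G\, \delta\theta = 0 \;\;\Longrightarrow\;\; \delta\theta = -\tfrac{1}{\lambda}\, G^{-1} \nabla\mathcal{L}$$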

Cramer-Rao as Cauchy-Schwarz on the tangent space

The Cramer-Rao bound has a clean geometric proof that is really just Cauchy-Schwarz on the tangent space.

The score $S_\theta(x) = \nabla_\theta \ell(x; \theta)$ sits in $L^2(P_\theta)$ and the Fisher information is its squared norm $\|S_\theta\|^2 = G(\theta)$.

For an unbiased estimator $\hat{\theta}$ we work out $\text{Cov}(\hat{\theta}, S_\theta)$. Using $\nabla p = p\, S_\theta$ and A4 we get

$$\mathbb{E}[(\hat{\theta} - \theta) S_\theta^T] = I$$

And matrix Cauchy-Schwarz on this gives

$$\text{Var}(\hat{\theta}) \ge G(\theta)^{-1}$$

High curvature means large $G$ and that says the distributions are easy to tell apart so estimators can be precise, and low curvature means small $G$ and the distributions look locally the same so no estimator can do well.
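A tiny Monte Carlo check of the two ingredients (my example, not from the original): for a single observation from $N(\theta, 1)$, the estimator $\hat\theta(x) = x$ is unbiased, the covariance identity $\mathbb{E}[(\hat\theta - \theta) S_\theta] = 1$ holds, and the variance sits exactly at the bound $G^{-1} = 1$.

```python
import jax.numpy as jnp
from jax import random

theta = 1.5
key = random.PRNGKey(1)
x = theta + random.normal(key, (200_000,))   # one observation per trial, X ~ N(theta, 1)

score = x - theta        # d/dtheta log N(x; theta, 1)
theta_hat = x            # unbiased estimator of theta

print(jnp.mean((theta_hat - theta) * score))   # ~1: the covariance identity
print(jnp.var(theta_hat), 1.0)                 # variance meets G^{-1} = 1 here
```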

Singular models and Watanabe’s RLCT

When identifiability fails and A1 goes out the window, the Fisher matrix goes singular and we get

$$\Theta_{\text{sing}} = \{ \theta \in \Theta \mid \det(G(\theta)) = 0 \}$$

The manifold dimension collapses at these points and the tangent space drops to a tangent cone, and in deep learning singular Fisher matrices are just the default rather than the exception.
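As a minimal illustration (my example, not from the original), take the model $N(ab, 1)$ with parameters $(a, b)$: every pair with the same product $ab$ gives the same distribution, the score components are proportional to each other, and the Fisher matrix is rank one everywhere, so $\det G = 0$ on the whole parameter space.

```python
import jax.numpy as jnp
from jax import grad, vmap, random

def log_lik(params, x):
    # N(a*b, 1) log-likelihood of a single observation.
    a, b = params
    return -0.5 * (x - a * b) ** 2 - 0.5 * jnp.log(2 * jnp.pi)

params = jnp.array([1.5, 2.0])
key = random.PRNGKey(0)
x = params[0] * params[1] + random.normal(key, (50_000,))    # samples from the model

scores = vmap(lambda xi: grad(log_lik)(params, xi))(x)       # per-sample score vectors
G = scores.T @ scores / x.shape[0]                           # Monte Carlo Fisher matrix

print(G)                          # ~ [[b^2, a*b], [a*b, a^2]]
print(jnp.linalg.det(G))          # ~ 0: the metric is singular everywhere
```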

Watanabe (2009) showed that near singularities the Bayesian posterior does not settle down to $N(0, G^{-1})$, and the standard BIC complexity term $\frac{d}{2} \log n$ gets replaced by $\lambda \log n$, where $\lambda$ is the Real Log Canonical Threshold (RLCT) and satisfies

$$\lambda < \frac{d}{2}$$

Singular models are simpler than their parameter count makes them look. The posterior volume is set by resolution of singularities and that is a question for algebraic geometry and not Riemannian geometry. Standard information geometry stops working here and you have to pick up Watanabe’s singular learning theory instead.

Implementation (JAX)

We check e- and m-geodesic orthogonality and put Natural Gradient up against SGD on a warped Gaussian to see how they converge.

FISHER INFORMATION MANIFOLD
MODEL: N(θ, diag(θ² + 1))
VISUALIZATION: TISSOT INDICATRIX (EXACT)

```python
import jax
import jax.numpy as jnp
from jax import random, grad, jit, vmap, lax
from jax.scipy.stats import multivariate_normal
from typing import NamedTuple, Tuple

# ------------------------------------------------------------------
# SYSTEM CONFIGURATION
# ------------------------------------------------------------------
SEED = 42
LEARNING_RATE = 0.1
NUM_STEPS = 100
DAMPING = 1e-4

# ------------------------------------------------------------------
# 1. Manifold Definition: Warped Gaussian
# ------------------------------------------------------------------
def get_sigma(theta: jax.Array) -> jax.Array:
    """Constructs the covariance matrix Sigma(theta) = diag(theta^2 + 1).

    Ensures positive definiteness everywhere.
    """
    return jnp.diag(theta**2 + 1.0)

def log_likelihood(theta: jax.Array, x: jax.Array) -> jax.Array:
    """Computes sum of log-likelihoods for data x given theta."""
    mu = theta
    cov = get_sigma(theta)
    return jnp.sum(multivariate_normal.logpdf(x, mu, cov))

# ------------------------------------------------------------------
# 2. Fisher Information Computation
# ------------------------------------------------------------------
@jit
def compute_fisher_mc(theta: jax.Array, key: jax.Array,
                      num_samples: int = 1000) -> jax.Array:
    """Approximates the Fisher Information Matrix using Monte Carlo integration.

    G(theta) = E[score * score^T]
    """
    cov = get_sigma(theta)
    # Sampling from the model distribution at theta
    samples = random.multivariate_normal(key, theta, cov, shape=(num_samples,))

    def score_fn(t, x_single):
        return grad(lambda p: multivariate_normal.logpdf(x_single, p, get_sigma(p)))(t)

    # Vectorized score computation
    scores = vmap(lambda x: score_fn(theta, x))(samples)

    # Outer product expectation
    outer_products = vmap(lambda s: jnp.outer(s, s))(scores)
    return jnp.mean(outer_products, axis=0)

# ------------------------------------------------------------------
# 3. Optimization Loop (JIT-Compiled Scan)
# ------------------------------------------------------------------
def loss_fn(theta: jax.Array, batch: jax.Array) -> jax.Array:
    """Negative Log Likelihood Loss."""
    return -log_likelihood(theta, batch) / batch.shape[0]

class OptState(NamedTuple):
    theta_sgd: jax.Array
    theta_ngd: jax.Array
    key: jax.Array

@jit
def update_step_sgd(theta: jax.Array, batch: jax.Array) -> jax.Array:
    grads = grad(loss_fn)(theta, batch)
    return theta - LEARNING_RATE * grads

@jit
def update_step_ngd(theta: jax.Array, batch: jax.Array, key: jax.Array) -> jax.Array:
    grads = grad(loss_fn)(theta, batch)
    fisher = compute_fisher_mc(theta, key)

    # Natural Gradient: G^-1 * grad
    # Numerically stable solve: (G + damping * I) * update = grad
    regularized_fisher = fisher + DAMPING * jnp.eye(fisher.shape[0])
    nat_grad = jnp.linalg.solve(regularized_fisher, grads)
    return theta - LEARNING_RATE * nat_grad

@jit
def run_experiment() -> Tuple[jax.Array, jax.Array]:
    """Fully compiled training loop using lax.scan."""
    key = random.PRNGKey(SEED)
    key, subkey_data, subkey_init = random.split(key, 3)

    # Ground Truth
    true_theta = jnp.array([2.0, 3.0])
    data = random.multivariate_normal(
        subkey_data, true_theta, get_sigma(true_theta), shape=(500,)
    )

    # Initialization
    theta_0 = jnp.array([0.5, 0.5])
    init_state = OptState(theta_sgd=theta_0, theta_ngd=theta_0, key=subkey_init)

    def step_fn(state: OptState, _):
        key, subkey_ngd = random.split(state.key)

        # Parallel updates
        next_sgd = update_step_sgd(state.theta_sgd, data)
        next_ngd = update_step_ngd(state.theta_ngd, data, subkey_ngd)

        new_state = OptState(theta_sgd=next_sgd, theta_ngd=next_ngd, key=key)
        # Record trajectories
        return new_state, (next_sgd, next_ngd)

    # Execute simulation
    final_state, (path_sgd, path_ngd) = lax.scan(step_fn, init_state, None, length=NUM_STEPS)
    return path_sgd, path_ngd
```

Implications for optimization

Pulling all the theory together we land on a handful of concrete takeaways.

  1. The Fisher metric is the unique valid choice (Chentsov).
  2. Exponential families are both e-flat and m-flat (Amari).
  3. The Cramer-Rao bound is Cauchy-Schwarz on $T_\theta \mathcal{S}$.
  4. Deep learning lives in the singular regime where $G$ is rank-deficient (Watanabe).
  5. The natural gradient $G^{-1} \nabla \mathcal{L}$ is the unique first-order update that respects the geometry.

A quick word on Adam. It approximates $G$ with a diagonal $\text{diag}(\sqrt{\mathbb{E}[g^2]})$ and that has no clean geometric story behind it. Natural gradient scales by $G^{-1}$ with units $1/g^2$ and Adam scales by $G^{-1/2}$ with units $1/g$. In practice Adam is much closer to sign descent, where you just normalize gradient magnitudes per coordinate, and not really Riemannian steepest descent at all. Adam works, but for reasons that are mostly beside the curvature correction story.
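A toy illustration of the scaling point (my sketch of standard Adam, not from the original): feed Adam a constant gradient and its steady-state step per coordinate is roughly $\text{lr} \cdot \text{sign}(g)$, the same size whether that coordinate's gradient is $10^{-3}$ or $10^{3}$.

```python
import jax.numpy as jnp

def adam_step(g, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    # One standard Adam update; returns the parameter delta and the new moments.
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g**2
    m_hat = m / (1 - b1**t)
    v_hat = v / (1 - b2**t)
    return -lr * m_hat / (jnp.sqrt(v_hat) + eps), m, v

# Two coordinates whose gradients differ by six orders of magnitude.
g = jnp.array([1e-3, 1e3])
m = jnp.zeros(2)
v = jnp.zeros(2)
for t in range(1, 201):            # run to steady state on a constant gradient
    delta, m, v = adam_step(g, m, v, t)

print(delta)   # ~[-1e-3, -1e-3]: both steps are ~lr * sign(g), like sign descent
```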

Without information geometry, optimization is running around in a coordinate system that has no real meaning attached to it.


Timeline

| Year | Event | Significance |
|------|-------|--------------|
| 1945 | C.R. Rao | Introduces Fisher Information Metric (Riemannian). |
| 1972 | N. Chentsov | Proves Uniqueness Theorem for the metric. |
| 1979 | B. Efron | Defines statistical curvature. |
| 1985 | Shun-ichi Amari | Develops Dualistic Geometry ($\alpha$-connections). |
| 1998 | Amari | Proposes Natural Gradient Descent. |
| 2009 | Sumio Watanabe | Singular Learning Theory (Algebraic Geometry of learning). |
| 2014 | Pascanu & Bengio | Revisited Natural Gradient for Neural Networks. |

Legendre Duality

The exponential and mixture duality is really just a Legendre transform.

$\psi(\theta)$ is the cumulant generating function and it reads
$$\psi(\theta) = \log \int \exp(\theta \cdot F(x))\, d\mu(x)$$

and its Legendre conjugate is
$$\phi(\eta) = \sup_{\theta} \{ \theta \cdot \eta - \psi(\theta) \}$$

The sup gets hit where $\eta = \nabla \psi(\theta)$, which says $\eta_i = \mathbb{E}[F_i(x)]$, and this is the map that takes you from natural parameters to expectation parameters.

$\phi(\eta)$ is negative entropy up to constants and the inverse map goes the other way as $\theta = \nabla \phi(\eta)$.

The Hessians give you the metric in each coordinate system,
$$G_{ij}(\theta) = \frac{\partial^2 \psi}{\partial \theta_i \partial \theta_j}, \qquad G^{ij}(\eta) = \frac{\partial^2 \phi}{\partial \eta_i \partial \eta_j}$$

and these two are matrix inverses of each other at corresponding points $\eta = \nabla \psi(\theta)$, which confirms the Riemannian structure is consistent across the dual representations.
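A small JAX check (my example, not from the original), using the Bernoulli family where everything is closed form: $\psi(\theta) = \log(1 + e^\theta)$, $\eta = \nabla\psi(\theta)$ is the mean, $\phi(\eta)$ is the negative entropy, $\nabla\phi$ inverts the map, and the two Hessians multiply to one.

```python
import jax.numpy as jnp
from jax import grad

def psi(theta):
    # Bernoulli log-partition in the natural parameter (the logit).
    return jnp.log1p(jnp.exp(theta))

def phi(eta):
    # Legendre conjugate: negative entropy of Bernoulli(eta).
    return eta * jnp.log(eta) + (1 - eta) * jnp.log(1 - eta)

theta = 0.7
eta = grad(psi)(theta)                  # expectation parameter E[x] = sigmoid(theta)
theta_back = grad(phi)(eta)             # inverse map: should recover theta

g_theta = grad(grad(psi))(theta)        # metric in natural coordinates
g_eta = grad(grad(phi))(eta)            # metric in expectation coordinates

print(eta, 1 / (1 + jnp.exp(-theta)))   # the two should match
print(theta_back)                        # ~0.7
print(g_theta * g_eta)                   # ~1.0: the Hessians are inverses
```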


Fisher vs. Wasserstein

People mix these two up all the time but they actually live on different spaces.

Fisher geometry sits on parameter space $\mathcal{S}$ and is about densities, and Wasserstein geometry sits on the sample space $\Omega$ lifted up to measures, and it needs a ground metric $d_\Omega(x, y)$ that Fisher geometry does not need at all.

Their geodesics are fundamentally different. The Fisher geodesic from the e-connection interpolates multiplicatively as
$$\log p_t(x) = (1-t) \log p_0(x) + t \log p_1(x) - \psi(t)$$

and the Wasserstein geodesic just slides mass around horizontally as
$$T(x) = x + t \nabla \phi(x)$$

Try interpolating $N(0, 1)$ and $N(10, 1)$ and see what happens. The Fisher m-geodesic runs through a bimodal mixture in the middle and the Wasserstein geodesic just slides the bump over smoothly and hits $N(5, 1)$ at the halfway point.
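A short numerical look at that midpoint (my sketch, not from the original): the m-geodesic midpoint is the 50/50 mixture, which has two density peaks, while the Wasserstein midpoint is just $N(5, 1)$ with one.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

xs = jnp.linspace(-5.0, 15.0, 2001)

# m-geodesic (mixture) midpoint between N(0,1) and N(10,1).
p_mix = 0.5 * norm.pdf(xs, 0.0, 1.0) + 0.5 * norm.pdf(xs, 10.0, 1.0)
# Wasserstein-geodesic midpoint: the whole bump has slid halfway over.
p_w2 = norm.pdf(xs, 5.0, 1.0)

def num_peaks(p):
    # Count strict local maxima of a density sampled on a grid.
    return int(jnp.sum((p[1:-1] > p[:-2]) & (p[1:-1] > p[2:])))

print(num_peaks(p_mix), num_peaks(p_w2))   # 2 vs 1
```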

Which one you reach for depends on the question you are asking. Fisher geometry fits inference and asks how much a sample tells you about θ\theta, and Wasserstein fits transport and asks what it costs to push one distribution into another. Fisher comes out of the entropy Hessian and Wasserstein comes out of kinetic energy minimization from Benamou-Brenier.


Vocabulary

An affine connection sets up parallel transport and covariant derivatives on a manifold and tells you how to compare tangent vectors at different points.

The Fisher Information Metric is the only Riemannian metric on families of probability distributions (up to scale) that stays invariant under sufficient statistics, and it falls out of the Hessian of KL divergence.

The natural gradient $G^{-1} \nabla \mathcal{L}$ is the steepest descent direction that actually accounts for the curvature of the statistical manifold instead of pretending parameter space is flat.

A statistical manifold is a family of probability distributions equipped with the Fisher metric and $\alpha$-connections.

Dual connections are a pair $(\nabla^{(\alpha)}, \nabla^{(-\alpha)})$ that satisfy the compatibility condition $X \langle Y, Z \rangle = \langle \nabla_X Y, Z \rangle + \langle Y, \nabla_X^* Z \rangle$ with respect to the metric.

The Kullback-Leibler divergence is the asymmetric divergence whose second-order Taylor expansion spits out the Fisher metric, and even though it is not a distance in the metric-space sense its local geometry is Riemannian.


References

1. Amari, S., & Nagaoka, H. (2000). “Methods of Information Geometry”. The standard reference. It sets up α\alpha-connections and dually flat spaces and applications to estimation.

2. Chentsov, N. N. (1972). “Statistical Decision Rules and Optimal Inference”. Proves the uniqueness of the Fisher metric from Markov invariance.

3. Rao, C. R. (1945). “Information and the accuracy attainable in the estimation of statistical parameters”. The original paper that proposed the Riemannian metric.

4. Watanabe, S. (2009). “Algebraic Geometry and Statistical Learning Theory”. The foundation of Singular Learning Theory and it handles the place where regular information geometry falls over in neural networks.

5. Martens, J. (2014). “New insights and perspectives on the natural gradient method”. A modern take on why Natural Gradient works for deep learning (K-FAC).