Wasserstein Gradient Flows

Gradient descent minimizes F(x) by chasing the steepest direction in Euclidean space. The obvious next question: what happens when the thing you are optimizing is not a point in \mathbb{R}^d but a probability distribution \rho, and you want the optimizer to respect how mass actually moves?

Neither the L^2 norm nor the KL divergence accounts for the cost of physically moving mass. An L^2 gradient flow wipes out mass in one region and spawns it somewhere else with no regard for geometry, and the result is a path that makes no physical sense.

The W_2 metric fixes this by turning \mathcal{P}_2(\mathbb{R}^d) into a formally infinite-dimensional Riemannian manifold. Jordan, Kinderlehrer, and Otto (1998) showed that many classical PDEs, including the heat equation, Fokker-Planck, and the porous medium equation, are gradient flows of familiar functionals like entropy and internal energy sitting in this geometry.

This lets you point optimization language at PDEs and manifold geometry at stochastic processes. The rest of this piece builds Otto calculus and the tools for optimization and stability analysis on the space of measures.


1. Benamou-Brenier: Transport as Fluid Flow

Kantorovich’s formulation is static: it picks out a coupling \gamma but says nothing about the path. Benamou and Brenier (2000) recast optimal transport as a fluid mechanics problem where you watch mass actually move.

A density \rho_t(x) moving under a velocity field v_t(x) satisfies the continuity equation.

\partial_t \rho_t + \nabla \cdot (\rho_t v_t) = 0

This is conservation of mass: the total measure holds steady, the flux is driven by v_t, and nothing leaks in or out.

Their theorem says the squared W_2 distance equals the minimum kinetic energy you need to push \rho_0 into \rho_1.

W_2^2(\rho_0, \rho_1) = \inf_{(\rho, v)} \left\{ \int_0^1 \int_{\mathbb{R}^d} \rho_t(x) |v_t(x)|^2 \, dx \, dt : \partial_t \rho + \nabla \cdot (\rho v) = 0 \right\}
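A quick numerical sanity check of the formula for the simplest possible case, a translated Gaussian: the optimal plan moves every particle at constant velocity v = m_1 - m_0, and the kinetic-energy action matches the static W_2^2 computed from the 1-D quantile coupling. The specific numbers below are illustrative choices.

```python
import numpy as np

# Benamou-Brenier sanity check: translating a Gaussian by m1 - m0.
rng = np.random.default_rng(0)
m0, m1, sigma = -1.0, 2.0, 0.5

x0 = rng.normal(m0, sigma, 100_000)    # particles of rho_0
x1 = x0 + (m1 - m0)                    # pushforward under the translation map

# Dynamic side: int_0^1 int rho_t |v|^2 dx dt with constant v and unit mass.
v = m1 - m0
action = v**2

# Static side: W_2^2 from the 1-D quantile (sorted-sample) coupling.
w2_sq = np.mean((np.sort(x0) - np.sort(x1))**2)

print(action, w2_sq)   # 9.0 9.0
```

Both sides agree exactly here because the translation field is already the optimal (gradient) velocity.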

This is where the Riemannian structure shows up: the density \rho plays the role of a point on the manifold, and v sits in the tangent space T_\rho \mathcal{P}_2.

Tangent space structure

Not every velocity field actually moves mass. A divergence-free piece, like a vortex, leaves the density untouched and only burns kinetic energy, so by the Helmholtz-Hodge decomposition the optimal velocity has to be a gradient.

v = \nabla \phi

So the tangent space at \rho is a closure of gradient fields.

T_\rho \mathcal{P}_2 \cong \overline{\{ \nabla \phi : \phi \in C_c^\infty(\mathbb{R}^d) \}}^{L^2(\rho)}

The inner product on tangent vectors \nabla \phi_1 and \nabla \phi_2 weights everything by the density.

g_\rho(\nabla \phi_1, \nabla \phi_2) = \int_{\mathbb{R}^d} \langle \nabla \phi_1(x), \nabla \phi_2(x) \rangle \, \rho(x) \, dx

Moving mass where there is no density is free and moving mass through high-density regions costs more.


2. Otto Calculus

Otto (2001) pinned down how to differentiate functionals on \mathcal{P}_2. The whole thing falls out of the continuity equation and an integration by parts.

Take a path \rho_t with velocity v_t and look at the rate of change of \mathcal{F} along this path.

\frac{d}{dt} \mathcal{F}(\rho_t) = \int_{\mathbb{R}^d} \frac{\delta \mathcal{F}}{\delta \rho}(x) \, \partial_t \rho_t(x) \, dx

Plug the continuity equation in and integrate by parts.

\frac{d}{dt} \mathcal{F}(\rho_t) = - \int \frac{\delta \mathcal{F}}{\delta \rho} \, \nabla \cdot (\rho v) \, dx = \int \nabla \left( \frac{\delta \mathcal{F}}{\delta \rho} \right) \cdot v \, \rho \, dx

By the Riesz representation theorem on (T_\rho \mathcal{P}_2, g_\rho), the Wasserstein gradient drops out clean.

\mathrm{grad}_{W} \mathcal{F}(\rho) = \nabla \left( \frac{\delta \mathcal{F}}{\delta \rho} \right)

Setting v = -\mathrm{grad}_{W} \mathcal{F} and dropping it back into the continuity equation gives the general Wasserstein gradient flow PDE.

\partial_t \rho = \nabla \cdot \left( \rho \, \nabla \frac{\delta \mathcal{F}}{\delta \rho} \right)

The specific equation you get hangs entirely on which \mathcal{F} you pick; different functionals give different classical PDEs.

Hessian and Fisher information

The Wasserstein Hessian \mathrm{Hess}_W \mathcal{F} holds the second-order behavior, and for \mathcal{F}(\rho) = \int \rho \log \rho it ties straight into Fisher information.

\langle \mathrm{Hess}_W \mathcal{F}(\rho) \nabla \phi, \nabla \phi \rangle_{g_\rho} = \int \sum_{ij} (\partial_{ij} \phi)^2 \, \rho \, dx

When this Hessian is bounded below by \lambda I you get displacement convexity, the flow contracts, and this is the route to the log-Sobolev inequality.

Bakry-Emery \Gamma_2

The iterated gradient shows up like this.

\Gamma_2(\phi) = \frac{1}{2} \Delta |\nabla \phi|^2 - \langle \nabla \phi, \nabla \Delta \phi \rangle

The condition \Gamma_2(\phi) \ge \lambda |\nabla \phi|^2 means Ricci curvature is at least \lambda, and this is the Bochner identity written in the language of optimal transport.
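The identity can be sanity-checked with finite differences: in flat 1-D space the Ricci term vanishes and Bochner reduces to \Gamma_2(\phi) = |\phi''|^2. A sketch, with \phi = \sin x as an arbitrary test function:

```python
import numpy as np

# Finite-difference check of the flat-space Bochner identity: with zero Ricci
# curvature in 1-D, Gamma_2(phi) should equal (phi'')^2.
x = np.linspace(0.0, 2 * np.pi, 2001)
d = lambda f: np.gradient(f, x)       # second-order central differences

phi = np.sin(x)
dphi = d(phi)                          # phi'
lap = d(dphi)                          # phi'' (the 1-D Laplacian)

# Gamma_2(phi) = (1/2) Delta |grad phi|^2 - <grad phi, grad(Delta phi)>
gamma2 = 0.5 * d(d(dphi**2)) - dphi * d(lap)

interior = slice(50, -50)              # drop the one-sided boundary stencils
err = np.max(np.abs(gamma2 - lap**2)[interior])
print(err < 1e-2)   # True: Gamma_2 matches the squared-Hessian term
```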


3. Classical PDEs as Gradient Flows

Specific choices of \mathcal{F} bring back classical evolution equations; the same machinery cranks out entropy flows, McKean-Vlasov, and porous medium.

Heat equation

Take \mathcal{F}(\rho) = \int \rho \log \rho \, dx, which is negative entropy. The first variation is \frac{\delta \mathcal{F}}{\delta \rho} = \log \rho + 1, its gradient is \nabla(\log \rho) = \nabla \rho / \rho, and plugging into the general equation gives the heat equation.

\partial_t \rho = \nabla \cdot \left( \rho \frac{\nabla \rho}{\rho} \right) = \nabla \cdot (\nabla \rho) = \Delta \rho

The heat equation is the gradient flow of entropy: diffusion is particles redistributing to maximize entropy as efficiently as they can in the Wasserstein metric.
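This is directly checkable: evolve a density by the heat equation and watch \int \rho \log \rho fall at every step. A sketch on a periodic grid with an illustrative bimodal start:

```python
import numpy as np

# Heat flow as entropy descent: evolve rho_t = lap(rho) on a periodic grid and
# check that the functional int rho log rho decreases monotonically.
n, L = 256, 10.0
x = np.linspace(-L/2, L/2, n, endpoint=False)
dx = x[1] - x[0]
rho = np.exp(-(x - 1.5)**2 / 0.1) + np.exp(-(x + 1.5)**2 / 0.1)
rho /= rho.sum() * dx                      # normalize to a probability density

lap = lambda f: (np.roll(f, 1) - 2*f + np.roll(f, -1)) / dx**2
dt = 0.2 * dx**2                           # explicit scheme needs dt < dx^2 / 2

entropies = []
for _ in range(2000):
    entropies.append(np.sum(rho * np.log(rho + 1e-300)) * dx)
    rho = rho + dt * lap(rho)

diffs = np.diff(entropies)
print(diffs.max() < 0)   # True: int rho log rho only goes down along the flow
```

Mass is conserved exactly by the discrete Laplacian, so the flow really does stay on \mathcal{P}_2.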

McKean-Vlasov

Now take a functional with three pieces: internal energy, potential energy, and interaction energy.

\mathcal{F}(\rho) = \int U(\rho) \, dx + \int V(x) \, \rho(dx) + \frac{1}{2} \int (W * \rho) \, \rho(dx)

The gradient flow drops straight out of the same recipe.

\partial_t \rho = \nabla \cdot \big( \rho \nabla U'(\rho) + \rho \nabla V + \rho \nabla (W * \rho) \big)

Setting U(\rho) = \rho \log \rho and W = 0 brings back Fokker-Planck. This one setup models everything from mean-field neural network training to collective biological motion; in the neural network setting V is the loss surface and W holds how particles push on each other.
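A particle sketch of the McKean-Vlasov dynamics, with the illustrative choices V(x) = x^2/2 and quadratic attraction W(x) = x^2/2. For this W the pairwise sum collapses to a mean-field term, and the stationary variance can be computed by hand and checked:

```python
import numpy as np

# McKean-Vlasov particles: drift on particle i is
#   -V'(X_i) - (1/N) sum_j W'(X_i - X_j) = -X_i - (X_i - mean(X)).
rng = np.random.default_rng(1)
N, dt, steps = 5000, 0.01, 4000
X = rng.normal(3.0, 1.0, N)                  # start far from equilibrium

for _ in range(steps):
    drift = -X - (X - X.mean())              # mean-field form of the pairwise sum
    X = X + drift * dt + np.sqrt(2 * dt) * rng.normal(size=N)

# Effective OU rate is 1 + 1 = 2 with sqrt(2) noise, so stationary var = 1/2.
print(X.mean(), X.var())   # near 0.0 and near 0.5
```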


4. The JKO Scheme

Solving \partial_t \rho = -\mathrm{grad}_W \mathcal{F}(\rho) head-on is usually hopeless. The JKO scheme (1998) discretizes the flow in a way that respects the Wasserstein geometry and turns the problem into a sequence of variational steps.

In \mathbb{R}^d, the implicit Euler step is the same thing as the proximal operator.

x_{k+1} = \operatorname{argmin}_{x} \left\{ F(x) + \frac{1}{2\tau} \| x - x_k \|^2 \right\}

Swap the Euclidean distance out for W_2.

\rho_{k+1} = \operatorname{argmin}_{\rho \in \mathcal{P}_2} \left\{ \mathcal{F}(\rho) + \frac{1}{2\tau} W_2^2(\rho, \rho_k) \right\}

Each step weighs decreasing \mathcal{F} against the cost of moving mass away from \rho_k. As \tau \to 0 the iterates converge to the PDE solution, and the JKO construction also hands you existence and uniqueness proofs by building the solution as a limit of JKO iterates.
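The proximal step can be run directly on particles. A Lagrangian sketch: freeze the particle labels so that W_2^2 becomes (1/N) \sum_i |x_i - y_i|^2, which is exact in 1-D as long as the optimal map stays monotone (it does for the convex toy potential chosen here), and solve each inner proximal problem by gradient descent:

```python
import numpy as np

# Lagrangian JKO: represent rho_k by particles, solve each proximal step by
# gradient descent on the particle positions. V is an illustrative choice.
V = lambda x: (x - 2.0)**2                 # F(rho) = int V drho
dV = lambda x: 2.0 * (x - 2.0)

rng = np.random.default_rng(2)
N, tau = 1000, 0.1
y = rng.normal(-1.0, 0.3, N)               # particles of rho_0

energies = []
for _ in range(50):                        # outer JKO steps
    x = y.copy()
    for _ in range(200):                   # inner prox: min F + (1/2 tau) dist^2
        x = x - 0.02 * (dV(x) + (x - y) / tau)
    y = x
    energies.append(V(y).mean())

print(energies[0] > energies[-1], energies[-1] < 1e-4)   # True True
```

Each outer step strictly decreases the energy, exactly the dissipation guarantee the variational formulation promises.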

Particle discretization

When the gradient flow corresponds to the Langevin SDE dX_t = -\nabla V(X_t) \, dt + \sqrt{2} \, dB_t, the scheme drops down to a particle simulation.

X_{k+1}^i = X_k^i - \tau \nabla V(X_k^i) + \sqrt{2\tau} \, \xi_k^i

with \xi_k^i \sim \mathcal{N}(0, I). This is a Monte Carlo discretization of the Wasserstein flow and nothing more.


5. Convergence via Functional Inequalities

The Riemannian structure turns geometric properties like curvature into concrete convergence rates: you can read off mixing times directly from the Hessian.

If \mathcal{F} is \lambda-displacement convex with \lambda > 0, the flow converges exponentially fast.

W_2(\rho_t, \rho_\infty) \le e^{-\lambda t} \, W_2(\rho_0, \rho_\infty)

Three fundamental inequalities tie together entropy H, Wasserstein distance W_2, and Fisher information I, and they all come out of the same curvature bound.

  1. Log-Sobolev (LSI): H(\rho \,|\, \rho_\infty) \le \frac{1}{2\lambda} I(\rho \,|\, \rho_\infty)
  2. Talagrand: W_2^2(\rho, \rho_\infty) \le \frac{2}{\lambda} H(\rho \,|\, \rho_\infty)
  3. HWI: H(\rho \,|\, \rho_\infty) \le W_2(\rho, \rho_\infty) \sqrt{I(\rho \,|\, \rho_\infty)} - \frac{\lambda}{2} W_2^2(\rho, \rho_\infty)

These form a chain: bounding Fisher information bounds entropy, and that bounds how far mass has to travel. They show up all over MCMC convergence proofs and generalization bounds.
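Mean-shifted Gaussians make the chain concrete, since all three quantities have closed forms. For \rho = N(m, 1) and \rho_\infty = N(0, 1) (so \lambda = 1), the standard formulas give H = m^2/2, I = m^2, W_2 = |m|, and all three inequalities hold with equality:

```python
import numpy as np

# Closed-form check that Gaussian mean shifts saturate all three inequalities
# (lambda = 1 for the standard Gaussian).
checks = []
for m in [0.5, 1.0, 3.0]:
    H, I, W2 = m**2 / 2, m**2, abs(m)
    checks.append(H <= I / 2 + 1e-12)                          # log-Sobolev
    checks.append(W2**2 <= 2 * H + 1e-12)                      # Talagrand
    checks.append(H <= W2 * np.sqrt(I) - 0.5 * W2**2 + 1e-12)  # HWI
print(all(checks))   # True: equality in every case
```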


6. SVGD

Stein Variational Gradient Descent (Liu and Wang 2016) runs a Wasserstein-like flow when you cannot evaluate \rho directly and only have the score \nabla \log p to work with.

Restrict velocity fields to an RKHS \mathcal{H}^d and pick out the one that decreases the KL divergence to p the fastest.

v(\cdot) = \mathbb{E}_{x \sim \rho} \left[ k(x, \cdot) \, \nabla \log p(x) + \nabla_x k(x, \cdot) \right]

The first term pulls particles toward high-density regions; the second shoves them apart through \nabla k and keeps them from collapsing onto each other. The kernel bandwidth slides you between two regimes: infinite bandwidth brings back the W_2 flow, and narrow kernels give you independent Langevin chains.
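A minimal 1-D SVGD sketch with an RBF kernel, targeting p = N(2, 1), so the score is \nabla \log p(x) = -(x - 2). The bandwidth h and step size here are illustrative choices:

```python
import numpy as np

# SVGD with a Gaussian RBF kernel on a 1-D target p = N(2, 1).
rng = np.random.default_rng(3)
X = rng.normal(-3.0, 0.5, 200)
h = 0.5                                        # fixed kernel bandwidth

def svgd_direction(X):
    diff = X[:, None] - X[None, :]             # diff[i, j] = x_i - x_j
    K = np.exp(-diff**2 / (2 * h**2))          # k(x_i, x_j), symmetric
    score = -(X - 2.0)                         # grad log p at each particle
    attract = K @ score                        # sum_j k(x_i, x_j) score(x_j)
    repel = (diff / h**2 * K).sum(axis=1)      # sum_j grad_{x_j} k(x_j, x_i)
    return (attract + repel) / len(X)

for _ in range(1000):
    X = X + 0.1 * svgd_direction(X)

print(X.mean(), X.std())   # mean near 2; the repulsive term keeps the spread open
```

Dropping the `repel` term makes every particle run plain gradient ascent on log p and the ensemble collapses to the mode, which is exactly the failure the kernel term exists to prevent.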


7. Mean Field Games

Look at N agents, each running its own Wasserstein flow to minimize an individual cost. An agent at position x solves a control problem.

\inf_{v} \mathbb{E} \left[ \int_0^T L(x_t, v_t, \rho_t) \, dt + \Phi(x_T, \rho_T) \right]

with \rho_t the distribution of all the other agents, each one responding to the crowd.

At equilibrium this system splits into two coupled PDEs: Hamilton-Jacobi-Bellman for the individual value function and Fokker-Planck for the population density. Optimal transport is the special potential case where agents want to move from \rho_0 to \nu at minimum cost, and this lets you point OT solvers at traffic flow, financial markets, and other decentralized systems.


8. Neural Gradient Flows

Parameterize the velocity field v_\theta(t, x) with a neural network, push a base distribution like a Gaussian through the ODE, and watch it land on the data distribution.

\frac{dX_t}{dt} = v_\theta(t, X_t)

The objective is to shrink the distance between the evolved density and the data distribution. This ends up being gradient descent in parameter space that tracks a Wasserstein flow in measure space; it is what sits under continuous normalizing flows, and it gives the theoretical skeleton of diffusion models.
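A deliberately tiny stand-in for the neural setup: the "network" is the two-parameter affine family v_\theta(x) = ax + b, the ODE is Euler-integrated on particles, and the loss matches the pushed-forward mean and spread to an illustrative target N(2, 0.25). Everything here (the family, targets, finite-difference gradients) is a toy assumption, not the actual CNF training recipe:

```python
import numpy as np

# Toy "neural" gradient flow: train v_theta(x) = a*x + b so the ODE pushes
# N(0,1) onto a target with mean 2 and std 0.5.
rng = np.random.default_rng(4)
X0 = rng.normal(0.0, 1.0, 20_000)
n_steps, dt = 10, 0.1

def pushforward(a, b):
    X = X0.copy()
    for _ in range(n_steps):               # Euler steps of dX/dt = aX + b
        X = X + dt * (a * X + b)
    return X

def loss(a, b):
    X = pushforward(a, b)
    return (X.mean() - 2.0)**2 + (X.std() - 0.5)**2

a, b, eps, lr = 0.0, 0.0, 1e-4, 0.2
for _ in range(2000):                      # finite-difference gradient descent
    ga = (loss(a + eps, b) - loss(a - eps, b)) / (2 * eps)
    gb = (loss(a, b + eps) - loss(a, b - eps)) / (2 * eps)
    a, b = a - lr * ga, b - lr * gb

X1 = pushforward(a, b)
print(X1.mean(), X1.std())   # close to 2.0 and 0.5
```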


9. Wasserstein Proximal Operators

The JKO variational step now gets used directly in machine learning and it shows up as a proximal operator on measures.

\operatorname{Prox}_{\tau \mathcal{L}}(\rho) = \operatorname{argmin}_{\nu} \left\{ \mathcal{L}(\nu) + \frac{1}{2\tau} W_2^2(\nu, \rho) \right\}

In practice you compute these through entropic regularization like Sinkhorn (see Optimal Transport), particle flows, or RKHS smoothing. The proximal operator contracts under displacement convexity, and it pays off in variational inference when you want to respect the geometry of the state space instead of just minimizing KL divergence.
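The entropic route is short enough to write out. A sketch of Sinkhorn iterations computing a regularized W_2^2 between two discrete 1-D Gaussians (grid, widths, and \varepsilon are illustrative choices); the answer should land near the true (\text{mean shift})^2 = 4, plus an O(\varepsilon) blur bias:

```python
import numpy as np

# Entropic W_2^2 between two discrete 1-D measures via Sinkhorn iterations.
x = np.linspace(-3, 3, 200)
a = np.exp(-(x + 1)**2); a /= a.sum()          # source weights, mean -1
b = np.exp(-(x - 1)**2); b /= b.sum()          # target weights, mean +1
C = (x[:, None] - x[None, :])**2               # squared-distance cost
eps = 0.1                                      # regularization strength
K = np.exp(-C / eps)

u = np.ones_like(a)
for _ in range(1000):                          # Sinkhorn fixed-point iterations
    v = b / (K.T @ u)
    u = a / (K @ v)

P = u[:, None] * K * v[None, :]                # entropic transport plan
cost = (P * C).sum()
print(cost)   # close to W_2^2 = 4 for these equal-width Gaussians
```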


10. Propagation of Chaos

Why does the particle system actually converge to the PDE, and what makes the limit work out? Look at N coupled SDEs.

dX_t^i = \left( -\nabla V(X_t^i) - \frac{1}{N} \sum_{j \neq i} \nabla W(X_t^i - X_t^j) \right) dt + \sqrt{2} \, dB_t^i

As N \to \infty, each particle’s influence on any one other particle fades to zero. McKean (1966) showed the joint distribution factors into identical marginals: each particle ends up tracking the mean field of the ensemble, and the limiting density satisfies the McKean-Vlasov equation. This is what justifies particle filters, ensemble Kalman filters, and particle-based variational inference.


11. Implementation: Langevin Dynamics as JKO

You can simulate the gradient flow of potential plus entropy, i.e. sampling from p(x) \propto e^{-V(x)/\epsilon}, with discretized Langevin dynamics. The code below runs particles through a double well and watches them settle.

```python
import numpy as np
import matplotlib.pyplot as plt

def potential_v(x):
    # Double well potential (non-convex, multiple modes)
    return (x**2 - 1)**2

def grad_v(x):
    # grad V = 4x(x^2 - 1)
    return 4 * x * (x**2 - 1)

def run_wasserstein_flow(n_particles=2000, n_steps=2000, dt=0.005, noise_scale=1.0):
    # Start with a narrow Gaussian at 0 (unstable saddle)
    X = np.random.normal(0, 0.05, n_particles)
    history = []
    times = [0, 100, 500, 1000, 2000]
    for t in range(n_steps + 1):
        # Update via Langevin dynamics (discretized gradient flow)
        diffusion = np.random.normal(0, np.sqrt(2 * noise_scale * dt), n_particles)
        X = X - grad_v(X) * dt + diffusion
        if t in times:
            history.append(X.copy())
    return history, times

def visualize_flow(history, times):
    plt.figure(figsize=(12, 7))
    x_range = np.linspace(-2.5, 2.5, 200)
    # Plot target Gibbs density
    target = np.exp(-potential_v(x_range))
    target /= target.sum() * (x_range[1] - x_range[0])
    plt.plot(x_range, target, 'k--', lw=2, label='Target (Gibbs) Density')
    # Plot particle histograms at different times
    colors = plt.cm.viridis(np.linspace(0, 1, len(history)))
    for snap, t, color in zip(history, times, colors):
        plt.hist(snap, bins=60, density=True, alpha=0.3, color=color, label=f'T={t}')
    plt.title("Convergence of Particle System to Equilibrium (Wasserstein Flow)")
    plt.xlabel("State x")
    plt.ylabel("Density")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

history, times = run_wasserstein_flow()
visualize_flow(history, times)
```

12. Summary

Seeing Fokker-Planck and diffusion as gradient flows pulls together PDE theory, statistical mechanics, and optimization. The Riemannian manifold of measures hands you Hessians, curvature, geodesics, and the full toolkit for dissipative systems. Whether you care about generative modeling, biological dynamics, or posterior sampling, the Wasserstein perspective gives you the right geometric framework.


Displacement convexity

Standard convexity of t \mapsto \mathcal{F}((1-t)\rho_0 + t\rho_1) is the wrong notion here because linear interpolation of measures drags in spurious multimodality and loosens the geometry in the wrong way. Displacement convexity instead asks for convexity along the Wasserstein geodesic \rho_t = ((1-t)\mathrm{Id} + tT)_\# \rho_0, with T the Brenier map. McCann showed that potential and interaction energies are displacement convex when their kernels are convex, and that internal energy is displacement convex under a condition on the associated pressure P, namely P' \rho - P \ge 0.


Porous medium equation

Take \mathcal{F}(\rho) = \frac{1}{m-1} \int \rho^m \, dx, and Otto calculus hands you \partial_t \rho = \Delta (\rho^m) straight out.

Unlike the heat equation, which has infinite propagation speed, the porous medium equation has finite propagation speed: compact initial support stays compact. It models gas flow through rock and biological dispersal.
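The finite propagation speed is visible numerically. A sketch with an explicit finite-difference scheme for \partial_t \rho = \Delta(\rho^m), m = 2, starting from a compactly supported bump: mass is conserved and points well ahead of the front stay at (numerically) zero, where a heat flow would already have filled all of space.

```python
import numpy as np

# Explicit finite differences for the porous medium equation rho_t = lap(rho^m).
n, L, m = 400, 20.0, 2
x = np.linspace(-L/2, L/2, n)
dx = x[1] - x[0]
rho = np.where(np.abs(x) < 1.0, 1.0 - x**2, 0.0)   # compactly supported bump
rho /= rho.sum() * dx
mass0 = rho.sum() * dx

dt = 0.1 * dx**2                     # conservative step for degenerate diffusion
for _ in range(5000):
    u = rho**m
    rho = rho + dt * (np.roll(u, 1) - 2*u + np.roll(u, -1)) / dx**2

print(abs(rho.sum() * dx - mass0))   # ~0: mass conserved
print(rho[np.abs(x) > 8].max())      # numerically zero: the front is still finite
```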


Discrete flows on graphs

Wasserstein flows extend to graphs G = (V, E) through the Maas-Mielke framework: define a distance using a discrete continuity equation, get a discrete Ricci curvature out of Bakry-Emery on graphs, and positive curvature gives you exponentially fast mixing of random walks.


JKO convergence sketch

Write the Euler-Lagrange equation for the JKO variational problem. The optimality condition ties the first variation of \mathcal{F} to the optimal transport potential; taking the gradient and substituting into the continuity equation yields a consistent discretization of the flux.


Mean field limit

Under propagation of chaos, the N-particle joint distribution converges to a product of marginals, the empirical measure converges weakly to the McKean-Vlasov solution as N \to \infty, and individual noise terms collapse into collective pressure.


Talagrand’s inequality and concentration

Talagrand’s T_2 says that for a Gaussian \gamma you get W_2^2(\rho, \gamma) \le 2 \, \mathrm{KL}(\rho \,\|\, \gamma): low KL means physically close, hence low W_2. This gets used for concentration of measure of Lipschitz functions.


Regularity theory

The 2008 book by Ambrosio, Gigli, and Savare sets out the rigorous framework and nails down the theory you need. Even for non-smooth \mathcal{F}, as long as it is lower semicontinuous and displacement convex, JKO defines a unique curve of maximal slope, and this curve is the solution to the evolution equation.


Quantum optimal transport

The state is a density matrix \sigma and the entropy is the von Neumann entropy S(\sigma) = -\mathrm{Tr}(\sigma \log \sigma). Gradient flows under a non-commutative W_2 metric, due to Carlen and Maas, bring back the Lindblad master equation: quantum decoherence and dissipation turn out to be transport processes, with probability flux moving between energy levels.


Fisher-Rao vs. Wasserstein

There are two ways to put manifold structure on the space of measures and they look at the problem from different angles.

Fisher-Rao comes from information geometry and builds distance out of local KL divergence. It is vertical: it picks up probability changes at each point but ignores distances between points, and its entropy gradient flow is \partial_t \rho = -\rho(\log \rho + 1), which is pure pointwise exponential decay.

Wasserstein builds distance out of physically moving mass. It is horizontal, and its entropy gradient flow is \partial_t \rho = \Delta \rho, the heat equation.

The practical split: the natural gradient (Fisher) gets used for parameter estimation, and the Wasserstein gradient gets used for generative modeling and interpolation. Each one fits its own job.


Key developments in Wasserstein gradient flows

The lineage of this field runs through a handful of decisive contributions, each one unlocking the next move. Schrodinger (1926) first noticed the link between diffusion and entropic interpolation, an idea that came back decades later as Schrodinger bridges. McKean (1966) proved propagation of chaos for interacting particle systems and nailed down the mean-field limit that justifies modern particle methods. McCann (1997) introduced displacement convexity, the right notion of convexity for functionals on measure space, which kicked out the naive and incorrect linear interpolation of densities.

The watershed was the 1998 paper by Jordan, Kinderlehrer, and Otto, which showed that the Fokker-Planck equation is a gradient flow of entropy in Wasserstein space and laid down the JKO scheme as a variational time-stepping method. Otto (2001) then worked out the full Riemannian calculus on \mathcal{P}_2 and gave the field its computational backbone. Benamou and Brenier (2000) contributed the dynamic fluid-mechanics formulation of optimal transport and tied the Lagrangian and Eulerian perspectives together. More recently, Liu and Wang (2016) introduced SVGD and pulled Wasserstein gradient flow ideas into practical machine learning inference.


Concepts at a glance

The continuity equation \partial_t \rho + \nabla \cdot (\rho v) = 0 says mass is conserved; it is the fundamental constraint on any flow of probability, and everything else builds on top of it. Displacement convexity asks for convexity along Wasserstein geodesics, not along linear interpolations, and this is the condition that hands you unique minimizers and contractive flows. The JKO scheme is a variational time-discretization where each step solves a proximal problem weighing energy reduction against transport cost, and the limit gives back the continuous PDE.

The McKean-Vlasov equation describes density evolution under combined potential, internal, and interaction forces; it is the general form from which the heat, Fokker-Planck, and porous medium equations drop out as special cases. Otto calculus is the Riemannian calculus on \mathcal{P}_2 that makes all of this precise, defining gradients, Hessians, and curvature for functionals on measure space. Propagation of chaos is the theorem that N interacting particles become independent in the N \to \infty limit, justifying particle-based methods. The Stein operator maps the score function to a velocity field in SVGD, and the Wasserstein manifold (\mathcal{P}_2, W_2) is the infinite-dimensional Riemannian manifold on which all of these flows live.


References

1. Jordan, R., Kinderlehrer, D., & Otto, F. (1998). “The Variational Formulation of the Fokker-Planck Equation”.
2. Otto, F. (2001). “The Geometry of Dissipative Evolution Equations: The Porous Medium Equation”.
3. McCann, R. J. (1997). “A Convexity Principle for Interacting Gases”.
4. Ambrosio, L., Gigli, N., & Savare, G. (2008). “Gradient Flows: In Metric Spaces and in the Space of Probability Measures”.
5. Villani, C. (2009). “Optimal Transport: Old and New”.
6. Bakry, D., Gentil, I., & Ledoux, M. (2014). “Analysis and Geometry of Markov Diffusion Operators”.
7. Liu, Q., & Wang, D. (2016). “Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm”.
8. Benamou, J.-D., & Brenier, Y. (2000). “A Computational Fluid Mechanics Solution to the Monge-Kantorovich Mass Transfer Problem”.