Stein's Paradox & Empirical Bayes

1. The sample mean in high dimensions

Suppose you observe $X \sim \mathcal{N}(\theta, 1)$ in one dimension and you want to estimate $\theta$. The natural guess is $X$ itself, and no other estimator beats it across the board.

In higher dimensions the picture changes dramatically. Suppose you observe $X \sim \mathcal{N}(\theta, I)$ in dimension $d = 100$ and try the same guess $\hat{\theta} = X$. It has a built-in flaw: its norm overshoots $\|\theta\|$ essentially every time.

By the law of large numbers, as $d \to \infty$:

$$\frac{\|X - \theta\|^2}{d} \to 1 \quad \text{almost surely}$$

Meanwhile, the squared norm of $X$ satisfies:

$$\|X\|^2 = \|\theta + Z\|^2 = \|\theta\|^2 + \|Z\|^2 + 2\,\theta \cdot Z, \qquad \mathbb{E}\|X\|^2 = \|\theta\|^2 + d$$

So $X$ sits on a shell of radius about $\sqrt{\|\theta\|^2 + d}$, which is much farther from the origin than $\theta$ is. The noise term $\|Z\|^2$ concentrates around $d$, while the cross term $\theta \cdot Z$ is only $O(\sqrt{d})$: by concentration of measure in high dimensions, the noise lands almost entirely in directions orthogonal to $\theta$.
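A quick numerical check makes the shell picture concrete. This is a minimal sketch; the dimension, the specific $\theta$, and the seed are arbitrary choices for illustration, not from the text.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 100
theta = np.full(d, 0.5)                  # ||theta||^2 = 25
X = theta + rng.standard_normal(d)       # X ~ N(theta, I)

print(np.linalg.norm(theta) ** 2)        # 25.0
print(np.linalg.norm(X) ** 2)            # typically near 25 + 100 = 125
print(2 * theta @ (X - theta))           # cross term: only O(sqrt(d))
```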

The obvious fix is to pull $X$ back toward the origin, and Stein proved in 1956 that this fix is not merely nice to have but required:

$$d \ge 3 \implies \delta_0(X) = X \text{ is inadmissible.}$$

2. The James-Stein estimator

Set up risk under squared error loss.

$$R(\theta, \delta) = \mathbb{E}_\theta \left[ \|\delta(X) - \theta\|^2 \right]$$

For $\delta_0(X) = X$ the risk is $d$ no matter what $\theta$ is.

The James-Stein estimator uses a shrinkage factor that depends on the data.

$$\delta_{JS}(X) = \left( 1 - \frac{d-2}{\|X\|^2} \right) X$$

Theorem (Stein’s Paradox). For all $d \ge 3$ and all $\theta \in \mathbb{R}^d$,

$$R(\theta, \delta_{JS}) < d$$

The improvement is largest near the origin, and at $\theta = 0$ we have

$$R(0, \delta_{JS}) = 2$$

In dimension 100 the MLE has risk 100 everywhere, while at $\theta = 0$ James-Stein has risk 2: a 98 percent reduction.
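A minimal Monte Carlo check of the $d = 100$, $\theta = 0$ case makes the headline numbers tangible; the sample size and seed are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)
d, n = 100, 50_000
X = rng.standard_normal((n, d))                    # draws with theta = 0
norm_sq = np.sum(X**2, axis=1)

risk_mle = np.mean(norm_sq)                        # ~ 100
shrink = 1 - (d - 2) / norm_sq
risk_js = np.mean(np.sum((shrink[:, None] * X) ** 2, axis=1))  # ~ 2
print(risk_mle, risk_js)
```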

To see why shrinkage is the natural move geometrically, consider the ideal or oracle shrinkage factor. Given $X = \theta + Z$, the scalar $c$ that minimizes $\|\theta - cX\|^2$ is

$$c_{\text{ideal}} = \frac{\theta \cdot X}{\|X\|^2} \approx \frac{\|\theta\|^2}{\|\theta\|^2 + d} = 1 - \frac{d}{\|\theta\|^2 + d} \approx 1 - \frac{d}{\|X\|^2}$$

The oracle factor matches the James-Stein form, with $d$ swapped for $d - 2$ to account for the estimation error in the denominator.
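A small sketch comparing the oracle factor (which needs $\theta$) with the plug-in James-Stein factor (which does not); the dimension, signal strength, and seed are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(2)
d = 100
theta = np.zeros(d)
theta[0] = 7.0                          # ||theta||^2 = 49

X = theta + rng.standard_normal(d)
c_oracle = (theta @ X) / (X @ X)        # uses theta: unavailable in practice
c_js = 1 - (d - 2) / (X @ X)            # computable from the data alone
print(c_oracle, c_js)                    # both near 49 / (49 + 100) ~ 0.33
```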

3. Stein’s Lemma

To work out the risk of a nonlinear estimator $\delta(X) = X + g(X)$ without knowing $\theta$ we need the following identity.

Lemma (Stein, 1981). Let $X \sim \mathcal{N}_d(\theta, I)$ and let $g: \mathbb{R}^d \to \mathbb{R}^d$ be weakly differentiable with $\mathbb{E}\,|\nabla \cdot g(X)| < \infty$. Then

$$\mathbb{E}\left[ (X - \theta)^T g(X) \right] = \mathbb{E}\left[ \nabla \cdot g(X) \right]$$

Proof. We go component by component. Because the coordinates are independent, the joint density factors as $p(x) = \prod_j \phi(x_j - \theta_j)$, and each component can be handled on its own.

For the $i$-th component,

$$\mathbb{E}[(X_i - \theta_i)\, g_i(X)] = \int_{\mathbb{R}^d} (x_i - \theta_i)\, g_i(x) \left( \prod_{j=1}^d \phi(x_j - \theta_j) \right) dx$$

Pull out the integral over $x_i$,

$$\int_{-\infty}^\infty (x_i - \theta_i)\, \phi(x_i - \theta_i)\, g_i(x)\, dx_i$$

The Gaussian kernel obeys $\phi'(z) = -z \phi(z)$, so $(x_i - \theta_i)\phi(x_i - \theta_i) = -\frac{\partial}{\partial x_i} \phi(x_i - \theta_i)$. Plugging this in gives

$$= \int_{-\infty}^\infty g_i(x) \left( - \frac{\partial}{\partial x_i} \phi(x_i - \theta_i) \right) dx_i$$

Integrating by parts with $u = g_i(x)$ and $dv = \phi'(x_i - \theta_i)\, dx_i$,

$$= \left[ -g_i(x)\, \phi(x_i - \theta_i) \right]_{-\infty}^\infty + \int_{-\infty}^\infty \frac{\partial g_i}{\partial x_i}\, \phi(x_i - \theta_i)\, dx_i$$

The boundary term vanishes whenever $g_i$ grows slower than $e^{x^2/2}$, which covers every estimator of polynomial growth, and we are left with

$$= \mathbb{E} \left[ \frac{\partial g_i}{\partial X_i} \right]$$

Summing over $i = 1, \ldots, d$ gives

$$\sum_i \mathbb{E}[(X_i - \theta_i)\, g_i(X)] = \sum_i \mathbb{E} \left[ \frac{\partial g_i}{\partial X_i} \right] = \mathbb{E}\left[ \nabla \cdot g(X) \right]$$

$\square$
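The identity is easy to sanity-check by simulation. The sketch below uses the particular $g$ analyzed in the next section; the dimension, constant $c$, $\theta$, and sample size are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(3)
d, c, n = 10, 3.0, 500_000
theta = np.ones(d)
X = theta + rng.standard_normal((n, d))
norm_sq = np.sum(X**2, axis=1)

g = -c * X / norm_sq[:, None]                       # g(x) = -c x / ||x||^2
lhs = np.mean(np.sum((X - theta) * g, axis=1))      # E[(X - theta)^T g(X)]
rhs = np.mean(-c * (d - 2) / norm_sq)               # E[div g(X)], see Section 4
print(lhs, rhs)                                      # the two should agree closely
```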

4. Deriving the James-Stein estimator

Write $\delta(X) = X + g(X)$ and expand the risk,

$$R(\theta, \delta) = \mathbb{E}\|X + g(X) - \theta\|^2 = \mathbb{E}\|X - \theta\|^2 + \mathbb{E}\|g(X)\|^2 + 2\,\mathbb{E}\langle X - \theta,\, g(X) \rangle$$

The first term equals $d$, and by Stein’s Lemma the cross term equals $2\,\mathbb{E}[\nabla \cdot g(X)]$, so

$$R(\theta, \delta) = d + \mathbb{E}\left[ \|g(X)\|^2 + 2\,\nabla \cdot g(X) \right]$$

The quantity $d + \|g\|^2 + 2\,\nabla \cdot g$ is Stein’s Unbiased Risk Estimate, known as SURE. To beat the MLE we need a nonzero $g$ that makes $\mathbb{E}[\|g\|^2 + 2\,\nabla \cdot g]$ negative.

Try $g(X) = -\frac{c}{\|X\|^2} X$. It blows up at $X = 0$, but that event has probability zero under a continuous distribution.

To work out the divergence, note that since $g_i(X) = -c X_i (X_1^2 + \dots + X_d^2)^{-1}$, the partial derivative is

$$\frac{\partial g_i}{\partial X_i} = -c \left( \frac{1}{\|X\|^2} - \frac{2 X_i^2}{\|X\|^4} \right)$$

Adding over all the components gives

$$\nabla \cdot g(X) = -c \left( \frac{d}{\|X\|^2} - \frac{2}{\|X\|^2} \right) = -\frac{c(d-2)}{\|X\|^2}$$
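If you would rather not trust the calculus, here is a symbolic check in $d = 3$ (a sketch using sympy; the choice of dimension is arbitrary):

```python
import sympy as sp

c = sp.symbols('c')
x1, x2, x3 = sp.symbols('x1 x2 x3')
norm_sq = x1**2 + x2**2 + x3**2
g = [-c * xi / norm_sq for xi in (x1, x2, x3)]

div = sp.simplify(sum(sp.diff(gi, xi) for gi, xi in zip(g, (x1, x2, x3))))
print(div)   # -c/(x1**2 + x2**2 + x3**2), i.e. -c(d-2)/||x||^2 with d = 3
```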

The squared norm of $g$ is

$$\|g(X)\|^2 = \frac{c^2}{\|X\|^4}\, \|X\|^2 = \frac{c^2}{\|X\|^2}$$

Plugging into SURE gives

$$\Delta R = \mathbb{E}\left[ \frac{c^2 - 2c(d-2)}{\|X\|^2} \right]$$

The risk change is quadratic in $c$, minimized at $c^* = d - 2$, which gives

$$\Delta R_{\min} = -(d-2)^2\, \mathbb{E}\left[ \frac{1}{\|X\|^2} \right]$$

For $d \ge 3$ this is strictly negative, so

$$R(\theta, \delta_{JS}) = d - (d-2)^2\, \mathbb{E}\left[ \frac{1}{\|X\|^2} \right] < d$$
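Here is a quick Monte Carlo check that this risk formula matches the estimator's actual loss; $d$, $\theta$, and the sample size are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(4)
d, n = 10, 500_000
theta = np.zeros(d)
theta[0] = 3.0
X = theta + rng.standard_normal((n, d))
norm_sq = np.sum(X**2, axis=1)

sure = np.mean(d - (d - 2)**2 / norm_sq)   # risk formula: no theta needed
shrink = 1 - (d - 2) / norm_sq
actual = np.mean(np.sum((shrink[:, None] * X - theta)**2, axis=1))
print(sure, actual)                         # the two should agree closely
```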

When $\theta = 0$ we have $\|X\|^2 \sim \chi^2_d$, and $\mathbb{E}[1/\chi^2_d] = 1/(d-2)$, so

$$R(0, \delta_{JS}) = d - (d-2)^2 \cdot \frac{1}{d-2} = d - (d-2) = 2$$
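The chi-square fact itself is one line to verify numerically (degrees of freedom and sample size assumed for illustration):

```python
import numpy as np

rng = np.random.default_rng(5)
d = 10
samples = rng.chisquare(d, size=1_000_000)
print(np.mean(1 / samples), 1 / (d - 2))   # both approximately 0.125
```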

The paradox is that a single estimator beats the MLE simultaneously at every value of the parameter.

5. The Bayesian derivation

SURE shows that the risk reduction is real; the Bayesian view explains why the estimator has this exact shape.

Put a Gaussian prior on $\theta$.

Level 1: $X \mid \theta \sim \mathcal{N}_d(\theta, I)$

Level 2: $\theta \sim \mathcal{N}_d(0, A \cdot I)$

By conjugacy the posterior is Gaussian and the precisions add,

$$\Sigma_{\text{post}}^{-1} = I + \frac{1}{A} I = \frac{1 + A}{A}\, I$$

The posterior mean is

$$\hat{\theta}_{\text{Bayes}} = \frac{A}{1+A}\, X = \left( 1 - \frac{1}{1+A} \right) X$$

Writing $B = \frac{1}{1+A}$, the Bayes estimate is $(1-B)X$: linear shrinkage toward the origin.

The snag is that $A$ is unknown, but the marginal distribution $X \sim \mathcal{N}(0, (1+A) I)$ gives us enough to pin down $A$ from the data,

$$\frac{\|X\|^2}{1+A} \sim \chi^2_d$$

Since $\mathbb{E}[1/\chi^2_d] = 1/(d-2)$, the quantity $\frac{d-2}{\|X\|^2}$ is an unbiased estimator of $B = \frac{1}{1+A}$:

$$\mathbb{E}\left[ \frac{d-2}{\|X\|^2} \right] = \frac{1}{1+A} = B$$

Plugging this empirical estimate of the shrinkage factor into the Bayes rule gives you the James-Stein estimator exactly.

The empirical Bayes story makes clear what the estimator is doing: it learns the signal-to-noise ratio straight from the data. When $\|X\|^2$ is large compared to $d$, the data indicates strong signal and the estimator barely shrinks; when $\|X\|^2$ is small, the data is mostly noise and the estimator shrinks hard.
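The two-level model is easy to simulate end to end. This sketch (with an arbitrary $A$, dimension, and seed) draws $\theta$ from the prior and recovers the shrinkage weight $B$ from $\|X\|^2$ alone:

```python
import numpy as np

rng = np.random.default_rng(6)
d, A = 200, 3.0
theta = np.sqrt(A) * rng.standard_normal(d)   # theta ~ N(0, A I)
X = theta + rng.standard_normal(d)            # X | theta ~ N(theta, I)

B_true = 1 / (1 + A)                          # 0.25
B_hat = (d - 2) / (X @ X)                     # estimated from the data alone
print(B_true, B_hat)                          # B_hat close to 0.25
```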


6. Admissibility and Brownian motion

There is a deep link, worked out by Brown in 1971, between the admissibility of the MLE and the recurrence of Brownian motion.

Consider estimators of the form $\delta(X) = X + \nabla \phi(X)$. Through Tweedie’s formula, any such estimator can be read as a formal Bayes estimator against the prior $\pi(\theta) = e^{\phi(\theta)}$.

The risk improvement of $\delta$ over $X$ is controlled by the Laplacian of $\sqrt{m}$, where $m$ is the marginal density:

$$\Delta R \propto \frac{\Delta \sqrt{m(X)}}{\sqrt{m(X)}}$$

If $\sqrt{m}$ is superharmonic (meaning $\Delta \sqrt{m} \le 0$) then the risk improves. The flat prior $\pi = 1$ gives $m = 1$ and $\Delta 1 = 0$: the boundary case, matching the MLE.

The harmonic prior $\pi(\theta) \propto \|\theta\|^{-(d-2)}$ gives a superharmonic marginal for $d \ge 3$.

The algebra reflects something topological. In $d = 1$ and $d = 2$, Brownian motion is recurrent: it returns to every neighborhood infinitely often. In $d \ge 3$, Brownian motion is transient: it escapes to infinity.

The inadmissibility of the MLE in $d \ge 3$ corresponds to probability mass leaking off to infinity under the flat prior, and the James-Stein estimator compensates for that leak. Shrinkage is needed exactly because the surrounding space is high-dimensional enough for random walks to escape.
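The recurrence/transience dichotomy is easy to glimpse with simulated Gaussian random walks. In this sketch, the walk counts, step counts, and the return radius are arbitrary choices; the fraction of walks that wander back near the origin drops visibly once $d \ge 3$.

```python
import numpy as np

rng = np.random.default_rng(7)
n_walks, n_steps = 500, 2000
for d in (1, 2, 3):
    steps = rng.standard_normal((n_walks, n_steps, d))
    paths = np.cumsum(steps, axis=1)
    dist = np.linalg.norm(paths, axis=2)
    # Count walks that come back within distance 1 after step 100
    returned = np.mean(np.any(dist[:, 100:] < 1.0, axis=1))
    print(d, returned)   # high for d = 1, 2; noticeably lower for d = 3
```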


7. The positive-part correction

The standard James-Stein estimator has a failure mode. When $\|X\|^2 < d - 2$, the shrinkage factor $1 - \frac{d-2}{\|X\|^2}$ goes negative and the estimate flips direction, which is obviously bad: the estimator places $\hat\theta$ on the opposite side of the origin from $X$.

Baranchik’s positive-part tweak clips the shrinkage factor,

$$\delta_{JS+}(X) = \max\left( 0,\; 1 - \frac{d-2}{\|X\|^2} \right) X$$
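A concrete instance of the flip and the clip (the dimension and the specific $X$ are assumptions for illustration):

```python
import numpy as np

d = 10
X = np.full(d, 0.5)                     # ||X||^2 = 2.5 < d - 2 = 8
factor = 1 - (d - 2) / (X @ X)          # = 1 - 8/2.5 = -2.2: sign flips
print(factor * X)                        # JS estimate on the wrong side
print(np.maximum(0.0, factor) * X)       # JS+ collapses it to zero instead
```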

This estimator dominates $\delta_{JS}$, but it is not itself admissible: it has non-analytic kinks at $\|X\|^2 = d - 2$, and Brown showed that admissible estimators for exponential families must be analytic.

No simple analytic estimator that beats JS+ is known, so the positive-part estimator is what people actually use in practice.


8. Exact risk of the positive-part estimator

Working out the exact risk of $\delta_{JS+}$ takes some care because the estimator acts differently in two regions.

$$\delta_{JS+}(X) = \max\left( 0,\; 1 - \frac{d-2}{\|X\|^2} \right) X$$

The difference between $\delta_{JS}$ and $\delta_{JS+}$ is nonzero only when the shrinkage overshoots, which happens when $\|X\|^2 < d-2$. In that region $\delta_{JS}(X) = \left(1 - \frac{d-2}{\|X\|^2}\right) X$ while $\delta_{JS+}(X) = 0$.

Let $A = \{ x : \|x\|^2 < d-2 \}$. The risk gap between the two estimators is

$$\Delta R_+ = \int_A \left( \left\| \left(1 - \frac{d-2}{\|x\|^2}\right) x - \theta \right\|^2 - \|0 - \theta\|^2 \right) p(x)\, dx$$

so that $R(\theta, \delta_{JS}) = R(\theta, \delta_{JS+}) + \Delta R_+$, since the two estimators agree outside $A$.

Expanding the squared norm inside the integral gives

$$\Delta R_+ = \mathbb{E}\left[\left(\left(1 - \frac{d-2}{\|X\|^2}\right)^2 \|X\|^2 - 2\left(1 - \frac{d-2}{\|X\|^2}\right)(X \cdot \theta) + \|\theta\|^2 - \|\theta\|^2 \right)\mathbf{1}_A\right]$$

The cross term $-2\left(1 - \frac{d-2}{\|X\|^2}\right)(X \cdot \theta)$ does not vanish on $A$, because the truncation breaks the symmetry that Stein’s lemma relies on over the full space. Even so, both surviving terms are nonnegative in expectation. Write $c(X) = 1 - \frac{d-2}{\|X\|^2}$, which is negative on $A$. The first term $c(X)^2 \|X\|^2\, \mathbf{1}_A$ is nonnegative pointwise. For the cross term, conditional on $\|X\| = r$ the density of $X$ on the sphere is proportional to $e^{\theta \cdot x}$, so $\mathbb{E}[X \cdot \theta \mid \|X\| = r] \ge 0$; since $-2c(X)\,\mathbf{1}_A \ge 0$ depends on $X$ only through $\|X\|$, the cross term is also nonnegative in expectation. Hence $\Delta R_+ > 0$ and

$$R(\theta, \delta_{JS+}) = R(\theta, \delta_{JS}) - \Delta R_+ < R(\theta, \delta_{JS}) \quad \text{for every } \theta.$$

Geometrically: on $A$ the factor $c(X)$ is negative, so the JS estimate shoots past the origin, and on average that is a bigger error than simply estimating zero.

Computing the exact risk of $\delta_{JS+}$ requires integrating against the non-central chi-square distribution $Y = \|X\|^2 \sim \chi^2_d(\|\theta\|^2)$, which has the Poisson mixture form

$$f_Y(y) = \sum_{k=0}^\infty \frac{e^{-\lambda/2} (\lambda/2)^k}{k!}\, f_{\chi^2_{d+2k}}(y)$$

where $\lambda = \|\theta\|^2$.

The risk of JS is

$$R(\theta, \delta_{JS}) = d - (d-2)^2\, \mathbb{E}\left[\frac{1}{\|X\|^2}\right] = d - (d-2)^2 \sum_{k=0}^\infty \frac{e^{-\lambda/2} (\lambda/2)^k}{k!} \cdot \frac{1}{d+2k-2}$$

The extra improvement from JS+ requires integrating the full risk gap, including the $\theta$-dependent cross term, over the region $\|X\|^2 < d-2$, which must be done numerically for general $\theta$. The qualitative takeaway, $R_{JS+} < R_{JS}$ for every $\theta$, follows from the argument above.
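The Poisson-mixture series is straightforward to evaluate. The sketch below (truncation level, $d$, $\lambda$, and the simulation size are arbitrary assumptions) checks it against a direct Monte Carlo estimate of the James-Stein risk:

```python
import numpy as np
from scipy.stats import poisson

d, lam = 10, 9.0                              # lam = ||theta||^2
k = np.arange(200)                             # truncate the Poisson series
series = np.sum(poisson.pmf(k, lam / 2) / (d + 2 * k - 2))
risk_series = d - (d - 2)**2 * series

rng = np.random.default_rng(8)
theta = np.zeros(d)
theta[0] = np.sqrt(lam)
X = theta + rng.standard_normal((400_000, d))
norm_sq = np.sum(X**2, axis=1)
shrink = 1 - (d - 2) / norm_sq
risk_mc = np.mean(np.sum((shrink[:, None] * X - theta)**2, axis=1))
print(risk_series, risk_mc)                    # the two should agree closely
```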

9. Simulation

The simulation below compares the risk of the MLE, James-Stein, and positive-part estimators as a function of $\|\theta\|$. The MLE risk is a flat line at $d$, and the shrinkage estimators achieve their biggest win near the origin.

```python
import numpy as np
import matplotlib.pyplot as plt


def simulate_risk(d=10, n_trials=5000, theta_norms=None):
    if theta_norms is None:
        theta_norms = np.linspace(0, 10, 20)
    risk_mle, risk_js, risk_plus = [], [], []

    for r in theta_norms:
        # Create a theta vector with norm r
        theta = np.zeros(d)
        theta[0] = r

        # Generate data: X ~ N(theta, I), shape (n_trials, d)
        X = np.random.randn(n_trials, d) + theta
        X_norm_sq = np.sum(X**2, axis=1)

        # MLE loss: ||X - theta||^2
        loss_mle = np.sum((X - theta)**2, axis=1)
        risk_mle.append(np.mean(loss_mle))

        # James-Stein estimator: (1 - (d-2)/||X||^2) * X
        shrinkage = 1 - (d - 2) / X_norm_sq
        theta_js = X * shrinkage[:, np.newaxis]
        loss_js = np.sum((theta_js - theta)**2, axis=1)
        risk_js.append(np.mean(loss_js))

        # Positive-part James-Stein: clip the shrinkage factor at zero
        shrinkage_plus = np.maximum(0, shrinkage)
        theta_plus = X * shrinkage_plus[:, np.newaxis]
        loss_plus = np.sum((theta_plus - theta)**2, axis=1)
        risk_plus.append(np.mean(loss_plus))

    # Plot the three risk curves
    plt.figure(figsize=(10, 6))
    plt.plot(theta_norms, risk_mle, 'k--', label='MLE Risk (d)')
    plt.plot(theta_norms, risk_js, 'b-o', label='James-Stein Risk')
    plt.plot(theta_norms, risk_plus, 'r-', label='Positive-Part JS')
    plt.axhline(d, color='gray', linestyle=':')
    plt.axhline(2, color='green', linestyle=':', label='Min Risk (at 0)')
    plt.xlabel('||Theta||')
    plt.ylabel('Risk (MSE)')
    plt.title(f"Stein's Paradox in {d} Dimensions")
    plt.legend()
    # plt.show()
    return theta_norms, risk_js


# Theoretical minimum: at theta = 0, Risk = d - (d-2)^2 * E[1/ChiSq_d]
# and E[1/ChiSq_d] = 1/(d-2), so Risk(0) = d - (d-2) = 2.
# A massive improvement: from d down to 2.
```

10. The Efron-Morris baseball example

Efron and Morris (1975) ran shrinkage estimation on batting averages. They took 18 players’ first 45 at-bats and used them to predict how those players would do over the rest of the season.

The MLE uses each player’s own 45-at-bat average, while the James-Stein estimator pulls every player toward the grand mean.

James-Stein cut the total squared prediction error by a factor of 3. Players who started unusually hot, batting .450, got pulled down toward the league average; players who started cold got pulled up. In both cases the shrunk estimates landed closer to the true end-of-season averages.

Shrinkage toward a common value is the same trick that underlies ridge regression, which is Bayesian estimation with a Gaussian prior centered at zero. James-Stein can be read as ridge regression where the regularization parameter tunes itself from the data, with $\lambda \approx \frac{d-2}{\|y\|^2}$.
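A sketch of the ridge reading, under assumptions of my own (identity design, arbitrary data): ridge shrinks by $1/(1+\lambda)$, and the data-driven $\lambda$ above reproduces the James-Stein factor to first order when the shrinkage is mild.

```python
import numpy as np

rng = np.random.default_rng(9)
d = 50
y = rng.standard_normal(d) + 2.0        # observations with some signal

lam = (d - 2) / (y @ y)                 # data-driven regularization
ridge_factor = 1 / (1 + lam)            # ridge with identity design
js_factor = 1 - (d - 2) / (y @ y)       # James-Stein shrinkage factor
print(ridge_factor, js_factor)           # close when (d-2)/||y||^2 is small
```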


11. Implications

Stein’s paradox shows something basic about estimation in high dimensions. Treating parameters one at a time leaves accuracy on the table even when they really are unrelated; borrowing strength across the ensemble by shrinking toward a common value cuts total risk, trading a small amount of bias for a large reduction in variance.

In modern machine learning this idea shows up everywhere. Ridge penalties, weight decay, and dropout are all forms of shrinkage that bias models toward simpler solutions, because high-dimensional parameter spaces inflate estimation error.


Background

| Year | Event | Significance |
| --- | --- | --- |
| 1956 | Charles Stein | Proves inadmissibility of the MLE in $d \ge 3$. |
| 1961 | James & Stein | Construct the explicit James-Stein estimator. |
| 1971 | Lawrence Brown | Links admissibility to recurrence of diffusions. |
| 1973 | Efron & Morris | Empirical Bayes interpretation of Stein estimation. |
| 1995 | Donoho & Johnstone | Wavelet shrinkage (soft thresholding). |
| 2006 | Candès & Tao | Compressed sensing ($\ell_1$ shrinkage). |

Terminology

An estimator is admissible if no other estimator dominates it. Estimator A dominates estimator B when $R(\theta, A) \le R(\theta, B)$ for every $\theta$, with strict inequality somewhere.

The James-Stein estimator is a shrinkage estimator that beats the MLE for the multivariate normal mean in dimensions $d \ge 3$. It works by pulling estimates toward a central value, trading bias for lower variance. The positive-part correction clips the shrinkage factor so it stays non-negative, which stops overshrinkage.

Stein’s Lemma is the integration-by-parts identity $\mathbb{E}[(X - \theta)^T g(X)] = \mathbb{E}[\nabla \cdot g(X)]$ for Gaussian $X$, and it lets us compute risk without knowing $\theta$. The empirical Bayes framework estimates prior hyperparameters straight from the data, and the James-Stein estimator drops out as an empirical Bayes procedure. A minimax estimator minimizes worst-case risk across all $\theta$.


References

1. Stein, C. (1956). “Inadmissibility of the usual estimator for the mean of a multivariate normal distribution”. The original paper establishing the paradox.

2. James, W., & Stein, C. (1961). “Estimation with quadratic loss”. Constructs the explicit estimator and computes its risk.

3. Efron, B., & Morris, C. (1973). “Stein’s estimation rule and its competitors, an empirical Bayes approach”. Develops the empirical Bayes interpretation.

4. Efron, B., & Morris, C. (1975). “Data analysis using Stein’s estimator and its generalizations”. The baseball analysis demonstrating shrinkage in practice.

5. Lehmann, E. L., & Casella, G. (2006). “Theory of Point Estimation”. Standard graduate reference for decision theory and point estimation.