Thomas originally posted this article here at http://twiecki.github.io
Hierarchical models are underappreciated. Hierarchies exist in many data sets, and modeling them appropriately adds a boatload of statistical power (the common metric of statistical power). I provided an introduction to hierarchical models in a previous blog post, "Best Of Both Worlds: Hierarchical Linear Regression in PyMC3", written with Danne Elbers. See also my interview with FastForwardLabs where I touch on these points.
Here I want to focus on a common but subtle problem when trying to estimate these models and how to solve it with a simple trick. Although I had been somewhat aware of this trick for quite some time, it only recently clicked for me. We will use the same hierarchical linear regression model on the Radon data set from the previous blog post, so if you are not familiar with it, I recommend starting there.
I will then use the intuitions we’ve built up to highlight a subtle point about expectations vs modes (i.e. the MAP). Several talks by Michael Betancourt have really expanded my thinking here.
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pymc3 as pm
import pandas as pd
import theano
import seaborn as sns
sns.set_style('whitegrid')
np.random.seed(123)
data = pd.read_csv('../data/radon.csv')
data['log_radon'] = data['log_radon'].astype(theano.config.floatX)
county_names = data.county.unique()
county_idx = data.county_code.values
n_counties = len(data.county.unique())
The intuitive specification
Usually, hierarchical models are specified in a centered way. In a regression model, individual slopes would be centered around a group mean with a certain group variance, which controls the shrinkage:
with pm.Model() as hierarchical_model_centered:
    # Hyperpriors for group nodes
    mu_a = pm.Normal('mu_a', mu=0., sd=100**2)
    sigma_a = pm.HalfCauchy('sigma_a', 5)
    mu_b = pm.Normal('mu_b', mu=0., sd=100**2)
    sigma_b = pm.HalfCauchy('sigma_b', 5)
    # Intercept for each county, distributed around group mean mu_a
    # Above we just set mu and sd to a fixed value while here we
    # plug in a common group distribution for all a and b (which are
    # vectors of length n_counties).
    a = pm.Normal('a', mu=mu_a, sd=sigma_a, shape=n_counties)
    # Slope for each county, distributed around group mean mu_b
    b = pm.Normal('b', mu=mu_b, sd=sigma_b, shape=n_counties)
    # Model error
    eps = pm.HalfCauchy('eps', 5)
    # Linear regression
    radon_est = a[county_idx] + b[county_idx] * data.floor.values
    # Data likelihood
    radon_like = pm.Normal('radon_like', mu=radon_est, sd=eps, observed=data.log_radon)
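For readers who prefer equations, the centered model can be sketched as follows. The symbols y_j (log radon of measurement j), x_j (its floor indicator) and c(j) (its county index) are notation introduced here for illustration; the hyperpriors on mu_a, mu_b, sigma_a, sigma_b and eps are as in the code.

```latex
\[
\begin{aligned}
a_i &\sim \mathcal{N}(\mu_a, \sigma_a^2), \\
b_i &\sim \mathcal{N}(\mu_b, \sigma_b^2), \qquad i = 1, \dots, n_{\text{counties}}, \\
y_j &\sim \mathcal{N}\!\left(a_{c(j)} + b_{c(j)}\, x_j,\ \epsilon^2\right).
\end{aligned}
\]
```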
# Inference button (TM)!
with hierarchical_model_centered:
    hierarchical_centered_trace = pm.sample(draws=5000, tune=1000, njobs=4)[1000:]
I have seen plenty of traces with terrible convergence, but this one might look fine to the unassuming eye. Perhaps sigma_b has some problems, so let’s look at the Rhat:
print('Rhat(sigma_b) = {}'.format(pm.diagnostics.gelman_rubin(hierarchical_centered_trace)['sigma_b']))
Not too bad, well below 1.01. I used to think this wasn’t a big deal, but Michael Betancourt, in his StanCon 2017 talk, makes a strong point that it is actually very problematic. To understand what’s going on, let’s take a closer look at the slopes b and their group variance sigma_b (i.e. how far they are allowed to move from the mean). I’m just plotting a single chain now.
fig, axs = plt.subplots(nrows=2)
axs[0].plot(hierarchical_centered_trace.get_values('sigma_b', chains=1), alpha=.5);
axs[0].set(ylabel='sigma_b');
axs[1].plot(hierarchical_centered_trace.get_values('b', chains=1), alpha=.5);
axs[1].set(ylabel='b');
sigma_b seems to drift into this area of very small values and gets stuck there for a while. This is a common pattern, and the sampler is trying to tell you that there is a region in space that it can’t quite explore efficiently. While stuck down there, the slopes b_i all become squished together. We’ve entered The Funnel of Hell (it’s just called the funnel, I added the last part for dramatic effect).
The Funnel of Hell (and how to escape it)
Let’s look at the joint posterior of a single slope b (I randomly chose the 75th one) and the slope group variance sigma_b.
x = pd.Series(hierarchical_centered_trace['b'][:, 75], name='slope b_75')
y = pd.Series(hierarchical_centered_trace['sigma_b'], name='slope group variance sigma_b')
sns.jointplot(x, y, ylim=(0, .7));
This makes sense: as the slope group variance goes to zero (or, said differently, as we apply maximum shrinkage), individual slopes are not allowed to deviate from the slope group mean, so they all collapse to the group mean.
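The collapse can be made concrete with a simplified calculation. The sketch below is not the radon model: it assumes a single slope with a known raw estimate b_hat and a known standard error se (both hypothetical numbers), in which case the posterior mean of the slope is a precision-weighted average of the raw estimate and the group mean. The function name shrunk_slope is made up for this illustration.

```python
import numpy as np

def shrunk_slope(b_hat, se, mu_b, sigma_b):
    """Posterior mean of one slope in a normal-normal model with known scales."""
    # Weight on the county's own estimate; it goes to 0 as sigma_b goes to 0.
    w = sigma_b**2 / (sigma_b**2 + se**2)
    return w * b_hat + (1.0 - w) * mu_b

# Hypothetical numbers, chosen only for illustration.
b_hat, se, mu_b = -1.5, 0.4, -0.6

for sigma_b in [1.0, 0.3, 0.1, 0.01]:
    print('sigma_b = {:<5} -> slope = {:.3f}'.format(
        sigma_b, shrunk_slope(b_hat, se, mu_b, sigma_b)))
```

As sigma_b shrinks, the printed slope moves from near the raw estimate towards the group mean, which is the narrowing seen at the bottom of the funnel.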
While this property of the posterior in itself is not problematic, it makes the job extremely difficult for our sampler. Imagine a Metropolis-Hastings sampler exploring this space with a medium step-size (we’re using NUTS here but the intuition works the same): in the wider top region we can comfortably make larger jumps to explore the space efficiently. However, once we move to the narrow bottom region we can change b_75 and sigma_b only by tiny amounts. This causes the sampler to become trapped in that region of space. Most of the proposals will be rejected because our step-size is too large for this narrow part of the space, and exploration will be very inefficient.
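To put rough numbers on this intuition, here is a minimal sketch that looks at one horizontal slice of the funnel at a time. It assumes the slice is a zero-mean Gaussian with a given width and uses a plain random-walk Metropolis proposal with a fixed step-size, so it is a toy stand-in and not what NUTS does. The function name acceptance_rate and all numbers are hypothetical.

```python
import numpy as np

def acceptance_rate(scale, step, n=20000, seed=0):
    """Average Metropolis acceptance probability for a N(0, scale^2) target
    under a symmetric random-walk proposal with standard deviation `step`."""
    rng = np.random.default_rng(seed)
    x = rng.normal(0.0, scale, size=n)              # current states, drawn from the target
    proposal = x + rng.normal(0.0, step, size=n)    # random-walk proposals
    # Log of the target density ratio p(proposal) / p(x).
    log_ratio = (x**2 - proposal**2) / (2.0 * scale**2)
    return np.exp(np.minimum(0.0, log_ratio)).mean()

step = 0.5  # one "medium" step-size used everywhere
wide = acceptance_rate(scale=1.0, step=step)     # wide top of the funnel
narrow = acceptance_rate(scale=0.01, step=step)  # narrow neck of the funnel
print('wide region:   {:.2f}'.format(wide))
print('narrow region: {:.2f}'.format(narrow))
```

With the same step-size, most proposals are accepted in the wide slice and almost all are rejected in the narrow one.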
You might wonder if we could somehow choose the step-size based on the denseness (or curvature) of the space. Indeed that’s possible and it’s called Riemannian HMC. It works very well but is quite costly to run. Here, we will explore a different, simpler method.
Finally, note that this problem does not exist for the intercept parameters a. Because we can determine individual intercepts a_i with enough confidence, sigma_a is not small enough to be problematic. Thus, the funnel of hell can be a problem in hierarchical models, but it does not have to be. (Thanks to John Hall for pointing this out).
Reparameterization
If we can’t easily make the sampler step-size adjust to the region of space, maybe we can adjust the region of space to make it simpler for the sampler? This is indeed possible and quite simple with a small reparameterization trick. We will call this the non-centered version.
with pm.Model() as hierarchical_model_non_centered:
    # Hyperpriors for group nodes
    mu_a = pm.Normal('mu_a', mu=0., sd=100**2)
    sigma_a = pm.HalfCauchy('sigma_a', 5)
    mu_b = pm.Normal('mu_b', mu=0., sd=100**2)
    sigma_b = pm.HalfCauchy('sigma_b', 5)
    # Before:
    # a = pm.Normal('a', mu=mu_a, sd=sigma_a, shape=n_counties)
    # Transformed:
    a_offset = pm.Normal('a_offset', mu=0, sd=1, shape=n_counties)
    a = pm.Deterministic("a", mu_a + a_offset * sigma_a)
    # Before:
    # b = pm.Normal('b', mu=mu_b, sd=sigma_b, shape=n_counties)
    # Now:
    b_offset = pm.Normal('b_offset', mu=0, sd=1, shape=n_counties)
    b = pm.Deterministic("b", mu_b + b_offset * sigma_b)
    # Model error
    eps = pm.HalfCauchy('eps', 5)
    radon_est = a[county_idx] + b[county_idx] * data.floor.values
    # Data likelihood
    radon_like = pm.Normal('radon_like', mu=radon_est, sd=eps, observed=data.log_radon)
Pay attention to the definitions of a_offset, a, b_offset, and b and compare them to before (commented out). What’s going on here? It’s pretty neat actually. Instead of saying that our individual slopes b are normally distributed around a group mean (i.e. modeling their absolute values directly), we can say that they are offset from a group mean by a certain value (b_offset; i.e. modeling their values relative to that mean). Now we still have to consider how far from that mean we actually allow things to deviate (i.e. how much shrinkage we apply). This is where sigma_b makes a comeback. We can simply multiply the offset by this scaling factor to get the same effect as before, just under a different parameterization. For a more formal introduction, see e.g. Betancourt & Girolami (2013).
Critically, b_offset and sigma_b are now mostly independent. This will become clearer soon. Let’s first look at whether this transform helped our sampling:
# Inference button (TM)!
with hierarchical_model_non_centered:
    hierarchical_non_centered_trace = pm.sample(draws=5000, tune=1000, njobs=4)[1000:]
pm.traceplot(hierarchical_non_centered_trace, varnames=['sigma_b']);