Importance-sampling trick for hierarchical models
In [19]:
import numpy as np
import matplotlib.pyplot as pl
import emcee
import scipy.misc as mi
from ipdb import set_trace as stop
%matplotlib inline
Let's generate the data. Each case is a linear trend through the origin, $y = a x$, whose slope $a$ is drawn from a $N(0,1)$ distribution, with Gaussian noise of standard deviation `sigma` added. There are `N` cases, and the number of points in each case is a random integer between 1 and 3.
In [79]:
seed = 1234  # seed the global NumPy RNG so the simulated data are reproducible
np.random.seed(seed)
N = 15                                          # number of cases
NPoints = np.random.randint(1, high=4, size=N)  # 1 to 3 points per case (high is exclusive)
sigma = 0.01                                    # std. dev. of the Gaussian noise added to y
In [80]:
# True per-case slopes, drawn i.i.d. from the population N(0, 1) whose
# (mu, s) the grid evaluation below tries to recover.
aTrue = np.random.normal(0.0, 1.0, N)
In [81]:
# Simulate each case: NPoints[i] abscissae in [0, 1) and noisy observations of
# the line y = aTrue[i] * x. Both arrays are stored so later cells can use them.
xAll = []
yAll = []
for i in range(N):
    x = np.random.rand(NPoints[i])
    y = aTrue[i] * x + sigma * np.random.randn(NPoints[i])
    xAll.append(x)
    yAll.append(y)
In [83]:
# One panel per case: simulated points plus the true trend.
nCols = 4
nRows = (N + nCols - 1) // nCols  # enough rows for all N cases (integer ceil)
f, ax = pl.subplots(nrows=nRows, ncols=nCols, figsize=(12, 2*nRows))
ax = ax.flatten()
# x is drawn from [0, 1), so draw the true line over that whole range; plotting it
# only at the data x-values leaves it invisible for single-point cases.
xLine = np.array([0.0, 1.0])
for i in range(N):
    ax[i].plot(xAll[i], yAll[i], 'o')
    ax[i].plot(xLine, aTrue[i]*xLine, color='red')
for extra in ax[N:]:
    extra.axis('off')  # hide panels left over when N is not a multiple of nCols
Now we sample the posterior of the slope of each case independently with MCMC, using a flat prior on $(-5, 5)$. For this, we define a class that wraps `emcee`.
In [84]:
class singleCase(object):
    """MCMC sampler for the slope of a single straight-line case.

    The model is y = theta * x with Gaussian noise of standard deviation
    `noise`, and a flat prior on theta over the open interval (lower, upper).

    Parameters
    ----------
    x, y : array_like
        Abscissae and observations of this case.
    noise : float
        Standard deviation of the Gaussian noise on y.
    lower, upper : float, optional
        Bounds of the flat prior on the slope (defaults -5 and 5).
    """

    def __init__(self, x, y, noise, lower=-5.0, upper=5.0):
        self.x = np.asarray(x, dtype=float)
        self.y = np.asarray(y, dtype=float)
        self.noise = noise
        self.upper = upper
        self.lower = lower

    def logPosterior(self, theta):
        """Unnormalized log-posterior of the slope.

        `theta` is the 1-element parameter vector passed by emcee (a plain
        scalar is accepted too). Returns -inf outside the prior support,
        otherwise the Gaussian log-likelihood; the flat prior only adds a
        constant, which is dropped.
        """
        # .item() extracts the scalar explicitly instead of relying on the
        # truth value of a 1-element array; it raises if theta has >1 element.
        slope = np.asarray(theta, dtype=float).item()
        if not (self.lower < slope < self.upper):
            return -np.inf
        residual = self.y - slope * self.x
        return -np.sum(residual**2 / (2.0 * self.noise**2))

    def sample(self, nwalkers=500, nsteps=100, theta0=1.0, spread=0.01):
        """Run emcee on the slope and store the sampler in `self.sampler`.

        Walkers start in a small Gaussian ball of width `spread` around
        `theta0`. Returns the sampler for convenience.
        """
        ndim = 1
        self.theta0 = np.asarray([theta0])
        p0 = [self.theta0 + spread*np.random.randn(ndim) for _ in range(nwalkers)]
        self.sampler = emcee.EnsembleSampler(nwalkers, ndim, self.logPosterior)
        self.sampler.run_mcmc(p0, nsteps)
        return self.sampler
In [85]:
# Run an independent MCMC for each case and keep, per case, a flat 1-D array of
# post-burn-in slope samples. These feed the importance-sampling step below.
nBurn = 50  # steps discarded from the start of each walker; sample() runs 100 steps by default
samples = []
for i in range(N):
    res = singleCase(xAll[i], yAll[i], sigma)
    res.sample()
    # emcee's `chain` is laid out as (nwalkers, nsteps, ndim); take the single parameter.
    samples.append(res.sampler.chain[:, nBurn:, 0].ravel())
In [76]:
# Posterior samples of each slope, with the true slope marked in red.
nCols = 4
nRows = (N + nCols - 1) // nCols  # one panel per case (integer ceil); a fixed 2x3 grid overflows for N > 6
f, ax = pl.subplots(nrows=nRows, ncols=nCols, figsize=(12, 2*nRows))
ax = ax.flatten()
for i in range(N):
    ax[i].hist(samples[i], bins=30)
    ax[i].axvline(aTrue[i], color='red')
for extra in ax[N:]:
    extra.axis('off')  # hide panels left over when N is not a multiple of nCols
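Now the importance-sampling trick. Each case was sampled under a flat (interim) prior. Its posterior samples $a_k^{(i)}$, $i = 1, \dots, M_k$, can therefore be reused to evaluate, up to a constant, the marginal likelihood of the hyperparameters of a Gaussian population prior $N(\mu, s)$ (assuming a flat hyperprior): $$p(\mu, s \mid \{d_k\}) \propto \prod_{k=1}^{N} \frac{1}{M_k} \sum_{i=1}^{M_k} \frac{1}{s} \exp\left(-\frac{(a_k^{(i)} - \mu)^2}{2 s^2}\right).$$ The constant interim prior and the $1/\sqrt{2\pi}$ factor drop out. We evaluate this on a grid in $(\mu, s)$.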
In [77]:
nGrid = 20
mu = np.linspace(-1.0, 1.0, nGrid)
s = np.linspace(0.1, 2.0, nGrid)

def logMeanGauss(a, m, sd):
    """Log of mean over samples `a` of (1/sd) * exp(-(a - m)**2 / (2 sd**2)).

    Computed with the max-shift (log-mean-exp) trick so that far-out samples
    with small `sd` do not underflow to zero.
    """
    logw = -np.log(sd) - 0.5*((np.asarray(a) - m) / sd)**2
    wMax = logw.max()
    return wMax + np.log(np.mean(np.exp(logw - wMax)))

# logp[i, j] = log of the product over cases, for mu[i] and s[j]. Working in
# log space avoids underflow of the product of N small factors.
logp = np.zeros((nGrid, nGrid))
for i in range(nGrid):
    for j in range(nGrid):
        for k in range(N):
            logp[i, j] += logMeanGauss(samples[k], mu[i], s[j])
# Same surface as the direct product, rescaled so that its maximum is 1.
p = np.exp(logp - logp.max())
In [78]:
f, ax = pl.subplots(figsize=(8,8))
# p[i, j] is indexed (mu, s), but contour(X, Y, Z) expects Z[y_index, x_index];
# transpose so that mu runs along x and s along y.
ax.contour(mu, s, p.T)
ax.set_xlabel(r'$\mu$')
ax.set_ylabel(r'$s$');
In [ ]: