# Last active January 29, 2022 16:23
import numpy as np, pandas as pd, matplotlib.pyplot as plt
from collections import namedtuple
from jax import random, lax
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, util, init_to_sample
from numpyro.infer.mcmc import MCMCKernel
# MCMC state carried by the ABC kernel: ``z`` holds the current sample (the
# kernel's sample field) and ``rng_key`` is the PRNG key consumed by the
# next call to ``ABC.sample``.
ABCState = namedtuple("ABCState", field_names=("z", "rng_key"))
class ABC(MCMCKernel):
    """Approximate Bayesian Computation (ABC) rejection kernel for NumPyro's MCMC.

    Each call to :meth:`sample` repeatedly draws a fresh proposal from the
    model's prior predictive (``numpyro.infer.util.Predictive`` with
    ``num_samples=1``). It stops when ``summary_statistic(data, proposal)`` is
    at most ``threshold`` or when ``max_attempts_per_sample`` proposals have
    been drawn. An acceptable proposal becomes the new state; otherwise the
    previous state is kept.

    Args:
        model: NumPyro model callable, invoked with the ``model_args`` /
            ``model_kwargs`` given to ``MCMC.run``.
        data: Observed data, passed as the first argument of
            ``summary_statistic``.
        threshold: Largest distance at which a proposal is accepted.
        summary_statistic: Callable ``(data, proposal) -> distance``.
            ``proposal`` is the dict of site samples returned by
            ``Predictive`` (each array has a leading axis of size 1).
        max_attempts_per_sample: Upper bound on the number of proposals drawn
            per call to :meth:`sample`.
    """

    def __init__(self, model, data, threshold, summary_statistic,
                 max_attempts_per_sample):
        self._model = model
        self._data = data
        self._predictive = util.Predictive(self._model, num_samples=1)
        self._threshold = jnp.array(threshold)
        self._summary_statistic = summary_statistic
        self._max_attempts_per_sample = max_attempts_per_sample

    @property
    def sample_field(self):
        """Name of the ``ABCState`` field that holds the current sample.

        Must be a property: ``MCMC`` reads it as a string attribute, as
        declared on ``MCMCKernel``.
        """
        return "z"

    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        """Draw an initial state from the prior predictive.

        ``num_warmup`` and ``init_params`` are ignored. The initial sample is
        not checked against ``threshold``.
        """
        assert rng_key.ndim == 1, "only non-vectorized, for now"
        # Split so the key stored in the state is never the key already
        # consumed to draw the initial proposal.
        rng_key, init_key = random.split(rng_key)
        proposal = self._predictive(init_key, *model_args, **model_kwargs)
        return ABCState(proposal, rng_key)

    def sample(self, state, model_args, model_kwargs):
        """Advance the chain by one ABC rejection step.

        Draws proposals until one satisfies
        ``summary_statistic(data, proposal) <= threshold``, or until
        ``max_attempts_per_sample`` proposals have been drawn. If a proposal
        was accepted, all of its sites replace the current sample. Otherwise
        ``state.z`` is kept unchanged. The returned state always carries the
        advanced PRNG key.
        """
        def while_condition_func(val):
            distance, rng_key, proposal, n = val
            return jnp.logical_and(distance > self._threshold,
                                   n < self._max_attempts_per_sample)

        def while_body_func(val):
            distance, rng_key, proposal, n = val
            rng_key, sample_key = random.split(rng_key)
            proposal = self._predictive(sample_key, *model_args, **model_kwargs)
            # FIXME: need to resample the values of the observed vars here
            distance = self._summary_statistic(self._data, proposal)
            return (distance, rng_key, proposal, n + 1)

        # Starting at distance=inf guarantees at least one proposal is drawn
        # (as long as max_attempts_per_sample > 0).
        init_val = (jnp.inf,        # distance
                    state.rng_key,  # rng_key
                    state.z,        # proposal
                    0)              # iteration
        distance, rng_key, proposal, n = lax.while_loop(
            while_condition_func, while_body_func, init_val)

        # Accept/reject every site together, so a rejected proposal never
        # leaks partial values into the chain.
        accepted = distance <= self._threshold
        new_z = {name: jnp.where(accepted, proposal[name], state.z[name])
                 for name in proposal}
        return ABCState(new_z, rng_key)
def my_model():
    """Toy prior: a plate of four ``theta`` draws from Uniform(-10, 10)."""
    with numpyro.plate('I', 4):
        # Registering the 'theta' site is all that matters; the sampled value
        # itself is not used inside the model.
        numpyro.sample('theta', dist.Uniform(-10, 10))
def sum_exceeds_threshold(threshold, proposal):
    """ABC distance: 0 when the sum of ``proposal['theta']`` exceeds ``threshold``.

    Args:
        threshold: Value the total must strictly exceed. ``ABC`` passes its
            ``data`` argument here.
        proposal: Mapping with a ``'theta'`` array; every element is summed.

    Returns:
        ``0`` if the total is strictly greater than ``threshold``, otherwise
        ``inf`` (including when the total is NaN).
    """
    total = proposal['theta'].sum()
    exceeds = total > threshold
    return jnp.where(exceeds, 0, jnp.inf)
def my_run(model):
    """Sample ``model`` with the ABC kernel and plot the trace of sum(theta).

    The constraint is that the sum of the ``theta`` site must exceed
    ``sum_lower_bound`` (-1). Each draw's sum is plotted against that bound.

    Args:
        model: NumPyro model with a ``theta`` site (e.g. ``my_model``).

    Returns:
        dict: Posterior samples from ``MCMC.get_samples()``.
    """
    rng_key = random.PRNGKey(12345)
    sum_lower_bound = jnp.array(-1)
    # NOTE(review): the source call was cut off after ``threshold=1,``.
    # ``summary_statistic`` and ``max_attempts_per_sample`` are reconstructed
    # here; the attempt cap is a guess and should be tuned.
    kernel = ABC(model,
                 data=sum_lower_bound, threshold=1,
                 summary_statistic=sum_exceeds_threshold,
                 max_attempts_per_sample=1000)
    mcmc = MCMC(kernel, num_warmup=0, num_samples=100, thinning=1)
    # run() must be called before get_samples() has anything to return.
    mcmc.run(rng_key)
    posterior_samples = mcmc.get_samples()
    # Predictive(num_samples=1) gives each draw a leading axis of size 1,
    # hence the [:, 0, :] before summing over the plate.
    plt.plot(posterior_samples['theta'][:, 0, :].sum(axis=1), label='trace')
    plt.axhline(float(sum_lower_bound), linestyle='dashed', color='k',
                label='lower bound')
    plt.legend()  # the labels above are only displayed once a legend exists
    return posterior_samples
