Created April 18, 2020 21:06
-
-
Save gilesvangruisen/4a8a4a8ff11af45dec3bc9587abcbc47 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _highest_density_interval(self, pmf, p=.9, title=''):
    """Return the highest-density interval of a discrete distribution.

    The PMF is expanded into an artificial sample set, repeating each index
    value in proportion to its probability. The interval is then computed on
    those samples with ``pymc3.stats.hpd``.

    Args:
        pmf: A ``pd.Series`` whose index holds the support values and whose
            values hold their probabilities, or a ``pd.DataFrame`` where each
            column is one such distribution.
        p: Probability mass the interval should contain. It is passed to
            ``pymc3.stats.hpd`` as ``credible_interval``.
        title: Not used in the computation. The DataFrame branch passes each
            column's name here.

    Returns:
        For a Series input: a ``pd.Series`` with entries ``'Rt_low'`` and
        ``'Rt_high'``. For a DataFrame input: a ``pd.DataFrame`` with one such
        row per input column, indexed by ``pmf.columns``.
    """
    # A DataFrame is handled by recursing on each column. ``p`` must be
    # forwarded explicitly; otherwise every column silently falls back to the
    # default 0.9 level regardless of what the caller requested.
    if isinstance(pmf, pd.DataFrame):
        return pd.DataFrame(
            [self._highest_density_interval(pmf[col], p=p, title=str(col))
             for col in pmf],
            index=pmf.columns,
        )

    # Broadcast the probability distribution to an artificial sample set.
    # Each index value is repeated int(probability * sample_precision) times,
    # so the samples' empirical distribution approximates the PMF. The int
    # cast truncates, so mass below 1 / sample_precision per value is dropped.
    sample_precision = 1000000
    repeat_counts = np.array(pmf.values * sample_precision).astype(int)
    samples = np.repeat(pmf.index, repeat_counts)

    # HPD interval of the samples at credible level ``p``; the first two
    # entries of the result are the lower and upper bounds.
    hdi = pymc3.stats.hpd(samples, credible_interval=p)
    low = hdi[0]
    high = hdi[1]
    return pd.Series([low, high], index=['Rt_low', 'Rt_high'])
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.