# GitHub Gist by @luispedro: luispedro/1f423ca4304fa3c6b10b0e6ecde34bee
# (last active May 10, 2020 11:40)
# %matplotlib qt
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib import style
# Reset matplotlib to its default style before any figures are drawn.
style.use('default')

# Total population summed over all classes.
TOTAL_POP = 10 ** 5
# Hard cap on the number of simulated days (one loop iteration per day).
MAX_ITERS = 10 ** 5

# Epidemic parameters read by simul():
#   rho   -- weight of exposed (E) people relative to infectious (I) ones
#            when computing transmission pressure (rho * E + I)
#   gamma -- daily fraction of I that moves to R
#   delta -- daily fraction of E that moves to I
#   beta  -- transmission rate; it is not in the reference paper and was
#            picked by hand because it seems to fit
rho, gamma, delta = 0.5, 0.25, 0.25
beta = 0.25
def simul(N, xs, sigma):
    """Run a discrete-time, two-class SEIR epidemic (one step per day).

    Parameters
    ----------
    N : array-like of length 2
        Population size of each class.  Integer input is accepted; it is
        converted to float so the daily updates are not truncated.
    xs : array-like of length 2
        Per-class multiplier.  It scales both how quickly a class becomes
        infected and how much that class contributes to transmission.
    sigma : float
        Each day, ``sigma * (S[1] - S[0])`` susceptibles are added to
        class 0 and removed from class 1 (a negative difference moves
        them the other way).  0 disables the exchange.

    Returns
    -------
    numpy.ndarray of shape (T, 4, 2)
        Daily state after each step; axis 1 is (S, E, I, R) and axis 2 is
        the class.  The run lasts at least 366 days, then stops as soon
        as the total absolute daily change in S drops below one person,
        or after MAX_ITERS days at the latest.
    """
    # Work in float: with an integer N, S would inherit an integer dtype
    # and the in-place updates below would silently truncate it.
    N = np.asarray(N, dtype=float)
    xs = np.asarray(xs, dtype=float)

    # Seed every class with 25 exposed and 25 infectious people (100 in
    # total); everyone else in the class starts susceptible.
    E = np.array([25., 25.])
    I = np.array([25., 25.])
    R = np.array([0., 0.])
    S = N - (E + I)

    Evol = []
    for _ in range(MAX_ITERS):
        # Transmission pressure shared by all classes; exposed people count
        # with weight rho relative to infectious ones.
        EIeff = np.sum(xs * (rho * E + I) * N/TOTAL_POP)
        # Per-class force of infection.
        lam = beta / N * EIeff

        # All increments are computed from the state at the start of the
        # day before anything is updated.
        new_exposed = lam * xs * S
        dE = new_exposed - delta * E
        dI = delta * E - gamma * I
        dR = gamma * I
        S -= new_exposed
        E += dE
        I += dI
        R += dR

        # Exchange susceptibles between the two classes in proportion to
        # the difference in their susceptible counts.
        shift = sigma * (S[1] - S[0])
        S += [shift, -shift]

        Evol.append((S.copy(), E.copy(), I.copy(), R.copy()))
        # Do at least 1 year, then stop once nobody is getting infected.
        if len(Evol) > 365 and np.abs(Evol[-1][0] - Evol[-2][0]).sum() < 1:
            break
    return np.array(Evol)
# Each scenario is (plot title, (class sizes N, per-class multipliers xs,
# sigma)) and is passed straight to simul(); one figure is drawn per scenario.
for title, args in [
    ('Traditional', (np.array([TOTAL_POP/2, TOTAL_POP/2], float), np.array([1, 1], float), 0)),
    ('Two classes (no big diff)', (np.array([TOTAL_POP/2, TOTAL_POP/2], float), np.array([0.5, 1.5], float), 0)),
    ('Two classes (minority)', (np.array([TOTAL_POP/5*4, TOTAL_POP/5], float), np.array([0.25, 4.0], float), 0)),
    ('Two classes (minority; 1% shift)', (np.array([TOTAL_POP/5*4, TOTAL_POP/5], float), np.array([0.25, 4.0], float), .01)),
    ('Two classes (minority; 0.1% shift)', (np.array([TOTAL_POP/5*4, TOTAL_POP/5], float), np.array([0.25, 4.0], float), .001)),
]:
    N, xs, _ = args  # sigma is only needed by simul()

    # Population-weighted mean and standard deviation of the multiplier.
    # std_x must be the square root of the weighted variance; otherwise
    # std_x / mu_x is not the coefficient of variation shown in the title.
    weights = N / np.sum(N)
    mu_x = np.sum(weights * xs)
    std_x = np.sqrt(np.sum(weights * (xs - mu_x)**2))

    Evol = simul(*args)
    fig, ax = plt.subplots()
    # Evol has axes (day, compartment, class): sum over the classes and
    # plot each compartment as a fraction of the total population.
    for compartment, label in enumerate('SEIR'):
        ax.plot(Evol[:, compartment, :].sum(1)/TOTAL_POP, label=label)
    ax.legend(loc='best')
    ax.set_xlim(0, 365)
    ax.set_title(title + ' (CV={})'.format(std_x/mu_x))
    sns.despine(fig, trim=True)
# (GitHub page footer: "Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment")