# Saved from GitHub Gist yukoba/76cacfee5fb373d944f29a5693859e83
# (last active: January 21, 2021 19:45).
# 新型コロナウイルスの感染者数の予測 (Forecasting the number of new COVID-19 cases)
# GitHub note: this file contains bidirectional Unicode text that may be
# interpreted or compiled differently than what appears below. To review, open
# the file in an editor that reveals hidden Unicode characters.
"""
Copyright (C) 2021 by Yu Kobayashi
Permission to use, copy, modify, and/or distribute this software for any purpose
with or without fee is hereby granted.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND
FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER
TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF
THIS SOFTWARE.
"""
from datetime import date, timedelta | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from dateutil.relativedelta import relativedelta | |
from matplotlib.ticker import ScalarFormatter | |
from scipy import stats | |
# noinspection PyProtectedMember | |
from scipy.stats.mstats_basic import LinregressResult | |
# Daily new COVID-19 cases in Tokyo, 2020-10-19 to 2021-01-21.
# Source: https://stopcovid19.metro.tokyo.lg.jp/
initial_date = date(2020, 10, 19)  # a Monday, so each row of `data` below is one Mon-Sun week

# Public holidays inside the data range. Tokyo's reported count drops sharply on
# the day after a holiday, so those days are left out of the fit (exclude_indices).
holidays = [
    date(2020, 11, 3),
    date(2020, 11, 23),
    date(2021, 1, 11),
]

data = np.array([
    78, 139, 145, 185, 186, 201, 124,
    102, 158, 171, 220, 203, 215, 116,
    87, 209, 122, 269, 242, 294, 189,
    157, 293, 316, 392, 374, 352, 255,
    180, 298, 485, 533, 522, 539, 391,
    314, 186, 401, 481, 570, 561, 418,
    311, 372, 500, 533, 449, 584, 327,
    299, 352, 572, 602, 595, 621, 480,
    305, 460, 660, 821, 664, 736, 556,
    392, 563, 748, 888, 884, 949, 708,
    481, 856, 944, 1337, 783, 814, 816,
    884, 1278, 1591, 2447, 2392, 2268, 1494,
    1219, 970, 1433, 1502, 2001, 1809, 1592,
    1204, 1240, 1274, 1471,
])

x = np.arange(len(data))  # day index counted from initial_date
y = np.log10(data)  # the model is fitted in log10 space

# Index of the day after each holiday. Days that fall outside the data are
# dropped: a negative index would silently wrap around to the end of `y`, and an
# index past the end would raise IndexError when used to mask residuals.
exclude_indices = [
    day
    for day in ((holiday - initial_date).days + 1 for holiday in holidays)
    if 0 <= day < len(data)
]
def search_param() -> np.ndarray: | |
# 探索開始の初期値として回帰直線を使用します | |
result: LinregressResult = stats.linregress(x, y) | |
param_init = np.array([result.slope, result.intercept] + [0] * 6) | |
def evaluate(param: np.ndarray) -> float: | |
slope = param[0] | |
intercept = param[1] | |
weekly_difference = param[2:8] | |
weekly_difference = np.append(weekly_difference, [-weekly_difference.sum()]) # 総和は0 | |
diff = y - (slope * x + intercept + weekly_difference[x % 7]) | |
# 東京都は祝日の翌日に大きく減少します。祝日の翌日は評価値から除外します。 | |
diff[exclude_indices] = 0 | |
return np.abs(diff).sum() # 絶対誤差の和を最小化します | |
# 進化戦略 https://qiita.com/yukoba/items/ed40e0c4f4a27b73c6b8 | |
iter_count = 30000 | |
pop_size = 300 | |
children_size = 30 | |
param_len = len(param_init) | |
tau = 1.0 / np.sqrt(2.0 * param_len) | |
individuals = [(param_init, np.full([param_len], 0.1), evaluate(param_init))] | |
best = individuals[0] | |
best_ev = best[2] | |
for i in range(iter_count): | |
for _ in range(children_size): | |
# 交叉 | |
if np.random.rand() < 0.8: | |
ind0 = individuals[np.random.randint(len(individuals))] | |
ind1 = individuals[np.random.randint(len(individuals))] | |
r = np.random.randint(0, 2, [param_len]) | |
parent = (ind0[0] * r + ind1[0] * (1 - r), ind0[1] * r + ind1[1] * (1 - r)) | |
else: | |
parent = individuals[np.random.randint(len(individuals))] | |
# 突然変異 | |
strategies2 = parent[1] * np.exp(tau * np.random.randn(param_len)) | |
params2 = parent[0] + strategies2 * np.random.randn(param_len) | |
individuals.append((params2, strategies2, evaluate(params2))) | |
# 上位を選択 | |
individuals = sorted(individuals, key=lambda ind: ind[2])[:pop_size] | |
best = individuals[0] | |
# 最善が更新されたら出力します | |
if best[2] < best_ev: | |
best_ev = best[2] | |
print(i, best_ev, np.array_repr(best[0], max_line_width=1000)) | |
return best[0] | |
def plot_result(param: np.ndarray, extend_days: int = 30):
    """Plot the observed counts against the fitted model on a log-scale y axis.

    Two figures are shown, one after the other:
      1. observed counts vs. the full model (trend + day-of-week offsets);
      2. the same with the day-of-week offsets removed from both curves, which
         leaves only the exponential trend ``10 ** (slope * t + intercept)``.

    The model curves continue ``extend_days`` days past the last observation.

    Args:
        param: ``[slope, intercept, w0, ..., w5]`` as returned by ``search_param``.
        extend_days: Number of days to extrapolate the model beyond ``data``
            (30, as in the original script).
    """
    slope, intercept = param[0], param[1]
    # 7 day-of-week offsets in log10 space; the 7th makes them sum to 0.
    weekly_difference = np.append(param[2:8], [-param[2:8].sum()])

    x2 = np.arange(len(data) + extend_days)
    x_date = [initial_date + timedelta(days=int(i)) for i in x]
    x2_date = [initial_date + timedelta(days=int(i)) for i in x2]
    trend = slope * x2 + intercept

    def show_log_chart(curves):
        # A fresh figure per chart. With non-interactive backends (e.g. Agg),
        # plt.show() does not close the current figure. Without this, the second
        # chart would be drawn on top of the first.
        _, ax = plt.subplots()
        for dates, log_values, label in curves:
            ax.plot(dates, 10 ** log_values, label=label)
        ax.set_yscale("log")
        # Plain numbers instead of 10^n tick labels on the log axis.
        ax.yaxis.set_major_formatter(ScalarFormatter())
        ax.legend()
        ax.grid()
        plt.show()

    # Chart 1: including the day-of-week offsets.
    show_log_chart([
        (x_date, y, 'real'),
        (x2_date, trend + weekly_difference[x2 % 7], 'model'),
    ])
    # Chart 2: day-of-week offsets removed.
    show_log_chart([
        (x_date, y - weekly_difference[x % 7], "real - weekly difference"),
        (x2_date, trend, "model - weekly difference"),
    ])
if __name__ == "__main__":
    # Only run the (slow) search and the blocking plots when executed as a
    # script, so the data and functions can be imported without side effects.
    # search_param() runs 30000 generations by default. To redraw the charts
    # without searching again, the vector below can be used instead. It is
    # presumably the result of an earlier run.
    # param_best = np.array([0.01137184, 2.13545116, -0.20030776, -0.01955204,
    #                        0.04198086, 0.09855929, 0.0490876, 0.10989335])
    param_best = search_param()
    plot_result(param_best)
# (GitHub page footer) Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.