Last active
November 1, 2017 09:59
-
-
Save yukoba/876d51bbeb3b0f7dcc53a40d17a43b13 to your computer and use it in GitHub Desktop.
混合正規分布をEMアルゴリズムを使わずに直接勾配法でパラメータを求めるが、パラメータに事前分布をつけて正しく収束させる
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Fit the parameters of a Gaussian mixture by direct gradient descent,
# without using the EM algorithm.
#
# [Priors on the parameters]
# u1, u2: normal distribution, mean = mean of the data, standard deviation 1000
# s1, s2: exponential distribution, theta = 10000
# pi: normal distribution, mean 0.5, standard deviation 100; however, if it
#     leaves its domain (0..1), it is given a very small prior probability.
# Having the priors keeps the parameters from wandering off to bad values.
import autograd
import autograd.numpy as np

# Small constant: used as the stopping tolerance of the optimization loop and
# as the initial value of the AdaGrad accumulator (avoids division by zero).
epsilon = 1e-8

# 20 one-dimensional observations (presumably drawn from two overlapping
# components — TODO confirm provenance).
data = np.array([-0.39, 0.12, 0.94, 1.67, 1.76, 2.44, 3.72, 4.28, 4.92, 5.53,
                 0.06, 0.48, 1.01, 1.68, 1.80, 3.25, 4.12, 4.60, 5.28, 6.22])
data_mean = data.mean()  # prior mean for u1 and u2
data_var = data.var()    # initial value for the variances s1 and s2
def loss(p):
    """Negative log posterior (up to additive constants) of a two-component
    one-dimensional Gaussian mixture.

    p is the packed parameter vector (u1, u2, s1, s2, pi): component means,
    component variances, and the mixing weight of component 2. Uses the
    module-level `data` and `data_mean`.
    """
    def gaussian_pdf(mean, var, x):
        # Density of N(mean, var) at x; `var` is the variance, not the std dev.
        return np.exp(-((x - mean) ** 2 / (2 * var))) / np.sqrt(2 * np.pi * var)

    u1, u2, s1, s2, pi = p

    # Prior on pi: N(0.5, 100^2) inside [0, 1]; outside the domain the
    # 1/(2*100^2) scaling is dropped, i.e. a much stronger quadratic penalty.
    if 0 <= pi <= 1:
        prior_nll = (pi - 0.5) ** 2 / (2 * 100 ** 2)
    else:
        prior_nll = (pi - 0.5) ** 2
    # Priors on the means: N(data_mean, 1000^2).
    prior_nll += (u1 - data_mean) ** 2 / (2 * 1000 ** 2)
    prior_nll += (u2 - data_mean) ** 2 / (2 * 1000 ** 2)
    # Priors on the variances: exponential with theta = 10000.
    prior_nll += s1 / 10000
    prior_nll += s2 / 10000

    # Average negative log-likelihood of the mixture over the data.
    mixture_density = (1 - pi) * gaussian_pdf(u1, s1, data) + pi * gaussian_pdf(u2, s2, data)
    return prior_nll - np.mean(np.log(mixture_density))
loss_grad = autograd.grad(loss)  # gradient of the loss w.r.t. the parameter vector

# Random initialization: two distinct (shuffled) data points as the component
# means, the overall data variance for both component variances, and an even
# 0.5 mixing weight.
np.random.shuffle(data)
p = np.array([data[0], data[1], data_var, data_var, 0.5])
r = np.array([epsilon] * 5)  # AdaGrad accumulator; epsilon avoids division by zero

for _ in range(10000):
    print(loss(p), p)  # progress trace, one line per iteration
    d = loss_grad(p)
    r += d ** 2  # AdaGrad: accumulate squared gradients per coordinate
    p_new = p - 0.3 * d / np.sqrt(r)
    # Stop when the update is below the absolute tolerance. Bug fix: the third
    # positional argument of np.allclose is `rtol`, not `atol`, so the original
    # `np.allclose(p, p_new, epsilon)` set a *relative* tolerance by accident;
    # make the intended absolute tolerance explicit.
    if np.allclose(p, p_new, rtol=0.0, atol=epsilon):
        break
    p = p_new
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment