Created
November 24, 2023 09:21
-
-
Save MatsuuraKentaro/c7457d8a207c4e631e0377b8d9efa8e6 to your computer and use it in GitHub Desktop.
『Pythonではじめる数理最適化』の7章「商品推薦のための興味のスコアリング」をStanで解く
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Binomial model for PV successes out of N trials per (Rcen, Freq) cell.
// The logit-scale score x is built so that it never increases when either
// index grows: every cell is its smallest upper/left neighbour minus a
// non-negative decrement dx.
data {
  int<lower=0> I;                        // number of data records
  int<lower=1> R;                        // number of Rcen levels
  int<lower=1> F;                        // number of Freq levels
  array[I] int<lower=1, upper=R> Rcen;   // Rcen level of each record (row index into q)
  array[I] int<lower=1, upper=F> Freq;   // Freq level of each record (column index into q)
  array[I] int<lower=0> N;               // binomial trials of each record
  array[I] int<lower=0> PV;              // binomial successes of each record
}

parameters {
  // Non-negative step down from the neighbouring cell(s).
  // No sampling statement is given for dx, so its prior is flat on [0, inf).
  matrix<lower=0>[R, F] dx;
}

transformed parameters {
  matrix[R, F] x;   // logit-scale score, non-increasing in both r and f
  matrix[R, F] q;   // success probability, inv_logit(x)

  // The (1,1) cell is the largest score; it is bounded above by 5.
  x[1, 1] = 5 - dx[1, 1];

  // First column and first row only have a single predecessor.
  for (r in 2:R) {
    x[r, 1] = x[r - 1, 1] - dx[r, 1];
  }
  for (f in 2:F) {
    x[1, f] = x[1, f - 1] - dx[1, f];
  }

  // Interior cells must not exceed either predecessor, so start from the
  // smaller of the two before subtracting the decrement.
  for (r in 2:R) {
    for (f in 2:F) {
      x[r, f] = fmin(x[r - 1, f], x[r, f - 1]) - dx[r, f];
    }
  }

  q = inv_logit(x);
}

model {
  for (i in 1:I) {
    PV[i] ~ binomial(N[i], q[Rcen[i], Freq[i]]);
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas
import cmdstanpy
import numpy as np
import matplotlib.pyplot as plt
from plotnine import ggplot, aes, theme, element_text, geom_ribbon, geom_line, geom_point, labs

# Plot the posterior of q: a wireframe of the per-cell median and, for one
# freq level, the credible bands over rcen together with the observed rates.

rf_df = pandas.read_csv('input/rf_df.csv')
fit = cmdstanpy.from_csv('output/result')

# Draws of the Stan matrix q[R, F]; indexed below as (draw, rcen, freq index).
q_ms = fit.stan_variable(var='q')
# Take the grid size from the draws instead of hard-coding 7.
n_rcen, n_freq = q_ms.shape[1:]

Freq = rf_df.freq.unique().tolist()
Rcen = rf_df.rcen.unique().tolist()

# The sampling script feeds the model a reversed freq index, so the
# zero-based column for a given freq is n_freq - freq.
q_median = np.median(q_ms, axis=0)
Z = np.array([[q_median[rcen - 1, n_freq - freq] for rcen in Rcen] for freq in Freq])
X, Y = np.meshgrid(Rcen, Freq)

fig = plt.figure(dpi=250)
ax = fig.add_subplot(111, projection='3d', xlabel='rcen', ylabel='freq', zlabel='pred_prob')
ax.plot_wireframe(X, Y, Z)
fig.savefig('output/q_median.png')
plt.close(fig)  # release the figure once it is written

# Credible bands across rcen at a single freq level.
plot_f = 5
qua = np.quantile(q_ms[:, :, n_freq - plot_f], [0.025, 0.25, 0.50, 0.75, 0.975], axis=0)
d_est = pandas.DataFrame(
    np.column_stack([np.arange(1, n_rcen + 1), qua.T]),
    columns=['rcen', '2.5%', '25%', '50%', '75%', '97.5%'],
)
rf_df_at_plot_f = rf_df[rf_df.freq == plot_f]

p = (ggplot()
    + theme(text=element_text(size=18))
    + geom_ribbon(d_est, aes(x='rcen', ymin='2.5%', ymax='97.5%'), fill='blue', alpha=1/6)
    + geom_ribbon(d_est, aes(x='rcen', ymin='25%', ymax='75%'), fill='blue', alpha=2/6)
    + geom_line(d_est, aes(x='rcen', y='50%'), color='blue', size=1)
    + geom_point(rf_df_at_plot_f, aes(x='rcen', y='prob'), size=1)
    + labs(y='prob')
)
p.save(filename=f'output/q_at_f_{plot_f}.png', dpi=300, width=5, height=4)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
rcen | freq | N | pv | prob | |
---|---|---|---|---|---|
1 | 1 | 19602 | 245 | 0.012498724619936742 | |
1 | 2 | 3323 | 132 | 0.039723141739392114 | |
1 | 3 | 1120 | 81 | 0.07232142857142858 | |
1 | 4 | 539 | 36 | 0.06679035250463822 | |
1 | 5 | 285 | 36 | 0.12631578947368421 | |
1 | 6 | 177 | 20 | 0.11299435028248588 | |
1 | 7 | 120 | 21 | 0.175 | |
2 | 1 | 19126 | 112 | 0.0058559029593223885 | |
2 | 2 | 3162 | 67 | 0.02118912080961417 | |
2 | 3 | 1001 | 27 | 0.026973026973026972 | |
2 | 4 | 459 | 26 | 0.05664488017429194 | |
2 | 5 | 302 | 20 | 0.06622516556291391 | |
2 | 6 | 162 | 16 | 0.09876543209876543 | |
2 | 7 | 94 | 6 | 0.06382978723404255 | |
3 | 1 | 22596 | 138 | 0.006107275624004248 | |
3 | 2 | 3616 | 84 | 0.023230088495575223 | |
3 | 3 | 1161 | 46 | 0.03962101636520241 | |
3 | 4 | 582 | 31 | 0.05326460481099656 | |
3 | 5 | 279 | 11 | 0.03942652329749104 | |
3 | 6 | 185 | 10 | 0.05405405405405406 | |
3 | 7 | 119 | 6 | 0.05042016806722689 | |
4 | 1 | 24385 | 133 | 0.005454172647119131 | |
4 | 2 | 4035 | 62 | 0.015365551425030979 | |
4 | 3 | 1305 | 32 | 0.024521072796934867 | |
4 | 4 | 597 | 28 | 0.04690117252931323 | |
4 | 5 | 300 | 11 | 0.03666666666666667 | |
4 | 6 | 185 | 7 | 0.03783783783783784 | |
4 | 7 | 109 | 2 | 0.01834862385321101 | |
5 | 1 | 25363 | 111 | 0.0043764538895241095 | |
5 | 2 | 3999 | 62 | 0.015503875968992248 | |
5 | 3 | 1225 | 29 | 0.0236734693877551 | |
5 | 4 | 603 | 9 | 0.014925373134328358 | |
5 | 5 | 274 | 6 | 0.021897810218978103 | |
5 | 6 | 173 | 5 | 0.028901734104046242 | |
5 | 7 | 98 | 3 | 0.030612244897959183 | |
6 | 1 | 26034 | 116 | 0.004455711761542598 | |
6 | 2 | 3757 | 37 | 0.009848283204684588 | |
6 | 3 | 1183 | 29 | 0.024513947590870666 | |
6 | 4 | 511 | 10 | 0.019569471624266144 | |
6 | 5 | 235 | 2 | 0.00851063829787234 | |
6 | 6 | 121 | 3 | 0.024793388429752067 | |
6 | 7 | 79 | 2 | 0.02531645569620253 | |
7 | 1 | 25611 | 109 | 0.004255983756979423 | |
7 | 2 | 3522 | 32 | 0.009085746734809767 | |
7 | 3 | 996 | 14 | 0.014056224899598393 | |
7 | 4 | 385 | 9 | 0.023376623376623377 | |
7 | 5 | 220 | 2 | 0.00909090909090909 | |
7 | 6 | 98 | 2 | 0.02040816326530612 | |
7 | 7 | 43 | 0 | 0.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas
import cmdstanpy

# Fit the Stan model to the per-cell counts and store the sampler output.

d = pandas.read_csv('input/rf_df.csv')

# Freq is passed as 8 - freq, i.e. reversed: freq 7 becomes model index 1.
data = dict(
    I=len(d),
    R=7,
    F=7,
    Rcen=d.rcen,
    Freq=8 - d.freq,
    N=d.N,
    PV=d.pv,
)

model = cmdstanpy.CmdStanModel(stan_file='model/model.stan')
fit = model.sample(data=data, seed=123)

# Keep the raw draws and a summary table.
fit.save_csvfiles('output/result')
summary = fit.summary()
summary.to_csv('output/fit-summary.csv')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment