Skip to content

Instantly share code, notes, and snippets.

@koozdra
Last active June 18, 2020 05:34
Show Gist options
  • Save koozdra/ed3b5456d79576f173ffebcc9afe3f28 to your computer and use it in GitHub Desktop.
Save koozdra/ed3b5456d79576f173ffebcc9afe3f28 to your computer and use it in GitHub Desktop.
post sign flow q-learning
/**
 * Returns a new array containing the elements of `arr` in random order.
 *
 * Uses a Fisher–Yates shuffle on a copy, so the caller's array is never
 * mutated and every permutation is equally likely (given a uniform
 * `Math.random`).
 *
 * @param {Array} arr - Elements to shuffle.
 * @returns {Array} A shuffled shallow copy of `arr`.
 */
const shuffle_array = arr => {
  const shuffled = [...arr];
  for (let i = shuffled.length - 1; i > 0; i -= 1) {
    // Pick a partner from the not-yet-fixed prefix [0, i].
    const j = Math.floor(Math.random() * (i + 1));
    [shuffled[i], shuffled[j]] = [shuffled[j], shuffled[i]];
  }
  return shuffled;
};
/**
 * Decides how the user responds to the page that was just shown.
 *
 * The page shown is identified by the last character of `state`. The user
 * either takes that page's primary action (looked up in
 * `primary_action_names`) or skips it:
 *   - states beginning with 'a' or 'm' always result in 'skip';
 *   - the states 'p', 'ps', 'psa' and 'psam' take the primary action when
 *     `randy < 0.7`;
 *   - every other state takes the primary action when `randy < 0.5`.
 *
 * @param {string} state - Sequence of action codes shown so far, newest last.
 * @param {number} randy - Random draw; callers in this file pass `Math.random()`.
 * @returns {string} The primary action name for the last page, or 'skip'.
 */
const user_action_for_state = (state, randy) => {
  // None of the 0.7-probability states start with 'a' or 'm', so checking
  // this first does not change which rule applies to them.
  if (state.startsWith('a') || state.startsWith('m')) {
    return 'skip';
  }

  const is_favoured_sequence =
    state === 'p' || state === 'ps' || state === 'psa' || state === 'psam';
  const primary_action_probability = is_favoured_sequence ? 0.7 : 0.5;

  return randy < primary_action_probability
    ? primary_action_names[state[state.length - 1]]
    : 'skip';
};
// Single-letter codes for the actions the agent can choose; a state is the
// string of codes chosen so far.
const actions = ['p', 's', 'a', 'm'];

// Maps each action code to the name of the response a user gives when they
// take that action instead of skipping.
const primary_action_names = {
  p: 'promote',
  s: 'share',
  a: 'action',
  m: 'member',
};

// Reward credited for each possible user response, keyed by response name.
const rewards = {
  promote: 3,
  share: 2,
  action: 2,
  member: 3,
  skip: 1,
};
/**
 * Lists the actions that have not been used in `state` yet.
 *
 * @param {string} state - Sequence of action codes already taken.
 * @param {string[]} all_actions - Every action code that exists.
 * @returns {string[]} The entries of `all_actions` not contained in `state`,
 *   in their original order.
 */
const possible_actions = (state, all_actions) =>
  all_actions.filter(action => !state.includes(action));
/**
 * Reads the Q-value stored for taking `action` in `state`.
 *
 * Entries are keyed by the concatenation `state + action`; a missing (or
 * otherwise falsy) entry reads as 0.
 *
 * @param {Object<string, number>} q_table - Q-values keyed by state + action.
 * @param {string} state - Sequence of action codes already taken.
 * @param {string} action - Action code being evaluated.
 * @returns {number} The stored Q-value, or 0 when none is stored.
 */
const get_q_value = (q_table, state, action) => q_table[state + action] || 0;
/**
 * Applies one Q-table update per step of a trajectory, in the order the
 * steps were taken.
 *
 * @param {Object<string, number>} q_table - Q-values, updated in place.
 * @param {number} learning_rate - Passed through to `update_q_table`.
 * @param {number} discount_factor - Passed through to `update_q_table`.
 * @param {Array<[string, string, number]>} trajectory - Steps as
 *   `[state, action, reward]` tuples.
 * @returns {void}
 */
const update_q_table_trajectory = (
  q_table,
  learning_rate,
  discount_factor,
  trajectory
) => {
  for (const [state, action, reward] of trajectory) {
    update_q_table(
      q_table,
      learning_rate,
      discount_factor,
      state,
      action,
      reward
    );
  }
};
/**
 * Performs a single Q-learning update for one `(state, action, reward)` step:
 *
 *   Q(s, a) <- Q(s, a) + learning_rate *
 *              (reward + discount_factor * max_a' Q(s', a') - Q(s, a))
 *
 * where the next state `s'` is `state + action` and `a'` ranges over the
 * actions still unused in `s'`. When no action remains, the future term is 0.
 *
 * @param {Object<string, number>} q_table - Q-values keyed by state + action;
 *   the entry for `state + action` is overwritten.
 * @param {number} learning_rate - Weight given to the new estimate.
 * @param {number} discount_factor - Weight given to the best next Q-value.
 * @param {string} state - Sequence of action codes taken before this step.
 * @param {string} action - Action code taken at this step.
 * @param {number} reward - Reward observed for this step.
 * @returns {void}
 */
const update_q_table = (
  q_table,
  learning_rate,
  discount_factor,
  state,
  action,
  reward
) => {
  const state_action = state + action;
  const next_q_values = possible_actions(state_action, actions).map(
    next_action => get_q_value(q_table, state_action, next_action)
  );
  // Math.max() of an empty list is -Infinity, so terminal states need the
  // explicit 0 fallback.
  const max_next_q =
    next_q_values.length > 0 ? Math.max(...next_q_values) : 0;
  const current_q_value = get_q_value(q_table, state, action);
  q_table[state_action] =
    current_q_value +
    learning_rate * (reward + discount_factor * max_next_q - current_q_value);
};
/**
 * Picks one element of `possibles` uniformly at random.
 *
 * @param {Array} possibles - Candidate actions.
 * @returns {*} A randomly chosen element, or `undefined` when `possibles`
 *   is empty.
 */
const select_random_action = possibles =>
  possibles[Math.floor(Math.random() * possibles.length)];
/**
 * Picks the action with the highest Q-value in `state`.
 *
 * When several actions share the highest value, the one that appears first
 * in `possibles` wins.
 *
 * @param {Object<string, number>} q_table - Q-values keyed by state + action.
 * @param {string[]} possibles - Candidate action codes.
 * @param {string} state - Sequence of action codes already taken.
 * @returns {string|undefined} The best-valued action, or `undefined` when
 *   `possibles` is empty.
 */
const select_greedy_q_table = (q_table, possibles, state) => {
  const q_values = possibles.map(action => get_q_value(q_table, state, action));
  return possibles[q_values.indexOf(Math.max(...q_values))];
};
/**
 * Builds a trajectory by repeatedly choosing an action, observing the user's
 * response, and recording the step, until the state is more than three
 * action codes long (or no unused action remains).
 *
 * At each step the agent explores (random action) with probability
 * `exploration_rate` and otherwise exploits (highest Q-value). Candidates
 * are shuffled first, so ties between equal Q-values are broken randomly.
 *
 * @param {Object<string, number>} q_table - Q-values used for greedy choices;
 *   not modified here.
 * @param {string} state - Sequence of action codes taken so far.
 * @param {Array<[string, string, number]>} acc - Steps recorded so far; it is
 *   not mutated.
 * @returns {Array<[string, string, number]>} `acc` followed by the newly
 *   generated `[state, action, reward]` steps.
 */
const generate_trajectory = (q_table, state, acc) => {
  if (state.length > 3) {
    return acc;
  }
  const is_explore = Math.random() < exploration_rate;
  const possibles = shuffle_array(possible_actions(state, actions));
  if (possibles.length === 0) {
    // Nothing left to choose; without this guard the next state would be
    // built from an undefined action.
    return acc;
  }
  const agent_action = is_explore
    ? select_random_action(possibles)
    : select_greedy_q_table(q_table, possibles, state);
  const next_state = state + agent_action;
  const user_page_action_name = user_action_for_state(
    next_state,
    Math.random()
  );
  const reward = rewards[user_page_action_name];
  return generate_trajectory(q_table, next_state, [
    ...acc,
    [state, agent_action, reward]
  ]);
};
// Number of trajectories generated (and learned from) per trial.
const iterations = 1000;
// Number of times the whole run of `iterations` trajectories is repeated.
const trials = 100;
// Q-learning step size: weight given to each new estimate.
const learning_rate = 0.9;
// Weight given to the best Q-value of the next state.
const discount_factor = 0.9;
// Probability of picking a random action instead of the greedy one.
const exploration_rate = 0.01;
// Learned Q-values, keyed by state + action.
const q_table = {};
// Per trial: the total reward of every trajectory, in generation order.
const trial_rewards = [];
// Per trial: the full action sequence of every trajectory, in generation order.
const trial_paths = [];
// Training: for every trial, generate `iterations` trajectories, learn from
// each one immediately, and record its total reward and full action sequence.
// NOTE(review): `q_table` is created once, outside this loop, so each trial
// continues from the Q-values learned by the previous trials — confirm that
// trials are not meant to start from an empty table.
for (let trial = 0; trial < trials; trial += 1) {
  const trajectory_rewards = [];
  const trajectory_paths = [];

  for (let iteration = 0; iteration < iterations; iteration += 1) {
    const trajectory = generate_trajectory(q_table, '', []);

    const total_reward = trajectory.reduce(
      (accumulator, [, , reward]) => accumulator + reward,
      0
    );
    trajectory_rewards.push(total_reward);

    // The last step's state plus its action is the complete action sequence.
    const [final_state, final_action] = trajectory[trajectory.length - 1];
    trajectory_paths.push(final_state + final_action);

    update_q_table_trajectory(
      q_table,
      learning_rate,
      discount_factor,
      trajectory
    );
  }

  trial_rewards.push(trajectory_rewards);
  trial_paths.push(trajectory_paths);
}
// const output_rewards = trial_rewards.map(a => a.join(',')).join(',');
// const output_paths = trial_paths.map(a => a.join(',')).join(',');
// Sliding-window statistics: for every window of `window_size` consecutive
// iterations, count how often each final action sequence occurred across all
// trials.
const window_size = 100;
// Every action sequence seen in any window; insertion order fixes the order
// of the output columns.
const all_keys = {};
// One { sequence: count } object per window, in window order.
const stat_frequencies = [];

// A series of `iterations` items has `iterations - window_size + 1` full
// windows; the `+ 1` makes the final window include the last iteration.
const window_count = Math.max(iterations - window_size + 1, 0);

for (let window_start = 0; window_start < window_count; window_start += 1) {
  const frequency = {};
  // Count straight from each trial's slice instead of first concatenating
  // all slices into one large temporary array.
  for (const trajectory_paths of trial_paths) {
    const trajectory_window = trajectory_paths.slice(
      window_start,
      window_start + window_size
    );
    for (const stat of trajectory_window) {
      all_keys[stat] = true;
      frequency[stat] = (frequency[stat] || 0) + 1;
    }
  }
  stat_frequencies.push(frequency);
}
// console.log(stat_frequencies);
// CSV output: a header row naming every action sequence that was seen,
// followed by one row of counts per window (0 where a sequence did not occur
// in that window).
const frequency_stat_keys = Object.keys(all_keys);
const header = frequency_stat_keys.join(',');
console.log(header);

for (const frequency_stat of stat_frequencies) {
  const output_line = frequency_stat_keys.map(
    frequency_stat_key => frequency_stat[frequency_stat_key] || 0
  );
  console.log(output_line.join(','));
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment