Skip to content

Instantly share code, notes, and snippets.

@koozdra
Created November 26, 2019 17:58
Show Gist options
  • Save koozdra/9715592fbdce35ecf124a1f57f867e2e to your computer and use it in GitHub Desktop.
Save koozdra/9715592fbdce35ecf124a1f57f867e2e to your computer and use it in GitHub Desktop.
A basic implementation of multi-armed bandit selection strategies (epsilon-greedy and UCB) for a demo.
const times = require('lodash/fp/times');
const map = require('lodash/fp/map');
const sample = require('lodash/fp/sample');
const maxBy = require('lodash/fp/maxBy');
const find = require('lodash/fp/find');
const join = require('lodash/fp/join');
const flow = require('lodash/fp/flow');
const get = require('lodash/fp/get');
const sum = require('lodash/fp/sum');
const constant = require('lodash/fp/constant');
// ---- Simulation configuration ----

// True reward probability of each variant; a pull of variant i pays out when
// Math.random() falls below variantExpectedValues[i].
const variantExpectedValues = [0.1, 0.8, 0.5];

// Selection rounds per test run, and how many test runs to execute.
const iterations = 1000;
const tests = 1;

// results[i] holds the cumulative pull count of every variant after round i.
const results = [];

// Payout settings. None of these are referenced by the visible code.
const startDollars = 100;
const valueOnWin = 2;
// NOTE(review): the identifier is misspelled ("vaue"); it is left unchanged so
// that anything referring to it keeps working.
const vaueOnLoss = 0;

// Coin flip: true when Math.random() is at least 0.5.
const randBoolean = () => !(Math.random() < 0.5);
// Mutable per-variant statistics, one record per entry of variantExpectedValues:
//   pulls         - how many times the variant has been selected so far
//   rewards       - how many of those selections paid out
//   expectedValue - the variant's true reward probability (used to simulate pulls)
const testData = map(function toVariantRecord(expectedValue) {
  return { pulls: 0, rewards: 0, expectedValue };
})(variantExpectedValues);
/**
 * Estimates a variant's reward rate from the outcomes observed so far.
 *
 * @param {{pulls: number, rewards: number}} variant - Variant statistics.
 * @returns {number} rewards / pulls, or 0 for a variant that has never been
 *   pulled (guards against 0 / 0 producing NaN).
 */
const rankVariant = variant =>
  // Strict comparison, consistent with isVariantUnvisited.
  variant.pulls === 0 ? 0 : variant.rewards / variant.pulls;
/** Returns true when the variant has not been pulled yet (pulls is exactly 0). */
const isVariantUnvisited = ({ pulls }) => pulls === 0;
// Exploration step: alias of lodash/fp `sample`, which returns a random
// element of the collection it is given.
const pickRandomVariant = sample;
// Exploitation step: the variant whose rankVariant score (observed reward
// rate) is highest.
const pickBestVariant = maxBy(rankVariant);
// First variant that has never been pulled, or undefined once every variant
// has been tried at least once.
const findUnvisitedVariant = find(isVariantUnvisited);
/**
 * Records one more pull on the given variant (mutates it in place).
 *
 * @param {{pulls: number}} variant - Variant statistics to update.
 * @returns {number} The variant's pull count after the increment.
 */
const visitVariant = variant => {
  variant.pulls = variant.pulls + 1;
  return variant.pulls;
};
/**
 * Epsilon-greedy selection: with probability `epsilon` pick a random variant
 * (explore), otherwise pick the variant with the best observed reward rate
 * (exploit).
 *
 * @param {Array<{pulls: number, rewards: number}>} variants - Candidate variants.
 * @param {number} [epsilon=0.1] - Exploration probability in [0, 1]. The
 *   default matches the rate that used to be hard-coded.
 * @returns {object|undefined} The chosen variant, as returned by
 *   pickBestVariant or pickRandomVariant.
 */
const selectVariantEpsilonGreedy = (variants, epsilon = 0.1) => {
  // Math.random() is in [0, 1), so `>=` makes epsilon = 0 always exploit and
  // epsilon = 1 always explore.
  const isExploit = Math.random() >= epsilon;
  return (isExploit ? pickBestVariant : pickRandomVariant)(variants);
};
/**
 * Upper-confidence-bound selection: scores every variant as
 *   observed reward rate + sqrt(2 * ln(total pulls) / variant pulls)
 * and returns the highest-scoring one.
 *
 * A variant that has never been pulled is scored as Infinity so it is always
 * tried first. Without that guard its bonus term evaluates to NaN whenever the
 * total pull count is 0 or 1 (log(0) / 0 or 0 / 0), and a NaN score can never
 * win the maxBy comparison.
 *
 * @param {Array<{pulls: number, rewards: number}>} variants - Candidate variants.
 * @returns {object|undefined} The chosen variant (whatever maxBy yields for an
 *   empty collection when there are no variants).
 */
const selectVariantUCB = variants => {
  const totalVisits = flow(map(get('pulls')), sum)(variants);
  // Loop-invariant, so compute it once instead of once per variant.
  const logTotalVisits = Math.log(totalVisits);
  return maxBy(variant => {
    if (isVariantUnvisited(variant)) {
      return Infinity;
    }
    const estimatedExpectedValue = rankVariant(variant);
    const explorationBonus = Math.sqrt(2 * (logTotalVisits / variant.pulls));
    return estimatedExpectedValue + explorationBonus;
  })(variants);
};
// Pre-fill `results` with one row per iteration, each row holding a zero pull
// count for every variant.
for (let row = 0; row < iterations; row += 1) {
  results[row] = variantExpectedValues.map(() => 0);
}
// Runs the simulation `tests` times. Each run performs `iterations` selection
// rounds against testData, stores the cumulative per-variant pull counts in
// `results`, prints them as CSV and writes the same CSV to /tmp/test.csv.
// NOTE(review): testData and results are not reset between runs, so state
// carries over when tests > 1 — confirm whether that is intended.
times(() => {
  times(iteration => {
    // Start from the previous round's cumulative counts (row 0 on the first round).
    const previousVisits = [...results[Math.max(iteration - 1, 0)]];
    // Try every variant once before handing over to the selection strategy.
    const unvisitedVariant = findUnvisitedVariant(testData);
    // Use selectVariantUCB(testData) here instead to compare strategies.
    const selectedVariant = unvisitedVariant
      ? unvisitedVariant
      : selectVariantEpsilonGreedy(testData);
    const selectedVariantIndex = testData.indexOf(selectedVariant);
    visitVariant(selectedVariant);
    // Simulate the pull: it pays out with probability expectedValue.
    const isReward = Math.random() < selectedVariant.expectedValue;
    if (isReward) {
      selectedVariant.rewards += 1;
    }
    previousVisits[selectedVariantIndex] += 1;
    results[iteration] = previousVisits;
  })(iterations);
  console.log(testData);
  // One CSV line per iteration, one column per variant.
  const output = flow(map(join(', ')), join('\n'))(results);
  console.log(output);
  const fs = require('fs');
  fs.writeFile('/tmp/test.csv', output, err => {
    if (err) {
      // Report on stderr so a failure is not mixed into the CSV on stdout.
      console.error(err);
      return;
    }
    console.log('The file was saved!');
  });
})(tests);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment