DQN Agent for Time Series in JavaScript
A sketch of a DQN (Deep Q-Network) agent for time-series trading decisions, in three files: a high-level script that prepares the data and drives the agent, a from-scratch DQN implementation in TensorFlow.js, and the helper functions used by the environment.
import { TimeSeries, avg } from 'pondjs';
// rl-js-dqn is assumed to expose a DQN class with train() and act()
import { DQN } from 'rl-js-dqn';

// Fetch the raw data; fetchTimeSeriesData() is assumed to resolve to an
// array of { timestamp, value } objects
const data = await fetchTimeSeriesData();

const timeseries = new TimeSeries({
  name: 'timeseries',
  columns: ['time', 'value'],
  points: data.map(d => [d.timestamp, d.value])
});

// Resample to hourly averages; pondjs does this with fixedWindowRollup
// rather than a separate DataView/resample API
const hourly = timeseries.fixedWindowRollup({
  windowSize: '1h',
  aggregation: { value: { value: avg() } }
});

// Flatten the rollup into plain state objects. Each state exposes
// get(column) and next() because the helper functions in the last file
// expect that interface.
const points = hourly.toJSON().points; // [[time, value], ...]
const processedSeries = points.map(([time, value], i) => ({
  get: column => (column === 'value' ? value : time),
  next: () => processedSeries[i + 1]
}));

// Hold out the last 20% of the series for evaluation
const splitIndex = Math.floor(processedSeries.length * 0.8);
const processedTestSeries = processedSeries.slice(splitIndex);

const environment = {
  // The possible actions the agent can take
  actions: ['buy', 'hold', 'sell'],
  // Reward for taking `action` in `state`, delegated to a helper
  // defined in the last file
  computeReward: (state, action) => calculateReward(state, action),
  // Transition function: the next point in the series
  computeNextState: (state, action) => calculateNextState(state, action),
  // Other parameters and settings for the environment
  maxSteps: 1000,
  initialState: processedSeries[0]
};

const agent = new DQN({
  environment,
  // Two hidden layers of 32 units; gamma is the discount factor,
  // epsilon the exploration rate of the epsilon-greedy policy
  hiddenLayers: [32, 32],
  gamma: 0.9,
  epsilon: 0.1
});

// Train the agent on the in-sample portion of the series
await agent.train(processedSeries.slice(0, splitIndex));

// Use the trained agent to make predictions on the held-out test data
const predictions = [];
let totalReward = 0;
for (let i = 0; i < processedTestSeries.length; i++) {
  const currentState = processedTestSeries[i];
  const action = agent.act(currentState);
  const reward = environment.computeReward(currentState, action);
  totalReward += reward;
  predictions.push(action);
}

// Compute evaluation metrics; testData is assumed to hold ground-truth
// action labels for the test period
const accuracy = calculateAccuracy(predictions, testData);
const roi = calculateROI(totalReward);
console.log(`Test accuracy: ${accuracy}`);
console.log(`Total ROI: ${roi}`);

// Persist the trained agent; ml5's save() lives on model instances, not
// on the module, so the agent is assumed to serialize itself here
await agent.save('time-series-model.json');
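For orientation, one hand-driven environment step looks like this (illustrative only; the actual values depend on the fetched data):

const s0 = environment.initialState;
console.log(s0.get('value'));                       // first hourly average
const r0 = environment.computeReward(s0, 'buy');    // reward for buying at t0
const s1 = environment.computeNextState(s0, 'buy'); // state at t1
console.log(r0, s1 && s1.get('value'));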
The second file is a standalone DQN agent built directly on @tensorflow/tfjs-node and trained on synthetic data:
const tf = require("@tensorflow/tfjs-node");

// Define the DQN agent
const agent = {
  // Q-network: maps an 8-dimensional state to one Q-value per action
  model: tf.sequential(),
  // Learning rate for the Adam optimizer
  learningRate: 0.01,
  // Discount factor for future rewards
  discountFactor: 0.95,
  // Probability of taking a random exploratory action
  explorationFactor: 0.1,
  // Build and compile the model
  init: function () {
    this.model.add(
      tf.layers.dense({ units: 32, inputShape: [8], activation: "relu" })
    );
    this.model.add(tf.layers.dense({ units: 4, activation: "linear" }));
    this.model.compile({
      loss: "meanSquaredError",
      optimizer: tf.train.adam(this.learningRate),
    });
  },
  // Select an action for a given state, epsilon-greedily
  selectAction: function (state) {
    // Choose a random action with probability equal to the exploration factor
    if (Math.random() < this.explorationFactor) {
      return Math.floor(Math.random() * 4);
    }
    // Otherwise, choose the action with the highest predicted Q-value
    const qValues = this.model.predict(state);
    return tf.argMax(qValues, 1).dataSync()[0];
  },
  // Update the model using a given batch of experiences
  update: async function (batch) {
    // Build the input and target arrays row by row
    const inputs = [];
    const targets = [];
    for (const [state, action, nextState, reward, done] of batch) {
      // Current Q-value estimates for this state
      const qValues = this.model.predict(state);
      // If the episode has ended, the Q-value for the next state is 0;
      // otherwise bootstrap from the best predicted next-state Q-value
      let nextQValue = 0;
      if (!done) {
        const nextQValues = this.model.predict(nextState);
        nextQValue = tf.max(nextQValues).dataSync()[0];
      }
      // Copy the predictions and overwrite the entry for the taken action
      // with the Q-learning target
      const target = Array.from(qValues.dataSync());
      target[action] = reward + this.discountFactor * nextQValue;
      targets.push(target);
      // Use the current state as the input row
      inputs.push(Array.from(state.dataSync()));
    }
    // Train the model on the inputs and targets
    await this.model.fit(tf.tensor2d(inputs), tf.tensor2d(targets), {
      epochs: 1,
    });
  },
};
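// Worked example of the target computed in update():
//   target[action] = reward + discountFactor * max(nextQValues)
// With reward = 1, discountFactor = 0.95 and a best next-state Q-value
// of 2, the taken action's target is 1 + 0.95 * 2 = 2.9, while every
// other action keeps its current prediction.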
// Initialize the DQN agent
agent.init();

// Replay buffer for experience replay (the size cap is illustrative)
const replayBuffer = [];
const replayBufferMaxSize = 10000;

// Generate some synthetic data for training
const xs = tf.randomNormal([100, 8]);

// Synthetic stand-ins for the real environment's reward and dynamics
const computeReward = action => Math.random() - 0.5;
const computeNextState = (state, action) => tf.randomNormal([1, 8]);

// Train the agent; wrapped in an async function because this is a
// CommonJS module, where top-level await is unavailable
async function main() {
  for (let i = 0; i < 100; i++) {
    // Slice out the i-th row as a [1, 8] state tensor (tfjs tensors
    // cannot be indexed as xs[i])
    const state = xs.slice([i, 0], [1, 8]);
    // Select an action for the current state
    const action = agent.selectAction(state);
    // Compute the reward and next state based on the action
    const reward = computeReward(action);
    const nextState = computeNextState(state, action);
    // Update the replay buffer, evicting the oldest experience when full
    replayBuffer.push([state, action, nextState, reward, false]);
    if (replayBuffer.length > replayBufferMaxSize) {
      replayBuffer.shift();
    }
    // Sample a random batch of experiences from the replay buffer
    const batch = [];
    for (let j = 0; j < 32; j++) {
      const index = Math.floor(Math.random() * replayBuffer.length);
      batch.push(replayBuffer[index]);
    }
    // Update the agent using the experience batch
    await agent.update(batch);
    // Decrease the exploration factor over time
    agent.explorationFactor *= 0.99;
  }
}

const training = main();
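Once training finishes, a quick smoke test is to greedily score a fresh random state (illustrative only, since the training data here is synthetic):

training.then(() => {
  const testState = tf.randomNormal([1, 8]);
  const greedyAction = tf.argMax(agent.model.predict(testState), 1).dataSync()[0];
  console.log(`Greedy action for test state: ${greedyAction}`);
});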
The last file holds the helper functions referenced by the first script's environment and evaluation code:
// Function to calculate the reward for a given action. States are the
// wrapper objects built in the first file, exposing get(column) for the
// current value and next() for the following point in the series.
function calculateReward(state, action) {
  // Get the current and next values; at the end of the series there is
  // no next point, so fall back to the current value
  const currentValue = state.get('value');
  const next = state.next();
  const nextValue = next ? next.get('value') : currentValue;
  let reward = 0;
  // Reward the agent based on the action and the change in value:
  // correct buys and sells earn the price move, wrong ones pay a fixed
  // penalty, and holding earns or loses a small amount
  if (action === 'buy') {
    reward = nextValue > currentValue ? nextValue - currentValue : -1;
  } else if (action === 'hold') {
    if (nextValue > currentValue) {
      reward = 0.1;
    } else if (nextValue === currentValue) {
      reward = 0.01;
    } else {
      reward = -0.1;
    }
  } else if (action === 'sell') {
    reward = nextValue < currentValue ? currentValue - nextValue : -1;
  }
  return reward;
}
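// Example: a buy ahead of a rise from 100 to 105 earns the 5-point move,
// while a sell in the same spot takes the flat -1 penalty. The literal
// object below just mimics the state wrapper's interface:
//   const s = { get: () => 100, next: () => ({ get: () => 105 }) };
//   calculateReward(s, 'buy');  // 5
//   calculateReward(s, 'sell'); // -1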
// Function to calculate the next state: simply the next point in the series
function calculateNextState(state, action) {
  return state.next();
}

// Function to calculate the accuracy of predictions; trueValues is
// assumed to hold the ground-truth action labels for the test period
function calculateAccuracy(predictions, trueValues) {
  let numCorrect = 0;
  for (let i = 0; i < predictions.length; i++) {
    if (predictions[i] === trueValues[i]) {
      numCorrect++;
    }
  }
  return numCorrect / predictions.length;
}

// Function to calculate the return on investment; the default starting
// capital is an arbitrary assumption
function calculateROI(totalReward, initialInvestment = 1000) {
  return totalReward / initialInvestment;
}
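// Example: with a total reward of 50 and the default starting capital,
// calculateROI(50) === 50 / 1000 === 0.05, i.e. a 5% return.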