Toy reinforcement learning on an array
# array_reinforcement_learning.py
"""
array_reinforcement_learning
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Reinforcement learning is performed on a 1-dimensional
finite state space ("array") of k elements:

    S = {1, ..., k}

There are two possible actions: move right (a = 1) or move left (a = -1)
(with special behavior at the boundaries s = 1 and s = k, described below).
This defines the action space:

    A = {-1, 1}

The reward for moving to the right is +1, and the reward for
moving to the left is -1. At a boundary the reward for attempting
to traverse outside the array is 0, and the state returned upon
attempting such an action is the original state.

Define a function f: S * A -> S
representing deterministic state traversal for a given action.
Also define a function g: S * A -> R (the real line) for the reward
for a given action in a given state.

The value iteration algorithm is used.

The resulting optimal policy is to move to the right in all states,
except in the last (rightmost) one, where the optimal policy requires
moving back to the second-rightmost state, then moving again to the
rightmost state. So the optimal policy results in cycling
between the last two states; for any initial state i, the sequence
traversed under the optimal policy is:

    i, i+1, i+2, ..., k-1, k, k-1, k, k-1, k, (etc.)
"""
from collections import defaultdict


class ArrayReinforcementLearning(object):

    k = 5             # number of states
    gamma = 0.1       # gamma factor in the value iteration algorithm
    tolerance = 0.01  # tolerance for value iteration convergence for each state

    # f and g are the state and reward maps, respectively
    f = {(1, -1): 1, (1, 1): 2, (2, -1): 1, (2, 1): 3,
         (3, -1): 2, (3, 1): 4, (4, -1): 3, (4, 1): 5,
         (5, -1): 4, (5, 1): 5}
    g = {(1, -1): 0, (1, 1): 1, (2, -1): -1, (2, 1): 1,
         (3, -1): -1, (3, 1): 1, (4, -1): -1, (4, 1): 1,
         (5, -1): -1, (5, 1): 0}
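
    # For example, f[(2, 1)] == 3: moving right from state 2 lands in state 3.
    # At the right boundary, f[(5, 1)] == 5 and g[(5, 1)] == 0: attempting to
    # move right out of state 5 leaves the state unchanged with zero reward.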
    states = [i + 1 for i in range(k)]
    actions = [-1, 1]

    def __init__(self):
        """ Initialize with an array of state values
        all zero.

        v[s - 1] = V(s)     := value of state s (the list is 0-indexed)
        q[(s, a)] = Q(s, a) := value of action a in state s.
        """
        self.v = [0.0] * ArrayReinforcementLearning.k
        self.q = defaultdict(lambda: 0.0)
    def learn(self):
        """ Performs the value iteration algorithm. Returns the
        optimal state values."""
        prev_v = [-10000.0] * ArrayReinforcementLearning.k
        while not all(abs(prev_v[i] - self.v[i]) < ArrayReinforcementLearning.tolerance
                      for i in range(len(self.v))):
            # take a copy (not an alias) so the convergence check compares
            # the values before and after each sweep
            prev_v = list(self.v)
            for s in ArrayReinforcementLearning.states:
                max_value = 0.0
                for a in ArrayReinforcementLearning.actions:
                    self.q[(s, a)] = ArrayReinforcementLearning.g[(s, a)] + \
                        ArrayReinforcementLearning.gamma * \
                        self.v[ArrayReinforcementLearning.f[(s, a)] - 1]
                    if self.q[(s, a)] > max_value:
                        max_value = self.q[(s, a)]
                self.v[s - 1] = max_value
        return self.v
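
    # Sanity check: with these rewards, the fixed point of the backup above
    # satisfies V(5) = 0 and V(s) = 1 + gamma * V(s + 1) for s < 5, so for
    # gamma = 0.1 the learned values should come out to roughly
    # [1.111, 1.11, 1.1, 1.0, 0.0].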
    def optimal_policy(self):
        """Returns the optimal policy pi: S -> A
        as a map <s, a>. """
        pi = {}
        for s in ArrayReinforcementLearning.states:
            max_value = 0.0
            max_value_action = -1
            for a in ArrayReinforcementLearning.actions:
                res = ArrayReinforcementLearning.g[(s, a)] + \
                    ArrayReinforcementLearning.gamma * \
                    self.v[ArrayReinforcementLearning.f[(s, a)] - 1]
                if res > max_value:
                    # track both the best value seen so far and its action
                    max_value = res
                    max_value_action = a
            pi[s] = max_value_action
        return pi
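
    # Per the module docstring, the returned policy should be to move right
    # from every state except the rightmost one, i.e. roughly
    # {1: 1, 2: 1, 3: 1, 4: 1, 5: -1}.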

def main():
    arl = ArrayReinforcementLearning()
    v = arl.learn()
    print("~~~~~~~~~~~~~~~~~~~~~~~~")
    print("OPTIMAL STATE VALUES: ")
    print(v)
    print("OPTIMAL POLICY: ")
    pi = arl.optimal_policy()
    print(pi)
    print("~~~~~~~~~~~~~~~~~~~~~~~~")


if __name__ == "__main__":
    main()
todo: redo in JavaScript with a UI