Skip to content

Instantly share code, notes, and snippets.

@sandeshbhusal
Created October 23, 2023 03:26
Show Gist options
  • Save sandeshbhusal/fe2981bff567dce58a11def2a4c8bd70 to your computer and use it in GitHub Desktop.
Save sandeshbhusal/fe2981bff567dce58a11def2a4c8bd70 to your computer and use it in GitHub Desktop.
A simple implementation of gridworld with one modification: stepping into state '9' earns a +3 reward, and every action taken from state '9' leads to state '3'. This can be changed in the transition map as required.
from typing import *
class State(object):
    """A single cell of the grid.

    Attributes:
        id: Identifier of the state.
        value: Current value estimate stored for the state.
        transitions: Successor state ids, one per action. An entry may be
            ``None`` when the corresponding action has been removed.
        terminal: Whether the state is terminal. Terminal states always
            report a value of 0 and ignore attempts to change their value.
    """

    def __init__(self, id: int, value: float, transitions: List[Optional[int]], is_terminal: bool) -> None:
        self.id = id
        self.value = value
        self.transitions = transitions
        self.terminal = is_terminal

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(id={self.id!r}, value={self.value!r}, "
            f"transitions={self.transitions!r}, is_terminal={self.terminal!r})"
        )

    def get_value(self):
        """Return the stored value, or 0 when the state is terminal."""
        if self.terminal:
            return 0
        return self.value

    def set_value(self, newvalue):
        """Store ``newvalue`` unless the state is terminal (then do nothing)."""
        if not self.terminal:
            self.value = newvalue

    def recalculate_value_and_transitions(self):
        """Placeholder kept for interface compatibility; it does nothing.

        Returns:
            None.
        """
        return None
# Successor table for the 4x4 grid; states are numbered row by row, 0-15.
# Transition order: left, right, top, down.
# A move that would leave the grid maps back to the same state.
transitions = {
    # Row 0
    0: [0, 1, 0, 4],
    1: [0, 2, 1, 5],
    2: [1, 3, 2, 6],
    3: [2, 3, 3, 7],
    # Row 1
    4: [4, 5, 0, 8],
    5: [4, 6, 1, 9],
    6: [5, 7, 2, 10],
    7: [6, 7, 3, 11],
    # Row 2 -- every action taken from state 9 leads to state 3.
    8: [8, 9, 4, 12],
    9: [3, 3, 3, 3],
    10: [9, 11, 6, 14],
    11: [10, 11, 7, 15],
    # Row 3
    12: [12, 13, 8, 12],
    13: [12, 14, 9, 13],
    14: [13, 15, 10, 14],
    15: [14, 15, 11, 15],
}
class gridworld(object):
    """A 4x4 gridworld evaluated by alternating value and policy updates.

    States are kept in ``self.states`` and looked up by list index, so the
    transition map passed to ``generate_states`` is expected to list ids
    0, 1, 2, ... in order. States 0 and 15 are terminal. Entering state 9
    is rewarded with +3 and entering any other state with -1. No discount
    factor is applied.
    """

    def __init__(self, num_states: int):
        self.num_states = num_states
        self.states: List["State"] = []

    def generate_states(self, transitions):
        """Append one zero-valued ``State`` per key of ``transitions``.

        Args:
            transitions: Mapping of state id to its list of successor ids.
        """
        for k in transitions:
            terminal = k == 0 or k == 15
            self.states.append(State(k, 0, transitions[k], terminal))

    def get_state(self, id: int) -> Optional["State"]:
        """Return the state stored at index ``id``, or ``None`` for ``None``."""
        if id is None:
            return None
        return self.states[id]

    def step_reward(self, id: int) -> int:
        """Return the reward for stepping into state ``id``: 3 for 9, else -1."""
        return 3 if id == 9 else -1

    def recalculate_values(self):
        """Perform one synchronous value sweep over all states.

        Each state's new value is the average, over the transitions that are
        not ``None``, of (reward for entering the successor + successor
        value). A ``None`` transition means the policy rejects that action.
        All new values are computed from the old ones before any is stored.
        """
        recalculated_values = {}
        for state in self.states:
            successors = [self.get_state(i) for i in state.transitions]
            allowed = [nxt for nxt in successors if nxt is not None]
            new_value = 0
            for nxt in allowed:
                new_value += (1 / len(allowed)) * (self.step_reward(nxt.id) + nxt.get_value())
            recalculated_values[state.id] = new_value
        # Write back only after the full sweep so that every update above
        # read the previous iteration's values.
        for state_id, new_value in recalculated_values.items():
            self.get_state(state_id).set_value(new_value)

    def update_policy(self):
        """Keep only the transitions that lead to a highest-valued successor.

        For every state, each transition whose successor value equals the
        maximum among the remaining successors keeps its id; every other
        entry becomes ``None``. Positions in the list are preserved.
        """
        for state in self.states:
            to_states = [self.get_state(state_id) for state_id in state.transitions]
            values = [nxt.get_value() for nxt in to_states if nxt is not None]
            if not values:
                print("Impossible - more more transitions left on state", state.id)
                state.transitions = [None] * len(to_states)
                continue
            max_val = max(values)
            state.transitions = [
                nxt.id if nxt is not None and nxt.get_value() == max_val else None
                for nxt in to_states
            ]

    def value_policy_iteration(self):
        """Run one value sweep followed by one policy update."""
        self.recalculate_values()
        self.update_policy()

    def __str__(self) -> str:
        """Render the 4x4 value grid followed by each state's allowed moves.

        States are ordered by id before printing. State 9 is always shown as
        ``9: 3``.
        """
        # sorted() returns a new list; its result must be kept for the
        # ordering to take effect.
        ordered = sorted(self.states, key=lambda s: s.id)
        direction_names = ("left ", "right ", "up ", "down ")
        parts = []
        for row in range(4):
            parts.append("\n")
            for col in range(4):
                parts.append(f" {ordered[4 * row + col].value:.2f} ")
        parts.append("\n\n")
        # Now the transition list, one line per state.
        for state in ordered:
            parts.append("\n")
            if state.id == 9:
                parts.append("9: 3")
                continue
            parts.append(str(state.id) + ": ")
            for name, transition in zip(direction_names, state.transitions):
                if transition is not None:
                    parts.append(name)
        parts.append("\n")
        parts.append("-" * 30)
        return "".join(parts)
# Build the world from the transition table, run five value/policy sweeps,
# print the resulting values and policy, then run one further sweep.
world = gridworld(len(transitions))
world.generate_states(transitions=transitions)
for _ in range(5):
    world.value_policy_iteration()
print(world)
world.value_policy_iteration()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment