Created
March 8, 2020 04:17
-
-
Save AcrylicShrimp/74a7493b736c2e30f4a9306ffe05f535 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import random | |
import torch | |
from torch.distributions.dirichlet import Dirichlet | |
from game import Game | |
class MCTSNode:
    """
    A single node of the Monte-Carlo search tree.

    A node wraps one game state and keeps the visit statistics (n, w, q)
    that are used to pick children during selection.
    """

    def __init__(self, state, parent=None, action=None):
        """
        Creates a fresh, unvisited node.

        state  -- game state held by this node
        parent -- parent node, or None for a root
        action -- action that was applied to the parent's state to get here
        """
        self.state = state  # board
        self.n = 0  # num. of visits
        self.w = 0  # total reward (accumulated over all backups)
        self.q = 0  # average reward (w / n)
        self.childs = {}  # action -> child node
        self.parent = parent  # parent node (None for a root)
        self.action = action  # action (integer)

    @property
    def is_leaf(self):
        """
        Tests whether self is a leaf (has no child nodes) or not.
        """
        return not self.childs

    def calc_ucb(self, c_puct):
        """
        Calculates the UCB value of self.
        UCB = q + U
        U = c_puct * sqrt(N) / (1 + n)
        where N is the parent's visit count and n is self's visit count.
        A node without a parent gets U = 0.
        """
        parent_term = math.sqrt(self.parent.n) if self.parent is not None else 0
        return self.q + c_puct * parent_term / (1 + self.n)

    @staticmethod
    def _pick_best(scored):
        """
        Returns the node of the highest-scored (score, node) pair.
        The list is shuffled in place first so that ties are broken randomly:
        max() keeps the first maximal item it encounters.
        """
        random.shuffle(scored)
        return max(scored, key=lambda pair: pair[0])[1]

    def select_most(self):
        """
        Returns a child node that has highest q value,
        or None if self is a leaf.
        Ties between childs with the same q value are broken randomly.
        """
        if self.is_leaf:
            return None

        return self._pick_best(
            [(child.q, child) for child in self.childs.values()])

    def select_leaf(self, c_puct=1.):
        """
        Walks down from self, always following the child with the highest
        UCB value, and returns the leaf node that is reached.
        Returns self if self is already a leaf.
        Ties between childs with the same UCB value are broken randomly.

        c_puct -- exploration constant handed to calc_ucb (default 1.)
        """
        node = self

        # Iterative descent: the depth of the tree is not bounded by the
        # interpreter's recursion limit.
        while not node.is_leaf:
            node = self._pick_best(
                [(child.calc_ucb(c_puct), child)
                 for child in node.childs.values()])

        return node

    def expand(self):
        """
        Generates all possible childs if self is leaf node.
        Does nothing if self already has childs or its state is terminated.
        """
        if not self.is_leaf or self.state.is_terminated:
            return

        # Gets all possible actions and creates child nodes for each action.
        for action in Game.possible_actions(self.state):
            self.childs[action] = MCTSNode(
                Game.next_state(self.state, action),  # next state
                self,  # parent node
                action)  # action

    def backup(self):
        """
        Simulates game once and propagate back the result.
        Plays uniformly random actions from self's state until a terminated
        state is reached, then updates n, w and q of self and every ancestor.
        """
        state = self.state

        # Random rollout.
        while not state.is_terminated:
            state = Game.next_state(
                state,
                random.choice(Game.possible_actions(state)))

        node = self

        while node is not None:
            node.n += 1
            # The reward is added negated when the final state's turn equals
            # the node's turn, and as-is otherwise.
            # NOTE(review): a node's q is read by its *parent* when it ranks
            # its childs, so this presumably scores each node from the point
            # of view of the player who moved into it - this assumes
            # state.reward is relative to state.turn; confirm against Game.
            sign = 1 if state.turn == node.state.turn else -1
            node.w += -state.reward * sign
            node.q = node.w / node.n
            node = node.parent
class MCTS:
    """
    Monte-Carlo tree search driver that owns the root node of the tree.
    """

    def __init__(self):
        """
        Creates a tree whose root holds the initial game state.
        """
        self.root = MCTSNode(Game.init())

    def step(self):
        """
        Advances MCTS by one step.
        1. Selects a leaf.
        2. Expands it (if needed).
        3. Selects a leaf once more (because step #2 may have added children to it).
        4. Simulates and backs up the result.
        """
        leaf = self.root.select_leaf()
        leaf.expand()
        leaf = leaf.select_leaf()
        leaf.backup()

    def select_most(self):
        """
        Selects an available action that has the highest Q value.
        Returns None if the root has no child nodes to choose from.
        """
        best = self.root.select_most()

        # The root's select_most() yields None when the root is a leaf;
        # pass that on instead of failing on the attribute access.
        if best is None:
            return None

        return best.action

    def place(self, action):
        """
        Moves the root to the child reached by the given action and
        detaches that child from the old tree.
        Raises RuntimeError if the action is still not among the root's
        childs after expanding the root.
        """
        if action not in self.root.childs:
            # The root may simply not be expanded yet; try that first.
            self.root.expand()

            if action not in self.root.childs:
                raise RuntimeError('given action is not legal')

        self.root = self.root.childs[action]
        # Detach: the new root no longer refers to its former parent.
        self.root.parent = None
        self.root.action = None
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment