import numpy as np
from collections import OrderedDict
import random
from copy import copy
import time
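# Counterfactual regret minimization (CFR) with full game-tree traversals on a
# small Kuhn-poker-style game: cards 0, 1, 2 with one dealt to each player, an
# ante of 0.5 per player, and a single fixed bet size of 0.5. Node builds the
# game tree; CFR accumulates regrets and an average strategy over it.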
class Node(object):
    def __init__(self, prev=None, action=(), player=None):
        self.prev = prev
        self.action: tuple = action
        self.player = player
        self.depth = 0
        if prev is not None:
            self.depth = self.prev.depth + 1
        self.next = self.list_next()
        self.utility = 0
        if self.is_terminal():
            self.utility = self.compute_utility()
    def is_terminal(self):
        return len(self.next) == 0
    def list_next(self):
        betsize = 0.5
        if self.depth == 0:  # depth=0 is the root; next is dealing the first card
            return [Node(self, ('dealt', c), 0) for c in (0, 1, 2)]
        if self.depth == 1:  # next is dealing the second card
            return [Node(self, ('dealt', c), 1) for c in (0, 1, 2) if c != self.action[1]]
        if self.depth == 2:  # player 0's first action
            return [Node(self, ('bet', 0), 0), Node(self, ('bet', betsize), 0)]
        if self.depth == 3:  # player 1's action
            return [Node(self, ('bet', 0), 1), Node(self, ('bet', betsize), 1)]
        if self.depth >= 4:
            is_terminal = self.prev.action[1] >= self.action[1]
            if is_terminal:
                return []
            # check -> bet -> ?
            next_p = int(not self.player)
            return [Node(self, ('bet', 0), next_p), Node(self, ('bet', betsize), next_p)]
    def is_root(self):
        return self.prev is None
    def is_fold(self):
        assert self.player is not None
        return self.prev.action[1] > self.action[1]
    def is_call(self):
        assert self.player is not None
        return self.prev.action[1] == self.action[1]
    def compute_utility(self):  # utility of player 0
        u = -0.5  # ante
        pot = 1  # ante
        ptr = self
        cards = [0, 0]
        while ptr.prev is not None:
            if ptr.action[0] == 'dealt':
                cards[ptr.player] = ptr.action[1]
            elif ptr.action[0] == 'bet':
                pot += ptr.action[1]
                if ptr.player == 0:
                    u -= ptr.action[1]
            ptr = ptr.prev
        win = False  # flag: does player 0 win the pot?
        if self.is_fold() and self.player == 1:
            win = True
        if self.is_call() and cards[0] > cards[1]:
            win = True
        if win:
            u += pot
        return u
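    # Worked example (traced from compute_utility above): for the history
    # ('dealt', 2), ('dealt', 0), ('bet', 0), ('bet', 0.5), ('bet', 0)
    # player 0 checks, player 1 bets 0.5, and player 0 folds, so player 0
    # simply loses the 0.5 ante and compute_utility returns -0.5.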
    def trace_up_actions(self):
        nodes = [self]
        while nodes[-1].prev is not None:
            nodes.append(nodes[-1].prev)
        return [n.action for n in reversed(nodes) if n.action != ()]
    def get_infoset(self, player):
        actions = self.trace_up_actions()
        if player == 0:
            actions[1] = ('dealt', -1)  # mask player 1's card
        elif player == 1:
            actions[0] = ('dealt', -1)  # mask player 0's card
        return (player,) + tuple(actions)
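# Per information set I, CFR tracks cumulative regrets r[I][a], a cumulative
# strategy sum s[I][a], and one strategy profile per iteration in self.profiles.
# __call__ runs one full traversal for `learning_player` under profile t: it
# returns the expected value for that player and, at that player's decision
# nodes, accumulates regrets weighted by the opponent's reach probability and
# strategy weights scaled by the player's own reach probability.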
class CFR(object):
    def __init__(self):
        self.regret_table = OrderedDict()  # r[I][a]
        self.cum_strategy_table = OrderedDict()  # s[I][a]
        self.profiles = [
            OrderedDict()
        ]
    def __call__(self, node: Node, learning_player: int, t: int, reach_probs):
        if node.is_terminal():
            u = node.utility
            if learning_player == 1:
                u = -u
            return u
        if node.next[0].action[0] == 'dealt':  # chance node
            # return self(random.choice(node.next), learning_player, t, reach_probs)
            return sum([self(n, learning_player, t, reach_probs) for n in node.next]) / len(node.next)
        player = node.next[0].player  # player of the NEXT action
        infoset = node.get_infoset(player)
        n_actions = len(node.next)
        value = 0
        value_given_action = np.zeros((n_actions,), dtype=np.float32)
        strategy_profile = self.profiles[t]
        if infoset not in strategy_profile:
            strategy_profile[infoset] = np.ones((n_actions,), dtype=np.float32) * 1. / n_actions
        for j, n in enumerate(node.next):
            action_prob = strategy_profile[infoset][j]
            if player == 0:
                value_given_action[j] = self(n, learning_player, t,
                                             (action_prob * reach_probs[0], reach_probs[1]))
            elif player == 1:
                value_given_action[j] = self(n, learning_player, t,
                                             (reach_probs[0], action_prob * reach_probs[1]))
            value += action_prob * value_given_action[j]
        if player == learning_player:
            # TODO: move this table initialization out of the traversal so it is
            # not re-checked on every visit
            if infoset not in self.regret_table:
                self.regret_table[infoset] = np.zeros((n_actions,), dtype=np.float32)
            if infoset not in self.cum_strategy_table:
                self.cum_strategy_table[infoset] = np.zeros((n_actions,), dtype=np.float32)
            self.regret_table[infoset] += reach_probs[1 - learning_player] * (value_given_action - value)
            self.cum_strategy_table[infoset] += reach_probs[learning_player] * strategy_profile[infoset]
        return value
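    # Regret matching: the next profile plays each action in proportion to its
    # positive cumulative regret, falling back to uniform when no regret is positive.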
    def update_profile(self):
        new_profile = copy(self.profiles[-1])
        for infoset in self.regret_table.keys():
            n_actions = self.regret_table[infoset].shape[0]
            pos_reg = np.maximum(self.regret_table[infoset], 0)
            pos_sum = np.sum(pos_reg)
            if pos_sum <= 0:
                new_profile[infoset] = np.ones(n_actions, dtype=np.float32) * 1. / n_actions
            else:
                new_profile[infoset] = pos_reg / pos_sum
        self.profiles.append(new_profile)
    def get_mes(self, player):
        # Maximally exploitative strategy (MES): at each of `player`'s infosets,
        # put all probability on the action with the largest cumulative regret,
        # i.e. a best-response-style strategy against the other player's profile.
        mes = copy(self.profiles[-1])
        for infoset in self.regret_table:
            if infoset[0] == player:
                mes_j = self.regret_table[infoset].argmax()
                mes[infoset] = np.zeros_like(self.regret_table[infoset])
                mes[infoset][mes_j] = 1.0
        return mes
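# A minimal sketch, not part of the original gist: main() below normalizes
# cum_strategy_table inline; that step could be factored out as a helper like
# this (the name get_average_strategy is hypothetical).
def get_average_strategy(cfr):
    avg = OrderedDict()
    for infoset, probs in cfr.cum_strategy_table.items():
        total = np.sum(probs)
        # Uniform fallback for an infoset whose strategy sum is all zeros.
        avg[infoset] = probs / total if total > 0 else np.ones_like(probs) / len(probs)
    return avg
# main() builds the game tree, prints every terminal history with player 0's
# utility, runs 15000 alternating CFR iterations, then prints the normalized
# average strategy, its expected value for each player, and the value each
# player's maximally exploitative strategy (MES) achieves against it.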
def main():
    tree = Node()
    # traverse
    next_nodes = [n for n in tree.next]
    while len(next_nodes) > 0:
        next_node = next_nodes.pop(0)
        if next_node.is_terminal():
            print(next_node.trace_up_actions(), next_node.utility)
        next_nodes += next_node.next
    cfr = CFR()
    timer = time.time()
    for t in range(15000):
        for learning_player in (0, 1):
            cfr(tree, learning_player, t, (1., 1.))
        cfr.update_profile()
    print(time.time() - timer)
    cumst_list = sorted([(infoset, probs)
                         for infoset, probs in cfr.cum_strategy_table.items()])
    for infoset, probs in cumst_list:
        print(infoset, probs / np.sum(probs))
    ev_calc = CFR()
    for infoset, probs in cfr.cum_strategy_table.items():
        ev_calc.profiles[0][infoset] = probs / np.sum(probs)
    evs = [ev_calc(tree, p, 0, [1, 1]) for p in (0, 1)]
    print("EV=", evs)
    print("MES:")
    for p in (0, 1):
        mes_ev_calc = CFR()
        mes_ev_calc.profiles[0] = ev_calc.get_mes(p)
        print(mes_ev_calc(tree, p, 0, (1, 1)))
if __name__ == '__main__':
    main()