Skip to content

Commit

Permalink
Update CFR abstraction training
Browse files Browse the repository at this point in the history
  • Loading branch information
Gongsta committed Jun 20, 2024
1 parent e8f73c4 commit d547590
Show file tree
Hide file tree
Showing 6 changed files with 253 additions and 560 deletions.
172 changes: 145 additions & 27 deletions src/aiplayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,52 @@ def trash_talk_fold(self):
self.engine.say(random.choice(self.get_trash_talk("opponent_fold")))
self.engine.runAndWait()

def place_bet(self, observed_env) -> int: # AI will call every time
def process_action(self, action, observed_env):
if action == "k": # check
if observed_env.game_stage == 2:
self.current_bet = 2
else:
self.current_bet = 0

self.engine.say("I Check")
elif action == "c":
if observed_env.get_highest_current_bet() == self.player_balance:
self.engine.say("I call your all-in. You think I'm afraid?")
else:
self.engine.say(random.choice(self.get_trash_talk("c")))
# If you call on the preflop
self.current_bet = observed_env.get_highest_current_bet()
elif action == "f":
self.engine.say(random.choice(self.get_trash_talk("f")))
else:
self.current_bet = int(action[1:])
if self.current_bet == self.player_balance:
self.engine.say(random.choice(self.get_trash_talk("all_in")))
else:
self.engine.say(random.choice(self.get_trash_talk("b", self.current_bet)))

self.engine.runAndWait()

def place_bet(self, observed_env):
raise NotImplementedError


# Strategy with Heuristic
class EquityAIPlayer(AIPlayer):
def __init__(self, balance) -> None:
super().__init__(balance)

def place_bet(self, observed_env) -> int: # AI will call every time
"""
A Strategy implemented with human heuristics
"""
if "k" in observed_env.valid_actions():
action = "k"
else:
action = "c"

card_str = [str(card) for card in self.hand]
community_cards = [str(card) for card in observed_env.community_cards]
# if observed_env.game_stage == 2:

equity = calculate_equity(card_str, community_cards)

# fold, check / call, raise
Expand All @@ -96,7 +131,9 @@ def place_bet(self, observed_env) -> int: # AI will call every time
): # If you are the dealer, raise more of the time
strategy = {
"k": np_strategy[0],
f"b{min(max(observed_env.BIG_BLIND, int(observed_env.total_pot_balance / 3)), self.player_balance)}": np_strategy[2],
f"b{min(max(observed_env.BIG_BLIND, int(observed_env.total_pot_balance / 3)), self.player_balance)}": np_strategy[
2
],
f"b{min(observed_env.total_pot_balance, self.player_balance)}": np_strategy[1],
}
else:
Expand Down Expand Up @@ -138,37 +175,118 @@ def place_bet(self, observed_env) -> int: # AI will call every time
print("equity", equity)
print("AI strategy ", strategy)
action = getAction(strategy)
self.process_action(action, observed_env)
return action

# history = HoldEmHistory(observed_env.history)
# strategy = observed_env.get_average_strategy()

# print("AI strategy", strategy)
# print("AI action", action)
import joblib
from abstraction import calculate_equity, predict_cluster_fast
from postflop_holdem import HoldemInfoSet, HoldEmHistory

if action == "k": # check
if observed_env.game_stage == 2:
self.current_bet = 2
else:
self.current_bet = 0
import copy

self.engine.say("I Check")
elif action == "c":
if observed_env.get_highest_current_bet() == self.player_balance:
self.engine.say("I call your all-in. You think I'm afraid?")

class CFRAIPlayer(AIPlayer):
def __init__(self, balance) -> None:
super().__init__(balance)

self.infosets = joblib.load("../src/infoSets_batch_7.joblib")

def perform_postflop_abstraction(self, observed_env):
history = copy.deepcopy(observed_env.history)

pot_total = observed_env.BIG_BLIND * 2
# Compute preflop pot size
flop_start = history.index("/")
for i, action in enumerate(history[:flop_start]):
if action[0] == "b":
bet_size = int(action[1:])
pot_total = 2 * bet_size

# Remove preflop actions
abstracted_history = history[:2]

# Bet Abstraction (card abstraction is done later)
stage_start = flop_start
stage = self.get_stage(history[stage_start + 1 :])
latest_bet = 0
while True:
abstracted_history += ["/"]

if (
len(stage) >= 4 and stage[3] != "c"
): # length 4 that isn't a call, we need to condense down
abstracted_history += [stage[0]]

if stage[-1] == "c":
if len(stage) % 2 == 1: # ended on dealer
abstracted_history += ["bMAX", "c"]
else:
if stage[0] == "k":
abstracted_history += ["k", "bMAX", "c"]
else:
abstracted_history += ["bMIN", "bMAX", "c"]
else:
self.engine.say(random.choice(self.get_trash_talk("c")))
# If you call on the preflop
self.current_bet = observed_env.get_highest_current_bet()
elif action == "f":
self.engine.say(random.choice(self.get_trash_talk("f")))
for i, action in enumerate(stage):
if action[0] == "b":
bet_size = int(action[1:])
latest_bet = bet_size
pot_total += bet_size

# this is a raise on a small bet
if abstracted_history[-1] == "bMIN":
abstracted_history += ["bMAX"]
# this is a raise on a big bet
elif abstracted_history[-1] == "bMAX":
abstracted_history[-1] = "k" # turn into a check
else: # first bet
if bet_size >= pot_total:
abstracted_history += ["bMAX"]
else:
abstracted_history += ["bMIN"]

elif action == "c":
pot_total += latest_bet
abstracted_history += ["c"]
else:
abstracted_history += [action]

# Proceed to next stage or exit if final stage
if "/" not in history[stage_start + 1 :]:
break
stage_start = history[stage_start + 1 :].index("/") + (stage_start + 1)
stage = self.get_stage(history[stage_start + 1 :])

return abstracted_history

def get_stage(self, history):
if "/" in history:
return history[: history.index("/")]
else:
self.current_bet = int(action[1:])
if self.current_bet == self.player_balance:
self.engine.say(random.choice(self.get_trash_talk("all_in")))
return history

def place_bet(self, observed_env):
if observed_env.game_stage == 2: # preflop
if "k" in observed_env.valid_actions():
action = "k"
else:
self.engine.say(random.choice(self.get_trash_talk("b", self.current_bet)))
action = "c"
else:
abstracted_history = self.perform_postflop_abstraction(observed_env)
print("abstracted history", abstracted_history)
infoset_key = HoldEmHistory(abstracted_history).get_infoSet_key_online()
strategy = self.infosets[infoset_key].get_average_strategy()
print(infoset_key)
print("AI strategy ", strategy)
action = getAction(strategy)
if action == "bMIN":
action = "b" + str(
max(observed_env.BIG_BLIND, int(1 / 3 * observed_env.total_pot_balance))
)
elif action == "bMAX":
action = "b" + str(min(observed_env.total_pot_balance, self.player_balance))

self.engine.runAndWait()
self.process_action(action, observed_env)
return action


Expand Down
49 changes: 30 additions & 19 deletions src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,10 @@ def __init__(
create_history,
n_players: int = 2,
iterations: int = 1000000,
tracker_interval=1000,
):
self.n_players = n_players
self.iterations = iterations
self.tracker_interval = tracker_interval
self.tracker_interval = int(iterations / 10)
self.infoSets: Dict[str, InfoSet] = {}
self.create_infoSet = create_infoSet
self.create_history = create_history
Expand Down Expand Up @@ -193,7 +192,7 @@ def vanilla_cfr(
if history.is_terminal():
if debug:
print(f"history: {history.history} utility: {history.terminal_utility(i)}")
time.sleep(1)
time.sleep(0.1)
return history.terminal_utility(i)
elif history.is_chance():
a = (
Expand All @@ -206,9 +205,6 @@ def vanilla_cfr(
infoSet = self.get_infoSet(history)
assert infoSet.player() == history.player()

if debug:
print("infoset", infoSet.to_dict())

v = 0
va = {}

Expand All @@ -233,19 +229,30 @@ def vanilla_cfr(
# Update regret matching values
infoSet.get_strategy()

if debug:
print("infoset", infoSet.to_dict())
print("strategy", infoSet.strategy)

return v

def vanilla_cfr_speedup(self, history: History, t: int, pi_0: float, pi_1: float, debug=False):
"""
We double the speed by updating both player values simultaneously, since this is a zero-sum game.
NOTE: Doesn't work super well, I don't understand why. The trick here to speedup is by assuming by whatever the opponent gains is
the opposite of what we gain. Zero-sum game. However, need to make sure we always return the correct utility.
"""
# Return payoff for terminal states
# ['3d7c', '4cQd', '/', '7sKd9c', 'bMIN', 'f']
if history.is_terminal():
if debug:
print(history.history, history.terminal_utility(0))
time.sleep(1)
return history.terminal_utility(0)
print(
f"utility returned: {history.terminal_utility((len(history.get_last_game_stage())) % 2)}, history: {history.history}"
)
return history.terminal_utility(
(len(history.get_last_game_stage()) + 1) % 2
) # overfit solution for holdem
elif history.is_chance():
a = (
history.sample_chance_outcome()
Expand All @@ -257,9 +264,6 @@ def vanilla_cfr_speedup(self, history: History, t: int, pi_0: float, pi_1: float
infoSet = self.get_infoSet(history)
assert infoSet.player() == history.player()

if debug:
print("infoset", infoSet.to_dict())

v = 0
va = {}

Expand All @@ -285,6 +289,12 @@ def vanilla_cfr_speedup(self, history: History, t: int, pi_0: float, pi_1: float
# Update regret matching values
infoSet.get_strategy()

if debug:
print("infoset", infoSet.to_dict())
print("va", va)
print("strategy", infoSet.strategy)
time.sleep(0.1)

return v

def vanilla_cfr_manim(
Expand Down Expand Up @@ -356,11 +366,11 @@ def solve(self, method="vanilla_speedup", debug=False):
for player in range(self.n_players):
if player == 0:
util_0 += self.vanilla_cfr_manim(
self.create_history(), player, t, 1, 1, histories
self.create_history(t), player, t, 1, 1, histories
)
else:
util_1 += self.vanilla_cfr_manim(
self.create_history(), player, t, 1, 1, histories
self.create_history(t), player, t, 1, 1, histories
)

print(histories)
Expand All @@ -371,11 +381,11 @@ def solve(self, method="vanilla_speedup", debug=False):
): # This is the slower way, we can speed by updating both players
if player == 0:
util_0 += self.vanilla_cfr(
self.create_history(), player, t, 1, 1, debug=debug
self.create_history(t), player, t, 1, 1, debug=debug
)
else:
util_1 += self.vanilla_cfr(
self.create_history(), player, t, 1, 1, debug=debug
self.create_history(t), player, t, 1, 1, debug=debug
)

if (t + 1) % self.tracker_interval == 0:
Expand All @@ -384,13 +394,14 @@ def solve(self, method="vanilla_speedup", debug=False):
self.tracker(self.infoSets)
self.tracker.pprint()

if t % 2500 == 0:
if t % 500000 == 0:
self.export_infoSets(f"infoSets_{t}.joblib")

self.export_infoSets("infoSets_solved.joblib")
if method == "manim":
return histories

def export_infoSets(self, filename = "infoSets.joblib"):
def export_infoSets(self, filename="infoSets.joblib"):
joblib.dump(self.infoSets, filename)

def get_expected_value(
Expand Down Expand Up @@ -525,4 +536,4 @@ def __call__(self, infoSets: Dict[str, InfoSet]):
def pprint(self):
infoSets = self.tracker_hist[-1]
for infoSet in infoSets.values():
print(infoSet.infoSet, infoSet.get_average_strategy())
print(infoSet.infoSet, "Regret: ", infoSet.regret, "Average Strategy: ", infoSet.get_average_strategy())
10 changes: 5 additions & 5 deletions src/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from evaluator import *
from typing import List
from player import Player
from aiplayer import AIPlayer
from postflop_holdem import PostflopHoldemHistory, PostflopHoldemInfoSet
from aiplayer import CFRAIPlayer


class PokerEnvironment:
Expand Down Expand Up @@ -42,7 +43,7 @@ def __init__(self) -> None:
self.SMALL_BLIND = 1
self.BIG_BLIND = 2

self.INPUT_CARDS = True
self.INPUT_CARDS = False

self.history = []
self.players_balance_history = [] # List of "n" list for "n" players
Expand All @@ -54,7 +55,7 @@ def get_player(self, idx) -> Player:
return self.players[idx]

def add_AI_player(self): # Add a dumb AI
self.players.append(AIPlayer(self.new_player_balance))
self.players.append(CFRAIPlayer(self.new_player_balance))
self.AI_player_idx = len(self.players) - 1

def get_winning_players(self) -> List:
Expand Down Expand Up @@ -358,14 +359,13 @@ def end_round(self):
if player.playing_current_round:
player.trash_talk_win()
else:
player.get_trash_lose()
player.trash_talk_lose()

else:
for player in self.players:
if player.is_AI:
if player.playing_current_round:
player.trash_talk_fold()


self.game_stage = 6 # mark end of round
self.distribute_pot_to_winning_players()
Loading

0 comments on commit d547590

Please sign in to comment.