[Example] Add MCTS example #2796

Open · wants to merge 7 commits into base: gh/kurtamohler/5/base
Changes from 6 commits
200 changes: 200 additions & 0 deletions examples/trees/mcts.py
@@ -0,0 +1,200 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torchrl
from tensordict import TensorDict

pgn_or_fen = "fen"
mask_actions = False
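# `pgn_or_fen` selects which hashed observation the search forest is keyed on,
# and `mask_actions` toggles action masking both in the env construction and
# in the forest's observation keys below.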

env = torchrl.envs.ChessEnv(
    include_pgn=False,
    include_fen=True,
    include_hash=True,
    include_hash_inv=True,
    include_san=True,
    stateful=True,
    mask_actions=mask_actions,
)


def transform_reward(td):
    if "reward" not in td:
        return td
    reward = td["reward"]
    if reward == 0.5:
        td["reward"] = 0
    elif reward == 1 and td["turn"]:
        td["reward"] = -td["reward"]
    return td


# ChessEnv sets the reward to 0.5 for a draw and 1 for a win for either player.
# Need to transform the reward to be:
# white win = 1
# draw = 0
# black win = -1
env = env.append_transform(transform_reward)
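# Spelled out, the values produced by the transform above are:
#   reward == 0.5               -> 0   (draw)
#   reward == 1, turn == True   -> -1  (counted as a win for black)
#   reward == 1, turn == False  -> +1  (counted as a win for white)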

forest = torchrl.data.MCTSForest()
forest.reward_keys = env.reward_keys
forest.done_keys = env.done_keys
forest.action_keys = env.action_keys

if mask_actions:
    forest.observation_keys = [f"{pgn_or_fen}_hash", "turn", "action_mask"]
else:
    forest.observation_keys = [f"{pgn_or_fen}_hash", "turn"]

C = 2.0**0.5


def traversal_priority_UCB1(tree):
    subtree = tree.subtree
    visits = subtree.visits
    reward_sum = subtree.wins

    # If it's black's turn, flip the reward, since black wants to
    # optimize for the lowest reward, not highest.
    if not subtree.rollout[0, 0]["turn"]:
        reward_sum = -reward_sum

    parent_visits = tree.visits
    reward_sum = reward_sum.squeeze(-1)
    priority = (reward_sum + C * torch.sqrt(torch.log(parent_visits))) / visits
    priority[visits == 0] = float("inf")
    return priority
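# For reference, the textbook UCB1 score of child i with n_i visits under a
# parent with N visits is Q_i / n_i + C * sqrt(ln(N) / n_i). The priority
# above divides both the reward sum and the exploration bonus by `visits`,
# and unvisited children get infinite priority so that every child is tried
# at least once.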


def _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps):
    done = False
    trees_visited = [tree]

    # Each call walks the usual MCTS phases: selection (descend by priority in
    # the `else` branch), expansion (stack a subtree for every legal action of
    # a leaf), simulation (roll the env out for up to `max_rollout_steps`),
    # and backpropagation (the final loop adds the rollout reward and a visit
    # to every node on the path).
    while not done:
        if tree.subtree is None:
            td_tree = tree.rollout[-1]["next"].clone()

            if (tree.visits > 0 or tree.parent is None) and not td_tree["done"]:
                actions = env.all_actions(td_tree)
                subtrees = []

                for action in actions:
                    td = env.step(env.reset(td_tree).update(action))
                    new_node = torchrl.data.Tree(
                        rollout=td.unsqueeze(0),
                        node_data=td["next"].select(*forest.node_map.in_keys),
                        count=torch.tensor(0),
                        wins=torch.zeros_like(td["next"]["reward"]),
                    )
                    subtrees.append(new_node)

                # NOTE: This whole script runs about 2x faster with lazy stack
                # versus eager stack.
                tree.subtree = TensorDict.lazy_stack(subtrees)
                chosen_idx = torch.randint(0, len(subtrees), ()).item()
                rollout_state = subtrees[chosen_idx].rollout[-1]["next"]

            else:
                rollout_state = td_tree

            if rollout_state["done"]:
                rollout_reward = rollout_state["reward"]
            else:
                rollout = env.rollout(
                    max_steps=max_rollout_steps,
                    tensordict=rollout_state,
                )
                rollout_reward = rollout[-1]["next", "reward"]
            done = True

        else:
            priorities = traversal_priority_UCB1(tree)
            chosen_idx = torch.argmax(priorities).item()
            tree = tree.subtree[chosen_idx]
            trees_visited.append(tree)

    for tree in trees_visited:
        tree.visits += 1
        tree.wins += rollout_reward


def traverse_MCTS(forest, root, env, num_steps, max_rollout_steps):
    """Performs Monte-Carlo tree search in an environment.

    Args:
        forest (MCTSForest): Forest of the tree to update. If the tree does not
            exist yet, it is added.
        root (TensorDict): The root step of the tree to update.
        env (EnvBase): Environment to perform actions in.
        num_steps (int): Number of iterations to traverse.
        max_rollout_steps (int): Maximum number of steps for each rollout.
    """
    if root not in forest:
        for action in env.all_actions(root):
            td = env.step(env.reset(root.clone()).update(action))
            forest.extend(td.unsqueeze(0))

    tree = forest.get_tree(root)
    tree.wins = torch.zeros_like(td["next", "reward"])
    for subtree in tree.subtree:
        subtree.wins = torch.zeros_like(td["next", "reward"])

    for _ in range(num_steps):
        _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps)

    return tree


def tree_format_fn(tree):
    td = tree.rollout[-1]["next"]
    return [
        td["san"],
        td[pgn_or_fen].split("\n")[-1],
        tree.wins,
        tree.visits,
    ]
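# `tree_format_fn` feeds the (commented-out) `tree.to_string` call in
# `get_best_move` below: for each node it lists the last SAN move, the final
# line of the FEN/board text, and the node's accumulated wins and visit count.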


def get_best_move(fen, mcts_steps, rollout_steps):
    root = env.reset(TensorDict({"fen": fen}))
    tree = traverse_MCTS(forest, root, env, mcts_steps, rollout_steps)

    # print('------------------------------')
    # print(tree.to_string(tree_format_fn))
    # print('------------------------------')

    moves = []

    for subtree in tree.subtree:
        san = subtree.rollout[0]["next", "san"]
        reward_sum = subtree.wins
        visits = subtree.visits
        value_avg = (reward_sum / visits).item()
        if not subtree.rollout[0]["turn"]:
            value_avg = -value_avg
        moves.append((value_avg, san))

    moves = sorted(moves, key=lambda x: -x[0])

    print("------------------")
    for value_avg, san in moves:
        print(f" {value_avg:0.02f} {san}")
    print("------------------")

    return moves[0][1]


# White has M1, best move Rd8#. Any other move loses to M2 or M1.
fen0 = "7k/6pp/7p/7K/8/8/6q1/3R4 w - - 0 1"
assert get_best_move(fen0, 100, 10) == "Rd8#"

# Black has M1, best move Qg6#. Other moves give rough equality or worse.
fen1 = "6qk/2R4p/7K/8/8/8/8/4R3 b - - 1 1"
assert get_best_move(fen1, 100, 10) == "Qg6#"

# White has M2, best move Rxg8+. Any other move loses.
fen2 = "2R3qk/5p1p/7K/8/8/8/5r2/2R5 w - - 0 1"
assert get_best_move(fen2, 1000, 10) == "Rxg8+"
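
To experiment beyond the three asserted positions, `get_best_move` can be called on any legal FEN. A minimal sketch with a hypothetical extra position and assumed step counts (more MCTS iterations generally give steadier move rankings at the cost of runtime):

# Hypothetical position: white to move has the Scholar's-mate strike Qxf7#.
# Both step counts are illustrative, not tuned values from the example.
fen3 = "r1bqkbnr/pppp1ppp/2n5/4p3/2B1P3/5Q2/PPPP1PPP/RNB1K1NR w KQkq - 0 1"
print(get_best_move(fen3, mcts_steps=500, rollout_steps=10))
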
67 changes: 47 additions & 20 deletions test/test_env.py
@@ -4157,43 +4157,68 @@ def test_env_reset_with_hash(self, stateful, include_san):
td_check = env.reset(td.select("fen_hash"))
assert (td_check == td).all()

@pytest.mark.parametrize("include_fen", [False, True])
@pytest.mark.parametrize("include_pgn", [False, True])
@pytest.mark.parametrize("include_fen,include_pgn", [[False, True], [True, False]])
@pytest.mark.parametrize("stateful", [False, True])
@pytest.mark.parametrize("mask_actions", [False, True])
def test_all_actions(self, include_fen, include_pgn, stateful, mask_actions):
if not stateful and not include_fen and not include_pgn:
# pytest.skip("fen or pgn must be included if not stateful")
return

@pytest.mark.parametrize("include_hash", [False, True])
@pytest.mark.parametrize("include_san", [False, True])
@pytest.mark.parametrize("append_transform", [False, True])
# @pytest.mark.parametrize("mask_actions", [False, True])
@pytest.mark.parametrize("mask_actions", [False])
def test_all_actions(
self,
include_fen,
include_pgn,
stateful,
include_hash,
include_san,
append_transform,
mask_actions,
):
env = ChessEnv(
include_fen=include_fen,
include_pgn=include_pgn,
include_san=include_san,
include_hash=include_hash,
include_hash_inv=include_hash,
stateful=stateful,
mask_actions=mask_actions,
)
td = env.reset()

if not mask_actions:
with pytest.raises(RuntimeError, match="Cannot generate legal actions"):
env.all_actions()
return
def transform_reward(td):
if "reward" not in td:
return td
reward = td["reward"]
if reward == 0.5:
td["reward"] = 0
elif reward == 1 and td["turn"]:
td["reward"] = -td["reward"]
return td

# ChessEnv sets the reward to 0.5 for a draw and 1 for a win for either player.
# Need to transform the reward to be:
# white win = 1
# draw = 0
# black win = -1
if append_transform:
env = env.append_transform(transform_reward)

check_env_specs(env)

td = env.reset()

# Choose random actions from the output of `all_actions`
for _ in range(100):
if stateful:
all_actions = env.all_actions()
else:
for step_idx in range(100):
if step_idx % 5 == 0:
# Reset to the initial state first, just to make sure
# `all_actions` knows how to get the board state from the input.
env.reset()
all_actions = env.all_actions(td.clone())
all_actions = env.all_actions(td.clone())

# Choose some random actions and make sure they match exactly one of
# the actions from `all_actions`. This part is not tested when
# `mask_actions == False`, because `rand_action` can pick illegal
# actions in that case.
if mask_actions:
if mask_actions and step_idx % 4 == 0:
# TODO: Something is wrong in `ChessEnv.rand_action` which makes
# it fail to work properly for stateless mode. It doesn't know
# how to correctly reset the board state to what is given in the
@@ -4210,7 +4235,9 @@ def test_all_actions(self, include_fen, include_pgn, stateful, mask_actions):

action_idx = torch.randint(0, all_actions.shape[0], ()).item()
chosen_action = all_actions[action_idx]
td = env.step(td.update(chosen_action))["next"]
td_new = env.step(td.update(chosen_action).clone())
assert (td == td_new.exclude("next")).all()
td = td_new["next"]

if td["done"]:
td = env.reset()
5 changes: 5 additions & 0 deletions torchrl/data/map/tree.py
@@ -1364,6 +1364,11 @@ def valid_paths(cls, tree: Tree):
    def __len__(self):
        return len(self.data_map)

    def __contains__(self, root: TensorDictBase):
        if self.node_map is None:
            return False
        return root.select(*self.node_map.in_keys) in self.node_map

    def to_string(self, td_root, node_format_fn=lambda tree: tree.node_data.to_dict()):
        """Generates a string representation of a tree in the forest.

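The new `MCTSForest.__contains__` above is what lets the example's `if root not in forest:` guard read naturally. A minimal sketch of the membership check, reusing names from the example script (assumes the `forest`, `env`, and `fen0` defined there):

# Membership only compares the keys tracked by the forest's node_map, so a
# freshly reset root TensorDict is enough to ask whether its tree exists yet.
root = env.reset(TensorDict({"fen": fen0}))
if root not in forest:  # dispatches to MCTSForest.__contains__
    for action in env.all_actions(root):
        td = env.step(env.reset(root.clone()).update(action))
        forest.extend(td.unsqueeze(0))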