[Feature,Example] Add MCTS algorithm and example #2796

Open
wants to merge 13 commits into base: gh/kurtamohler/5/base
Changes from 10 commits
126 changes: 126 additions & 0 deletions examples/trees/mcts.py
@@ -0,0 +1,126 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import time

import torch
import torchrl
import torchrl.envs
import torchrl.modules.mcts
from tensordict import TensorDict

pgn_or_fen = "fen"
mask_actions = True

env = torchrl.envs.ChessEnv(
    include_pgn=False,
    include_fen=True,
    include_hash=True,
    include_hash_inv=True,
    include_san=True,
    stateful=True,
    mask_actions=mask_actions,
)


class TransformReward:
    def __call__(self, td):
        if "reward" not in td:
            return td

        reward = td["reward"]

        if reward == 0.5:
            reward = 0
        elif reward == 1 and td["turn"]:
            reward = -reward

        td["reward"] = reward
        return td


# ChessEnv sets the reward to 0.5 for a draw and 1 for a win for either player.
# Need to transform the reward to be:
# white win = 1
# draw = 0
# black win = -1
transform_reward = TransformReward()
env = env.append_transform(transform_reward)

forest = torchrl.data.MCTSForest()
forest.reward_keys = env.reward_keys
forest.done_keys = env.done_keys
forest.action_keys = env.action_keys

if mask_actions:
    forest.observation_keys = [f"{pgn_or_fen}_hash", "turn", "action_mask"]
else:
    forest.observation_keys = [f"{pgn_or_fen}_hash", "turn"]


def tree_format_fn(tree):
    td = tree.rollout[-1]["next"]
    return [
        td["san"],
        td[pgn_or_fen].split("\n")[-1],
        tree.wins,
        tree.visits,
    ]


def get_best_move(fen, mcts_steps, rollout_steps):
    root = env.reset(TensorDict({"fen": fen}))
    tree = torchrl.modules.mcts.MCTS(forest, root, env, mcts_steps, rollout_steps)
    moves = []

    for subtree in tree.subtree:
        san = subtree.rollout[0]["next", "san"]
        reward_sum = subtree.wins
        visits = subtree.visits
        value_avg = (reward_sum / visits).item()
        if not root["turn"]:
            value_avg = -value_avg
        moves.append((value_avg, san))

    moves = sorted(moves, key=lambda x: -x[0])

    # print(tree.to_string(tree_format_fn))

    print("------------------")
    for value_avg, san in moves:
        print(f" {value_avg:0.02f} {san}")
    print("------------------")

    return moves[0][1]


for idx in range(3):
    print("==========")
    print(idx)
    print("==========")
    torch.manual_seed(idx)

    start_time = time.time()

    # White has M1, best move Rd8#. Any other moves lose to M2 or M1.
    fen0 = "7k/6pp/7p/7K/8/8/6q1/3R4 w - - 0 1"
    assert get_best_move(fen0, 40, 10) == "Rd8#"

    # Black has M1, best move Qg6#. Other moves give rough equality or worse.
    fen1 = "6qk/2R4p/7K/8/8/8/8/4R3 b - - 1 1"
    assert get_best_move(fen1, 40, 10) == "Qg6#"

    # White has M2, best move Rxg8+. Any other move loses.
    fen2 = "2R3qk/5p1p/7K/8/8/8/5r2/2R5 w - - 0 1"
    assert get_best_move(fen2, 600, 10) == "Rxg8+"

    # Black has M2, best move Rxg1+. Any other move loses.
    fen3 = "2r5/5R2/8/8/8/7k/5P1P/2r3QK b - - 0 1"
    assert get_best_move(fen3, 600, 10) == "Rxg1+"

    end_time = time.time()
    total_time = end_time - start_time

    print(f"Took {total_time} s")
5 changes: 5 additions & 0 deletions torchrl/data/map/tree.py
@@ -1363,6 +1363,11 @@ def valid_paths(cls, tree: Tree):
    def __len__(self):
        return len(self.data_map)

    def __contains__(self, root: TensorDictBase):
        if self.node_map is None:
            return False
        return root.select(*self.node_map.in_keys) in self.node_map

    def to_string(self, td_root, node_format_fn=lambda tree: tree.node_data.to_dict()):
        """Generates a string representation of a tree in the forest.

9 changes: 7 additions & 2 deletions torchrl/envs/custom/chess.py
@@ -222,12 +222,15 @@ def lib(cls):
        return chess

    _san_moves = []
+    _san_move_to_index_map = {}

    @_classproperty
    def san_moves(cls):
        if not cls._san_moves:
            with open(pathlib.Path(__file__).parent / "san_moves.txt", "r+") as f:
                cls._san_moves.extend(f.read().split("\n"))
+            for idx, san_move in enumerate(cls._san_moves):
+                cls._san_move_to_index_map[san_move] = idx
        return cls._san_moves

    def _legal_moves_to_index(
@@ -255,7 +258,7 @@ def _legal_moves_to_index(
        board = self.board

        indices = torch.tensor(
-            [self._san_moves.index(board.san(m)) for m in board.legal_moves],
+            [self._san_move_to_index_map[board.san(m)] for m in board.legal_moves],
            dtype=torch.int64,
        )
        mask = None
@@ -409,7 +412,9 @@ def _reset(self, tensordict=None):
            if move is None:
                dest.set("san", "<start>")
            else:
-                dest.set("san", self.board.san(move))
+                prev_board = self.board.copy()
+                prev_board.pop()
+                dest.set("san", prev_board.san(move))
        if self.include_fen:
            dest.set("fen", fen)
        if self.include_pgn:
6 changes: 6 additions & 0 deletions torchrl/modules/mcts/__init__.py
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .mcts import MCTS
118 changes: 118 additions & 0 deletions torchrl/modules/mcts/mcts.py
@@ -0,0 +1,118 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torchrl
from tensordict import TensorDict, TensorDictBase

from torchrl.data.map import MCTSForest, Tree
from torchrl.envs import EnvBase

C = 2.0**0.5


# TODO: Allow user to specify different priority functions with PR #2358
def _traversal_priority_UCB1(tree):
    subtree = tree.subtree
    visits = subtree.visits
    reward_sum = subtree.wins

    # If it's black's turn, flip the reward, since black wants to optimize for
    # the lowest reward, not highest.
    # TODO: Need a more generic way to do this, since not all use cases of MCTS
@kurtamohler (Collaborator, Author) commented on May 13, 2025:

We'll need to come up with a good way to specify how to process the exploitation value for each different player.

During an MCTS traversal, when we visit each node, we have to rank all the child nodes and decide which one to traverse to, using one of the standard exploration/exploitation formulas. I only know about the UCB1 formula at the moment, so I'll focus on that one. To calculate the exploitation value of each child node, one of the values that the UCB1 formula operates on is the "reward average" attached to each child node: the sum of all the rewards of the rollouts that have been performed under the child node, divided by the number of times that child node has been visited during traversals.

Let's say that at the end of a rollout, the reward value that the chess environment gives for a white win is 1, black win is -1, and draw is 0. In order to do the exploitation part correctly, we should assume that each player wants to maximize their chances of winning. So if it's white's turn at a particular node, we want exploitation actions to maximize the reward average. But if it's black's turn, we want exploitation actions to minimize the reward average, so we have to flip the sign of the reward average when we calculate the UCB1 value on black's turn.

Or let's say we instead want to use a two-element reward. A white win is [1, -1], black win is [-1, 1], and draw is [0, 0]. Now the reward average at each node of the MCTS tree has two elements. When it's white's turn at a particular node, we want to look at the first element of the reward average of each child node. When it's black's turn, we want to look at the second element.

Ideally, our MCTS API should be able to handle both of the above reward schemes, and any other sensible kind of reward scheme that users would want to use, for environments with any number of agents.
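As a minimal illustration of the two schemes above (a sketch only, not code from this PR; the white_to_move flag and player_index argument are assumed stand-ins for whatever the environment exposes):

import torch


def exploitation_scalar(reward_sum, visits, white_to_move):
    # Scalar reward scheme: white win = 1, black win = -1, draw = 0.
    # Flip the sign on black's turn so that "larger is better" holds for
    # whichever player is choosing the move.
    avg = reward_sum / visits
    return avg if white_to_move else -avg


def exploitation_per_player(reward_sum, visits, player_index):
    # Vector reward scheme: one reward element per player, e.g.
    # white win = [1, -1], black win = [-1, 1], draw = [0, 0].
    avg = reward_sum / visits
    return avg[player_index]


# A child node visited 4 times with a summed scalar reward of 2.0, black to move:
print(exploitation_scalar(torch.tensor(2.0), 4, white_to_move=False))  # tensor(-0.5000)
# The same node under the two-element scheme, read from the second player's perspective:
print(exploitation_per_player(torch.tensor([2.0, -2.0]), 4, player_index=1))  # tensor(-0.5000)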

@kurtamohler (Collaborator, Author) commented on May 13, 2025:

I'm thinking of just adding an argument to the MCTS function called select_player_reward_fn, or something along those lines, which is a callable with the signature (td, reward_average). In the two examples I gave above, the user could specify something like this:

For a 1-element reward between 1 and -1: select_player_reward_fn=lambda td, reward_avg: reward_avg if td["turn"] else -reward_avg

For an n-element reward: select_player_reward_fn=lambda td, reward_avg: reward_avg[td["turn"]]

In order to have a sensible default, we could assume that the reward is normally n-element, with one value for each player, since that is the normal setup for multi-agent environments.
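For instance, the priority function could apply the callable to the reward average before adding the exploration bonus. The sketch below uses the textbook UCB1 exploration term and hypothetical names (ucb1_priority, select_player_reward_fn); it is not the code in this PR:

import torch

C = 2.0**0.5


def ucb1_priority(td, wins, visits, parent_visits, select_player_reward_fn):
    # Exploitation: the reward average as seen by the player whose turn it is,
    # produced by the user-supplied callable (td, reward_avg) -> tensor.
    visits = visits.float()
    safe_visits = visits.clamp_min(1.0)
    reward_avg = select_player_reward_fn(td, wins / safe_visits)
    # Exploration: standard UCB1 bonus.
    exploration = C * torch.sqrt(torch.log(parent_visits.float()) / safe_visits)
    priority = reward_avg + exploration
    # Unvisited children get infinite priority so they are expanded first.
    return torch.where(visits == 0, torch.full_like(priority, float("inf")), priority)


# 1-element reward between 1 and -1, boolean "turn" entry in the tensordict:
# select_player_reward_fn = lambda td, reward_avg: reward_avg if td["turn"] else -reward_avg
# n-element reward, one entry per player:
# select_player_reward_fn = lambda td, reward_avg: reward_avg[td["turn"]]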

@kurtamohler (Collaborator, Author) commented:

I was curious about how other libraries handle this, so I looked into a couple of them.

The mcts Python library has a sign flip for two-player games. The total reward is multiplied by the player value (https://github.com/pbsinclair42/MCTS/blob/f4b6226ce06840e51b66bdca4bedbe2c4c143012/mcts.py#L112), which is either 1 or -1 in the example scripts of this repo. So this library seems to be specifically for two player games.

The MCTS implementation within the open_spiel library has an n-element reward, one for each player, and chooses the reward for whichever player whose turn it is, like the default behavior I proposed in the last comment: https://github.com/google-deepmind/open_spiel/blob/8296179b697644cf957c7c9313f594c062cbd17c/open_spiel/python/algorithms/mcts.py#L368-L369

@kurtamohler (Collaborator, Author) commented on May 13, 2025:

I think maybe our implementation of MCTS should require the reward to have one element per player, and also require that the environment output spec has a key indicating which player's turn it is. Or at least that could be the default behavior and we could allow the user to override it with a callable argument.

Let me know what you think @vmoens.

A collaborator replied:

Usually the way we frame things in torchrl is that if you have 2 agents, we have

reward_spec = Composite(
    agent0=Composite(Unbounded(...)),
    agent1=Composite(Unbounded(...)),
)

(you can perfectly substitute agent0 with white)

If you have separate actions you can do the same. Now we should even be able to support easily turn-based environments. One way to do this would be to use the new ConditionalPolicySwitch transform, or simply to have action=torch.zeros((0,)) when the player is not acting at this turn.

I guess observations will be shared in the case of chess.

> The mcts Python library has a sign flip for two-player games

That would only scale to two players though, no?

> The MCTS implementation within the open_spiel library has an n-element reward, one for each player, and chooses the reward for whichever player whose turn it is

That's better than flipping the sign but IMO it's also not awesome as it doesn't show explicitly who owns what reward.

> I think maybe our implementation of MCTS should require the reward to have one element per player, and also require that the environment output spec has a key indicating which player's turn it is. Or at least that could be the default behavior and we could allow the user to override it with a callable argument.

So let's imagine that we're at the root (observing a given disposition of the board) and we ask ourselves what action should be taken.
We can have an observation spec

TensorDict(
  agent0=TensorDict(turn=torch.Tensor(True, ...)),
  agent1=TensorDict(turn=torch.Tensor(False, ...)),
  ...
)

that informs us of the current turn. Then we can use that to index the rewards.
Wdyt?

I can give a shot at making Chess look like this if you think that would help.

The same collaborator added:

Quick hacky example:

from torchrl.envs import ChessEnv

from torchrl.envs import Transform
from torchrl.data import Unbounded, Composite
import torch
class TurnBasedChess(Transform):
    def transform_observation_spec(self, obsspec):
        obsspec["agent0", "turn"] = Unbounded(dtype=torch.bool, shape=())
        obsspec["agent1", "turn"] = Unbounded(dtype=torch.bool, shape=())
        return obsspec
    def transform_reward_spec(self, reward_spec):
        reward = reward_spec["reward"].clone()
        del reward_spec["reward"]
        return Composite(
            agent0=Composite(reward=reward),
            agent1=Composite(reward=reward),
        )
    def _reset(self, _td, td):
        td["agent0", "turn"] = True
        td["agent1", "turn"] = False
        print(f"{td=}")
        return td
    def _step(self, td, td_next):
        td_next["agent0", "turn"] = ~td["agent0", "turn"]
        td_next["agent1", "turn"] = ~td["agent1", "turn"]
        
        td_next["agent0", "reward"] = td_next["reward"]
        td_next["agent1", "reward"] = -td_next["reward"]
        del td_next["reward"]
        
        print(f"{td_next=}")
        return td_next
base_env = ChessEnv()
env = base_env.append_transform(TurnBasedChess())
env.rollout(3)

@kurtamohler (Collaborator, Author) commented on May 17, 2025:

To make sure I understand, I think you're suggesting:

  • The MCTS module should expect the following two points and understand how to pull out the reward for whichever agent's turn it is:
    • Reward spec should have one reward per agent, keyed by the agent name. Example key: ("agent0", "reward")
    • Observation spec should contain a bool for each agent indicating whether it's their turn, also keyed by the agent name. It should only ever be one agent's turn at each step. Example key: ("agent0", "turn"). (Presumably we would give MCTS an argument like turn_key that tells it what key the turn value is under. I suppose in most cases it would just be turn_key="turn".)

That basically seems reasonable to me.

I still wonder, though: should we provide a way for the user to override the above reward processing behavior and define their own method for pulling out and processing a reward value? For instance, what if the user's environment stacks all the agent data, so the agents don't have separate keys?
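To illustrate the layout being discussed, here is a small sketch of how the backup step could read the reward for the agent whose turn it is. The helper name, the agent keys, and the turn_key/reward_key defaults are hypothetical, not part of this PR:

import torch
from tensordict import TensorDict


def reward_for_current_player(td_next, agent_names=("agent0", "agent1"),
                              turn_key="turn", reward_key="reward"):
    # Assumed layout: one ("agent_i", "reward") entry per agent and one
    # ("agent_i", "turn") boolean, with exactly one agent's turn set to True.
    for agent in agent_names:
        if td_next[agent, turn_key].item():
            return td_next[agent, reward_key]
    raise ValueError("no agent has the turn flag set")


# Example matching the proposed spec, with agent1 to move:
td_next = TensorDict(
    {
        "agent0": {"turn": torch.tensor(False), "reward": torch.tensor([1.0])},
        "agent1": {"turn": torch.tensor(True), "reward": torch.tensor([-1.0])},
    },
    batch_size=[],
)
print(reward_for_current_player(td_next))  # tensor([-1.])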

    # will be two player turn based games.
    if not subtree.rollout[0, 0]["turn"]:
        reward_sum = -reward_sum

    parent_visits = tree.visits
    reward_sum = reward_sum.squeeze(-1)
    priority = (reward_sum + C * torch.sqrt(torch.log(parent_visits))) / visits
    priority[visits == 0] = float("inf")
    return priority


def _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps):
    done = False
    trees_visited = [tree]

    while not done:
        if tree.subtree is None:
            td_tree = tree.rollout[-1]["next"].clone()

            if (tree.visits > 0 or tree.parent is None) and not td_tree["done"]:
                actions = env.all_actions(td_tree)
                subtrees = []

                for action in actions:
                    td = env.step(env.reset(td_tree).update(action))
                    new_node = torchrl.data.Tree(
                        rollout=td.unsqueeze(0),
                        node_data=td["next"].select(*forest.node_map.in_keys),
                        count=torch.tensor(0),
                        wins=torch.zeros_like(td["next", env.reward_key]),
                    )
                    subtrees.append(new_node)

                # NOTE: This whole script runs about 2x faster with lazy stack
                # versus eager stack.
                tree.subtree = TensorDict.lazy_stack(subtrees)
                chosen_idx = torch.randint(0, len(subtrees), ()).item()
                rollout_state = subtrees[chosen_idx].rollout[-1]["next"]

            else:
                rollout_state = td_tree

            if rollout_state["done"]:
                rollout_reward = rollout_state[env.reward_key]
            else:
                rollout = env.rollout(
                    max_steps=max_rollout_steps,
                    tensordict=rollout_state,
                )
                rollout_reward = rollout[-1]["next", env.reward_key]
            done = True

        else:
            priorities = _traversal_priority_UCB1(tree)
            chosen_idx = torch.argmax(priorities).item()
            tree = tree.subtree[chosen_idx]
            trees_visited.append(tree)

    for tree in trees_visited:
        tree.visits += 1
        tree.wins += rollout_reward


def MCTS(
    forest: MCTSForest,
    root: TensorDictBase,
    env: EnvBase,
    num_steps: int,
    max_rollout_steps: int | None = None,
) -> Tree:
    """Performs Monte-Carlo tree search in an environment.

    Args:
        forest (MCTSForest): Forest of the tree to update. If the tree does not
            exist yet, it is added.
        root (TensorDict): The root step of the tree to update.
        env (EnvBase): Environment to perform actions in.
        num_steps (int): Number of iterations to traverse.
        max_rollout_steps (int): Maximum number of steps for each rollout.
    """
    for action in env.all_actions(root):
        td = env.step(env.reset(root.clone()).update(action))
        forest.extend(td.unsqueeze(0))

    tree = forest.get_tree(root)

    tree.wins = torch.zeros_like(td["next", env.reward_key])
    for subtree in tree.subtree:
        subtree.wins = torch.zeros_like(td["next", env.reward_key])

    for _ in range(num_steps):
        _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps)

    return tree