
Commit 89dffab

[DRAFT, Example] Add MCTS example
ghstack-source-id: 15144dfb9e3ce724bc8c7a403b436f46ac8c5f8d
Pull Request resolved: #2796
1 parent 8c9dc05 commit 89dffab

File tree

7 files changed (+353, -29 lines)

examples/trees/mcts.py

+200
@@ -0,0 +1,200 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torchrl
from tensordict import TensorDict

pgn_or_fen = "fen"
mask_actions = False

env = torchrl.envs.ChessEnv(
    include_pgn=False,
    include_fen=True,
    include_hash=True,
    include_hash_inv=True,
    include_san=True,
    stateful=True,
    mask_actions=mask_actions,
)


def transform_reward(td):
    if "reward" not in td:
        return td
    reward = td["reward"]
    if reward == 0.5:
        td["reward"] = 0
    elif reward == 1 and td["turn"]:
        td["reward"] = -td["reward"]
    return td


# ChessEnv sets the reward to 0.5 for a draw and 1 for a win for either player.
# Need to transform the reward to be:
# white win = 1
# draw = 0
# black win = -1
env = env.append_transform(transform_reward)

forest = torchrl.data.MCTSForest()
forest.reward_keys = env.reward_keys
forest.done_keys = env.done_keys
forest.action_keys = env.action_keys

if mask_actions:
    forest.observation_keys = [f"{pgn_or_fen}_hash", "turn", "action_mask"]
else:
    forest.observation_keys = [f"{pgn_or_fen}_hash", "turn"]

C = 2.0**0.5


def traversal_priority_UCB1(tree):
    subtree = tree.subtree
    visits = subtree.visits
    reward_sum = subtree.wins

    # If it's black's turn, flip the reward, since black wants to
    # optimize for the lowest reward, not highest.
    if not subtree.rollout[0, 0]["turn"]:
        reward_sum = -reward_sum

    parent_visits = tree.visits
    reward_sum = reward_sum.squeeze(-1)
    priority = (reward_sum + C * torch.sqrt(torch.log(parent_visits))) / visits
    priority[visits == 0] = float("inf")
    return priority


def _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps):
    done = False
    trees_visited = [tree]

    while not done:
        if tree.subtree is None:
            td_tree = tree.rollout[-1]["next"].clone()

            if (tree.visits > 0 or tree.parent is None) and not td_tree["done"]:
                actions = env.all_actions(td_tree)
                subtrees = []

                for action in actions:
                    td = env.step(env.reset(td_tree).update(action))
                    new_node = torchrl.data.Tree(
                        rollout=td.unsqueeze(0),
                        node_data=td["next"].select(*forest.node_map.in_keys),
                        count=torch.tensor(0),
                        wins=torch.zeros_like(td["next"]["reward"]),
                    )
                    subtrees.append(new_node)

                # NOTE: This whole script runs about 2x faster with lazy stack
                # versus eager stack.
                tree.subtree = TensorDict.lazy_stack(subtrees)
                chosen_idx = torch.randint(0, len(subtrees), ()).item()
                rollout_state = subtrees[chosen_idx].rollout[-1]["next"]

            else:
                rollout_state = td_tree

            if rollout_state["done"]:
                rollout_reward = rollout_state["reward"]
            else:
                rollout = env.rollout(
                    max_steps=max_rollout_steps,
                    tensordict=rollout_state,
                )
                rollout_reward = rollout[-1]["next", "reward"]
            done = True

        else:
            priorities = traversal_priority_UCB1(tree)
            chosen_idx = torch.argmax(priorities).item()
            tree = tree.subtree[chosen_idx]
            trees_visited.append(tree)

    for tree in trees_visited:
        tree.visits += 1
        tree.wins += rollout_reward


def traverse_MCTS(forest, root, env, num_steps, max_rollout_steps):
    """Performs Monte-Carlo tree search in an environment.

    Args:
        forest (MCTSForest): Forest of the tree to update. If the tree does not
            exist yet, it is added.
        root (TensorDict): The root step of the tree to update.
        env (EnvBase): Environment to perform actions in.
        num_steps (int): Number of iterations to traverse.
        max_rollout_steps (int): Maximum number of steps for each rollout.
    """
    if root not in forest:
        for action in env.all_actions(root):
            td = env.step(env.reset(root.clone()).update(action))
            forest.extend(td.unsqueeze(0))

    tree = forest.get_tree(root)
    tree.wins = torch.zeros_like(td["next", "reward"])
    for subtree in tree.subtree:
        subtree.wins = torch.zeros_like(td["next", "reward"])

    for _ in range(num_steps):
        _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps)

    return tree


def tree_format_fn(tree):
    td = tree.rollout[-1]["next"]
    return [
        td["san"],
        td[pgn_or_fen].split("\n")[-1],
        tree.wins,
        tree.visits,
    ]


def get_best_move(fen, mcts_steps, rollout_steps):
    root = env.reset(TensorDict({"fen": fen}))
    tree = traverse_MCTS(forest, root, env, mcts_steps, rollout_steps)

    # print('------------------------------')
    # print(tree.to_string(tree_format_fn))
    # print('------------------------------')

    moves = []

    for subtree in tree.subtree:
        san = subtree.rollout[0]["next", "san"]
        reward_sum = subtree.wins
        visits = subtree.visits
        value_avg = (reward_sum / visits).item()
        if not subtree.rollout[0]["turn"]:
            value_avg = -value_avg
        moves.append((value_avg, san))

    moves = sorted(moves, key=lambda x: -x[0])

    print("------------------")
    for value_avg, san in moves:
        print(f" {value_avg:0.02f} {san}")
    print("------------------")

    return moves[0][1]


# White has M1, best move Rd8#. Any other moves lose to M2 or M1.
fen0 = "7k/6pp/7p/7K/8/8/6q1/3R4 w - - 0 1"
assert get_best_move(fen0, 100, 10) == "Rd8#"

# Black has M1, best move Qg6#. Other moves give rough equality or worse.
fen1 = "6qk/2R4p/7K/8/8/8/8/4R3 b - - 1 1"
assert get_best_move(fen1, 100, 10) == "Qg6#"

# White has M2, best move Rxg8+. Any other move loses.
fen2 = "2R3qk/5p1p/7K/8/8/8/5r2/2R5 w - - 0 1"
assert get_best_move(fen2, 1000, 10) == "Rxg8+"
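
The three asserts above double as a smoke test for the example. As a further illustration only (not part of the commit), a minimal sketch of calling get_best_move on one more position might look like the following; the FEN and the expected move are assumptions chosen for this sketch, and env, forest and get_best_move are assumed to already be defined as above:

# Hypothetical usage sketch, not part of the commit: reuses env, forest and
# get_best_move defined above. The position is a back-rank mate in one for
# white, chosen purely for illustration.
fen3 = "6k1/5ppp/8/8/8/8/8/4R2K w - - 0 1"
print(get_best_move(fen3, 100, 10))  # expected to settle on "Re8#" given enough iterations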

test/test_env.py

+47 -20
@@ -4157,43 +4157,68 @@ def test_env_reset_with_hash(self, stateful, include_san):
         td_check = env.reset(td.select("fen_hash"))
         assert (td_check == td).all()

-    @pytest.mark.parametrize("include_fen", [False, True])
-    @pytest.mark.parametrize("include_pgn", [False, True])
+    @pytest.mark.parametrize("include_fen,include_pgn", [[False, True], [True, False]])
     @pytest.mark.parametrize("stateful", [False, True])
-    @pytest.mark.parametrize("mask_actions", [False, True])
-    def test_all_actions(self, include_fen, include_pgn, stateful, mask_actions):
-        if not stateful and not include_fen and not include_pgn:
-            # pytest.skip("fen or pgn must be included if not stateful")
-            return
-
+    @pytest.mark.parametrize("include_hash", [False, True])
+    @pytest.mark.parametrize("include_san", [False, True])
+    @pytest.mark.parametrize("append_transform", [False, True])
+    # @pytest.mark.parametrize("mask_actions", [False, True])
+    @pytest.mark.parametrize("mask_actions", [False])
+    def test_all_actions(
+        self,
+        include_fen,
+        include_pgn,
+        stateful,
+        include_hash,
+        include_san,
+        append_transform,
+        mask_actions,
+    ):
         env = ChessEnv(
             include_fen=include_fen,
             include_pgn=include_pgn,
+            include_san=include_san,
+            include_hash=include_hash,
+            include_hash_inv=include_hash,
             stateful=stateful,
             mask_actions=mask_actions,
         )
-        td = env.reset()

-        if not mask_actions:
-            with pytest.raises(RuntimeError, match="Cannot generate legal actions"):
-                env.all_actions()
-            return
+        def transform_reward(td):
+            if "reward" not in td:
+                return td
+            reward = td["reward"]
+            if reward == 0.5:
+                td["reward"] = 0
+            elif reward == 1 and td["turn"]:
+                td["reward"] = -td["reward"]
+            return td
+
+        # ChessEnv sets the reward to 0.5 for a draw and 1 for a win for either player.
+        # Need to transform the reward to be:
+        # white win = 1
+        # draw = 0
+        # black win = -1
+        if append_transform:
+            env = env.append_transform(transform_reward)
+
+        check_env_specs(env)
+
+        td = env.reset()

         # Choose random actions from the output of `all_actions`
-        for _ in range(100):
-            if stateful:
-                all_actions = env.all_actions()
-            else:
+        for step_idx in range(100):
+            if step_idx % 5 == 0:
                 # Reset the the initial state first, just to make sure
                 # `all_actions` knows how to get the board state from the input.
                 env.reset()
-                all_actions = env.all_actions(td.clone())
+            all_actions = env.all_actions(td.clone())

             # Choose some random actions and make sure they match exactly one of
             # the actions from `all_actions`. This part is not tested when
             # `mask_actions == False`, because `rand_action` can pick illegal
             # actions in that case.
-            if mask_actions:
+            if mask_actions and step_idx % 4 == 0:
                 # TODO: Something is wrong in `ChessEnv.rand_action` which makes
                 # it fail to work properly for stateless mode. It doesn't know
                 # how to correctly reset the board state to what is given in the
@@ -4210,7 +4235,9 @@ def test_all_actions(self, include_fen, include_pgn, stateful, mask_actions):

             action_idx = torch.randint(0, all_actions.shape[0], ()).item()
             chosen_action = all_actions[action_idx]
-            td = env.step(td.update(chosen_action))["next"]
+            td_new = env.step(td.update(chosen_action).clone())
+            assert (td == td_new.exclude("next")).all()
+            td = td_new["next"]

             if td["done"]:
                 td = env.reset()

torchrl/data/map/tree.py

+5
Original file line numberDiff line numberDiff line change
@@ -1364,6 +1364,11 @@ def valid_paths(cls, tree: Tree):
     def __len__(self):
         return len(self.data_map)

+    def __contains__(self, root: TensorDictBase):
+        if self.node_map is None:
+            return False
+        return root.select(*self.node_map.in_keys) in self.node_map
+
     def to_string(self, td_root, node_format_fn=lambda tree: tree.node_data.to_dict()):
         """Generates a string representation of a tree in the forest.

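The `__contains__` method added above is what backs the `if root not in forest` check in the MCTS example. A minimal sketch of the membership test, reusing names from examples/trees/mcts.py purely for illustration:

# Illustration only: env, forest and fen0 as defined in examples/trees/mcts.py above.
root = env.reset(TensorDict({"fen": fen0}))
print(root in forest)  # dispatches to MCTSForest.__contains__, keyed on node_map.in_keys
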
0 commit comments
