# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

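# Example: Monte-Carlo tree search (MCTS) over torchrl's ChessEnv. The search
# tree is stored in an MCTSForest, children are selected with a UCB1-style
# rule, and leaf values are estimated with random rollouts.
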
import torch
import torchrl
from tensordict import TensorDict

pgn_or_fen = "fen"

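# Chess environment configured to expose FEN strings, SAN move names, hashed
# observations (used below as node keys in the forest), and a mask of legal
# actions.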
env = torchrl.envs.ChessEnv(
    include_pgn=False,
    include_fen=True,
    include_hash=True,
    include_hash_inv=True,
    include_san=True,
    stateful=True,
    mask_actions=True,
)


def transform_reward(td):
    if "reward" not in td:
        return td
    reward = td["reward"]
    if reward == 0.5:
        td["reward"] = 0
    elif reward == 1 and td["turn"]:
        td["reward"] = -td["reward"]
    return td


# ChessEnv sets the reward to 0.5 for a draw and 1 for a win for either player.
# Need to transform the reward to be:
#   white win = 1
#   draw = 0
#   black win = -1
env.append_transform(transform_reward)

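# Store per-node visit counts and reward sums alongside the environment's
# reward, so the MCTS statistics live in the forest with each node.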
forest = torchrl.data.MCTSForest()
forest.reward_keys = env.reward_keys + ["_visits", "_reward_sum"]
forest.done_keys = env.done_keys
forest.action_keys = env.action_keys
forest.observation_keys = [f"{pgn_or_fen}_hash", "turn", "action_mask"]

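# UCB1 exploration constant; sqrt(2) is the usual default.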
C = 2.0**0.5


def traversal_priority_UCB1(tree, root_visits):
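    """Computes a UCB1-style priority for each child of ``tree``.

    The priority combines each child's accumulated reward (negated when the
    children are black's moves, since black minimizes the white-positive
    reward) with an exploration bonus that grows with the parent's visit
    count and shrinks with the child's. Unvisited children get infinite
    priority so they are tried first.
    """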
    subtree = tree.subtree
    td_subtree = subtree.rollout[:, -1]["next"]
    visits = td_subtree["_visits"]
    reward_sum = td_subtree["_reward_sum"].clone()

    # If it's black's turn, flip the reward, since black wants to
    # optimize for the lowest reward, not the highest.
    if not subtree.rollout[0, 0]["turn"]:
        reward_sum = -reward_sum

    if tree.rollout is None:
        parent_visits = root_visits
    else:
        parent_visits = tree.rollout[-1]["next", "_visits"]
    reward_sum = reward_sum.squeeze(-1)
    priority = (reward_sum + C * torch.sqrt(torch.log(parent_visits))) / visits
    priority[visits == 0] = float("inf")
    return priority


def _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps, root_visits):
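    """Runs a single MCTS iteration.

    Descends the tree by UCB1 priority, expands a previously visited
    non-terminal leaf by enumerating its legal actions, estimates the value of
    the selected state with a rollout (or uses the terminal reward if the game
    is over), and backpropagates the result to every node visited on the way
    down.
    """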
    done = False
    td_trees_visited = []

    while not done:
        if tree.subtree is None:
            td_tree = tree.rollout[-1]["next"].clone()

            if (td_tree["_visits"] > 0 or tree.parent is None) and not td_tree["done"]:
                actions = env.all_actions(td_tree)
                subtrees = []

                for action in actions:
                    td = env.step(env.reset(td_tree).update(action)).update(
                        TensorDict(
                            {
                                ("next", "_visits"): 0,
                                ("next", "_reward_sum"): env.reward_spec.zeros(),
                            }
                        )
                    )

                    new_node = torchrl.data.Tree(
                        rollout=td.unsqueeze(0),
                        node_data=td["next"].select(*forest.node_map.in_keys),
                    )
                    subtrees.append(new_node)

                # NOTE: This whole script runs about 2x faster with lazy stack
                # versus eager stack.
                tree.subtree = TensorDict.lazy_stack(subtrees)
                chosen_idx = torch.randint(0, len(subtrees), ()).item()
                rollout_state = subtrees[chosen_idx].rollout[-1]["next"]

            else:
                rollout_state = td_tree

            if rollout_state["done"]:
                rollout_reward = rollout_state["reward"]
            else:
                rollout = env.rollout(
                    max_steps=max_rollout_steps,
                    tensordict=rollout_state,
                )
                rollout_reward = rollout[-1]["next", "reward"]
            done = True

        else:
            priorities = traversal_priority_UCB1(tree, root_visits)
            chosen_idx = torch.argmax(priorities).item()
            tree = tree.subtree[chosen_idx]
            td_trees_visited.append(tree.rollout[-1]["next"])

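    # Backpropagation: update the statistics of every node visited during
    # selection with the result of the rollout.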
    for td in td_trees_visited:
        td["_visits"] += 1
        td["_reward_sum"] += rollout_reward


def traverse_MCTS(forest, root, env, num_steps, max_rollout_steps):
    """Performs Monte-Carlo tree search in an environment.

    Args:
        forest (MCTSForest): Forest of the tree to update. If the tree does not
            exist yet, it is added.
        root (TensorDict): The root step of the tree to update.
        env (EnvBase): Environment to perform actions in.
        num_steps (int): Number of iterations to traverse.
        max_rollout_steps (int): Maximum number of steps for each rollout.
    """
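    # Seed the forest with the root's children by taking each legal action
    # once, so that a tree rooted at this position exists.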
    if root not in forest:
        for action in env.all_actions(root):
            td = env.step(env.reset(root.clone()).update(action)).update(
                TensorDict(
                    {
                        ("next", "_visits"): 0,
                        ("next", "_reward_sum"): env.reward_spec.zeros(),
                    }
                )
            )
            forest.extend(td.unsqueeze(0))

    tree = forest.get_tree(root)

    # TODO: Add this to the root node
    root_visits = torch.tensor([0])

    for _ in range(num_steps):
        _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps, root_visits)
        root_visits += 1

    return tree


def tree_format_fn(tree):
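    """Formats a node as [SAN move, FEN, reward sum, visit count]; intended
    for the (commented-out) ``tree.to_string`` call in ``get_best_move``."""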
    td = tree.rollout[-1]["next"]
    return [
        td["san"],
        td[pgn_or_fen].split("\n")[-1],
        td["_reward_sum"].item(),
        td["_visits"].item(),
    ]


def get_best_move(fen, mcts_steps, rollout_steps):
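    """Runs MCTS from the given FEN position and returns the SAN move whose
    average value is highest from the perspective of the player to move.
    Also prints every candidate move with its average value."""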
    root = env.reset(TensorDict({"fen": fen}))
    tree = traverse_MCTS(forest, root, env, mcts_steps, rollout_steps)

    # print('------------------------------')
    # print(tree.to_string(tree_format_fn))
    # print('------------------------------')

    moves = []

    for subtree in tree.subtree:
        san = subtree.rollout[0]["next", "san"]
        reward_sum = subtree.rollout[-1]["next", "_reward_sum"]
        visits = subtree.rollout[-1]["next", "_visits"]
        value_avg = (reward_sum / visits).item()
        if not subtree.rollout[0]["turn"]:
            value_avg = -value_avg
        moves.append((value_avg, san))

    moves = sorted(moves, key=lambda x: -x[0])

    print("------------------")
    for value_avg, san in moves:
        print(f" {value_avg:0.02f} {san}")
    print("------------------")

    return moves[0][1]


# White has M1, best move Rd8#. Any other moves lose to M2 or M1.
fen0 = "7k/6pp/7p/7K/8/8/6q1/3R4 w - - 0 1"
assert get_best_move(fen0, 100, 10) == "Rd8#"

# Black has M1, best move Qg6#. Other moves give rough equality or worse.
fen1 = "6qk/2R4p/7K/8/8/8/8/4R3 b - - 1 1"
assert get_best_move(fen1, 100, 10) == "Qg6#"

# White has M2, best move Rxg8+. Any other move loses.
fen2 = "2R3qk/5p1p/7K/8/8/8/5r2/2R5 w - - 0 1"
assert get_best_move(fen2, 1000, 10) == "Rxg8+"