Commit 9245a68

[Feature] ConditionalPolicySwitch transform
ghstack-source-id: ff74aec4b65679fe6b4b7a84c2cadc8186ed2451
Pull Request resolved: #2711
1 parent 6179523 commit 9245a68

File tree

8 files changed (+509 −24 lines)

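Before the per-file diffs, the gist of the feature: ConditionalPolicySwitch is an environment transform that, whenever a user-supplied condition evaluates to True on the current tensordict, hands the step over to an alternative policy instead of the one driving the rollout. Below is a minimal sketch adapted from TestConditionalPolicySwitch.test_transform_env added in this commit; CountingEnv is a helper from TorchRL's test suite and the two constant policies are placeholders, not part of the library API.

from torchrl.envs import Compose, ConditionalPolicySwitch, StepCounter

# CountingEnv is a test helper (test/mocking_classes.py); any TorchRL env works here.
base_env = CountingEnv(max_steps=15)

# Hand even-numbered steps to `policy_even`; the policy passed to `rollout`
# (here `policy_odd`) plays the remaining steps. The lambdas look up `env`
# lazily, so defining it afterwards is fine.
condition = lambda td: ((td.get("step_count") % 2) == 0).all()
policy_odd = lambda td: td.set("action", env.action_spec.zero())
policy_even = lambda td: td.set("action", env.action_spec.one())

env = base_env.append_transform(
    Compose(
        StepCounter(),
        ConditionalPolicySwitch(condition=condition, policy=policy_even),
    )
)
# Only the steps played by `policy_odd` land in the rollout; the transform
# plays the even steps internally (the tests below assert r["action"] == 0).
r = env.rollout(1000, policy_odd, break_when_all_done=True)

As the tests also show, the transform cannot be called on a tensordict directly; only its step and reset paths are functional.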

docs/source/reference/envs.rst

+1
@@ -816,6 +816,7 @@ to be able to create this other composition:
     CenterCrop
     ClipTransform
     Compose
+    ConditionalPolicySwitch
     Crop
     DTypeCastTransform
     DeviceCastTransform

examples/agents/ppo-chess.py

+54 −18

@@ -5,20 +5,24 @@
 import tensordict.nn
 import torch
 import tqdm
-from tensordict.nn import TensorDictSequential as TDSeq, TensorDictModule as TDMod, \
-    ProbabilisticTensorDictModule as TDProb, ProbabilisticTensorDictSequential as TDProbSeq
+from tensordict.nn import (
+    ProbabilisticTensorDictModule as TDProb,
+    ProbabilisticTensorDictSequential as TDProbSeq,
+    TensorDictModule as TDMod,
+    TensorDictSequential as TDSeq,
+)
 from torch import nn
 from torch.nn.utils import clip_grad_norm_
 from torch.optim import Adam

 from torchrl.collectors import SyncDataCollector
+from torchrl.data import LazyTensorStorage, ReplayBuffer, SamplerWithoutReplacement

 from torchrl.envs import ChessEnv, Tokenizer
 from torchrl.modules import MLP
 from torchrl.modules.distributions import MaskedCategorical
 from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value import GAE
-from torchrl.data import ReplayBuffer, LazyTensorStorage, SamplerWithoutReplacement

 tensordict.nn.set_composite_lp_aggregate(False)

@@ -39,7 +43,9 @@
 embedding_moves = nn.Embedding(num_embeddings=n + 1, embedding_dim=64)

 # Embedding for the fen
-embedding_fen = nn.Embedding(num_embeddings=transform.tokenizer.vocab_size, embedding_dim=64)
+embedding_fen = nn.Embedding(
+    num_embeddings=transform.tokenizer.vocab_size, embedding_dim=64
+)

 backbone = MLP(out_features=512, num_cells=[512] * 8, activation_class=nn.ReLU)

@@ -49,20 +55,30 @@
 critic_head = nn.Linear(512, 1)
 critic_head.bias.data.fill_(0)

-prob = TDProb(in_keys=["logits", "mask"], out_keys=["action"], distribution_class=MaskedCategorical, return_log_prob=True)
+prob = TDProb(
+    in_keys=["logits", "mask"],
+    out_keys=["action"],
+    distribution_class=MaskedCategorical,
+    return_log_prob=True,
+)
+

 def make_mask(idx):
     mask = idx.new_zeros((*idx.shape[:-1], n + 1), dtype=torch.bool)
     return mask.scatter_(-1, idx, torch.ones_like(idx, dtype=torch.bool))[..., :-1]

+
 actor = TDProbSeq(
-    TDMod(
-        make_mask,
-        in_keys=["legal_moves"], out_keys=["mask"]),
+    TDMod(make_mask, in_keys=["legal_moves"], out_keys=["mask"]),
     TDMod(embedding_moves, in_keys=["legal_moves"], out_keys=["embedded_legal_moves"]),
     TDMod(embedding_fen, in_keys=["fen_tokenized"], out_keys=["embedded_fen"]),
-    TDMod(lambda *args: torch.cat([arg.view(*arg.shape[:-2], -1) for arg in args], dim=-1), in_keys=["embedded_legal_moves", "embedded_fen"],
-        out_keys=["features"]),
+    TDMod(
+        lambda *args: torch.cat(
+            [arg.view(*arg.shape[:-2], -1) for arg in args], dim=-1
+        ),
+        in_keys=["embedded_legal_moves", "embedded_fen"],
+        out_keys=["features"],
+    ),
     TDMod(backbone, in_keys=["features"], out_keys=["hidden"]),
     TDMod(actor_head, in_keys=["hidden"], out_keys=["logits"]),
     prob,

@@ -78,7 +94,9 @@ def make_mask(idx):

 optim = Adam(loss.parameters())

-gae = GAE(value_network=TDSeq(*actor[:-2], critic), gamma=0.99, lmbda=0.95, shifted=True)
+gae = GAE(
+    value_network=TDSeq(*actor[:-2], critic), gamma=0.99, lmbda=0.95, shifted=True
+)

 # Create a data collector
 collector = SyncDataCollector(

@@ -88,12 +106,20 @@ def make_mask(idx):
     total_frames=1_000_000,
 )

-replay_buffer0 = ReplayBuffer(storage=LazyTensorStorage(max_size=collector.frames_per_batch//2), batch_size=batch_size, sampler=SamplerWithoutReplacement())
-replay_buffer1 = ReplayBuffer(storage=LazyTensorStorage(max_size=collector.frames_per_batch//2), batch_size=batch_size, sampler=SamplerWithoutReplacement())
+replay_buffer0 = ReplayBuffer(
+    storage=LazyTensorStorage(max_size=collector.frames_per_batch // 2),
+    batch_size=batch_size,
+    sampler=SamplerWithoutReplacement(),
+)
+replay_buffer1 = ReplayBuffer(
+    storage=LazyTensorStorage(max_size=collector.frames_per_batch // 2),
+    batch_size=batch_size,
+    sampler=SamplerWithoutReplacement(),
+)

 for data in tqdm.tqdm(collector):
     data = data.filter_non_tensor_data()
-    print('data', data[0::2])
+    print("data", data[0::2])
     for i in range(num_epochs):
         replay_buffer0.empty()
         replay_buffer1.empty()

@@ -103,14 +129,24 @@ def make_mask(idx):
         # player 1
         data1 = gae(data[1::2])
         if i == 0:
-            print('win rate for 0', data0["next", "reward"].sum() / data["next", "done"].sum().clamp_min(1e-6))
-            print('win rate for 1', data1["next", "reward"].sum() / data["next", "done"].sum().clamp_min(1e-6))
+            print(
+                "win rate for 0",
+                data0["next", "reward"].sum()
+                / data["next", "done"].sum().clamp_min(1e-6),
+            )
+            print(
+                "win rate for 1",
+                data1["next", "reward"].sum()
+                / data["next", "done"].sum().clamp_min(1e-6),
+            )

         replay_buffer0.extend(data0)
         replay_buffer1.extend(data1)

-        n_iter = collector.frames_per_batch//(2 * batch_size)
-        for (d0, d1) in tqdm.tqdm(zip(replay_buffer0, replay_buffer1, strict=True), total=n_iter):
+        n_iter = collector.frames_per_batch // (2 * batch_size)
+        for (d0, d1) in tqdm.tqdm(
+            zip(replay_buffer0, replay_buffer1, strict=True), total=n_iter
+        ):
             loss_vals = (loss(d0) + loss(d1)) / 2
             loss_vals.sum(reduce=True).backward()
             gn = clip_grad_norm_(loss.parameters(), 100.0)
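One detail of the reworked example: a single collector gathers the self-play games, and the two players are separated by slicing the collected batch along its leading dimension, so data[0::2] holds player 0's turns and data[1::2] player 1's, each advantage-processed and stored in its own replay buffer. A toy illustration of that even/odd slicing on a made-up tensor, not code from the example:

import torch

# Turns alternate between the two players along the leading axis,
# so even indices belong to player 0 and odd indices to player 1.
turns = torch.arange(10)
player0_turns = turns[0::2]  # tensor([0, 2, 4, 6, 8])
player1_turns = turns[1::2]  # tensor([1, 3, 5, 7, 9])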

test/test_transforms.py

+202
@@ -20,6 +20,7 @@

 import tensordict.tensordict
 import torch
+from tensordict.nn import WrapModule

 from torchrl.collectors import MultiSyncDataCollector


@@ -106,6 +107,7 @@
     CenterCrop,
     ClipTransform,
     Compose,
+    ConditionalPolicySwitch,
     Crop,
     DeviceCastTransform,
     DiscreteActionProjection,

@@ -13192,6 +13194,206 @@ def test_composite_reward_spec(self) -> None:
         assert transform.transform_reward_spec(reward_spec) == expected_reward_spec


+class TestConditionalPolicySwitch(TransformBase):
+    def test_single_trans_env_check(self):
+        base_env = CountingEnv(max_steps=15)
+        condition = lambda td: ((td.get("step_count") % 2) == 0).all()
+        # Player 0
+        policy_odd = lambda td: td.set("action", env.action_spec.zero())
+        policy_even = lambda td: td.set("action", env.action_spec.one())
+        transforms = Compose(
+            StepCounter(),
+            ConditionalPolicySwitch(condition=condition, policy=policy_even),
+        )
+        env = base_env.append_transform(transforms)
+        env.check_env_specs()
+
+    def _create_policy_odd(self, base_env):
+        return WrapModule(
+            lambda td, base_env=base_env: td.set(
+                "action", base_env.action_spec_unbatched.zero(td.shape)
+            ),
+            out_keys=["action"],
+        )
+
+    def _create_policy_even(self, base_env):
+        return WrapModule(
+            lambda td, base_env=base_env: td.set(
+                "action", base_env.action_spec_unbatched.one(td.shape)
+            ),
+            out_keys=["action"],
+        )
+
+    def _create_transforms(self, condition, policy_even):
+        return Compose(
+            StepCounter(),
+            ConditionalPolicySwitch(condition=condition, policy=policy_even),
+        )
+
+    def _make_env(self, max_count, env_cls):
+        torch.manual_seed(0)
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        base_env = env_cls(max_steps=max_count)
+        policy_even = self._create_policy_even(base_env)
+        transforms = self._create_transforms(condition, policy_even)
+        return base_env.append_transform(transforms)
+
+    def _test_env(self, env, policy_odd):
+        env.check_env_specs()
+        env.set_seed(0)
+        r = env.rollout(100, policy_odd, break_when_any_done=False)
+        # Check results are independent: one reset / step in one env should not impact results in another
+        r0, r1, r2 = r.unbind(0)
+        r0_split = r0.split(6)
+        assert all(((r == r0_split[0][: r.numel()]).all() for r in r0_split[1:]))
+        r1_split = r1.split(7)
+        assert all(((r == r1_split[0][: r.numel()]).all() for r in r1_split[1:]))
+        r2_split = r2.split(8)
+        assert all(((r == r2_split[0][: r.numel()]).all() for r in r2_split[1:]))
+
+    def test_trans_serial_env_check(self):
+        torch.manual_seed(0)
+        base_env = SerialEnv(
+            3,
+            [partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
+            batch_locked=False,
+        )
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        policy_odd = self._create_policy_odd(base_env)
+        policy_even = self._create_policy_even(base_env)
+        transforms = self._create_transforms(condition, policy_even)
+        env = base_env.append_transform(transforms)
+        self._test_env(env, policy_odd)
+
+    def test_trans_parallel_env_check(self):
+        torch.manual_seed(0)
+        base_env = ParallelEnv(
+            3,
+            [partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
+            batch_locked=False,
+            mp_start_method=mp_ctx,
+        )
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        policy_odd = self._create_policy_odd(base_env)
+        policy_even = self._create_policy_even(base_env)
+        transforms = self._create_transforms(condition, policy_even)
+        env = base_env.append_transform(transforms)
+        self._test_env(env, policy_odd)
+
+    def test_serial_trans_env_check(self):
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        policy_odd = self._create_policy_odd(CountingEnv())
+
+        def make_env(max_count):
+            return partial(self._make_env, max_count, CountingEnv)
+
+        env = SerialEnv(3, [make_env(6), make_env(7), make_env(8)])
+        self._test_env(env, policy_odd)
+
+    def test_parallel_trans_env_check(self):
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        policy_odd = self._create_policy_odd(CountingEnv())
+
+        def make_env(max_count):
+            return partial(self._make_env, max_count, CountingEnv)
+
+        env = ParallelEnv(
+            3, [make_env(6), make_env(7), make_env(8)], mp_start_method=mp_ctx
+        )
+        self._test_env(env, policy_odd)
+
+    def test_transform_no_env(self):
+        policy_odd = lambda td: td
+        policy_even = lambda td: td
+        condition = lambda td: True
+        transforms = ConditionalPolicySwitch(condition=condition, policy=policy_even)
+        with pytest.raises(
+            RuntimeError,
+            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
+        ):
+            transforms(TensorDict())
+
+    def test_transform_compose(self):
+        policy_odd = lambda td: td
+        policy_even = lambda td: td
+        condition = lambda td: True
+        transforms = Compose(
+            ConditionalPolicySwitch(condition=condition, policy=policy_even),
+        )
+        with pytest.raises(
+            RuntimeError,
+            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
+        ):
+            transforms(TensorDict())
+
+    def test_transform_env(self):
+        base_env = CountingEnv(max_steps=15)
+        condition = lambda td: ((td.get("step_count") % 2) == 0).all()
+        # Player 0
+        policy_odd = lambda td: td.set("action", env.action_spec.zero())
+        policy_even = lambda td: td.set("action", env.action_spec.one())
+        transforms = Compose(
+            StepCounter(),
+            ConditionalPolicySwitch(condition=condition, policy=policy_even),
+        )
+        env = base_env.append_transform(transforms)
+        env.check_env_specs()
+        r = env.rollout(1000, policy_odd, break_when_all_done=True)
+        assert r.shape[0] == 15
+        assert (r["action"] == 0).all()
+        assert (
+            r["step_count"] == torch.arange(1, r.numel() * 2, 2).unsqueeze(-1)
+        ).all()
+        assert r["next", "done"].any()
+
+        # Player 1
+        condition = lambda td: ((td.get("step_count") % 2) == 1).all()
+        transforms = Compose(
+            StepCounter(),
+            ConditionalPolicySwitch(condition=condition, policy=policy_odd),
+        )
+        env = base_env.append_transform(transforms)
+        r = env.rollout(1000, policy_even, break_when_all_done=True)
+        assert r.shape[0] == 16
+        assert (r["action"] == 1).all()
+        assert (
+            r["step_count"] == torch.arange(0, r.numel() * 2, 2).unsqueeze(-1)
+        ).all()
+        assert r["next", "done"].any()
+
+    def test_transform_model(self):
+        policy_odd = lambda td: td
+        policy_even = lambda td: td
+        condition = lambda td: True
+        transforms = nn.Sequential(
+            ConditionalPolicySwitch(condition=condition, policy=policy_even),
+        )
+        with pytest.raises(
+            RuntimeError,
+            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
+        ):
+            transforms(TensorDict())
+
+    @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
+    def test_transform_rb(self, rbclass):
+        policy_odd = lambda td: td
+        policy_even = lambda td: td
+        condition = lambda td: True
+        rb = rbclass(storage=LazyTensorStorage(10))
+        rb.append_transform(
+            ConditionalPolicySwitch(condition=condition, policy=policy_even)
+        )
+        rb.extend(TensorDict(batch_size=[2]))
+        with pytest.raises(
+            RuntimeError,
+            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
+        ):
+            rb.sample(2)
+
+    def test_transform_inverse(self):
+        return
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/envs/__init__.py

+1
@@ -55,6 +55,7 @@
     CenterCrop,
     ClipTransform,
     Compose,
+    ConditionalPolicySwitch,
     Crop,
     DeviceCastTransform,
     DiscreteActionProjection,
