
Commit 3c1241d

[Feature] ConditionalPolicySwitch transform
ghstack-source-id: f147e7c6b0f55da5746f79563af66ad057021d66
Pull Request resolved: #2711
1 parent bf707f5 commit 3c1241d
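
Note: the commit message carries no usage example; the sketch below is inferred from the tests added in test/test_transforms.py in this commit. It assumes the constructor signature ConditionalPolicySwitch(condition=..., policy=...) used there, and that the supplied policy is executed inside the environment step whenever the condition evaluates to True, so those turns never reach the caller's policy. `base_env` is a placeholder for any TorchRL environment (the tests use a toy CountingEnv).

    from torchrl.envs import Compose, ConditionalPolicySwitch, StepCounter

    base_env = ...  # placeholder: any TorchRL env; the tests use CountingEnv(max_steps=15)

    # Hand control to the alternate policy whenever the step count is even.
    condition = lambda td: ((td.get("step_count") % 2) == 0).all()

    env = base_env.append_transform(
        Compose(
            StepCounter(),
            ConditionalPolicySwitch(
                condition=condition,
                policy=lambda td: td.set("action", env.action_spec.one()),
            ),
        )
    )

    # The rollout policy then only ever plays the odd turns.
    policy_odd = lambda td: td.set("action", env.action_spec.zero())
    rollout = env.rollout(100, policy_odd, break_when_all_done=True)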

File tree

8 files changed: +510 -24 lines changed


docs/source/reference/envs.rst

+1
@@ -816,6 +816,7 @@ to be able to create this other composition:
     CenterCrop
     ClipTransform
     Compose
+    ConditionalPolicySwitch
     Crop
     DTypeCastTransform
     DeviceCastTransform

examples/agents/ppo-chess.py

+54 -18
@@ -5,20 +5,24 @@
 import tensordict.nn
 import torch
 import tqdm
-from tensordict.nn import TensorDictSequential as TDSeq, TensorDictModule as TDMod, \
-    ProbabilisticTensorDictModule as TDProb, ProbabilisticTensorDictSequential as TDProbSeq
+from tensordict.nn import (
+    ProbabilisticTensorDictModule as TDProb,
+    ProbabilisticTensorDictSequential as TDProbSeq,
+    TensorDictModule as TDMod,
+    TensorDictSequential as TDSeq,
+)
 from torch import nn
 from torch.nn.utils import clip_grad_norm_
 from torch.optim import Adam

 from torchrl.collectors import SyncDataCollector
+from torchrl.data import LazyTensorStorage, ReplayBuffer, SamplerWithoutReplacement

 from torchrl.envs import ChessEnv, Tokenizer
 from torchrl.modules import MLP
 from torchrl.modules.distributions import MaskedCategorical
 from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value import GAE
-from torchrl.data import ReplayBuffer, LazyTensorStorage, SamplerWithoutReplacement

 tensordict.nn.set_composite_lp_aggregate(False)
@@ -39,7 +43,9 @@
 embedding_moves = nn.Embedding(num_embeddings=n + 1, embedding_dim=64)

 # Embedding for the fen
-embedding_fen = nn.Embedding(num_embeddings=transform.tokenizer.vocab_size, embedding_dim=64)
+embedding_fen = nn.Embedding(
+    num_embeddings=transform.tokenizer.vocab_size, embedding_dim=64
+)

 backbone = MLP(out_features=512, num_cells=[512] * 8, activation_class=nn.ReLU)
@@ -49,20 +55,30 @@
 critic_head = nn.Linear(512, 1)
 critic_head.bias.data.fill_(0)

-prob = TDProb(in_keys=["logits", "mask"], out_keys=["action"], distribution_class=MaskedCategorical, return_log_prob=True)
+prob = TDProb(
+    in_keys=["logits", "mask"],
+    out_keys=["action"],
+    distribution_class=MaskedCategorical,
+    return_log_prob=True,
+)
+

 def make_mask(idx):
     mask = idx.new_zeros((*idx.shape[:-1], n + 1), dtype=torch.bool)
     return mask.scatter_(-1, idx, torch.ones_like(idx, dtype=torch.bool))[..., :-1]

+
 actor = TDProbSeq(
-    TDMod(
-        make_mask,
-        in_keys=["legal_moves"], out_keys=["mask"]),
+    TDMod(make_mask, in_keys=["legal_moves"], out_keys=["mask"]),
     TDMod(embedding_moves, in_keys=["legal_moves"], out_keys=["embedded_legal_moves"]),
     TDMod(embedding_fen, in_keys=["fen_tokenized"], out_keys=["embedded_fen"]),
-    TDMod(lambda *args: torch.cat([arg.view(*arg.shape[:-2], -1) for arg in args], dim=-1), in_keys=["embedded_legal_moves", "embedded_fen"],
-          out_keys=["features"]),
+    TDMod(
+        lambda *args: torch.cat(
+            [arg.view(*arg.shape[:-2], -1) for arg in args], dim=-1
+        ),
+        in_keys=["embedded_legal_moves", "embedded_fen"],
+        out_keys=["features"],
+    ),
     TDMod(backbone, in_keys=["features"], out_keys=["hidden"]),
     TDMod(actor_head, in_keys=["hidden"], out_keys=["logits"]),
     prob,
@@ -78,7 +94,9 @@ def make_mask(idx):

 optim = Adam(loss.parameters())

-gae = GAE(value_network=TDSeq(*actor[:-2], critic), gamma=0.99, lmbda=0.95, shifted=True)
+gae = GAE(
+    value_network=TDSeq(*actor[:-2], critic), gamma=0.99, lmbda=0.95, shifted=True
+)

 # Create a data collector
 collector = SyncDataCollector(
@@ -88,12 +106,20 @@ def make_mask(idx):
     total_frames=1_000_000,
 )

-replay_buffer0 = ReplayBuffer(storage=LazyTensorStorage(max_size=collector.frames_per_batch//2), batch_size=batch_size, sampler=SamplerWithoutReplacement())
-replay_buffer1 = ReplayBuffer(storage=LazyTensorStorage(max_size=collector.frames_per_batch//2), batch_size=batch_size, sampler=SamplerWithoutReplacement())
+replay_buffer0 = ReplayBuffer(
+    storage=LazyTensorStorage(max_size=collector.frames_per_batch // 2),
+    batch_size=batch_size,
+    sampler=SamplerWithoutReplacement(),
+)
+replay_buffer1 = ReplayBuffer(
+    storage=LazyTensorStorage(max_size=collector.frames_per_batch // 2),
+    batch_size=batch_size,
+    sampler=SamplerWithoutReplacement(),
+)

 for data in tqdm.tqdm(collector):
     data = data.filter_non_tensor_data()
-    print('data', data[0::2])
+    print("data", data[0::2])
     for i in range(num_epochs):
         replay_buffer0.empty()
         replay_buffer1.empty()
@@ -103,14 +129,24 @@ def make_mask(idx):
         # player 1
         data1 = gae(data[1::2])
         if i == 0:
-            print('win rate for 0', data0["next", "reward"].sum() / data["next", "done"].sum().clamp_min(1e-6))
-            print('win rate for 1', data1["next", "reward"].sum() / data["next", "done"].sum().clamp_min(1e-6))
+            print(
+                "win rate for 0",
+                data0["next", "reward"].sum()
+                / data["next", "done"].sum().clamp_min(1e-6),
+            )
+            print(
+                "win rate for 1",
+                data1["next", "reward"].sum()
+                / data["next", "done"].sum().clamp_min(1e-6),
+            )

         replay_buffer0.extend(data0)
         replay_buffer1.extend(data1)

-        n_iter = collector.frames_per_batch//(2 * batch_size)
-        for (d0, d1) in tqdm.tqdm(zip(replay_buffer0, replay_buffer1, strict=True), total=n_iter):
+        n_iter = collector.frames_per_batch // (2 * batch_size)
+        for (d0, d1) in tqdm.tqdm(
+            zip(replay_buffer0, replay_buffer1, strict=True), total=n_iter
+        ):
             loss_vals = (loss(d0) + loss(d1)) / 2
             loss_vals.sum(reduce=True).backward()
             gn = clip_grad_norm_(loss.parameters(), 100.0)
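
Aside from the black-style reformatting above, the script's structure is unchanged: chess alternates players every ply, so data[0::2] and data[1::2] presumably split one collected batch into player 0's and player 1's transitions, each run through GAE and stored in its own half-sized replay buffer. A minimal, self-contained illustration of that interleaved split (the names below are hypothetical, not from the script):

    import torch

    # Stand-in for a batch of transitions from one alternating-turn game:
    # even indices were produced by player 0, odd indices by player 1.
    frames = torch.arange(8)
    player0 = frames[0::2]  # tensor([0, 2, 4, 6])
    player1 = frames[1::2]  # tensor([1, 3, 5, 7])
    assert player0.numel() + player1.numel() == frames.numel()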

test/test_transforms.py

+203
@@ -20,6 +20,8 @@

 import tensordict.tensordict
 import torch
+from tensordict.nn import WrapModule
+
 from tensordict import (
     NonTensorData,
     NonTensorStack,
@@ -56,6 +58,7 @@
     CenterCrop,
     ClipTransform,
     Compose,
+    ConditionalPolicySwitch,
     Crop,
     DeviceCastTransform,
     DiscreteActionProjection,
@@ -13341,6 +13344,206 @@ def test_composite_reward_spec(self) -> None:
         assert transform.transform_reward_spec(reward_spec) == expected_reward_spec


+class TestConditionalPolicySwitch(TransformBase):
+    def test_single_trans_env_check(self):
+        base_env = CountingEnv(max_steps=15)
+        condition = lambda td: ((td.get("step_count") % 2) == 0).all()
+        # Player 0
+        policy_odd = lambda td: td.set("action", env.action_spec.zero())
+        policy_even = lambda td: td.set("action", env.action_spec.one())
+        transforms = Compose(
+            StepCounter(),
+            ConditionalPolicySwitch(condition=condition, policy=policy_even),
+        )
+        env = base_env.append_transform(transforms)
+        env.check_env_specs()
+
+    def _create_policy_odd(self, base_env):
+        return WrapModule(
+            lambda td, base_env=base_env: td.set(
+                "action", base_env.action_spec_unbatched.zero(td.shape)
+            ),
+            out_keys=["action"],
+        )
+
+    def _create_policy_even(self, base_env):
+        return WrapModule(
+            lambda td, base_env=base_env: td.set(
+                "action", base_env.action_spec_unbatched.one(td.shape)
+            ),
+            out_keys=["action"],
+        )
+
+    def _create_transforms(self, condition, policy_even):
+        return Compose(
+            StepCounter(),
+            ConditionalPolicySwitch(condition=condition, policy=policy_even),
+        )
+
+    def _make_env(self, max_count, env_cls):
+        torch.manual_seed(0)
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        base_env = env_cls(max_steps=max_count)
+        policy_even = self._create_policy_even(base_env)
+        transforms = self._create_transforms(condition, policy_even)
+        return base_env.append_transform(transforms)
+
+    def _test_env(self, env, policy_odd):
+        env.check_env_specs()
+        env.set_seed(0)
+        r = env.rollout(100, policy_odd, break_when_any_done=False)
+        # Check results are independent: one reset / step in one env should not impact results in another
+        r0, r1, r2 = r.unbind(0)
+        r0_split = r0.split(6)
+        assert all(((r == r0_split[0][: r.numel()]).all() for r in r0_split[1:]))
+        r1_split = r1.split(7)
+        assert all(((r == r1_split[0][: r.numel()]).all() for r in r1_split[1:]))
+        r2_split = r2.split(8)
+        assert all(((r == r2_split[0][: r.numel()]).all() for r in r2_split[1:]))
+
+    def test_trans_serial_env_check(self):
+        torch.manual_seed(0)
+        base_env = SerialEnv(
+            3,
+            [partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
+            batch_locked=False,
+        )
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        policy_odd = self._create_policy_odd(base_env)
+        policy_even = self._create_policy_even(base_env)
+        transforms = self._create_transforms(condition, policy_even)
+        env = base_env.append_transform(transforms)
+        self._test_env(env, policy_odd)
+
+    def test_trans_parallel_env_check(self):
+        torch.manual_seed(0)
+        base_env = ParallelEnv(
+            3,
+            [partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
+            batch_locked=False,
+            mp_start_method=mp_ctx,
+        )
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        policy_odd = self._create_policy_odd(base_env)
+        policy_even = self._create_policy_even(base_env)
+        transforms = self._create_transforms(condition, policy_even)
+        env = base_env.append_transform(transforms)
+        self._test_env(env, policy_odd)
+
+    def test_serial_trans_env_check(self):
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        policy_odd = self._create_policy_odd(CountingEnv())
+
+        def make_env(max_count):
+            return partial(self._make_env, max_count, CountingEnv)
+
+        env = SerialEnv(3, [make_env(6), make_env(7), make_env(8)])
+        self._test_env(env, policy_odd)
+
+    def test_parallel_trans_env_check(self):
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        policy_odd = self._create_policy_odd(CountingEnv())
+
+        def make_env(max_count):
+            return partial(self._make_env, max_count, CountingEnv)
+
+        env = ParallelEnv(
+            3, [make_env(6), make_env(7), make_env(8)], mp_start_method=mp_ctx
+        )
+        self._test_env(env, policy_odd)
+
+    def test_transform_no_env(self):
+        policy_odd = lambda td: td
+        policy_even = lambda td: td
+        condition = lambda td: True
+        transforms = ConditionalPolicySwitch(condition=condition, policy=policy_even)
+        with pytest.raises(
+            RuntimeError,
+            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
+        ):
+            transforms(TensorDict())
+
+    def test_transform_compose(self):
+        policy_odd = lambda td: td
+        policy_even = lambda td: td
+        condition = lambda td: True
+        transforms = Compose(
+            ConditionalPolicySwitch(condition=condition, policy=policy_even),
+        )
+        with pytest.raises(
+            RuntimeError,
+            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
+        ):
+            transforms(TensorDict())
+
+    def test_transform_env(self):
+        base_env = CountingEnv(max_steps=15)
+        condition = lambda td: ((td.get("step_count") % 2) == 0).all()
+        # Player 0
+        policy_odd = lambda td: td.set("action", env.action_spec.zero())
+        policy_even = lambda td: td.set("action", env.action_spec.one())
+        transforms = Compose(
+            StepCounter(),
+            ConditionalPolicySwitch(condition=condition, policy=policy_even),
+        )
+        env = base_env.append_transform(transforms)
+        env.check_env_specs()
+        r = env.rollout(1000, policy_odd, break_when_all_done=True)
+        assert r.shape[0] == 15
+        assert (r["action"] == 0).all()
+        assert (
+            r["step_count"] == torch.arange(1, r.numel() * 2, 2).unsqueeze(-1)
+        ).all()
+        assert r["next", "done"].any()
+
+        # Player 1
+        condition = lambda td: ((td.get("step_count") % 2) == 1).all()
+        transforms = Compose(
+            StepCounter(),
+            ConditionalPolicySwitch(condition=condition, policy=policy_odd),
+        )
+        env = base_env.append_transform(transforms)
+        r = env.rollout(1000, policy_even, break_when_all_done=True)
+        assert r.shape[0] == 16
+        assert (r["action"] == 1).all()
+        assert (
+            r["step_count"] == torch.arange(0, r.numel() * 2, 2).unsqueeze(-1)
+        ).all()
+        assert r["next", "done"].any()
+
+    def test_transform_model(self):
+        policy_odd = lambda td: td
+        policy_even = lambda td: td
+        condition = lambda td: True
+        transforms = nn.Sequential(
+            ConditionalPolicySwitch(condition=condition, policy=policy_even),
+        )
+        with pytest.raises(
+            RuntimeError,
+            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
+        ):
+            transforms(TensorDict())
+
+    @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
+    def test_transform_rb(self, rbclass):
+        policy_odd = lambda td: td
+        policy_even = lambda td: td
+        condition = lambda td: True
+        rb = rbclass(storage=LazyTensorStorage(10))
+        rb.append_transform(
+            ConditionalPolicySwitch(condition=condition, policy=policy_even)
+        )
+        rb.extend(TensorDict(batch_size=[2]))
+        with pytest.raises(
+            RuntimeError,
+            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
+        ):
+            rb.sample(2)
+
+    def test_transform_inverse(self):
+        return
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/envs/__init__.py

+1
@@ -55,6 +55,7 @@
     CenterCrop,
     ClipTransform,
     Compose,
+    ConditionalPolicySwitch,
     Crop,
     DeviceCastTransform,
     DiscreteActionProjection,

0 commit comments