20 | 20 |
21 | 21 | import tensordict.tensordict
22 | 22 | import torch
   | 23 | +from tensordict.nn import WrapModule
   | 24 | +
23 | 25 | from tensordict import (
24 | 26 |     NonTensorData,
25 | 27 |     NonTensorStack,

56 | 58 |     CenterCrop,
57 | 59 |     ClipTransform,
58 | 60 |     Compose,
   | 61 | +    ConditionalPolicySwitch,
59 | 62 |     Crop,
60 | 63 |     DeviceCastTransform,
61 | 64 |     DiscreteActionProjection,
@@ -13341,6 +13344,206 @@ def test_composite_reward_spec(self) -> None:
13341 | 13344 |         assert transform.transform_reward_spec(reward_spec) == expected_reward_spec
13342 | 13345 |
13343 | 13346 |
      | 13347 | +class TestConditionalPolicySwitch(TransformBase):
      | 13348 | +    def test_single_trans_env_check(self):
      | 13349 | +        base_env = CountingEnv(max_steps=15)
      | 13350 | +        condition = lambda td: ((td.get("step_count") % 2) == 0).all()
      | 13351 | +        # Player 0
      | 13352 | +        policy_odd = lambda td: td.set("action", env.action_spec.zero())
      | 13353 | +        policy_even = lambda td: td.set("action", env.action_spec.one())
      | 13354 | +        transforms = Compose(
      | 13355 | +            StepCounter(),
      | 13356 | +            ConditionalPolicySwitch(condition=condition, policy=policy_even),
      | 13357 | +        )
      | 13358 | +        env = base_env.append_transform(transforms)
      | 13359 | +        env.check_env_specs()
      | 13360 | +
      | 13361 | +    def _create_policy_odd(self, base_env):
      | 13362 | +        return WrapModule(
      | 13363 | +            lambda td, base_env=base_env: td.set(
      | 13364 | +                "action", base_env.action_spec_unbatched.zero(td.shape)
      | 13365 | +            ),
      | 13366 | +            out_keys=["action"],
      | 13367 | +        )
      | 13368 | +
      | 13369 | +    def _create_policy_even(self, base_env):
      | 13370 | +        return WrapModule(
      | 13371 | +            lambda td, base_env=base_env: td.set(
      | 13372 | +                "action", base_env.action_spec_unbatched.one(td.shape)
      | 13373 | +            ),
      | 13374 | +            out_keys=["action"],
      | 13375 | +        )
      | 13376 | +
      | 13377 | +    def _create_transforms(self, condition, policy_even):
      | 13378 | +        return Compose(
      | 13379 | +            StepCounter(),
      | 13380 | +            ConditionalPolicySwitch(condition=condition, policy=policy_even),
      | 13381 | +        )
      | 13382 | +
      | 13383 | +    def _make_env(self, max_count, env_cls):
      | 13384 | +        torch.manual_seed(0)
      | 13385 | +        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
      | 13386 | +        base_env = env_cls(max_steps=max_count)
      | 13387 | +        policy_even = self._create_policy_even(base_env)
      | 13388 | +        transforms = self._create_transforms(condition, policy_even)
      | 13389 | +        return base_env.append_transform(transforms)
      | 13390 | +
      | 13391 | +    def _test_env(self, env, policy_odd):
      | 13392 | +        env.check_env_specs()
      | 13393 | +        env.set_seed(0)
      | 13394 | +        r = env.rollout(100, policy_odd, break_when_any_done=False)
      | 13395 | +        # Check results are independent: one reset / step in one env should not impact results in another
      | 13396 | +        r0, r1, r2 = r.unbind(0)
      | 13397 | +        r0_split = r0.split(6)
      | 13398 | +        assert all(((r == r0_split[0][: r.numel()]).all() for r in r0_split[1:]))
      | 13399 | +        r1_split = r1.split(7)
      | 13400 | +        assert all(((r == r1_split[0][: r.numel()]).all() for r in r1_split[1:]))
      | 13401 | +        r2_split = r2.split(8)
      | 13402 | +        assert all(((r == r2_split[0][: r.numel()]).all() for r in r2_split[1:]))
      | 13403 | +
      | 13404 | +    def test_trans_serial_env_check(self):
      | 13405 | +        torch.manual_seed(0)
      | 13406 | +        base_env = SerialEnv(
      | 13407 | +            3,
      | 13408 | +            [partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
      | 13409 | +            batch_locked=False,
      | 13410 | +        )
      | 13411 | +        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
      | 13412 | +        policy_odd = self._create_policy_odd(base_env)
      | 13413 | +        policy_even = self._create_policy_even(base_env)
      | 13414 | +        transforms = self._create_transforms(condition, policy_even)
      | 13415 | +        env = base_env.append_transform(transforms)
      | 13416 | +        self._test_env(env, policy_odd)
      | 13417 | +
      | 13418 | +    def test_trans_parallel_env_check(self):
      | 13419 | +        torch.manual_seed(0)
      | 13420 | +        base_env = ParallelEnv(
      | 13421 | +            3,
      | 13422 | +            [partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
      | 13423 | +            batch_locked=False,
      | 13424 | +            mp_start_method=mp_ctx,
      | 13425 | +        )
      | 13426 | +        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
      | 13427 | +        policy_odd = self._create_policy_odd(base_env)
      | 13428 | +        policy_even = self._create_policy_even(base_env)
      | 13429 | +        transforms = self._create_transforms(condition, policy_even)
      | 13430 | +        env = base_env.append_transform(transforms)
      | 13431 | +        self._test_env(env, policy_odd)
      | 13432 | +
      | 13433 | +    def test_serial_trans_env_check(self):
      | 13434 | +        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
      | 13435 | +        policy_odd = self._create_policy_odd(CountingEnv())
      | 13436 | +
      | 13437 | +        def make_env(max_count):
      | 13438 | +            return partial(self._make_env, max_count, CountingEnv)
      | 13439 | +
      | 13440 | +        env = SerialEnv(3, [make_env(6), make_env(7), make_env(8)])
      | 13441 | +        self._test_env(env, policy_odd)
      | 13442 | +
      | 13443 | +    def test_parallel_trans_env_check(self):
      | 13444 | +        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
      | 13445 | +        policy_odd = self._create_policy_odd(CountingEnv())
      | 13446 | +
      | 13447 | +        def make_env(max_count):
      | 13448 | +            return partial(self._make_env, max_count, CountingEnv)
      | 13449 | +
      | 13450 | +        env = ParallelEnv(
      | 13451 | +            3, [make_env(6), make_env(7), make_env(8)], mp_start_method=mp_ctx
      | 13452 | +        )
      | 13453 | +        self._test_env(env, policy_odd)
      | 13454 | +
      | 13455 | +    def test_transform_no_env(self):
      | 13456 | +        policy_odd = lambda td: td
      | 13457 | +        policy_even = lambda td: td
      | 13458 | +        condition = lambda td: True
      | 13459 | +        transforms = ConditionalPolicySwitch(condition=condition, policy=policy_even)
      | 13460 | +        with pytest.raises(
      | 13461 | +            RuntimeError,
      | 13462 | +            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
      | 13463 | +        ):
      | 13464 | +            transforms(TensorDict())
      | 13465 | +
      | 13466 | +    def test_transform_compose(self):
      | 13467 | +        policy_odd = lambda td: td
      | 13468 | +        policy_even = lambda td: td
      | 13469 | +        condition = lambda td: True
      | 13470 | +        transforms = Compose(
      | 13471 | +            ConditionalPolicySwitch(condition=condition, policy=policy_even),
      | 13472 | +        )
      | 13473 | +        with pytest.raises(
      | 13474 | +            RuntimeError,
      | 13475 | +            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
      | 13476 | +        ):
      | 13477 | +            transforms(TensorDict())
      | 13478 | +
      | 13479 | +    def test_transform_env(self):
      | 13480 | +        base_env = CountingEnv(max_steps=15)
      | 13481 | +        condition = lambda td: ((td.get("step_count") % 2) == 0).all()
      | 13482 | +        # Player 0
      | 13483 | +        policy_odd = lambda td: td.set("action", env.action_spec.zero())
      | 13484 | +        policy_even = lambda td: td.set("action", env.action_spec.one())
      | 13485 | +        transforms = Compose(
      | 13486 | +            StepCounter(),
      | 13487 | +            ConditionalPolicySwitch(condition=condition, policy=policy_even),
      | 13488 | +        )
      | 13489 | +        env = base_env.append_transform(transforms)
      | 13490 | +        env.check_env_specs()
      | 13491 | +        r = env.rollout(1000, policy_odd, break_when_all_done=True)
      | 13492 | +        assert r.shape[0] == 15
      | 13493 | +        assert (r["action"] == 0).all()
      | 13494 | +        assert (
      | 13495 | +            r["step_count"] == torch.arange(1, r.numel() * 2, 2).unsqueeze(-1)
      | 13496 | +        ).all()
      | 13497 | +        assert r["next", "done"].any()
      | 13498 | +
      | 13499 | +        # Player 1
      | 13500 | +        condition = lambda td: ((td.get("step_count") % 2) == 1).all()
      | 13501 | +        transforms = Compose(
      | 13502 | +            StepCounter(),
      | 13503 | +            ConditionalPolicySwitch(condition=condition, policy=policy_odd),
      | 13504 | +        )
      | 13505 | +        env = base_env.append_transform(transforms)
      | 13506 | +        r = env.rollout(1000, policy_even, break_when_all_done=True)
      | 13507 | +        assert r.shape[0] == 16
      | 13508 | +        assert (r["action"] == 1).all()
      | 13509 | +        assert (
      | 13510 | +            r["step_count"] == torch.arange(0, r.numel() * 2, 2).unsqueeze(-1)
      | 13511 | +        ).all()
      | 13512 | +        assert r["next", "done"].any()
      | 13513 | +
      | 13514 | +    def test_transform_model(self):
      | 13515 | +        policy_odd = lambda td: td
      | 13516 | +        policy_even = lambda td: td
      | 13517 | +        condition = lambda td: True
      | 13518 | +        transforms = nn.Sequential(
      | 13519 | +            ConditionalPolicySwitch(condition=condition, policy=policy_even),
      | 13520 | +        )
      | 13521 | +        with pytest.raises(
      | 13522 | +            RuntimeError,
      | 13523 | +            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
      | 13524 | +        ):
      | 13525 | +            transforms(TensorDict())
      | 13526 | +
      | 13527 | +    @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
      | 13528 | +    def test_transform_rb(self, rbclass):
      | 13529 | +        policy_odd = lambda td: td
      | 13530 | +        policy_even = lambda td: td
      | 13531 | +        condition = lambda td: True
      | 13532 | +        rb = rbclass(storage=LazyTensorStorage(10))
      | 13533 | +        rb.append_transform(
      | 13534 | +            ConditionalPolicySwitch(condition=condition, policy=policy_even)
      | 13535 | +        )
      | 13536 | +        rb.extend(TensorDict(batch_size=[2]))
      | 13537 | +        with pytest.raises(
      | 13538 | +            RuntimeError,
      | 13539 | +            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
      | 13540 | +        ):
      | 13541 | +            rb.sample(2)
      | 13542 | +
      | 13543 | +    def test_transform_inverse(self):
      | 13544 | +        return
      | 13545 | +
      | 13546 | +
13344 | 13547 | if __name__ == "__main__":
13345 | 13548 |     args, unknown = argparse.ArgumentParser().parse_known_args()
13346 | 13549 |     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
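For reference, the pattern these tests exercise can be reduced to a short standalone sketch mirroring test_transform_env above. It is a sketch only: ConditionalPolicySwitch, StepCounter and Compose are imported here from torchrl.envs, matching the transform import block this diff extends (the exact module path is not visible in the hunk), and the CountingEnv import path is assumed to be the test suite's mocking_classes helper rather than a library API.

# Sketch: alternate two policies on a single CountingEnv, as in test_transform_env.
from mocking_classes import CountingEnv  # assumed test-helper import, not a torchrl API
from torchrl.envs import Compose, ConditionalPolicySwitch, StepCounter

base_env = CountingEnv(max_steps=15)

# Hand control to `policy_even` whenever the step count is even; the policy
# passed to rollout() (`policy_odd`) plays the remaining, odd steps.
condition = lambda td: ((td.get("step_count") % 2) == 0).all()
policy_odd = lambda td: td.set("action", base_env.action_spec.zero())
policy_even = lambda td: td.set("action", base_env.action_spec.one())

env = base_env.append_transform(
    Compose(
        StepCounter(),
        ConditionalPolicySwitch(condition=condition, policy=policy_even),
    )
)
rollout = env.rollout(1000, policy_odd, break_when_all_done=True)
# Only the steps played by the rollout policy end up in the recorded trajectory,
# which is what the (r["action"] == 0).all() assertion above relies on.
assert (rollout["action"] == 0).all()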