20 | 20 |
21 | 21 | import tensordict.tensordict
22 | 22 | import torch
| 23 | +from tensordict.nn import WrapModule
23 | 24 |
24 | 25 | from torchrl.collectors import MultiSyncDataCollector
25 | 26 |
106 | 107 |     CenterCrop,
107 | 108 |     ClipTransform,
108 | 109 |     Compose,
| 110 | +    ConditionalPolicySwitch,
109 | 111 |     Crop,
110 | 112 |     DeviceCastTransform,
111 | 113 |     DiscreteActionProjection,
@@ -13192,6 +13194,206 @@ def test_composite_reward_spec(self) -> None:
13192 | 13194 |         assert transform.transform_reward_spec(reward_spec) == expected_reward_spec
13193 | 13195 |
13194 | 13196 |
| 13197 | +class TestConditionalPolicySwitch(TransformBase):
| 13198 | +    def test_single_trans_env_check(self):
| 13199 | +        base_env = CountingEnv(max_steps=15)
| 13200 | +        condition = lambda td: ((td.get("step_count") % 2) == 0).all()
| 13201 | +        # Player 0 (`env` is assigned below; the lambdas resolve it lazily at call time)
| 13202 | +        policy_odd = lambda td: td.set("action", env.action_spec.zero())
| 13203 | +        policy_even = lambda td: td.set("action", env.action_spec.one())
| 13204 | +        transforms = Compose(
| 13205 | +            StepCounter(),
| 13206 | +            ConditionalPolicySwitch(condition=condition, policy=policy_even),
| 13207 | +        )
| 13208 | +        env = base_env.append_transform(transforms)
| 13209 | +        env.check_env_specs()
| 13210 | +
| 13211 | +    def _create_policy_odd(self, base_env):
| 13212 | +        return WrapModule(
| 13213 | +            lambda td, base_env=base_env: td.set(
| 13214 | +                "action", base_env.action_spec_unbatched.zero(td.shape)
| 13215 | +            ),
| 13216 | +            out_keys=["action"],
| 13217 | +        )
| 13218 | +
| 13219 | +    def _create_policy_even(self, base_env):
| 13220 | +        return WrapModule(
| 13221 | +            lambda td, base_env=base_env: td.set(
| 13222 | +                "action", base_env.action_spec_unbatched.one(td.shape)
| 13223 | +            ),
| 13224 | +            out_keys=["action"],
| 13225 | +        )
| 13226 | +
| 13227 | +    def _create_transforms(self, condition, policy_even):
| 13228 | +        return Compose(
| 13229 | +            StepCounter(),
| 13230 | +            ConditionalPolicySwitch(condition=condition, policy=policy_even),
| 13231 | +        )
| 13232 | +
| 13233 | +    def _make_env(self, max_count, env_cls):
| 13234 | +        torch.manual_seed(0)
| 13235 | +        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
| 13236 | +        base_env = env_cls(max_steps=max_count)
| 13237 | +        policy_even = self._create_policy_even(base_env)
| 13238 | +        transforms = self._create_transforms(condition, policy_even)
| 13239 | +        return base_env.append_transform(transforms)
| 13240 | +
| 13241 | +    def _test_env(self, env, policy_odd):
| 13242 | +        env.check_env_specs()
| 13243 | +        env.set_seed(0)
| 13244 | +        r = env.rollout(100, policy_odd, break_when_any_done=False)
| 13245 | +        # Check results are independent: one reset / step in one env should not impact results in another
| 13246 | +        r0, r1, r2 = r.unbind(0)
| 13247 | +        r0_split = r0.split(6)
| 13248 | +        assert all((r == r0_split[0][: r.numel()]).all() for r in r0_split[1:])
| 13249 | +        r1_split = r1.split(7)
| 13250 | +        assert all((r == r1_split[0][: r.numel()]).all() for r in r1_split[1:])
| 13251 | +        r2_split = r2.split(8)
| 13252 | +        assert all((r == r2_split[0][: r.numel()]).all() for r in r2_split[1:])
| 13253 | +
| 13254 | +    def test_trans_serial_env_check(self):
| 13255 | +        torch.manual_seed(0)
| 13256 | +        base_env = SerialEnv(
| 13257 | +            3,
| 13258 | +            [partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
| 13259 | +            batch_locked=False,
| 13260 | +        )
| 13261 | +        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
| 13262 | +        policy_odd = self._create_policy_odd(base_env)
| 13263 | +        policy_even = self._create_policy_even(base_env)
| 13264 | +        transforms = self._create_transforms(condition, policy_even)
| 13265 | +        env = base_env.append_transform(transforms)
| 13266 | +        self._test_env(env, policy_odd)
| 13267 | +
| 13268 | +    def test_trans_parallel_env_check(self):
| 13269 | +        torch.manual_seed(0)
| 13270 | +        base_env = ParallelEnv(
| 13271 | +            3,
| 13272 | +            [partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
| 13273 | +            batch_locked=False,
| 13274 | +            mp_start_method=mp_ctx,
| 13275 | +        )
| 13276 | +        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
| 13277 | +        policy_odd = self._create_policy_odd(base_env)
| 13278 | +        policy_even = self._create_policy_even(base_env)
| 13279 | +        transforms = self._create_transforms(condition, policy_even)
| 13280 | +        env = base_env.append_transform(transforms)
| 13281 | +        self._test_env(env, policy_odd)
| 13282 | +
| 13283 | +    def test_serial_trans_env_check(self):
| 13284 | +        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
| 13285 | +        policy_odd = self._create_policy_odd(CountingEnv())
| 13286 | +
| 13287 | +        def make_env(max_count):
| 13288 | +            return partial(self._make_env, max_count, CountingEnv)
| 13289 | +
| 13290 | +        env = SerialEnv(3, [make_env(6), make_env(7), make_env(8)])
| 13291 | +        self._test_env(env, policy_odd)
| 13292 | +
| 13293 | +    def test_parallel_trans_env_check(self):
| 13294 | +        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
| 13295 | +        policy_odd = self._create_policy_odd(CountingEnv())
| 13296 | +
| 13297 | +        def make_env(max_count):
| 13298 | +            return partial(self._make_env, max_count, CountingEnv)
| 13299 | +
| 13300 | +        env = ParallelEnv(
| 13301 | +            3, [make_env(6), make_env(7), make_env(8)], mp_start_method=mp_ctx
| 13302 | +        )
| 13303 | +        self._test_env(env, policy_odd)
| 13304 | +
| 13305 | +    def test_transform_no_env(self):
| 13306 | +        policy_odd = lambda td: td
| 13307 | +        policy_even = lambda td: td
| 13308 | +        condition = lambda td: True
| 13309 | +        transforms = ConditionalPolicySwitch(condition=condition, policy=policy_even)
| 13310 | +        with pytest.raises(
| 13311 | +            RuntimeError,
| 13312 | +            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
| 13313 | +        ):
| 13314 | +            transforms(TensorDict())
| 13315 | +
| 13316 | +    def test_transform_compose(self):
| 13317 | +        policy_odd = lambda td: td
| 13318 | +        policy_even = lambda td: td
| 13319 | +        condition = lambda td: True
| 13320 | +        transforms = Compose(
| 13321 | +            ConditionalPolicySwitch(condition=condition, policy=policy_even),
| 13322 | +        )
| 13323 | +        with pytest.raises(
| 13324 | +            RuntimeError,
| 13325 | +            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
| 13326 | +        ):
| 13327 | +            transforms(TensorDict())
| 13328 | +
| 13329 | +    def test_transform_env(self):
| 13330 | +        base_env = CountingEnv(max_steps=15)
| 13331 | +        condition = lambda td: ((td.get("step_count") % 2) == 0).all()
| 13332 | +        # Player 0 (`env` is assigned below; the lambdas resolve it lazily at call time)
| 13333 | +        policy_odd = lambda td: td.set("action", env.action_spec.zero())
| 13334 | +        policy_even = lambda td: td.set("action", env.action_spec.one())
| 13335 | +        transforms = Compose(
| 13336 | +            StepCounter(),
| 13337 | +            ConditionalPolicySwitch(condition=condition, policy=policy_even),
| 13338 | +        )
| 13339 | +        env = base_env.append_transform(transforms)
| 13340 | +        env.check_env_specs()
| 13341 | +        r = env.rollout(1000, policy_odd, break_when_all_done=True)
| 13342 | +        assert r.shape[0] == 15
| 13343 | +        assert (r["action"] == 0).all()
| 13344 | +        assert (
| 13345 | +            r["step_count"] == torch.arange(1, r.numel() * 2, 2).unsqueeze(-1)
| 13346 | +        ).all()
| 13347 | +        assert r["next", "done"].any()
| 13348 | +
| 13349 | +        # Player 1
| 13350 | +        condition = lambda td: ((td.get("step_count") % 2) == 1).all()
| 13351 | +        transforms = Compose(
| 13352 | +            StepCounter(),
| 13353 | +            ConditionalPolicySwitch(condition=condition, policy=policy_odd),
| 13354 | +        )
| 13355 | +        env = base_env.append_transform(transforms)
| 13356 | +        r = env.rollout(1000, policy_even, break_when_all_done=True)
| 13357 | +        assert r.shape[0] == 16
| 13358 | +        assert (r["action"] == 1).all()
| 13359 | +        assert (
| 13360 | +            r["step_count"] == torch.arange(0, r.numel() * 2, 2).unsqueeze(-1)
| 13361 | +        ).all()
| 13362 | +        assert r["next", "done"].any()
| 13363 | +
| 13364 | +    def test_transform_model(self):
| 13365 | +        policy_odd = lambda td: td
| 13366 | +        policy_even = lambda td: td
| 13367 | +        condition = lambda td: True
| 13368 | +        transforms = nn.Sequential(
| 13369 | +            ConditionalPolicySwitch(condition=condition, policy=policy_even),
| 13370 | +        )
| 13371 | +        with pytest.raises(
| 13372 | +            RuntimeError,
| 13373 | +            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
| 13374 | +        ):
| 13375 | +            transforms(TensorDict())
| 13376 | +
| 13377 | +    @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
| 13378 | +    def test_transform_rb(self, rbclass):
| 13379 | +        policy_odd = lambda td: td
| 13380 | +        policy_even = lambda td: td
| 13381 | +        condition = lambda td: True
| 13382 | +        rb = rbclass(storage=LazyTensorStorage(10))
| 13383 | +        rb.append_transform(
| 13384 | +            ConditionalPolicySwitch(condition=condition, policy=policy_even)
| 13385 | +        )
| 13386 | +        rb.extend(TensorDict(batch_size=[2]))
| 13387 | +        with pytest.raises(
| 13388 | +            RuntimeError,
| 13389 | +            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
| 13390 | +        ):
| 13391 | +            rb.sample(2)
| 13392 | +
| 13393 | +    def test_transform_inverse(self):
| 13394 | +        return  # nothing to test: the transform defines no inverse
| 13395 | +
| 13396 | +
13195 | 13397 | if __name__ == "__main__":
13196 | 13398 |     args, unknown = argparse.ArgumentParser().parse_known_args()
13197 | 13399 |     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
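
The two-player pattern exercised by `test_transform_env` above doubles as the basic usage recipe for the new transform: the policy passed to `rollout()` only sees the odd step counts, while the policy handed to `ConditionalPolicySwitch` is run inside the env whenever the condition holds, so those turns never appear in the collected rollout. A condensed sketch of that recipe follows; it assumes `ConditionalPolicySwitch` is exported alongside the other torchrl transforms imported at the top of this file and that `CountingEnv` is the test helper environment used throughout this suite.

```python
# Condensed usage sketch distilled from test_transform_env above.
# Assumed import locations: ConditionalPolicySwitch from torchrl.envs.transforms
# (added by this PR) and CountingEnv from the test suite's mocking_classes helpers.
from mocking_classes import CountingEnv

from torchrl.envs import Compose, StepCounter
from torchrl.envs.transforms import ConditionalPolicySwitch

base_env = CountingEnv(max_steps=15)

# "Player 1": invoked by the transform whenever step_count is even.
policy_even = lambda td: td.set("action", base_env.action_spec.one())
# "Player 0": passed to rollout(); it only ever acts on odd step counts.
policy_odd = lambda td: td.set("action", base_env.action_spec.zero())

condition = lambda td: ((td.get("step_count") % 2) == 0).all()
env = base_env.append_transform(
    Compose(
        StepCounter(),
        ConditionalPolicySwitch(condition=condition, policy=policy_even),
    )
)

# Only the steps taken by policy_odd are recorded; the even-count turns are
# played inside the transform, so step_count advances by 2 per recorded frame.
r = env.rollout(1000, policy_odd, break_when_all_done=True)
assert (r["action"] == 0).all()
```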