|
9 | 9 | from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 |
10 | 10 | from stable_baselines3.common.envs import FakeImageEnv |
11 | 11 | from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first |
12 | | -from stable_baselines3.common.utils import zip_strict |
13 | 12 | from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize, VecTransposeImage, is_vecenv_wrapped |
14 | 13 |
|
15 | 14 |
|
@@ -102,26 +101,26 @@ def patch_dqn_names_(model): |
102 | 101 |
|
103 | 102 |
|
104 | 103 | def params_should_match(params, other_params): |
105 | | - for param, other_param in zip_strict(params, other_params): |
| 104 | + for param, other_param in zip(params, other_params, strict=True): |
106 | 105 | assert th.allclose(param, other_param) |
107 | 106 |
|
108 | 107 |
|
109 | 108 | def params_should_differ(params, other_params): |
110 | | - for param, other_param in zip_strict(params, other_params): |
| 109 | + for param, other_param in zip(params, other_params, strict=True): |
111 | 110 | assert not th.allclose(param, other_param) |
112 | 111 |
|
113 | 112 |
|
114 | 113 | def check_td3_feature_extractor_match(model): |
115 | 114 | for (key, actor_param), critic_param in zip( |
116 | | - model.actor_target.named_parameters(), model.critic_target.parameters(), strict=True |
| 115 | + model.actor_target.named_parameters(), model.critic_target.parameters(), strict=False |
117 | 116 | ): |
118 | 117 | if "features_extractor" in key: |
119 | 118 | assert th.allclose(actor_param, critic_param), key |
120 | 119 |
|
121 | 120 |
|
122 | 121 | def check_td3_feature_extractor_differ(model): |
123 | 122 | for (key, actor_param), critic_param in zip( |
124 | | - model.actor_target.named_parameters(), model.critic_target.parameters(), strict=True |
| 123 | + model.actor_target.named_parameters(), model.critic_target.parameters(), strict=False |
125 | 124 | ): |
126 | 125 | if "features_extractor" in key: |
127 | 126 | assert not th.allclose(actor_param, critic_param), key |
|
0 commit comments