Skip to content

Commit 5539563

Browse files
committed
Fix tests when zip strict is not needed
1 parent a4520b8 commit 5539563

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

tests/test_cnn.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
1010
from stable_baselines3.common.envs import FakeImageEnv
1111
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
12-
from stable_baselines3.common.utils import zip_strict
1312
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize, VecTransposeImage, is_vecenv_wrapped
1413

1514

@@ -102,26 +101,26 @@ def patch_dqn_names_(model):
102101

103102

104103
def params_should_match(params, other_params):
105-
for param, other_param in zip_strict(params, other_params):
104+
for param, other_param in zip(params, other_params, strict=True):
106105
assert th.allclose(param, other_param)
107106

108107

109108
def params_should_differ(params, other_params):
110-
for param, other_param in zip_strict(params, other_params):
109+
for param, other_param in zip(params, other_params, strict=True):
111110
assert not th.allclose(param, other_param)
112111

113112

114113
def check_td3_feature_extractor_match(model):
115114
for (key, actor_param), critic_param in zip(
116-
model.actor_target.named_parameters(), model.critic_target.parameters(), strict=True
115+
model.actor_target.named_parameters(), model.critic_target.parameters(), strict=False
117116
):
118117
if "features_extractor" in key:
119118
assert th.allclose(actor_param, critic_param), key
120119

121120

122121
def check_td3_feature_extractor_differ(model):
123122
for (key, actor_param), critic_param in zip(
124-
model.actor_target.named_parameters(), model.critic_target.parameters(), strict=True
123+
model.actor_target.named_parameters(), model.critic_target.parameters(), strict=False
125124
):
126125
if "features_extractor" in key:
127126
assert not th.allclose(actor_param, critic_param), key

0 commit comments

Comments
 (0)