From 8a96e03b3b86b8cbe9ef0f1b4e8820154c1d57a8 Mon Sep 17 00:00:00 2001 From: RickFqt Date: Mon, 18 Nov 2024 07:44:28 -0300 Subject: [PATCH] fix: Added reward reset to custom env test --- .../environments/stablebaselines3/test_custom_env.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/units/environments/stablebaselines3/test_custom_env.py b/tests/units/environments/stablebaselines3/test_custom_env.py index a004cec..28be11f 100644 --- a/tests/units/environments/stablebaselines3/test_custom_env.py +++ b/tests/units/environments/stablebaselines3/test_custom_env.py @@ -37,18 +37,22 @@ def test_custom_env_step(self,env_mock, state_mock, action_space_mock, reward_mo reward_mock.get.assert_called_once() self.assertEqual(step_return, (None, None, None, None, {})) + @patch('urnai.rewards.reward_base.RewardBase') @patch('urnai.environments.environment_base.EnvironmentBase') @patch('urnai.states.state_base.StateBase') - def test_custom_env_reset(self, state_mock, env_mock): + def test_custom_env_reset(self, state_mock, env_mock, reward_mock): # GIVEN - env = CustomEnv(env_mock, state_mock, None, None, None, None) + env = CustomEnv(env_mock, state_mock, None, reward_mock, None, None) env_mock.reset.return_value = None state_mock.update.return_value = None + reward_mock.reset.return_value = None + reward_mock.get.return_value = None # WHEN reset_return = env.reset() # THEN env_mock.reset.assert_called_once() state_mock.update.assert_called_once() + reward_mock.reset.assert_called_once() self.assertEqual(reset_return, (None, {})) def test_custom_env_render(self):