diff --git a/tests/units/environments/stablebaselines3/test_custom_env.py b/tests/units/environments/stablebaselines3/test_custom_env.py index a004cec..28be11f 100644 --- a/tests/units/environments/stablebaselines3/test_custom_env.py +++ b/tests/units/environments/stablebaselines3/test_custom_env.py @@ -37,18 +37,22 @@ def test_custom_env_step(self,env_mock, state_mock, action_space_mock, reward_mo reward_mock.get.assert_called_once() self.assertEqual(step_return, (None, None, None, None, {})) + @patch('urnai.rewards.reward_base.RewardBase') @patch('urnai.environments.environment_base.EnvironmentBase') @patch('urnai.states.state_base.StateBase') - def test_custom_env_reset(self, state_mock, env_mock): + def test_custom_env_reset(self, state_mock, env_mock, reward_mock): # GIVEN - env = CustomEnv(env_mock, state_mock, None, None, None, None) + env = CustomEnv(env_mock, state_mock, None, reward_mock, None, None) env_mock.reset.return_value = None state_mock.update.return_value = None + reward_mock.reset.return_value = None + reward_mock.get.return_value = None # WHEN reset_return = env.reset() # THEN env_mock.reset.assert_called_once() state_mock.update.assert_called_once() + reward_mock.reset.assert_called_once() self.assertEqual(reset_return, (None, {})) def test_custom_env_render(self):