From 1b86382522d1861e9a067e2819a91d6bed7692c2 Mon Sep 17 00:00:00 2001
From: Yi Wan
Date: Fri, 13 Dec 2024 10:44:34 -0800
Subject: [PATCH] modify epsilon greedy exploration to add epsilon scheduling

Summary: Our current epsilon-greedy exploration module does not allow annealing epsilon. Add this.

Reviewed By: rodrigodesalvobraz

Differential Revision: D65923616

fbshipit-source-id: b3e280010e2adae22a6b428abc9e6e1af42cc0dc
---
 .../common/epsilon_greedy_exploration.py      | 36 +++++++++++++++++--
 .../tabular_q_learning.py                     |  2 +-
 2 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/pearl/policy_learners/exploration_modules/common/epsilon_greedy_exploration.py b/pearl/policy_learners/exploration_modules/common/epsilon_greedy_exploration.py
index 831578c..596e36c 100644
--- a/pearl/policy_learners/exploration_modules/common/epsilon_greedy_exploration.py
+++ b/pearl/policy_learners/exploration_modules/common/epsilon_greedy_exploration.py
@@ -26,9 +26,28 @@ class EGreedyExploration(UniformExplorationBase):
     epsilon Greedy exploration module.
     """
 
-    def __init__(self, epsilon: float) -> None:
+    def __init__(
+        self,
+        epsilon: float,
+        start_epsilon: Optional[float] = None,
+        end_epsilon: Optional[float] = None,
+        warmup_steps: Optional[int] = None,
+    ) -> None:
         super().__init__()
-        self.epsilon = epsilon
+        self.start_epsilon = start_epsilon
+        self.end_epsilon = end_epsilon
+        self.warmup_steps = warmup_steps
+        self.time_step = 0
+        self._epsilon_scheduling: bool = (
+            self.start_epsilon is not None
+            and self.end_epsilon is not None
+            and self.warmup_steps is not None
+        )
+        if self._epsilon_scheduling:
+            assert self.start_epsilon is not None
+            self.curr_epsilon: float = self.start_epsilon
+        else:
+            self.curr_epsilon = epsilon
 
     def act(
         self,
@@ -39,13 +58,24 @@ def act(
         action_availability_mask: torch.Tensor | None = None,
         representation: torch.nn.Module | None = None,
     ) -> Action:
+        if self._epsilon_scheduling:
+            assert self.warmup_steps is not None
+            if self.time_step < self.warmup_steps:
+                assert self.warmup_steps is not None
+                frac = self.time_step / self.warmup_steps
+                assert self.start_epsilon is not None
+                assert self.end_epsilon is not None
+                self.curr_epsilon = (
+                    self.start_epsilon + (self.end_epsilon - self.start_epsilon) * frac
+                )
+        self.time_step += 1
         if exploit_action is None:
             raise ValueError(
                 "exploit_action cannot be None for epsilon-greedy exploration"
             )
         if not isinstance(action_space, DiscreteActionSpace):
             raise TypeError("action space must be discrete")
-        if random.random() < self.epsilon:
+        if random.random() < self.curr_epsilon:
             return action_space.sample(action_availability_mask).to(
                 exploit_action.device
             )
diff --git a/pearl/policy_learners/sequential_decision_making/tabular_q_learning.py b/pearl/policy_learners/sequential_decision_making/tabular_q_learning.py
index 8867680..6975021 100644
--- a/pearl/policy_learners/sequential_decision_making/tabular_q_learning.py
+++ b/pearl/policy_learners/sequential_decision_making/tabular_q_learning.py
@@ -209,7 +209,7 @@ def __str__(self) -> str:
         assert isinstance(exploration_module, EGreedyExploration)
         items = [
             "α=" + str(self.learning_rate),
-            "ε=" + str(exploration_module.epsilon),
+            "ε=" + str(exploration_module.curr_epsilon),
             "λ=" + str(self.discount_factor),
         ]
         return "Q-Learning" + (
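
The schedule introduced by this patch is a plain linear interpolation: while time_step < warmup_steps, each call to act() sets curr_epsilon = start_epsilon + (end_epsilon - start_epsilon) * time_step / warmup_steps; after the warmup window, curr_epsilon keeps its last value, and the usual random.random() < curr_epsilon draw decides whether to explore. Scheduling is enabled only when all three of start_epsilon, end_epsilon, and warmup_steps are passed; otherwise the module behaves as before with the fixed epsilon. The standalone sketch below is not part of the patch or of Pearl's API; it replays the same update rule with hypothetical local names (LinearEpsilonSchedule, step, explore) so the annealing curve can be inspected without constructing action spaces or policy learners.

# Illustrative sketch only -- not part of the patch or of Pearl's API.
# It mirrors the linear epsilon schedule added to EGreedyExploration.act().
import random


class LinearEpsilonSchedule:
    """Anneals epsilon linearly from start_epsilon to end_epsilon over warmup_steps calls."""

    def __init__(self, start_epsilon: float, end_epsilon: float, warmup_steps: int) -> None:
        self.start_epsilon = start_epsilon
        self.end_epsilon = end_epsilon
        self.warmup_steps = warmup_steps
        self.time_step = 0
        self.curr_epsilon = start_epsilon

    def step(self) -> float:
        # Same update rule as the patched act(): interpolate while still in the
        # warmup window, then keep the last computed value.
        if self.time_step < self.warmup_steps:
            frac = self.time_step / self.warmup_steps
            self.curr_epsilon = (
                self.start_epsilon + (self.end_epsilon - self.start_epsilon) * frac
            )
        self.time_step += 1
        return self.curr_epsilon

    def explore(self) -> bool:
        # The exploration decision act() makes: a uniform draw against current epsilon.
        return random.random() < self.curr_epsilon


if __name__ == "__main__":
    schedule = LinearEpsilonSchedule(start_epsilon=1.0, end_epsilon=0.0, warmup_steps=4)
    print([schedule.step() for _ in range(6)])  # [1.0, 0.75, 0.5, 0.25, 0.25, 0.25]

Because frac is computed before time_step is incremented, the last scheduled value is start_epsilon + (end_epsilon - start_epsilon) * (warmup_steps - 1) / warmup_steps, so curr_epsilon approaches but never exactly reaches end_epsilon (0.25 rather than 0.0 in the example above).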