Skip to content

Commit

Permalink
modify epsilon greedy exploration to add epsilon scheduling
Browse files Browse the repository at this point in the history
Summary: Our current epsilon-greedy exploration module does not allow annealing epsilon. Add this.

Reviewed By: rodrigodesalvobraz

Differential Revision: D65923616

fbshipit-source-id: b3e280010e2adae22a6b428abc9e6e1af42cc0dc
  • Loading branch information
yiwan-rl authored and facebook-github-bot committed Dec 13, 2024
1 parent c9576ed commit 1b86382
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,28 @@ class EGreedyExploration(UniformExplorationBase):
epsilon Greedy exploration module.
"""

def __init__(self, epsilon: float) -> None:
def __init__(
self,
epsilon: float,
start_epsilon: Optional[float] = None,
end_epsilon: Optional[float] = None,
warmup_steps: Optional[int] = None,
) -> None:
super().__init__()
self.epsilon = epsilon
self.start_epsilon = start_epsilon
self.end_epsilon = end_epsilon
self.warmup_steps = warmup_steps
self.time_step = 0
self._epsilon_scheduling: bool = (
self.start_epsilon is not None
and self.end_epsilon is not None
and self.warmup_steps is not None
)
if self._epsilon_scheduling:
assert self.start_epsilon is not None
self.curr_epsilon: float = self.start_epsilon
else:
self.curr_epsilon = epsilon

def act(
self,
Expand All @@ -39,13 +58,24 @@ def act(
action_availability_mask: torch.Tensor | None = None,
representation: torch.nn.Module | None = None,
) -> Action:
if self._epsilon_scheduling:
assert self.warmup_steps is not None
if self.time_step < self.warmup_steps:
assert self.warmup_steps is not None
frac = self.time_step / self.warmup_steps
assert self.start_epsilon is not None
assert self.end_epsilon is not None
self.curr_epsilon = (
self.start_epsilon + (self.end_epsilon - self.start_epsilon) * frac
)
self.time_step += 1
if exploit_action is None:
raise ValueError(
"exploit_action cannot be None for epsilon-greedy exploration"
)
if not isinstance(action_space, DiscreteActionSpace):
raise TypeError("action space must be discrete")
if random.random() < self.epsilon:
if random.random() < self.curr_epsilon:
return action_space.sample(action_availability_mask).to(
exploit_action.device
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def __str__(self) -> str:
assert isinstance(exploration_module, EGreedyExploration)
items = [
"α=" + str(self.learning_rate),
"ε=" + str(exploration_module.epsilon),
"ε=" + str(exploration_module.curr_epsilon),
"λ=" + str(self.discount_factor),
]
return "Q-Learning" + (
Expand Down

0 comments on commit 1b86382

Please sign in to comment.