KL Adaptive LR for PPO and LR schedule for SAC/TQC #72
Conversation
Pull Request Overview
This PR introduces adaptive learning rate scheduling based on KL divergence for PPO, extends learning rate schedules and dynamic updates to SAC and TQC, updates the SBX package version and dependencies, and adds optimizer hyperparameter injection across various policies.
- Implement `KLAdaptiveLR` and integrate it into PPO (`target_kl`) for adaptive LR control.
- Refactor SAC and TQC to use `lr_schedule` for initial and per-step learning rates.
- Inject hyperparameters into JAX optimizers for dynamic LR updates (sketched below) and add `ortho_init` support in PPO policies.
- Bump SBX version to 0.21.0 and update the `stable_baselines3` requirement to `>=2.6.1a1`.
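The optimizer hyperparameter injection mentioned above builds on `optax.inject_hyperparams`, which moves the learning rate into the optimizer state so it can be overwritten between gradient steps. A minimal sketch of the mechanism (illustrative only, not the PR's exact wiring):

```python
import jax.numpy as jnp
import optax

# Wrapping the optimizer factory makes "learning_rate" part of the optimizer
# state instead of a baked-in constant.
tx = optax.inject_hyperparams(optax.adam)(learning_rate=3e-4)
params = {"w": jnp.zeros(3)}
opt_state = tx.init(params)

# Later (e.g. after a schedule or KL-based controller picks a new value),
# the stored learning rate can simply be overwritten:
opt_state.hyperparams["learning_rate"] = 1e-4

grads = {"w": jnp.ones(3)}
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```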
Reviewed Changes
Copilot reviewed 14 out of 14 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| tests/test_run.py | Added target_kl argument to test_ppo to cover adaptive LR scheduling in PPO tests. |
| setup.py | Updated SBX dependency on stable_baselines3 to a pre-release >=2.6.1a1,<3.0. |
| sbx/version.txt | Bumped SBX package version from 0.20.0 to 0.21.0. |
| sbx/common/utils.py | Introduced KLAdaptiveLR dataclass for KL-based adaptive learning rate adjustments. |
| sbx/common/on_policy_algorithm.py | Added _update_learning_rate to allow in-flight hyperparameter injection for on-policy algos. |
| sbx/common/off_policy_algorithm.py | Added override of _update_learning_rate and renamed qf_learning_rate to initial_qf_learning_rate. |
| sbx/tqc/tqc.py | Switched optimizers to use lr_schedule at init and call dynamic LR updates during training. |
| sbx/tqc/policies.py | Injected optimizer hyperparameters via optax.inject_hyperparams and updated log_std_init default. |
| sbx/sac/sac.py | Mirrored TQC changes: use lr_schedule for initial LR and update per training step. |
| sbx/sac/policies.py | Refactored policy builder to inject hyperparameters into optimizers for dynamic LR updates. |
| sbx/ppo/ppo.py | Integrated KLAdaptiveLR, switched to FloatSchedule, extended _one_update to return ratio, and updated training loop for adaptive LR. |
| sbx/ppo/policies.py | Added ortho_init flag and optimizer hyperparameter injection; updated default log_std_init. |
| sbx/dqn/dqn.py | Replaced legacy get_linear_fn with the new LinearSchedule utility. |
| sbx/crossq/crossq.py | Changed optimizer setup to use lr_schedule(1) instead of fixed learning rate. |
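For readers unfamiliar with the SB3 convention referenced in the table: `lr_schedule` is a callable that maps the remaining training progress (from 1.0 at the start down to 0.0 at the end) to a learning rate, so `lr_schedule(1)` yields the initial learning rate. A rough illustration of that behavior (not the PR's code):

```python
def linear_lr_schedule(initial_lr: float, final_lr: float = 0.0):
    """SB3-style schedule: progress_remaining goes from 1.0 (start) to 0.0 (end)."""
    def schedule(progress_remaining: float) -> float:
        return final_lr + progress_remaining * (initial_lr - final_lr)
    return schedule

lr_schedule = linear_lr_schedule(3e-4)
print(lr_schedule(1.0))  # 3e-4: initial LR, what a call like lr_schedule(1) returns
print(lr_schedule(0.5))  # 1.5e-4: halfway through training
```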
Comments suppressed due to low confidence (1)
sbx/common/utils.py:7 (`class KLAdaptiveLR:`)
- Add unit tests for `KLAdaptiveLR.update` to verify its behavior when `kl_div` values are above, below, or within the target KL margin.
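Such tests are straightforward once the controller's semantics are fixed. As a hedged sketch only (the field names, margin, and multiplicative rule below are assumptions, not necessarily what sbx/common/utils.py implements), a KL-adaptive controller and the three cases the suppressed comment asks about might look like this:

```python
from dataclasses import dataclass


@dataclass
class KLAdaptiveLR:
    # Hypothetical sketch: shrink the LR when the policy moved too far
    # (KL above target), grow it when it barely moved (KL below target).
    target_kl: float
    current_adaptive_lr: float = 3e-4
    kl_margin: float = 2.0          # tolerated ratio around target_kl
    adaptive_lr_factor: float = 1.5  # multiplicative step for the LR
    min_lr: float = 1e-6
    max_lr: float = 1e-2

    def update(self, kl_div: float) -> None:
        if kl_div > self.target_kl * self.kl_margin:
            self.current_adaptive_lr /= self.adaptive_lr_factor
        elif kl_div < self.target_kl / self.kl_margin:
            self.current_adaptive_lr *= self.adaptive_lr_factor
        self.current_adaptive_lr = min(max(self.current_adaptive_lr, self.min_lr), self.max_lr)


def test_kl_adaptive_lr():
    lr = KLAdaptiveLR(target_kl=0.01, current_adaptive_lr=1e-3)
    lr.update(kl_div=0.05)   # far above target -> LR decreases
    assert lr.current_adaptive_lr < 1e-3

    lr = KLAdaptiveLR(target_kl=0.01, current_adaptive_lr=1e-3)
    lr.update(kl_div=0.001)  # far below target -> LR increases
    assert lr.current_adaptive_lr > 1e-3

    lr = KLAdaptiveLR(target_kl=0.01, current_adaptive_lr=1e-3)
    lr.update(kl_div=0.01)   # within the margin -> LR unchanged
    assert lr.current_adaptive_lr == 1e-3
```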
```diff
  # Note: most gSDE parameters are not used
  # this is to keep API consistent with SB3
- log_std_init: float = -3,
+ log_std_init: float = 0.0,
```
Copilot AI (May 19, 2025):
Changing the default log_std_init from -3 to 0.0 alters the initial policy variance and is a breaking behavior change; document this in the changelog or consider deprecation notices.
Suggested change:
```diff
- log_std_init: float = 0.0,
+ log_std_init: float = -3,
```
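For context on why this default change is behavior-breaking: with the usual Gaussian-policy parameterization the initial action standard deviation is exp(log_std_init), so the default jumps from a near-deterministic initial policy to unit-variance exploration:

```python
import numpy as np

# std = exp(log_std): effect of the changed default on the initial action noise.
print(np.exp(-3.0))  # ~0.05: old default, tight initial action distribution
print(np.exp(0.0))   # 1.0:   new default, much broader initial exploration
```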
```python
        excluded.remove("policy")
        return excluded

    def _update_learning_rate(  # type: ignore[override]
```
Copilot AI (May 19, 2025):
[nitpick] The learning rate update logic is duplicated in both on-policy and off-policy algorithm classes; consider refactoring this into a shared utility to reduce code duplication.
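One hedged way to act on this nitpick (the helper name is hypothetical, and it assumes the optimizers were built with `optax.inject_hyperparams` as in the sketch near the top) would be a small shared utility that both `_update_learning_rate` overrides delegate to:

```python
from typing import Iterable


def set_learning_rate(opt_states: Iterable, learning_rate: float) -> None:
    """Overwrite the LR stored in each injected-hyperparams optimizer state."""
    for opt_state in opt_states:
        opt_state.hyperparams["learning_rate"] = learning_rate
```

The on-policy and off-policy classes would then differ only in which optimizer states they pass in.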
```python
    # For MultiDiscrete
    max_num_choices: int = 0
    split_indices: np.ndarray = field(default_factory=lambda: np.array([]))
    # Last layer with small scale
```
Copilot AI (May 19, 2025):
[nitpick] Document the new ortho_init parameter in the class docstring and __init__ signature to clarify its purpose and how it affects weight initialization.
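As a rough illustration of what such documentation could convey (the module and wiring below are hypothetical, not the PR's code): `ortho_init` typically toggles orthogonal kernel initialization for the network layers, with a smaller gain on the final layer, e.g. in Flax:

```python
import numpy as np
import flax.linen as nn


class Critic(nn.Module):
    # Hypothetical module, only to illustrate the effect of the ortho_init flag.
    ortho_init: bool = False

    @nn.compact
    def __call__(self, x):
        if self.ortho_init:
            # Orthogonal init with sqrt(2) gain for hidden layers,
            # and a small gain for the last layer.
            hidden_init = nn.initializers.orthogonal(np.sqrt(2))
            last_init = nn.initializers.orthogonal(0.01)
        else:
            hidden_init = nn.initializers.lecun_normal()  # Flax default
            last_init = nn.initializers.lecun_normal()
        x = nn.relu(nn.Dense(64, kernel_init=hidden_init)(x))
        return nn.Dense(1, kernel_init=last_init)(x)
```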
Description
See #55 and #25
Motivation and Context
Types of changes
Checklist:
- `make format` (required)
- `make check-codestyle` and `make lint` (required)
- `make pytest` and `make type` both pass (required)
- `make doc` (required)

Note: You can run most of the checks using `make commit-checks`.
Note: we are using a maximum length of 127 characters per line.