Conversation

@araffin (Owner) commented on May 19, 2025

Description

See #55 and #25

Motivation and Context

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist:

  • I've read the CONTRIBUTION guide (required)
  • I have updated the changelog accordingly (required).
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • I have reformatted the code using make format (required)
  • I have checked the codestyle using make check-codestyle and make lint (required)
  • I have ensured make pytest and make type both pass. (required)
  • I have checked that the documentation builds using make doc (required)

Note: You can run most of the checks using make commit-checks.

Note: we are using a maximum length of 127 characters per line

@araffin araffin requested a review from Copilot May 19, 2025 10:42

Copilot AI left a comment


Pull Request Overview

This PR introduces adaptive learning rate scheduling based on KL divergence for PPO, extends learning rate schedules and dynamic updates to SAC and TQC, updates the SBX package version and dependencies, and adds optimizer hyperparameter injection across various policies.

  • Implement KLAdaptiveLR and integrate it into PPO (target_kl) for adaptive LR control (a sketch follows this list).
  • Refactor SAC and TQC to use lr_schedule for initial and per-step learning rates.
  • Inject hyperparameters into JAX optimizers for dynamic LR updates and add ortho_init support in PPO policies.
  • Bump SBX version to 0.21.0 and update stable_baselines3 requirement to >=2.6.1a1.
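
For context, here is a minimal sketch of what such a KL-adaptive learning rate rule can look like, in the spirit of adaptive-KL heuristics used by some PPO implementations. Field names, default factors, and clipping bounds are illustrative and may differ from the actual implementation in sbx/common/utils.py:

```python
from dataclasses import dataclass


@dataclass
class KLAdaptiveLR:
    """Adjust the learning rate based on how far the measured KL divergence
    is from the target KL (illustrative sketch, not the actual SBX code)."""

    target_kl: float
    current_adaptive_lr: float
    # Shrink/grow the lr when the KL leaves [target_kl / margin, target_kl * margin]
    kl_margin: float = 2.0
    adaptive_lr_factor: float = 1.5
    min_lr: float = 1e-6
    max_lr: float = 1e-2

    def update(self, kl_div: float) -> None:
        if kl_div > self.target_kl * self.kl_margin:
            # Policy moved too far: slow down
            self.current_adaptive_lr /= self.adaptive_lr_factor
        elif kl_div < self.target_kl / self.kl_margin:
            # Policy barely moved: speed up
            self.current_adaptive_lr *= self.adaptive_lr_factor
        # Keep the learning rate in a sane range
        self.current_adaptive_lr = min(max(self.current_adaptive_lr, self.min_lr), self.max_lr)
```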

Reviewed Changes

Copilot reviewed 14 out of 14 changed files in this pull request and generated 3 comments.

Summary per file:

tests/test_run.py: Added a target_kl argument to test_ppo to cover adaptive LR scheduling in PPO tests.
setup.py: Updated the SBX dependency on stable_baselines3 to a pre-release, >=2.6.1a1,<3.0.
sbx/version.txt: Bumped the SBX package version from 0.20.0 to 0.21.0.
sbx/common/utils.py: Introduced the KLAdaptiveLR dataclass for KL-based adaptive learning rate adjustments.
sbx/common/on_policy_algorithm.py: Added _update_learning_rate to allow in-flight hyperparameter injection for on-policy algorithms.
sbx/common/off_policy_algorithm.py: Added an override of _update_learning_rate and renamed qf_learning_rate to initial_qf_learning_rate.
sbx/tqc/tqc.py: Switched optimizers to use lr_schedule at init and call dynamic LR updates during training.
sbx/tqc/policies.py: Injected optimizer hyperparameters via optax.inject_hyperparams (see the sketch after this table) and updated the log_std_init default.
sbx/sac/sac.py: Mirrored the TQC changes: use lr_schedule for the initial LR and update it per training step.
sbx/sac/policies.py: Refactored the policy builder to inject hyperparameters into optimizers for dynamic LR updates.
sbx/ppo/ppo.py: Integrated KLAdaptiveLR, switched to FloatSchedule, extended _one_update to return the ratio, and updated the training loop for adaptive LR.
sbx/ppo/policies.py: Added an ortho_init flag and optimizer hyperparameter injection; updated the default log_std_init.
sbx/dqn/dqn.py: Replaced the legacy get_linear_fn with the new LinearSchedule utility.
sbx/crossq/crossq.py: Changed the optimizer setup to use lr_schedule(1) instead of a fixed learning rate.
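
For readers unfamiliar with the pattern referenced for sbx/tqc/policies.py, sbx/sac/policies.py, and sbx/ppo/policies.py, here is a small standalone sketch of how optax.inject_hyperparams lets the learning rate be changed after the optimizer has been created (toy parameters, not SBX code):

```python
import jax.numpy as jnp
import optax

# Wrapping the optimizer with inject_hyperparams exposes its hyperparameters
# in the optimizer state, so they can be changed without re-creating the optimizer.
tx = optax.inject_hyperparams(optax.adam)(learning_rate=3e-4)

params = {"w": jnp.zeros((2, 2))}
opt_state = tx.init(params)

# Later, e.g. after a KL-adaptive update or a schedule step:
opt_state.hyperparams["learning_rate"] = 1e-4

grads = {"w": jnp.ones((2, 2))}
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```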
Comments suppressed due to low confidence (1)

sbx/common/utils.py:7

  • Add unit tests for KLAdaptiveLR.update to verify its behavior when kl_div values are above, below, or within the target KL margin.
class KLAdaptiveLR:
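
Building on the illustrative KLAdaptiveLR sketch above, such a test could look roughly like this (the thresholds assume a margin factor of 2; adjust to the actual defaults in sbx/common/utils.py):

```python
# Assuming the KLAdaptiveLR sketch above, or: from sbx.common.utils import KLAdaptiveLR
def test_kl_adaptive_lr():
    # KL well above target * margin -> learning rate should decrease
    adaptive_lr = KLAdaptiveLR(target_kl=0.01, current_adaptive_lr=1e-3)
    adaptive_lr.update(kl_div=0.05)
    assert adaptive_lr.current_adaptive_lr < 1e-3

    # KL well below target / margin -> learning rate should increase
    adaptive_lr = KLAdaptiveLR(target_kl=0.01, current_adaptive_lr=1e-3)
    adaptive_lr.update(kl_div=0.001)
    assert adaptive_lr.current_adaptive_lr > 1e-3

    # KL within the margin -> learning rate unchanged
    adaptive_lr = KLAdaptiveLR(target_kl=0.01, current_adaptive_lr=1e-3)
    adaptive_lr.update(kl_div=0.01)
    assert adaptive_lr.current_adaptive_lr == 1e-3
```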

# Note: most gSDE parameters are not used
# this is to keep API consistent with SB3
- log_std_init: float = -3,
+ log_std_init: float = 0.0,

Copilot AI May 19, 2025


Changing the default log_std_init from -3 to 0.0 alters the initial policy variance and is a breaking behavior change; document this in the changelog or consider deprecation notices.

Suggested change:
- log_std_init: float = 0.0,
+ log_std_init: float = -3,
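
For context (assuming the usual std = exp(log_std) parameterization of the Gaussian policy): exp(-3) ≈ 0.05 while exp(0) = 1.0, so the new default raises the initial action standard deviation by roughly a factor of 20.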

excluded.remove("policy")
return excluded

def _update_learning_rate( # type: ignore[override]

Copilot AI May 19, 2025


[nitpick] The learning rate update logic is duplicated in both on-policy and off-policy algorithm classes; consider refactoring this into a shared utility to reduce code duplication.
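
For reference, such a shared helper could look roughly like this, assuming both algorithm classes create their optimizers with optax.inject_hyperparams (the function name, location, and attribute names are hypothetical, not the actual SBX code):

```python
import optax


def set_learning_rate(opt_state: optax.OptState, learning_rate: float) -> None:
    """Update the learning rate of an optimizer wrapped with optax.inject_hyperparams.

    inject_hyperparams stores hyperparameters in a mutable dict on the optimizer
    state, so they can be changed in place without re-creating the train state.
    """
    opt_state.hyperparams["learning_rate"] = learning_rate


# Hypothetical usage from an algorithm's train():
#   set_learning_rate(self.policy.actor_state.opt_state, self.lr_schedule(progress))
```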

# For MultiDiscrete
max_num_choices: int = 0
split_indices: np.ndarray = field(default_factory=lambda: np.array([]))
# Last layer with small scale

Copilot AI May 19, 2025


[nitpick] Document the new ortho_init parameter in the class docstring and __init__ signature to clarify its purpose and how it affects weight initialization.

Suggested change
# Last layer with small scale
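
To illustrate what such documentation could cover, here is a minimal Flax sketch of how an ortho_init flag typically switches kernel initializers (orthogonal with sqrt(2) gain for hidden layers, a small gain for the last layer, as is common for PPO). The module structure and layer sizes are illustrative, not the actual SBX policy code:

```python
import flax.linen as nn


class Critic(nn.Module):
    ortho_init: bool = False

    @nn.compact
    def __call__(self, x):
        if self.ortho_init:
            # Orthogonal init with sqrt(2) gain for hidden layers,
            # last layer with small scale
            hidden_init = nn.initializers.orthogonal(2**0.5)
            out_init = nn.initializers.orthogonal(0.01)
        else:
            # Flax default initializer
            hidden_init = nn.initializers.lecun_normal()
            out_init = nn.initializers.lecun_normal()
        x = nn.tanh(nn.Dense(64, kernel_init=hidden_init)(x))
        return nn.Dense(1, kernel_init=out_init)(x)
```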

@araffin araffin merged commit 849e908 into master May 19, 2025
4 checks passed
@araffin araffin deleted the feat/adaptive-lr branch May 19, 2025 10:47