
Commit 849e908

KL Adaptive LR for PPO and LR schedule for SAC/TQC (#72)
* Only check for terminated episodes
* Start adding ortho init
* Add SimbaPolicy for PPO
* Try adding ortho init to SAC
* Enable lr schedule for PPO
* Allow to pass lr, prepare for adaptive lr
* Implement adaptive lr
* Add small test
* Refactor adaptive lr
* Add adaptive lr for SAC
* Fix qf_learning_rate
* Revert "Fix qf_learning_rate" (this reverts commit ab33983)
* Revert "Add adaptive lr for SAC" (this reverts commit 5832702)
* Revert kl div for SAC changes
* Revert dist.mode() in two lines
* Cleanup code
* Add support for Gaussian actor for SAC
* Enable Gaussian actor for TQC
* Log std too
* Avoid NaN in kl div approx
* Allow to use layer_norm in actor
* Reformat
* Allow max grad norm for TQC and fix optimizer class
* Comment out max grad norm
* Update to schedule classes
* Add lr schedule support for TQC
* Revert experimental changes and add support for lr schedule for SAC
* Add test for adaptive kl div, remove squash output param
1 parent 8238fcc commit 849e908

File tree

14 files changed: +211 additions, -62 deletions


sbx/common/off_policy_algorithm.py

Lines changed: 30 additions & 4 deletions
@@ -4,6 +4,7 @@
 import jax
 import numpy as np
+import optax
 from gymnasium import spaces
 from stable_baselines3 import HerReplayBuffer
 from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
@@ -15,6 +16,8 @@


 class OffPolicyAlgorithmJax(OffPolicyAlgorithm):
+    qf_learning_rate: float
+
     def __init__(
         self,
         policy: type[BasePolicy],
@@ -75,8 +78,8 @@ def __init__(
         )
         # Will be updated later
         self.key = jax.random.PRNGKey(0)
-        # Note: we do not allow schedule for it
-        self.qf_learning_rate = qf_learning_rate
+        # Note: we do not allow separate schedule for it
+        self.initial_qf_learning_rate = qf_learning_rate
         self.param_resets = param_resets
         self.reset_idx = 0

@@ -89,7 +92,7 @@ def _maybe_reset_params(self) -> None:
         ):
             # Note: we are not resetting the entropy coeff
             assert isinstance(self.qf_learning_rate, float)
-            self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate)
+            self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate)  # type: ignore[operator]
             self.reset_idx += 1

     def _get_torch_save_params(self):
@@ -100,6 +103,29 @@ def _excluded_save_params(self) -> list[str]:
         excluded.remove("policy")
         return excluded

+    def _update_learning_rate(  # type: ignore[override]
+        self,
+        optimizers: Union[list[optax.OptState], optax.OptState],
+        learning_rate: float,
+        name: str = "learning_rate",
+    ) -> None:
+        """
+        Update the optimizers learning rate using the current learning rate schedule
+        and the current progress remaining (from 1 to 0).
+
+        :param optimizers: An optimizer or a list of optimizers.
+        :param learning_rate: The current learning rate to apply
+        :param name: (Optional) A custom name for the lr (for instance qf_learning_rate)
+        """
+        # Log the current learning rate
+        self.logger.record(f"train/{name}", learning_rate)
+
+        if not isinstance(optimizers, list):
+            optimizers = [optimizers]
+        for optimizer in optimizers:
+            # Note: the optimizer must have been defined with inject_hyperparams
+            optimizer.hyperparams["learning_rate"] = learning_rate
+
     def set_random_seed(self, seed: Optional[int]) -> None:  # type: ignore[override]
         super().set_random_seed(seed)
         if seed is None:
@@ -116,7 +142,7 @@ def _setup_model(self) -> None:

         self._setup_lr_schedule()
         # By default qf_learning_rate = pi_learning_rate
-        self.qf_learning_rate = self.qf_learning_rate or self.lr_schedule(1)
+        self.qf_learning_rate = self.initial_qf_learning_rate or self.lr_schedule(1)
         self.set_random_seed(self.seed)
         # Make a local copy as we should not pickle
         # the environment when using HerReplayBuffer
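
For reference, here is a minimal standalone sketch (not part of this commit) of the optax.inject_hyperparams pattern that the new _update_learning_rate relies on: wrapping the optimizer makes the learning rate a mutable entry of the optimizer state instead of a value baked into the update function. The parameter pytree and values below are made up.

import jax.numpy as jnp
import optax

params = {"w": jnp.zeros(3)}
# Wrap the optimizer so that "learning_rate" lives in opt_state.hyperparams
tx = optax.inject_hyperparams(optax.adam)(learning_rate=3e-4)
opt_state = tx.init(params)

# Later (e.g. from _update_learning_rate) the rate can be overwritten in place
opt_state.hyperparams["learning_rate"] = 1e-4

grads = {"w": jnp.ones(3)}
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)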

sbx/common/on_policy_algorithm.py

Lines changed: 24 additions & 6 deletions
@@ -3,6 +3,7 @@
 import gymnasium as gym
 import jax
 import numpy as np
+import optax
 import torch as th
 from gymnasium import spaces
 from stable_baselines3.common.buffers import RolloutBuffer
@@ -75,6 +76,27 @@ def _excluded_save_params(self) -> list[str]:
         excluded.remove("policy")
         return excluded

+    def _update_learning_rate(  # type: ignore[override]
+        self,
+        optimizers: Union[list[optax.OptState], optax.OptState],
+        learning_rate: float,
+    ) -> None:
+        """
+        Update the optimizers learning rate using the current learning rate schedule
+        and the current progress remaining (from 1 to 0).
+
+        :param optimizers:
+            An optimizer or a list of optimizers.
+        """
+        # Log the current learning rate
+        self.logger.record("train/learning_rate", learning_rate)
+
+        if not isinstance(optimizers, list):
+            optimizers = [optimizers]
+        for optimizer in optimizers:
+            # Note: the optimizer must have been defined with inject_hyperparams
+            optimizer.hyperparams["learning_rate"] = learning_rate
+
     def set_random_seed(self, seed: Optional[int]) -> None:  # type: ignore[override]
         super().set_random_seed(seed)
         if seed is None:
@@ -167,12 +189,8 @@ def collect_rollouts(

             # Handle timeout by bootstraping with value function
             # see GitHub issue #633
-            for idx, done in enumerate(dones):
-                if (
-                    done
-                    and infos[idx].get("terminal_observation") is not None
-                    and infos[idx].get("TimeLimit.truncated", False)
-                ):
+            for idx in dones.nonzero()[0]:
+                if infos[idx].get("terminal_observation") is not None and infos[idx].get("TimeLimit.truncated", False):
                     terminal_obs = self.policy.prepare_obs(infos[idx]["terminal_observation"])[0]
                     terminal_value = np.array(
                         self.vf.apply(  # type: ignore[union-attr]
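
A tiny illustration (standalone, not from the diff) of the rewritten loop: dones.nonzero()[0] yields only the indices of environments that just finished, so the truncation check now runs once per terminated episode instead of once per environment.

import numpy as np

dones = np.array([False, True, False, True])
infos = [{}, {"TimeLimit.truncated": True, "terminal_observation": np.zeros(3)}, {}, {}]

for idx in dones.nonzero()[0]:  # iterates over indices 1 and 3 only
    if infos[idx].get("terminal_observation") is not None and infos[idx].get("TimeLimit.truncated", False):
        print(f"env {idx}: bootstrap the return with the value of terminal_observation")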

sbx/common/utils.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+from dataclasses import dataclass
+
+import numpy as np
+
+
+@dataclass
+class KLAdaptiveLR:
+    """Adaptive lr schedule, see https://arxiv.org/abs/1707.02286"""
+
+    # If set will trigger adaptive lr
+    target_kl: float
+    current_adaptive_lr: float
+    # Values taken from https://github.com/leggedrobotics/rsl_rl
+    min_learning_rate: float = 1e-5
+    max_learning_rate: float = 1e-2
+    kl_margin: float = 2.0
+    # Divide or multiply the lr by this factor
+    adaptive_lr_factor: float = 1.5
+
+    def update(self, kl_div: float) -> None:
+        if kl_div > self.target_kl * self.kl_margin:
+            self.current_adaptive_lr /= self.adaptive_lr_factor
+        elif kl_div < self.target_kl / self.kl_margin:
+            self.current_adaptive_lr *= self.adaptive_lr_factor
+
+        self.current_adaptive_lr = np.clip(self.current_adaptive_lr, self.min_learning_rate, self.max_learning_rate)
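
A quick usage sketch of the new KLAdaptiveLR dataclass (the numbers are made up): the PPO update measures an approximate KL divergence, lets the schedule shrink or grow the learning rate, and then pushes the result into the optimizer via _update_learning_rate.

from sbx.common.utils import KLAdaptiveLR

adaptive_lr = KLAdaptiveLR(target_kl=0.01, current_adaptive_lr=3e-4)

# KL above target_kl * kl_margin (0.02): divide the lr by adaptive_lr_factor
adaptive_lr.update(kl_div=0.05)
print(adaptive_lr.current_adaptive_lr)  # 3e-4 / 1.5 = 2e-4

# KL below target_kl / kl_margin (0.005): multiply the lr by adaptive_lr_factor
adaptive_lr.update(kl_div=0.001)
print(adaptive_lr.current_adaptive_lr)  # back to 3e-4, always clipped to [1e-5, 1e-2]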

sbx/crossq/crossq.py

Lines changed: 1 addition & 1 deletion
@@ -158,7 +158,7 @@ def _setup_model(self) -> None:
                 apply_fn=self.ent_coef.apply,
                 params=self.ent_coef.init(ent_key)["params"],
                 tx=optax.adam(
-                    learning_rate=self.learning_rate,
+                    learning_rate=self.lr_schedule(1),
                 ),
             )
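
Context for this one-line change, as a small sketch (not from the diff): an SB3 Schedule is a callable of progress_remaining, which goes from 1 at the start of training down to 0 at the end, so lr_schedule(1) is the initial learning-rate value even when the user passed a schedule instead of a float.

def linear_lr_schedule(progress_remaining: float) -> float:
    # Illustrative schedule: decays linearly from 3e-4 to 0 over training
    return 3e-4 * progress_remaining

print(linear_lr_schedule(1.0))  # 3e-4, the value used to build the entropy-coefficient optimizer
print(linear_lr_schedule(0.5))  # 1.5e-4 halfway through training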

sbx/dqn/dqn.py

Lines changed: 2 additions & 2 deletions
@@ -7,7 +7,7 @@
 import numpy as np
 import optax
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
-from stable_baselines3.common.utils import get_linear_fn
+from stable_baselines3.common.utils import LinearSchedule

 from sbx.common.off_policy_algorithm import OffPolicyAlgorithmJax
 from sbx.common.type_aliases import ReplayBufferSamplesNp, RLTrainState
@@ -83,7 +83,7 @@ def __init__(
     def _setup_model(self) -> None:
         super()._setup_model()

-        self.exploration_schedule = get_linear_fn(
+        self.exploration_schedule = LinearSchedule(
             self.exploration_initial_eps,
             self.exploration_final_eps,
             self.exploration_fraction,
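
A hedged sketch of the intent behind swapping get_linear_fn for LinearSchedule, assuming the class keeps the same (start, end, end_fraction) arguments and is called with progress_remaining, as the call site above suggests: epsilon decays linearly over the first exploration_fraction of training and then stays at the final value.

from stable_baselines3.common.utils import LinearSchedule

exploration_schedule = LinearSchedule(1.0, 0.05, 0.1)

print(exploration_schedule(1.0))  # 1.0 at the very start of training
print(exploration_schedule(0.5))  # 0.05, the decay is finished after the first 10% of training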

sbx/ppo/policies.py

Lines changed: 26 additions & 11 deletions
@@ -44,6 +44,8 @@ class Actor(nn.Module):
     # For MultiDiscrete
     max_num_choices: int = 0
     split_indices: np.ndarray = field(default_factory=lambda: np.array([]))
+    # Last layer with small scale
+    ortho_init: bool = False

     def get_std(self) -> jnp.ndarray:
         # Make it work with gSDE
@@ -65,7 +67,15 @@ def __call__(self, x: jnp.ndarray) -> tfd.Distribution:  # type: ignore[name-defined]
             x = nn.Dense(n_units)(x)
             x = self.activation_fn(x)

-        action_logits = nn.Dense(self.action_dim)(x)
+        if self.ortho_init:
+            orthogonal_init = nn.initializers.orthogonal(scale=0.01)
+            bias_init = nn.initializers.zeros
+            action_logits = nn.Dense(self.action_dim, kernel_init=orthogonal_init, bias_init=bias_init)(x)
+
+        else:
+            action_logits = nn.Dense(self.action_dim)(x)
+
+        log_std = jnp.zeros(1)
         if self.num_discrete_choices is None:
             # Continuous actions
             log_std = self.param("log_std", constant(self.log_std_init), (self.action_dim,))
@@ -118,6 +128,8 @@ def __init__(
         optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam,
         optimizer_kwargs: Optional[dict[str, Any]] = None,
         share_features_extractor: bool = False,
+        actor_class: type[nn.Module] = Actor,
+        critic_class: type[nn.Module] = Critic,
     ):
         if optimizer_kwargs is None:
             # Small values to avoid NaN in Adam optimizer
@@ -146,6 +158,9 @@
         else:
             self.net_arch_pi = self.net_arch_vf = [64, 64]
         self.use_sde = use_sde
+        self.ortho_init = ortho_init
+        self.actor_class = actor_class
+        self.critic_class = critic_class

         self.key = self.noise_key = jax.random.PRNGKey(0)

@@ -189,38 +204,38 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) ->
         else:
             raise NotImplementedError(f"{self.action_space}")

-        self.actor = Actor(
+        self.actor = self.actor_class(
             net_arch=self.net_arch_pi,
             log_std_init=self.log_std_init,
             activation_fn=self.activation_fn,
+            ortho_init=self.ortho_init,
             **actor_kwargs,  # type: ignore[arg-type]
         )
         # Hack to make gSDE work without modifying internal SB3 code
         self.actor.reset_noise = self.reset_noise

+        # Inject hyperparameters to be able to modify it later
+        # See https://stackoverflow.com/questions/78527164
+        # Note: eps=1e-5 for Adam
+        optimizer_class = optax.inject_hyperparams(self.optimizer_class)(learning_rate=lr_schedule(1), **self.optimizer_kwargs)
+
         self.actor_state = TrainState.create(
             apply_fn=self.actor.apply,
             params=self.actor.init(actor_key, obs),
             tx=optax.chain(
                 optax.clip_by_global_norm(max_grad_norm),
-                self.optimizer_class(
-                    learning_rate=lr_schedule(1),  # type: ignore[call-arg]
-                    **self.optimizer_kwargs,  # , eps=1e-5
-                ),
+                optimizer_class,
             ),
         )

-        self.vf = Critic(net_arch=self.net_arch_vf, activation_fn=self.activation_fn)
+        self.vf = self.critic_class(net_arch=self.net_arch_vf, activation_fn=self.activation_fn)

         self.vf_state = TrainState.create(
             apply_fn=self.vf.apply,
             params=self.vf.init({"params": vf_key}, obs),
             tx=optax.chain(
                 optax.clip_by_global_norm(max_grad_norm),
-                self.optimizer_class(
-                    learning_rate=lr_schedule(1),  # type: ignore[call-arg]
-                    **self.optimizer_kwargs,  # , eps=1e-5
-                ),
+                optimizer_class,
             ),
         )
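
Finally, a hedged sketch of what the new actor_class / critic_class hooks enable: a custom Flax actor can be plugged in through policy_kwargs without forking the policy. It assumes ortho_init and actor_class are exposed as policy constructor arguments, as the diff suggests; CustomActor and the hyperparameter values are illustrative only.

from sbx import PPO
from sbx.ppo.policies import Actor


class CustomActor(Actor):
    """Placeholder subclass: identical to the default Actor, shown only to demonstrate the hook."""


model = PPO(
    "MlpPolicy",
    "Pendulum-v1",
    policy_kwargs={"actor_class": CustomActor, "ortho_init": True},
    verbose=1,
)
model.learn(total_timesteps=1_000)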
