### 🐛 Bug
PPO fails when the optimizer is not an `optax.chain` with at least two transforms. Using a single-transform optimizer (e.g. just `optax.adam`, or a chain containing a single element) causes a runtime error, because `sbx/ppo/ppo.py` indexes `opt_state[1]` when updating the learning rate.
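For context, a minimal standalone sketch (plain Optax, no SBX involved) of why the indexing fails: the `opt_state` of an `optax.chain` is a tuple with one entry per transform, so `opt_state[1]` only exists when the chain has at least two elements.

```python
import jax.numpy as jnp
import optax

params = {"w": jnp.zeros(3)}

# Two transforms (what ppo.py currently assumes): opt_state has two entries.
tx_two = optax.chain(optax.clip_by_global_norm(0.5), optax.adam(3e-4))
print(len(tx_two.init(params)))  # 2 -> opt_state[1] is valid

# One transform: opt_state has a single entry, so opt_state[1] raises IndexError.
tx_one = optax.chain(optax.adam(3e-4))
print(len(tx_one.init(params)))  # 1 -> opt_state[1] -> IndexError
```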
### To Reproduce

- Define a custom `PPOPolicy` that builds its `TrainState` with an `optax.chain` containing only a single element (e.g. just `optax.adam`).
- Initialize a PPO model with this policy.
- Call `.learn(...)` on the PPO model.
- Observe that training fails with an `IndexError` at `sbx/ppo/ppo.py:262`.
```python
from typing import Any

import flax.linen as nn
import numpy as np
import optax
import jax
from flax.training.train_state import TrainState
import jax.numpy as jnp
from gymnasium import spaces
from stable_baselines3.common.env_util import make_vec_env

from sbx import PPO
from sbx.ppo.policies import PPOPolicy, Actor, Critic


class CustomPPO(PPOPolicy):
    def __init__(self, observation_space, action_space, lr_schedule, **kwargs):
        super().__init__(observation_space, action_space, lr_schedule, **kwargs)

    def build(self, key: jax.Array, lr_schedule, max_grad_norm) -> jax.Array:
        key, actor_key, vf_key = jax.random.split(key, 3)
        key, self.key = jax.random.split(key, 2)
        self.reset_noise()
        obs = jnp.array([self.observation_space.sample()])
        self.actor = self.actor_class(
            net_arch=self.net_arch_pi,
            log_std_init=self.log_std_init,
            activation_fn=self.activation_fn,
            ortho_init=self.ortho_init,
            action_dim=int(np.prod(self.action_space.shape)),
        )
        optimizer_class = optax.inject_hyperparams(self.optimizer_class)(
            learning_rate=lr_schedule(1), **self.optimizer_kwargs
        )
        # Optax chain with only one element
        self.actor_state = TrainState.create(
            apply_fn=self.actor.apply,
            params=self.actor.init(actor_key, obs),
            tx=optax.chain(
                optimizer_class,
            ),
        )
        self.vf = self.critic_class(net_arch=self.net_arch_vf, activation_fn=self.activation_fn)
        self.vf_state = TrainState.create(
            apply_fn=self.vf.apply,
            params=self.vf.init(vf_key, obs),
            tx=optax.chain(
                optimizer_class,
            ),
        )
        self.actor.apply = jax.jit(self.actor.apply)  # type: ignore
        self.vf.apply = jax.jit(self.vf.apply)  # type: ignore
        return key


def test_ppo() -> None:
    env = make_vec_env("Pendulum-v1")
    model = PPO(CustomPPO, env)
    model.learn(64, progress_bar=True)
```

```
tests\test_ppo.py:110:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
sbx\ppo\ppo.py:351: in learn
    return super().learn(
venv\.venv\Lib\site-packages\stable_baselines3\common\on_policy_algorithm.py:337: in learn
    self.train()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <sbx.ppo.ppo.PPO object at 0x000001DCE2D957F0>

    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.
        """
        # Update optimizer learning rate
        if self.target_kl is None:
            self._update_learning_rate(
>               [self.policy.actor_state.opt_state[1], self.policy.vf_state.opt_state[1]],
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                learning_rate=self.lr_schedule(self._current_progress_remaining),
            )
E           IndexError: tuple index out of range

sbx\ppo\ppo.py:262: IndexError
```

### Expected behavior
PPO should not assume that the optimizer is an `optax.chain` with two transforms. It should work with any valid Optax optimizer, including a single transform like `optax.adam`. The number of transforms in the chain should be irrelevant for learning to proceed.
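One possible direction, sketched under the assumption that the learning-rate update works by mutating the `hyperparams` dict of the state created by `optax.inject_hyperparams` (the helper names below are hypothetical, not the actual sbx API): locate that state by inspection instead of hard-coding `opt_state[1]`.

```python
from typing import Any, Iterator


def iter_hyperparams_states(opt_state: Any) -> Iterator[Any]:
    # Yield every sub-state that carries a `hyperparams` dict (the state created
    # by optax.inject_hyperparams), however many transforms the chain contains.
    if hasattr(opt_state, "hyperparams"):
        yield opt_state
    elif isinstance(opt_state, (tuple, list)):
        for sub_state in opt_state:
            yield from iter_hyperparams_states(sub_state)


def update_learning_rate(opt_states: list, learning_rate: float) -> None:
    # Hypothetical replacement for the hard-coded `opt_state[1]` access in
    # sbx/ppo/ppo.py:262: works for a bare optimizer, a single-element chain,
    # or a two-element chain alike.
    for opt_state in opt_states:
        for state in iter_hyperparams_states(opt_state):
            state.hyperparams["learning_rate"] = learning_rate
```

Until something along these lines is in place, a possible user-side workaround is to pad the chain back to two transforms (e.g. `optax.chain(optax.clip_by_global_norm(max_grad_norm), optimizer_class)`), so that the `inject_hyperparams` state ends up at index 1 as the current code expects.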
### System Info
- OS: Windows-11-10.0.26100-SP0 10.0.26100
- Python: 3.12.7
- Stable-Baselines3: 2.7.0
- PyTorch: 2.8.0+cpu
- GPU Enabled: False
- Numpy: 2.3.2
- Cloudpickle: 3.1.1
- Gymnasium: 1.2.0
### Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)