[Bug] PPO fails with IndexError when using single-transform Optax optimizer #77

@wittlsn

Description

🐛 Bug

PPO fails when the optimizer is not an optax.chain with at least two transforms. Using a single-transform optimizer (e.g. just optax.adam) causes a runtime error.
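
For context, sbx's train() indexes opt_state[1], which only exists when the chain holds at least two transforms. A minimal plain-optax sketch of the state layout (the "default layout" below is my reading of PPOPolicy.build, so treat it as an assumption):

import optax
import jax.numpy as jnp

params = {"w": jnp.zeros(3)}

# Single-transform chain: the chained opt_state is a tuple of length 1,
# so opt_state[1] raises IndexError.
single = optax.chain(optax.inject_hyperparams(optax.adam)(learning_rate=1e-3))
print(len(single.init(params)))  # 1

# Two-transform chain (clipping + optimizer, as the default policy seems
# to build it): length 2, so the hard-coded opt_state[1] happens to work.
default = optax.chain(
    optax.clip_by_global_norm(0.5),
    optax.inject_hyperparams(optax.adam)(learning_rate=1e-3),
)
print(len(default.init(params)))  # 2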

To Reproduce

  1. Define a custom PPOPolicy that builds its TrainState with an optax.chain
    containing only a single element (e.g. just optax.adam).
  2. Initialize a PPO model with this policy.
  3. Call .learn(...) on the PPO model.
  4. Observe that training fails with IndexError at sbx/ppo/ppo.py:262.
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.training.train_state import TrainState
from stable_baselines3.common.env_util import make_vec_env

from sbx import PPO
from sbx.ppo.policies import PPOPolicy


class CustomPPO(PPOPolicy):
    def __init__(self, observation_space, action_space, lr_schedule, **kwargs):
        super().__init__(observation_space, action_space, lr_schedule, **kwargs)

    def build(self, key: jax.Array, lr_schedule, max_grad_norm) -> jax.Array:
        key, actor_key, vf_key = jax.random.split(key, 3)
        key, self.key = jax.random.split(key, 2)
        self.reset_noise()
        obs = jnp.array([self.observation_space.sample()])

        self.actor = self.actor_class(
            net_arch=self.net_arch_pi,
            log_std_init=self.log_std_init,
            activation_fn=self.activation_fn,
            ortho_init=self.ortho_init,
            action_dim=int(np.prod(self.action_space.shape)),
        )

        optimizer_class = optax.inject_hyperparams(self.optimizer_class)(
            learning_rate=lr_schedule(1), **self.optimizer_kwargs
        )

        # Optax chain with only one transform: this is what triggers the bug
        self.actor_state = TrainState.create(
            apply_fn=self.actor.apply,
            params=self.actor.init(actor_key, obs),
            tx=optax.chain(
                optimizer_class,
            ),
        )

        self.vf = self.critic_class(net_arch=self.net_arch_vf, activation_fn=self.activation_fn)
        # Same single-transform chain for the value function
        self.vf_state = TrainState.create(
            apply_fn=self.vf.apply,
            params=self.vf.init(vf_key, obs),
            tx=optax.chain(
                optimizer_class,
            ),
        )

        self.actor.apply = jax.jit(self.actor.apply)  # type: ignore
        self.vf.apply = jax.jit(self.vf.apply)        # type: ignore
        return key


def test_ppo() -> None:
    env = make_vec_env("Pendulum-v1")
    model = PPO(CustomPPO, env)
    model.learn(64, progress_bar=True)
tests\test_ppo.py:110: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
sbx\ppo\ppo.py:351: in learn
    return super().learn(
venv\.venv\Lib\site-packages\stable_baselines3\common\on_policy_algorithm.py:337: in learn
    self.train()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <sbx.ppo.ppo.PPO object at 0x000001DCE2D957F0>

    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.
        """
        # Update optimizer learning rate
        if self.target_kl is None:
            self._update_learning_rate(
>               [self.policy.actor_state.opt_state[1], self.policy.vf_state.opt_state[1]],
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                learning_rate=self.lr_schedule(self._current_progress_remaining),
            )
E           IndexError: tuple index out of range

sbx\ppo\ppo.py:262: IndexError
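
As a temporary workaround, padding the chain to two transforms with a no-op keeps the hard-coded index valid, since opt_state[1] is then the InjectHyperparamsState again. A sketch of the actor_state from the repro above, assuming that position is all the learning-rate update relies on:

        self.actor_state = TrainState.create(
            apply_fn=self.actor.apply,
            params=self.actor.init(actor_key, obs),
            tx=optax.chain(
                optax.identity(),  # no-op transform, only pads the chain
                optimizer_class,
            ),
        )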

Expected behavior

PPO should not assume that the optimizer is an optax.chain with exactly two transforms. It should work with any valid Optax optimizer, including a single transform such as optax.adam; the number of transforms in the chain should be irrelevant to whether learning can proceed.
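
One possible fix (a sketch only, not a tested patch): instead of hard-coding opt_state[1], scan the chained state for the InjectHyperparamsState that carries the learning rate. This assumes the state type produced by optax.inject_hyperparams is optax.InjectHyperparamsState, which it is in current optax:

import optax

def set_injected_learning_rate(opt_state, learning_rate):
    # Return a copy of a chained opt_state with the injected learning_rate
    # replaced, regardless of how many transforms the chain contains.
    new_states = []
    for sub_state in opt_state:
        if isinstance(sub_state, optax.InjectHyperparamsState):
            hyperparams = dict(sub_state.hyperparams)
            hyperparams["learning_rate"] = learning_rate
            sub_state = sub_state._replace(hyperparams=hyperparams)
        new_states.append(sub_state)
    return tuple(new_states)

The caller would then write the new state back (e.g. train_state.replace(opt_state=...)) rather than mutating hyperparams in place, which is what the current code appears to do given the list of states it passes. If the minimum supported optax version allows it, optax.tree_utils.tree_set(opt_state, learning_rate=...) should perform the same lookup generically.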

System Info

  • OS: Windows-11-10.0.26100-SP0 10.0.26100
  • Python: 3.12.7
  • Stable-Baselines3: 2.7.0
  • PyTorch: 2.8.0+cpu
  • GPU Enabled: False
  • Numpy: 2.3.2
  • Cloudpickle: 3.1.1
  • Gymnasium: 1.2.0

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
