
SB3 and SBX versions of SAC have radically different behaviours #55

@jamesheald


I am following a tutorial that trains on the MyoSuite `myoHandReorient8-v0` environment using the Stable Baselines3 (SB3) implementation of SAC. The main training function, which sets all of the SAC parameters (hence why I'm including it here), is:

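```python
import gym  # assumption: the issue does not show the tutorial's own imports
import myosuite  # noqa: F401  (importing MyoSuite registers the myo* environments)
from stable_baselines3 import SAC  # or: from sbx import SAC (the run that diverges)
from stable_baselines3.common.logger import configure
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

# linear_schedule and SaveSuccesses are helpers defined elsewhere in the tutorial.


def train(env_name, policy_name, timesteps, seed):
    """
    Trains a policy using the SB3 implementation of SAC.

    env_name: str; name of the gym env.
    policy_name: str; unique identifier for this policy.
    timesteps: int; how long to train the policy for.
    seed: str (not int); distinguishes multiple policies trained with the same params.
    """
    env = gym.make(env_name)
    env = Monitor(env)
    # DummyVecEnv calls the lambda immediately, so it captures the Monitor-wrapped env.
    env = DummyVecEnv([lambda: env])
    env = VecNormalize(env, norm_obs=True, norm_reward=False, clip_obs=10.)

    # Two hidden layers (400, 300) for both the actor (pi) and the critics (qf).
    net_shape = [400, 300]
    policy_kwargs = dict(net_arch=dict(pi=net_shape, qf=net_shape))

    # train_freq=(1, "episode"): update once per finished episode;
    # gradient_steps=-1: take as many gradient steps as env steps collected.
    model = SAC('MlpPolicy', env, learning_rate=linear_schedule(.001), buffer_size=int(3e5),
                learning_starts=1000, batch_size=256, tau=.02, gamma=.98, train_freq=(1, "episode"),
                gradient_steps=-1, policy_kwargs=policy_kwargs, verbose=1)

    succ_callback = SaveSuccesses(check_freq=1, env_name=env_name + '_' + seed,
                                  log_dir=f'{policy_name}_successes_{env_name}_{seed}')

    model.set_logger(configure(f'{policy_name}_results_{env_name}_{seed}'))
    model.learn(total_timesteps=int(timesteps), callback=succ_callback, log_interval=4)
    model.save(f"{policy_name}_model_{env_name}_{seed}")
    env.save(f'{policy_name}_env_{env_name}_{seed}')
```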
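The issue does not show `linear_schedule`. Assuming it follows the learning-rate schedule pattern from the SB3 documentation, where the schedule is called with `progress_remaining` going from 1.0 at the start of training to 0.0 at the end, a minimal sketch would be:

```python
from typing import Callable


def linear_schedule(initial_value: float) -> Callable[[float], float]:
    """Linear learning-rate schedule (assumed to match the tutorial's helper).

    The returned function receives `progress_remaining`, which decreases
    from 1.0 at the start of training to 0.0 at the end.
    """

    def schedule(progress_remaining: float) -> float:
        # Learning rate decays linearly from initial_value down to 0.
        return progress_remaining * initial_value

    return schedule


lr = linear_schedule(0.001)
print(lr(1.0), lr(0.5), lr(0.0))
```

With this schedule the learning rate starts at 1e-3 and decays to 0 by the end of `total_timesteps`.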

When I call this `train` function with the SB3 version of SAC (`from stable_baselines3 import SAC`), the model trains well. However, if I instead use the SBX version of SAC (`from sbx import SAC`), the actor loss, critic loss and entropy coefficient diverge:

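[Figure: training curves showing the actor loss, critic loss and entropy coefficient diverging with SBX SAC]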

The MuJoCo simulation also often becomes unstable in the SBX case:

```
WARNING:absl:Nan, Inf or huge value in QACC at DOF 26. The simulation is unstable. Time = 0.2380.
Simulation couldn't be stepped as intended. Issuing a reset
```

Naively, I would have expected the SB3 and SBX versions of SAC to perform roughly the same with identical training parameters. Can you help me understand why this is not the case, and why parameters that work well for SB3 SAC are catastrophic for SBX SAC?
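One way to make the comparison explicit is to keep the hyperparameters from `train()` in a single dict and select the SAC class by name, so that the import is the only thing that differs between the two runs. This is a minimal sketch; `SAC_KWARGS`, `BACKENDS` and `get_sac_class` are hypothetical names, and each package has to be installed before it can be selected:

```python
import importlib

# Hyperparameters copied from train(); the learning-rate schedule and
# policy_kwargs are built inside train() and passed separately.
SAC_KWARGS = dict(
    buffer_size=int(3e5),
    learning_starts=1000,
    batch_size=256,
    tau=0.02,
    gamma=0.98,
    train_freq=(1, "episode"),
    gradient_steps=-1,
    verbose=1,
)

# Hypothetical mapping from a short backend name to the package exporting SAC.
BACKENDS = {"sb3": "stable_baselines3", "sbx": "sbx"}


def get_sac_class(backend: str):
    """Return the SAC class exported by the requested backend package."""
    if backend not in BACKENDS:
        raise ValueError(f"unknown backend {backend!r}; expected one of {sorted(BACKENDS)}")
    # Import lazily so only the selected package needs to be installed.
    return importlib.import_module(BACKENDS[backend]).SAC


# Usage inside train(), replacing the hard-coded SAC(...) call:
# SAC = get_sac_class("sbx")
# model = SAC("MlpPolicy", env, learning_rate=linear_schedule(0.001),
#             policy_kwargs=policy_kwargs, **SAC_KWARGS)
```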

I am using stable_baselines3 2.3.2 and sbx 0.13.0.
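When reporting a difference between the two libraries, it can help to list the exact installed version of every related package, since the backends (PyTorch for SB3, JAX for SBX) and MuJoCo may also matter. A minimal standard-library sketch; the package list is an assumption, using PyPI distribution names (SBX is published as `sbx-rl`), and `installed_versions` is a hypothetical helper:

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed list of distributions relevant to this comparison.
PACKAGES = ["stable-baselines3", "sbx-rl", "jax", "torch", "mujoco", "myosuite"]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, ver in installed_versions(PACKAGES).items():
        print(f"{name}: {ver or 'not installed'}")
```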
