Description
I am following a tutorial that trains the myosuite `myoHandReorient8-v0` environment using the Stable Baselines3 version of SAC. The main block of training code (included here because it details the SAC parameters) is:
```python
import gym  # or gymnasium, depending on your myosuite/sb3 setup
import myosuite  # noqa: F401 -- importing myosuite registers the myo* environments
from stable_baselines3 import SAC
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3.common.logger import configure

# linear_schedule and SaveSuccesses are helpers defined elsewhere in the tutorial.

def train(env_name, policy_name, timesteps, seed):
    """
    Trains a policy using the sb3 implementation of SAC.
    env_name: str; name of gym env.
    policy_name: str; a unique identifier for this policy
    timesteps: int; how long you want to train your policy for
    seed: str (not int); relevant if you want to train multiple policies with the same params
    """
    # Monitor episode statistics, vectorize, and normalize observations (not rewards).
    env = gym.make(env_name)
    env = Monitor(env)
    env = DummyVecEnv([lambda: env])
    env = VecNormalize(env, norm_obs=True, norm_reward=False, clip_obs=10.)

    # Actor (pi) and critic (qf) networks share the same architecture.
    net_shape = [400, 300]
    policy_kwargs = dict(net_arch=dict(pi=net_shape, qf=net_shape))

    model = SAC('MlpPolicy', env, learning_rate=linear_schedule(.001), buffer_size=int(3e5),
                learning_starts=1000, batch_size=256, tau=.02, gamma=.98,
                train_freq=(1, "episode"), gradient_steps=-1,
                policy_kwargs=policy_kwargs, verbose=1)

    # Custom callback from the tutorial that logs success rates during training.
    succ_callback = SaveSuccesses(check_freq=1, env_name=env_name + '_' + seed,
                                  log_dir=f'{policy_name}_successes_{env_name}_{seed}')

    model.set_logger(configure(f'{policy_name}_results_{env_name}_{seed}'))
    model.learn(total_timesteps=int(timesteps), callback=succ_callback, log_interval=4)
    model.save(f"{policy_name}_model_{env_name}_{seed}")
    env.save(f'{policy_name}_env_{env_name}_{seed}')
```
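`linear_schedule` is defined elsewhere in the tutorial; a common implementation (following the learning-rate-schedule example in the SB3 docs, so this is an assumption about what the tutorial uses) is:

```python
def linear_schedule(initial_value: float):
    """Learning rate decays linearly from initial_value to 0 over training."""
    def func(progress_remaining: float) -> float:
        # progress_remaining goes from 1 (start of training) to 0 (end).
        return progress_remaining * initial_value
    return func
```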
When I call this train function using the Stable Baselines3 version of SAC (`from stable_baselines3 import SAC`), the model trains well. However, if I instead use the sbx version of SAC (`from sbx import SAC`), the actor loss, critic loss, and entropy coefficient all diverge:
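For reference, the only difference between the two runs is the import; the invocation looks like this (the policy name and timestep count here are placeholders, not the tutorial's exact values):

```python
# from stable_baselines3 import SAC  # trains well
from sbx import SAC                   # losses and entropy coefficient diverge

train('myoHandReorient8-v0', 'sac_test', timesteps=1_000_000, seed='0')
```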
The MuJoCo simulation also often becomes unstable in the SBX case:
```
WARNING:absl:Nan, Inf or huge value in QACC at DOF 26. The simulation is unstable. Time = 0.2380.
Simulation couldn't be stepped as intended. Issuing a reset
```
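To narrow down where the NaNs first appear (in the actions coming out of the policy or in the observations coming back from the simulation), I can wrap the env with SB3's `VecCheckNan`; this is a diagnostic I added, not part of the tutorial:

```python
from stable_baselines3.common.vec_env import VecCheckNan

# Raises a ValueError as soon as a NaN/inf passes through step() or reset(),
# identifying whether it entered via the actions or the observations.
env = VecCheckNan(env, raise_exception=True)
```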
Naively, I would have expected the SB3 and SBX versions of SAC to perform approximately the same with the same training parameters. Can you help me understand why this is not the case, and why parameters that work well for SB3's SAC are catastrophic for SBX's SAC?
I am using stable_baselines3 2.3.2 and sbx 0.13.0.
