Hi 👋,
Thanks for the repo!
I am currently testing out your implementation of S5. Unfortunately, I am not very familiar with the S5 architecture.
When I run your code, I get this warning:
~/.local/lib/python3.10/site-packages/jax/_src/lax/lax.py:2652: ComplexWarning: Casting complex values to real discards the imaginary part
x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)
The warning originates in the PPO loss computation and is related to the complex parameters of the S5 model.
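One way to confirm where the warning comes from is to escalate only this warning to an exception, so that Python prints a full stack trace up to the user-level call that triggered it. The sketch below is an assumption-based helper, not part of the repo or of JAX. `complex_cast_raises` and `COMPLEX_CAST_MSG` are hypothetical names, and the filter matches on the warning text shown above rather than on a specific warning class.

```python
import contextlib
import warnings

# Text of the warning reported above (matched as a regex against the start of the message).
COMPLEX_CAST_MSG = "Casting complex values to real discards the imaginary part"


@contextlib.contextmanager
def complex_cast_raises():
    """Hypothetical helper: temporarily turn only the complex-to-real cast warning into an exception."""
    with warnings.catch_warnings():
        # No category is given, so any Warning subclass carrying this text is escalated;
        # all other warnings keep their normal behavior.
        warnings.filterwarnings("error", message=COMPLEX_CAST_MSG)
        yield


# Hypothetical usage; `loss_fn`, `params` and `batch` stand in for the repo's PPO loss inputs:
# with complex_cast_raises():
#     jax.grad(loss_fn)(params, batch)
```

JAX emits Python warnings while it traces a function, so the wrapped call has to be the first (not yet compiled and cached) call of the jitted function. Otherwise the warning may not be raised again.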
The command I am running is below.
Is this behavior intended? I tried reading the S4 and S5 literature, but the answer was not immediately obvious to me, so I have little intuition for what it means to cast complex parameters to float.
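A useful fact for intuition: when a real parameter is mixed into a complex computation, the transpose of "embed a real number into the complex plane" is "take the real part". So when JAX maps a complex cotangent back to a real parameter's dtype during `jax.grad`, dropping the imaginary part is mathematically exact for a real-valued loss. The sketch below is a hypothetical, S5-inspired toy and not the repository's code. `toy_ssm_loss`, `lam_re`, `lam_im`, `dt` and the choice of an identity input matrix are all assumptions. It builds complex eigenvalues from two real arrays, runs a complex diagonal recurrence, and checks that the real gradient matches a finite-difference estimate. Whether every cast in the repo's PPO loss is of this benign kind is an assumption worth checking against the actual code.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)  # float64 so the finite-difference check is tight


def toy_ssm_loss(lam_re, lam_im, u, dt=0.1):
    """Hypothetical S5-style toy: a real scalar loss computed through a complex diagonal recurrence."""
    lam = lam_re + 1j * lam_im            # real parameters -> complex eigenvalues
    lam_bar = jnp.exp(lam * dt)           # zero-order-hold discretization of a diagonal state matrix

    def step(x, u_k):
        x = lam_bar * x + u_k             # complex hidden state (input matrix assumed to be 1)
        return x, jnp.real(x)             # real-valued readout

    x0 = jnp.zeros_like(lam)              # complex initial state
    _, ys = jax.lax.scan(step, x0, u)
    return jnp.sum(ys ** 2)               # real scalar loss, standing in for a PPO loss


lam_re = jnp.array([-0.5, -1.0])
lam_im = jnp.array([1.0, 3.0])
u = jnp.sin(jnp.arange(20.0))

# Gradient w.r.t. the real part of the eigenvalues: real dtype, same as lam_re.
grad_re = jax.grad(toy_ssm_loss)(lam_re, lam_im, u)

# Central finite differences as an independent check.
eps = 1e-6
fd = np.array([
    float(toy_ssm_loss(lam_re.at[i].add(eps), lam_im, u)
          - toy_ssm_loss(lam_re.at[i].add(-eps), lam_im, u)) / (2 * eps)
    for i in range(lam_re.shape[0])
])

print(grad_re.dtype)
print(np.allclose(grad_re, fd, rtol=1e-5, atol=1e-6))
```

If the two gradients agree, the imaginary part that the cast discards carried no information about the real parameter. That is the situation in which the warning is noise rather than a bug.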
Feedback is appreciated! Thanks!
python3 -m minimax.train \
--seed=1 \
--agent_rl_algo=ppo \
--n_total_updates=30000 \
--train_runner=plr \
--n_devices=1 \
--student_model_name=default_student_cnn \
--env_name=Maze \
--verbose=False \
--log_dir=~/logs/minimax \
--log_interval=10 \
--from_last_checkpoint=True \
--checkpoint_interval=1000 \
--archive_interval=0 \
--archive_init_checkpoint=False \
--test_interval=100 \
--n_students=1 \
--n_parallel=32 \
--n_eval=1 \
--n_rollout_steps=256 \
--lr=3e-05 \
--lr_anneal_steps=0 \
--max_grad_norm=0.5 \
--adam_eps=1e-05 \
--track_env_metrics=True \
--discount=0.999 \
--n_unroll_rollout=10 \
--render=False \
--ued_score=max_mc \
--plr_replay_prob=0.5 \
--plr_buffer_size=4000 \
--plr_staleness_coef=0.3 \
--plr_temp=0.3 \
--plr_use_score_ranks=True \
--plr_min_fill_ratio=0.5 \
--plr_use_robust_plr=True \
--plr_use_parallel_eval=False \
--plr_force_unique=True \
--student_gae_lambda=0.98 \
--student_entropy_coef=0.001 \
--student_value_loss_coef=0.5 \
--student_n_unroll_update=5 \
--student_ppo_n_epochs=5 \
--student_ppo_n_minibatches=1 \
--student_ppo_clip_eps=0.2 \
--student_ppo_clip_value_loss=True \
--student_recurrent_arch=s5 \
--student_recurrent_hidden_dim=256 \
--student_hidden_dim=32 \
--student_n_hidden_layers=1 \
--student_n_conv_filters=16 \
--student_n_scalar_embeddings=4 \
--student_scalar_embed_dim=5 \
--student_s5_n_blocks=2 \
--student_s5_n_layers=2 \
--student_s5_layernorm_pos=pre \
--student_s5_activation=half_glu1 \
--maze_height=13 \
--maze_width=13 \
--maze_n_walls=60 \
--maze_replace_wall_pos=True \
--maze_sample_n_walls=False \
--maze_see_agent=False \
--maze_normalize_obs=True \
--maze_obs_agent_pos=False \
--maze_max_episode_steps=250 \
--test_n_episodes=10 \
--test_env_names=Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze \
--maze_test_see_agent=False \
--maze_test_normalize_obs=True \
--xpid=plr-maze13x13w60na_f-rf_p0.5b4000t0.3s0.3m0.5r_r1s_32p_1e_256t_ae1e-05_smm-ppo_lr3e-05g0.999cv0.5ce0.001e5mb1l0.98_pc0.2_h32cf16fc1se5ba_re_lpr_ahg1_s5_h256nb2nl2_0