
Q: "S5 ComplexWarning: Casting complex values to real discards the imaginary part" intended? #3

@ConstantinRuhdorfer


Hi 👋,

Thanks for the repo!
I am currently testing your implementation of S5. Unfortunately, I am not very familiar with the S5 architecture.
When I run your code, I get this warning:

~/.local/lib/python3.10/site-packages/jax/_src/lax/lax.py:2652: ComplexWarning: Casting complex values to real discards the imaginary part
  x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)
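The warning text is the same one NumPy emits as `ComplexWarning` when a complex array is cast to a real dtype. A minimal NumPy sketch (not code from this repository) shows what such a cast does: the real part is kept and the imaginary part is dropped.

```python
import warnings

import numpy as np

z = np.array([1.5 + 2.0j, -0.5 - 1.0j])  # complex128 input

# Record warnings instead of printing them, so the message can be inspected.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    x = z.astype(np.float64)  # keeps the real part, discards the imaginary part

print(x)                  # [ 1.5 -0.5]
print(caught[0].message)  # Casting complex values to real discards the imaginary part
```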

The warning is raised during the PPO loss computation and appears to be related to the complex-valued parameters of the S5 model.
The command I am running is below.
Is this behavior intended? I tried reading the S4 and S5 literature, but the answer was not immediately obvious, so I have little intuition for what it means to cast complex parameters to float.
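For context, here is a minimal JAX sketch of the general situation. It does not inspect the minimax S5 code, so whether this applies there is an assumption. A real parameter vector feeds into complex arithmetic, and the loss is real. JAX still returns a real gradient of the parameter's dtype, and it matches the analytic derivative, so dropping the imaginary part at the real/complex boundary is what yields the correct real gradient. The toy `loss` below is hypothetical, loosely mimicking S5 storing eigenvalues as separate real and imaginary parts.

```python
import jax
import jax.numpy as jnp


# Hypothetical toy model: real parameters theta = (a, b) form a complex value
# lam = a + i*b, similar in spirit to a diagonal SSM eigenvalue.
def loss(theta):
    lam = theta[0] + 1j * theta[1]  # complex intermediate (complex64)
    z = jnp.exp(lam)                # e^a (cos b + i sin b)
    return jnp.real(z) ** 2         # real scalar loss: e^(2a) cos^2(b)


theta = jnp.array([0.1, 0.3], dtype=jnp.float32)
g = jax.grad(loss)(theta)

# Analytic gradient of e^(2a) cos^2(b) with respect to (a, b).
a, b = theta
expected = jnp.array([2 * jnp.exp(2 * a) * jnp.cos(b) ** 2,
                      -jnp.exp(2 * a) * jnp.sin(2 * b)])

print(g.dtype)  # float32: the gradient stays real, like the parameter
print(jnp.allclose(g, expected, rtol=1e-4))
```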

Feedback is appreciated! Thanks!

python3 -m minimax.train \
--seed=1 \
--agent_rl_algo=ppo \
--n_total_updates=30000 \
--train_runner=plr \
--n_devices=1 \
--student_model_name=default_student_cnn \
--env_name=Maze \
--verbose=False \
--log_dir=~/logs/minimax \
--log_interval=10 \
--from_last_checkpoint=True \
--checkpoint_interval=1000 \
--archive_interval=0 \
--archive_init_checkpoint=False \
--test_interval=100 \
--n_students=1 \
--n_parallel=32 \
--n_eval=1 \
--n_rollout_steps=256 \
--lr=3e-05 \
--lr_anneal_steps=0 \
--max_grad_norm=0.5 \
--adam_eps=1e-05 \
--track_env_metrics=True \
--discount=0.999 \
--n_unroll_rollout=10 \
--render=False \
--ued_score=max_mc \
--plr_replay_prob=0.5 \
--plr_buffer_size=4000 \
--plr_staleness_coef=0.3 \
--plr_temp=0.3 \
--plr_use_score_ranks=True \
--plr_min_fill_ratio=0.5 \
--plr_use_robust_plr=True \
--plr_use_parallel_eval=False \
--plr_force_unique=True \
--student_gae_lambda=0.98 \
--student_entropy_coef=0.001 \
--student_value_loss_coef=0.5 \
--student_n_unroll_update=5 \
--student_ppo_n_epochs=5 \
--student_ppo_n_minibatches=1 \
--student_ppo_clip_eps=0.2 \
--student_ppo_clip_value_loss=True \
--student_recurrent_arch=s5 \
--student_recurrent_hidden_dim=256 \
--student_hidden_dim=32 \
--student_n_hidden_layers=1 \
--student_n_conv_filters=16 \
--student_n_scalar_embeddings=4 \
--student_scalar_embed_dim=5 \
--student_s5_n_blocks=2 \
--student_s5_n_layers=2 \
--student_s5_layernorm_pos=pre \
--student_s5_activation=half_glu1 \
--maze_height=13 \
--maze_width=13 \
--maze_n_walls=60 \
--maze_replace_wall_pos=True \
--maze_sample_n_walls=False \
--maze_see_agent=False \
--maze_normalize_obs=True \
--maze_obs_agent_pos=False \
--maze_max_episode_steps=250 \
--test_n_episodes=10 \
--test_env_names=Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze \
--maze_test_see_agent=False \
--maze_test_normalize_obs=True \
--xpid=plr-maze13x13w60na_f-rf_p0.5b4000t0.3s0.3m0.5r_r1s_32p_1e_256t_ae1e-05_smm-ppo_lr3e-05g0.999cv0.5ce0.001e5mb1l0.98_pc0.2_h32cf16fc1se5ba_re_lpr_ahg1_s5_h256nb2nl2_0
