diff --git a/src/xminigrid/envs/minigrid/memory.py b/src/xminigrid/envs/minigrid/memory.py index 90319d1..83fcea4 100644 --- a/src/xminigrid/envs/minigrid/memory.py +++ b/src/xminigrid/envs/minigrid/memory.py @@ -110,7 +110,12 @@ def step( self, params: EnvParams, timestep: TimeStep[MemoryEnvCarry], action: IntOrArray ) -> TimeStep[MemoryEnvCarry]: # disabling pick_up action - action = jax.lax.select(jnp.equal(action, 3), jnp.asarray(5, dtype=jnp.uint8), action) + action = jax.lax.select( + jnp.equal(action, 3), + jnp.asarray(5), + action, + ).astype(jnp.uint8) + new_grid, new_agent, _ = take_action(timestep.state.grid, timestep.state.agent, action) new_state = timestep.state.replace(grid=new_grid, agent=new_agent, step_num=timestep.state.step_num + 1)