Skip to content

Commit

Permalink
use new signature for jax.numpy.clip (#33)
Browse files Browse the repository at this point in the history
* use new signature for jax.numpy.clip

jax jax 0.4.27 deprecated a_min and a_max and added min and max. See https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-27-may-7-2024

* bump dependencies
  • Loading branch information
garymm authored Aug 10, 2024
1 parent bbc5f5f commit 5c65e54
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ classifiers = [
]

dependencies = [
"jax>=0.4.26",
"jaxlib>=0.4.26",
"jax>=0.4.27",
"jaxlib>=0.4.27",
"flax>=0.8.0",
"rich>=13.4.2",
"chex>=0.1.85",
Expand Down Expand Up @@ -101,4 +101,4 @@ reportMissingTypeStubs = false
reportMissingModuleSource = false

pythonVersion = "3.10"
pythonPlatform = "All"
pythonPlatform = "All"
4 changes: 2 additions & 2 deletions src/xminigrid/core/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ def _move(position: jax.Array, direction: jax.Array) -> jax.Array:
def move_forward(grid: GridState, agent: AgentState) -> ActionOutput:
next_position = jnp.clip(
_move(agent.position, agent.direction),
a_min=jnp.array((0, 0)),
a_max=jnp.array((grid.shape[0] - 1, grid.shape[1] - 1)), # H, W
min=jnp.array((0, 0)),
max=jnp.array((grid.shape[0] - 1, grid.shape[1] - 1)), # H, W
)
position = jax.lax.select(
check_walkable(grid, next_position),
Expand Down

0 comments on commit 5c65e54

Please sign in to comment.