Skip to content

Commit

Permalink
[experimental] Add error message to auto_reset (#1102)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Nov 8, 2023
1 parent 9b97fbe commit 3186fd2
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion pgx/experimental/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import jax
import jax.numpy as jnp

Expand Down Expand Up @@ -29,7 +31,16 @@ def auto_reset(step_fn, init_fn):
2. Performance
"""

def wrapped_step_fn(state: State, action: Array, key: PRNGKey):
def wrapped_step_fn(
state: State, action: Array, key: Optional[PRNGKey] = None
):
assert key is not None, (
"v2.0.0 changes the signature of auto reset. Please specify PRNGKey at the third argument:\n\n"
" * < v2.0.0: step_fn(state, action)\n"
" * >= v2.0.0: step_fn(state, action, key)\n\n"
"Note that codes under pgx.experimental are subject to change without notice."
)

key1, key2 = jax.random.split(key)
state = jax.lax.cond(
(state.terminated | state.truncated),
Expand Down

0 comments on commit 3186fd2

Please sign in to comment.