Skip to content

Commit 3186fd2

Browse files
authored
[experimental] Add error message to auto_reset (#1102)
1 parent 9b97fbe commit 3186fd2

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

pgx/experimental/wrappers.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
import jax
24
import jax.numpy as jnp
35

@@ -29,7 +31,16 @@ def auto_reset(step_fn, init_fn):
2931
2. Performance
3032
"""
3133

32-
def wrapped_step_fn(state: State, action: Array, key: PRNGKey):
34+
def wrapped_step_fn(
35+
state: State, action: Array, key: Optional[PRNGKey] = None
36+
):
37+
assert key is not None, (
38+
"v2.0.0 changes the signature of auto reset. Please specify PRNGKey at the third argument:\n\n"
39+
" * < v2.0.0: step_fn(state, action)\n"
40+
" * >= v2.0.0: step_fn(state, action, key)\n\n"
41+
"Note that codes under pgx.experimental are subject to change without notice."
42+
)
43+
3344
key1, key2 = jax.random.split(key)
3445
state = jax.lax.cond(
3546
(state.terminated | state.truncated),

0 commit comments

Comments
 (0)