File tree Expand file tree Collapse file tree 1 file changed +12
-1
lines changed Expand file tree Collapse file tree 1 file changed +12
-1
lines changed Original file line number Diff line number Diff line change
1
+ from typing import Optional
2
+
1
3
import jax
2
4
import jax .numpy as jnp
3
5
@@ -29,7 +31,16 @@ def auto_reset(step_fn, init_fn):
29
31
2. Performance
30
32
"""
31
33
32
- def wrapped_step_fn (state : State , action : Array , key : PRNGKey ):
34
+ def wrapped_step_fn (
35
+ state : State , action : Array , key : Optional [PRNGKey ] = None
36
+ ):
37
+ assert key is not None , (
38
+ "v2.0.0 changes the signature of auto reset. Please specify PRNGKey at the third argument:\n \n "
39
+ " * < v2.0.0: step_fn(state, action)\n "
40
+ " * >= v2.0.0: step_fn(state, action, key)\n \n "
41
+ "Note that codes under pgx.experimental are subject to change without notice."
42
+ )
43
+
33
44
key1 , key2 = jax .random .split (key )
34
45
state = jax .lax .cond (
35
46
(state .terminated | state .truncated ),
You can’t perform that action at this time.
0 commit comments