Skip to content

Commit

Permalink
Updated to JAX's new key naming (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
Howuhh authored Jul 12, 2024
1 parent 405d47a commit 9ddd611
Show file tree
Hide file tree
Showing 11 changed files with 31 additions and 31 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ import xminigrid
from xminigrid.wrappers import GymAutoResetWrapper
from xminigrid.experimental.img_obs import RGBImgObservationWrapper

key = jax.random.PRNGKey(0)
key = jax.random.key(0)
reset_key, ruleset_key = jax.random.split(key)

# to list available benchmarks: xminigrid.registered_benchmarks()
Expand Down Expand Up @@ -196,11 +196,11 @@ benchmark: Benchmark = xminigrid.load_benchmark(name="trivial-1m")
benchmark: Benchmark = xminigrid.load_benchmark(name="trivial-1m")

# users can sample or get specific rulesets
benchmark.sample_ruleset(jax.random.PRNGKey(0))
benchmark.sample_ruleset(jax.random.key(0))
benchmark.get_ruleset(ruleset_id=benchmark.num_rulesets() - 1)

# or split them for train & test
train, test = benchmark.shuffle(key=jax.random.PRNGKey(0)).split(prop=0.8)
train, test = benchmark.shuffle(key=jax.random.key(0)).split(prop=0.8)
```

We also provide the [script](scripts/ruleset_generator.py) used to generate these benchmarks. Users can use it for their own purposes:
Expand Down
8 changes: 4 additions & 4 deletions examples/train_meta_standalone.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@
" benchmark = xminigrid.load_benchmark(config.benchmark_id)\n",
"\n",
" # set up training state\n",
" rng = jax.random.PRNGKey(config.train_seed)\n",
" rng = jax.random.key(config.train_seed)\n",
" rng, _rng = jax.random.split(rng)\n",
"\n",
" network = ActorCriticRNN(\n",
Expand Down Expand Up @@ -629,7 +629,7 @@
" rng, train_state = runner_state[:2]\n",
"\n",
" # EVALUATE AGENT\n",
" eval_ruleset_rng, eval_reset_rng = jax.random.split(jax.random.PRNGKey(config.eval_seed))\n",
" eval_ruleset_rng, eval_reset_rng = jax.random.split(jax.random.key(config.eval_seed))\n",
" eval_ruleset_rng = jax.random.split(eval_ruleset_rng, num=config.eval_num_envs_per_device)\n",
" eval_reset_rng = jax.random.split(eval_reset_rng, num=config.eval_num_envs_per_device)\n",
"\n",
Expand Down Expand Up @@ -756,7 +756,7 @@
"total_reward, num_episodes = 0, 0\n",
"rendered_imgs = []\n",
"\n",
"rng = jax.random.PRNGKey(1)\n",
"rng = jax.random.key(1)\n",
"rng, _rng = jax.random.split(rng)\n",
"\n",
"# initial inputs\n",
Expand Down Expand Up @@ -823,7 +823,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down
6 changes: 3 additions & 3 deletions examples/train_single_standalone.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@
" env = RGBImgObservationWrapper(env)\n",
"\n",
" # setup training state\n",
" rng = jax.random.PRNGKey(config.seed)\n",
" rng = jax.random.key(config.seed)\n",
" rng, _rng = jax.random.split(rng)\n",
"\n",
" network = ActorCriticRNN(\n",
Expand Down Expand Up @@ -722,7 +722,7 @@
"total_reward = 0\n",
"rendered_imgs = []\n",
"\n",
"rng = jax.random.PRNGKey(1)\n",
"rng = jax.random.key(1)\n",
"rng, _rng = jax.random.split(rng)\n",
"\n",
"# initial inputs\n",
Expand Down Expand Up @@ -786,7 +786,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down
20 changes: 10 additions & 10 deletions examples/walkthrough.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@
"source": [
"import xminigrid\n",
"\n",
"key = jax.random.PRNGKey(0)\n",
"key = jax.random.key(0)\n",
"key, reset_key = jax.random.split(key)\n",
"\n",
"# to list available environments: xminigrid.registered_environments()\n",
Expand Down Expand Up @@ -345,7 +345,7 @@
"rollout_fn = jax.jit(build_rollout(env, env_params, num_steps=1000))\n",
"\n",
"# first execution will compile\n",
"transitions = rollout_fn(jax.random.PRNGKey(0))\n",
"transitions = rollout_fn(jax.random.key(0))\n",
"\n",
"print(\"Transitions shapes: \\n\", jtu.tree_map(jnp.shape, transitions))"
]
Expand Down Expand Up @@ -418,7 +418,7 @@
"outputs": [],
"source": [
"vmap_rollout = jax.jit(jax.vmap(build_rollout(env, env_params, num_steps=1000)))\n",
"rngs = jax.random.split(jax.random.PRNGKey(0), num=1024)\n",
"rngs = jax.random.split(jax.random.key(0), num=1024)\n",
"\n",
"vmap_transitions = vmap_rollout(rngs)\n",
"\n",
Expand Down Expand Up @@ -527,7 +527,7 @@
" benchmark_fn_pmap = build_benchmark(\"MiniGrid-EmptyRandom-8x8\", num_envs // num_devices, 1024)\n",
" benchmark_fn_pmap = jax.pmap(benchmark_fn_pmap)\n",
"\n",
" key = jax.random.PRNGKey(0)\n",
" key = jax.random.key(0)\n",
" pmap_keys = jax.random.split(key, num=num_devices)\n",
"\n",
" elapsed_time = timeit_benchmark(jax.tree_util.Partial(benchmark_fn_vmap, key))\n",
Expand Down Expand Up @@ -896,7 +896,7 @@
"env, env_params = xminigrid.make(\"XLand-MiniGrid-R4-9x9\")\n",
"env_params = env_params.replace(ruleset=ruleset)\n",
"\n",
"timestep = env.reset(env_params, jax.random.PRNGKey(0))\n",
"timestep = env.reset(env_params, jax.random.key(0))\n",
"\n",
"show_img(env.render(env_params, timestep), dpi=64)"
]
Expand All @@ -921,7 +921,7 @@
"benchmark = xminigrid.load_benchmark(name=\"trivial-1m\")\n",
"print(\"Total rulesets:\", benchmark.num_rulesets())\n",
"print(\"Ruleset with id 128: \\n\", benchmark.get_ruleset(ruleset_id=128))\n",
"print(\"Random ruleset: \\n\", benchmark.sample_ruleset(jax.random.PRNGKey(0)))"
"print(\"Random ruleset: \\n\", benchmark.sample_ruleset(jax.random.key(0)))"
]
},
{
Expand All @@ -942,7 +942,7 @@
"outputs": [],
"source": [
"env_params = env_params.replace(ruleset=benchmark.get_ruleset(ruleset_id=128))\n",
"timestep = env.reset(env_params, jax.random.PRNGKey(0))\n",
"timestep = env.reset(env_params, jax.random.key(0))\n",
"\n",
"show_img(env.render(env_params, timestep), dpi=64)"
]
Expand Down Expand Up @@ -992,7 +992,7 @@
"metadata": {},
"outputs": [],
"source": [
"train, test = benchmark.shuffle(key=jax.random.PRNGKey(0)).split(prop=0.8)\n",
"train, test = benchmark.shuffle(key=jax.random.key(0)).split(prop=0.8)\n",
"\n",
"# or, by some function:\n",
"def cond_fn(goal, rules):\n",
Expand Down Expand Up @@ -1042,7 +1042,7 @@
"outputs": [],
"source": [
"env_params = env_params.replace(ruleset=rulesets)\n",
"timestep = jax.vmap(env.reset, in_axes=(0, None))(env_params, jax.random.PRNGKey(0))"
"timestep = jax.vmap(env.reset, in_axes=(0, None))(env_params, jax.random.key(0))"
]
},
{
Expand Down Expand Up @@ -1102,7 +1102,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ classifiers = [
]

dependencies = [
"jax>=0.4.16",
"jaxlib>=0.4.16",
"jax>=0.4.26",
"jaxlib>=0.4.26",
"flax>=0.8.0",
"rich>=13.4.2",
"chex>=0.1.85",
Expand Down
4 changes: 2 additions & 2 deletions scripts/benchmark_xland.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def build_benchmark(

# choose XLand benchmark if needed
if "XLand-MiniGrid" in env_id and benchmark_id is not None:
ruleset = load_benchmark(benchmark_id).sample_ruleset(jax.random.PRNGKey(0))
ruleset = load_benchmark(benchmark_id).sample_ruleset(jax.random.key(0))
env_params = env_params.replace(ruleset=ruleset)

def benchmark_fn(key):
Expand Down Expand Up @@ -98,7 +98,7 @@ def timeit_benchmark(args, benchmark_fn):
)
benchmark_fn_pmap = jax.pmap(benchmark_fn_pmap)

key = jax.random.PRNGKey(0)
key = jax.random.key(0)
pmap_keys = jax.random.split(key, num=num_devices)

# benchmarking
Expand Down
4 changes: 2 additions & 2 deletions scripts/benchmark_xland_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def build_benchmark(

# choose XLand benchmark if needed
if "XLand-MiniGrid" in env_id and benchmark_id is not None:
ruleset = load_benchmark(benchmark_id).sample_ruleset(jax.random.PRNGKey(0))
ruleset = load_benchmark(benchmark_id).sample_ruleset(jax.random.key(0))
env_params = env_params.replace(ruleset=ruleset)

def benchmark_fn(key):
Expand Down Expand Up @@ -93,7 +93,7 @@ def timeit_benchmark(args, benchmark_fn):
benchmark_fn_pmap = jax.pmap(benchmark_fn_pmap)

# benchmarking
pmap_keys = jax.random.split(jax.random.PRNGKey(0), num=num_devices)
pmap_keys = jax.random.split(jax.random.key(0), num=num_devices)

elapsed_time = timeit_benchmark(args, jtu.Partial(benchmark_fn_pmap, pmap_keys))
pmap_fps = (args.timesteps * num_envs) // elapsed_time
Expand Down
2 changes: 1 addition & 1 deletion src/xminigrid/manual_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(

self._reset = jax.jit(self.env.reset)
self._step = jax.jit(self.env.step)
self._key = jax.random.PRNGKey(0)
self._key = jax.random.key(0)

self.timestep = None

Expand Down
2 changes: 1 addition & 1 deletion training/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def main():
total_reward, num_episodes = 0, 0
rendered_imgs = []

rng = jax.random.PRNGKey(0)
rng = jax.random.key(0)
rng, _rng = jax.random.split(rng)

timestep = reset_fn(env_params, _rng)
Expand Down
4 changes: 2 additions & 2 deletions training/train_meta_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def linear_schedule(count):
benchmark = xminigrid.load_benchmark(config.benchmark_id)

# set up training state
rng = jax.random.PRNGKey(config.train_seed)
rng = jax.random.key(config.train_seed)
rng, _rng = jax.random.split(rng)

network = ActorCriticRNN(
Expand Down Expand Up @@ -269,7 +269,7 @@ def _update_minbatch(train_state, batch_info):
rng, train_state = runner_state[:2]

# EVALUATE AGENT
eval_ruleset_rng, eval_reset_rng = jax.random.split(jax.random.PRNGKey(config.eval_seed))
eval_ruleset_rng, eval_reset_rng = jax.random.split(jax.random.key(config.eval_seed))
eval_ruleset_rng = jax.random.split(eval_ruleset_rng, num=config.eval_num_envs_per_device)
eval_reset_rng = jax.random.split(eval_reset_rng, num=config.eval_num_envs_per_device)

Expand Down
2 changes: 1 addition & 1 deletion training/train_single_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def linear_schedule(count):
env = RGBImgObservationWrapper(env)

# set up training state
rng = jax.random.PRNGKey(config.seed)
rng = jax.random.key(config.seed)
rng, _rng = jax.random.split(rng)

network = ActorCriticRNN(
Expand Down

0 comments on commit 9ddd611

Please sign in to comment.