Skip to content

Commit

Permalink
RGB image observation wrapper compatible with jit (#9)
Browse files Browse the repository at this point in the history
* image obs rendering

* revert pre-commit

* fix types

* img obs benchmarking, fixed cache bug

* regenerated benchmarks

* add check

* remove check

* manual control video save

* adapted training scripts for img obs

* fix imageio

* fix bug

* add force cache reload, fix bug in wrapper

* fix bug with rules and empty tiles

* refine readme, cnn arch
  • Loading branch information
Howuhh authored Mar 24, 2024
1 parent e07118a commit a401d41
Show file tree
Hide file tree
Showing 21 changed files with 335 additions and 112 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ repos:
rev: v1.1.350
hooks:
- id: pyright
# args: [--project=pyproject.toml]
args: [--project=pyproject.toml]
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ On the high level, current API combines [dm_env](https://github.com/google-deepm
import jax
import xminigrid
from xminigrid.wrappers import GymAutoResetWrapper
from xminigrid.experimental.img_obs import RGBImgObservationWrapper

key = jax.random.PRNGKey(0)
reset_key, ruleset_key = jax.random.split(key)
Expand All @@ -109,6 +110,9 @@ env_params = env_params.replace(ruleset=ruleset)
# auto-reset wrapper
env = GymAutoResetWrapper(env)

# render observations as RGB images if needed (warning: this will slow down the environment considerably)
env = RGBImgObservationWrapper(env)

# fully jit-compatible step and reset methods
timestep = jax.jit(env.reset)(env_params, reset_key)
timestep = jax.jit(env.step)(env_params, timestep, action=0)
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ dependencies = [
"flax>=0.8.0",
"rich>=13.4.2",
"chex>=0.1.85",
"imageio>=2.31.2",
"imageio-ffmpeg>=0.4.9",
]

[project.optional-dependencies]
Expand All @@ -49,8 +51,6 @@ dev = [

baselines = [
"matplotlib>=3.7.2",
"imageio>=2.31.2",
"imageio-ffmpeg>=0.4.9",
"wandb>=0.15.10",
"pyrallis>=0.3.1",
"distrax>=0.1.4",
Expand Down
24 changes: 20 additions & 4 deletions scripts/benchmark_xland.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,29 @@
parser = argparse.ArgumentParser()
parser.add_argument("--env-id", type=str, default="MiniGrid-Empty-16x16")
parser.add_argument("--benchmark-id", type=str, default="Trivial")
parser.add_argument("--img-obs", action="store_true")
parser.add_argument("--timesteps", type=int, default=1000)
parser.add_argument("--num-envs", type=int, default=8192)
parser.add_argument("--num-repeat", type=int, default=10, help="Number of timing repeats")
parser.add_argument("--num-iter", type=int, default=1, help="Number of runs during one repeat (time is summed)")


def build_benchmark(env_id: str, num_envs: int, timesteps: int, benchmark_id: Optional[str] = None):
def build_benchmark(
env_id: str,
num_envs: int,
timesteps: int,
benchmark_id: Optional[str] = None,
img_obs: bool = False,
):
env, env_params = xminigrid.make(env_id)
env = GymAutoResetWrapper(env)

# enable img observations if needed
if img_obs:
from xminigrid.experimental.img_obs import RGBImgObservationWrapper

env = RGBImgObservationWrapper(env)

# choose XLand benchmark if needed
if "XLand-MiniGrid" in env_id and benchmark_id is not None:
ruleset = load_benchmark(benchmark_id).sample_ruleset(jax.random.PRNGKey(0))
Expand Down Expand Up @@ -73,13 +87,15 @@ def timeit_benchmark(args, benchmark_fn):
print("Num devices for pmap:", num_devices)

# building for single env benchmarking
benchmark_fn_single = build_benchmark(args.env_id, 1, args.timesteps, args.benchmark_id)
benchmark_fn_single = build_benchmark(args.env_id, 1, args.timesteps, args.benchmark_id, args.img_obs)
benchmark_fn_single = jax.jit(benchmark_fn_single)
# building vmap for vectorization benchmarking
benchmark_fn_vmap = build_benchmark(args.env_id, args.num_envs, args.timesteps, args.benchmark_id)
benchmark_fn_vmap = build_benchmark(args.env_id, args.num_envs, args.timesteps, args.benchmark_id, args.img_obs)
benchmark_fn_vmap = jax.jit(benchmark_fn_vmap)
# building pmap for multi-gpu benchmarking (each doing (num_envs / num_devices) vmaps)
benchmark_fn_pmap = build_benchmark(args.env_id, args.num_envs // num_devices, args.timesteps, args.benchmark_id)
benchmark_fn_pmap = build_benchmark(
args.env_id, args.num_envs // num_devices, args.timesteps, args.benchmark_id, args.img_obs
)
benchmark_fn_pmap = jax.pmap(benchmark_fn_pmap)

key = jax.random.PRNGKey(0)
Expand Down
16 changes: 14 additions & 2 deletions scripts/benchmark_xland_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,24 @@

parser = argparse.ArgumentParser()
parser.add_argument("--benchmark-id", type=str, default="trivial-1m")
parser.add_argument("--img-obs", action="store_true")
parser.add_argument("--timesteps", type=int, default=1000)
parser.add_argument("--num-repeat", type=int, default=10, help="Number of timing repeats")
parser.add_argument("--num-iter", type=int, default=1, help="Number of runs during one repeat (time is summed)")


def build_benchmark(env_id: str, num_envs: int, timesteps: int, benchmark_id: Optional[str] = None):
def build_benchmark(
env_id: str, num_envs: int, timesteps: int, benchmark_id: Optional[str] = None, img_obs: bool = False
):
env, env_params = xminigrid.make(env_id)
env = GymAutoResetWrapper(env)

# enable img observations if needed
if img_obs:
from xminigrid.experimental.img_obs import RGBImgObservationWrapper

env = RGBImgObservationWrapper(env)

# choose XLand benchmark if needed
if "XLand-MiniGrid" in env_id and benchmark_id is not None:
ruleset = load_benchmark(benchmark_id).sample_ruleset(jax.random.PRNGKey(0))
Expand Down Expand Up @@ -77,7 +87,9 @@ def timeit_benchmark(args, benchmark_fn):
for env_id in tqdm(environments, desc="Envs.."):
assert num_envs % num_devices == 0
# building pmap for multi-gpu benchmarking (each doing (num_envs / num_devices) vmaps)
benchmark_fn_pmap = build_benchmark(env_id, num_envs // num_devices, args.timesteps, args.benchmark_id)
benchmark_fn_pmap = build_benchmark(
env_id, num_envs // num_devices, args.timesteps, args.benchmark_id, args.img_obs
)
benchmark_fn_pmap = jax.pmap(benchmark_fn_pmap)

# benchmarking
Expand Down
22 changes: 10 additions & 12 deletions scripts/generate_benchmarks.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ python scripts/ruleset_generator.py \
--total_rulesets=1_000_000 \
--save_path="trivial_1m"


# small
python scripts/ruleset_generator.py \
--prune_chain \
Expand Down Expand Up @@ -41,17 +40,16 @@ python scripts/ruleset_generator.py \
--total_rulesets=1_000_000 \
--save_path="high_1m"


# medium + distractors
python scripts/ruleset_generator.py \
--prune_chain \
--prune_prob=0.8 \
--chain_depth=2 \
--sample_distractor_rules \
--num_distractor_rules=4 \
--num_distractor_objects=2 \
--total_rulesets=1_000_000 \
--save_path="medium_dist_1m"
## medium + distractors
#python scripts/ruleset_generator.py \
# --prune_chain \
# --prune_prob=0.8 \
# --chain_depth=2 \
# --sample_distractor_rules \
# --num_distractor_rules=4 \
# --num_distractor_objects=2 \
# --total_rulesets=1_000_000 \
# --save_path="medium_dist_1m"

# medium 3M
python scripts/ruleset_generator.py \
Expand Down
5 changes: 3 additions & 2 deletions src/xminigrid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .registration import make, register, registered_environments

# TODO: add __all__
__version__ = "0.6.0"
__version__ = "0.7.0"

# ---------- XLand-MiniGrid environments ----------

Expand Down Expand Up @@ -210,7 +210,8 @@

# BlockedUnlockPickUp
register(
id="MiniGrid-BlockedUnlockPickUp", entry_point="xminigrid.envs.minigrid.blockedunlockpickup:BlockedUnlockPickUp"
id="MiniGrid-BlockedUnlockPickUp",
entry_point="xminigrid.envs.minigrid.blockedunlockpickup:BlockedUnlockPickUp",
)

# DoorKey
Expand Down
13 changes: 6 additions & 7 deletions src/xminigrid/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,12 @@
DATA_PATH = os.environ.get("XLAND_MINIGRID_DATA", os.path.expanduser("~/.xland_minigrid"))

NAME2HFFILENAME = {
"trivial-1m": "trivial_1m",
"small-1m": "small_1m",
"small-dist-1m": "small_dist_1m",
"medium-1m": "medium_1m_v1",
"medium-3m": "medium_3m_v1",
"high-1m": "high_1m",
"high-3m": "high_3m",
"trivial-1m": "trivial_1m_v2",
"small-1m": "small_1m_v2",
"medium-1m": "medium_1m_v2",
"medium-3m": "medium_3m_v2",
"high-1m": "high_1m_v2",
"high-3m": "high_3m_v2",
}


Expand Down
66 changes: 28 additions & 38 deletions src/xminigrid/core/constants.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,44 @@
import jax.numpy as jnp
from flax import struct

NUM_ACTIONS = 6

# GRID: [tile, color]
NUM_LAYERS = 2
NUM_TILES = 15
NUM_COLORS = 14
NUM_ACTIONS = 6
NUM_TILES = 13
NUM_COLORS = 12


# TODO: do we really need END_OF_MAP? It seems like UNSEEN could be used instead...
# enums, kinda...
class Tiles(struct.PyTreeNode):
EMPTY: int = struct.field(pytree_node=False, default=0)
END_OF_MAP: int = struct.field(pytree_node=False, default=1)
UNSEEN: int = struct.field(pytree_node=False, default=2)
FLOOR: int = struct.field(pytree_node=False, default=3)
WALL: int = struct.field(pytree_node=False, default=4)
BALL: int = struct.field(pytree_node=False, default=5)
SQUARE: int = struct.field(pytree_node=False, default=6)
PYRAMID: int = struct.field(pytree_node=False, default=7)
GOAL: int = struct.field(pytree_node=False, default=8)
KEY: int = struct.field(pytree_node=False, default=9)
DOOR_LOCKED: int = struct.field(pytree_node=False, default=10)
DOOR_CLOSED: int = struct.field(pytree_node=False, default=11)
DOOR_OPEN: int = struct.field(pytree_node=False, default=12)
HEX: int = struct.field(pytree_node=False, default=13)
STAR: int = struct.field(pytree_node=False, default=14)
FLOOR: int = struct.field(pytree_node=False, default=1)
WALL: int = struct.field(pytree_node=False, default=2)
BALL: int = struct.field(pytree_node=False, default=3)
SQUARE: int = struct.field(pytree_node=False, default=4)
PYRAMID: int = struct.field(pytree_node=False, default=5)
GOAL: int = struct.field(pytree_node=False, default=6)
KEY: int = struct.field(pytree_node=False, default=7)
DOOR_LOCKED: int = struct.field(pytree_node=False, default=8)
DOOR_CLOSED: int = struct.field(pytree_node=False, default=9)
DOOR_OPEN: int = struct.field(pytree_node=False, default=10)
HEX: int = struct.field(pytree_node=False, default=11)
STAR: int = struct.field(pytree_node=False, default=12)


class Colors(struct.PyTreeNode):
EMPTY: int = struct.field(pytree_node=False, default=0)
END_OF_MAP: int = struct.field(pytree_node=False, default=1)
UNSEEN: int = struct.field(pytree_node=False, default=2)
RED: int = struct.field(pytree_node=False, default=3)
GREEN: int = struct.field(pytree_node=False, default=4)
BLUE: int = struct.field(pytree_node=False, default=5)
PURPLE: int = struct.field(pytree_node=False, default=6)
YELLOW: int = struct.field(pytree_node=False, default=7)
GREY: int = struct.field(pytree_node=False, default=8)
BLACK: int = struct.field(pytree_node=False, default=9)
ORANGE: int = struct.field(pytree_node=False, default=10)
WHITE: int = struct.field(pytree_node=False, default=11)
BROWN: int = struct.field(pytree_node=False, default=12)
PINK: int = struct.field(pytree_node=False, default=13)
RED: int = struct.field(pytree_node=False, default=1)
GREEN: int = struct.field(pytree_node=False, default=2)
BLUE: int = struct.field(pytree_node=False, default=3)
PURPLE: int = struct.field(pytree_node=False, default=4)
YELLOW: int = struct.field(pytree_node=False, default=5)
GREY: int = struct.field(pytree_node=False, default=6)
BLACK: int = struct.field(pytree_node=False, default=7)
ORANGE: int = struct.field(pytree_node=False, default=8)
WHITE: int = struct.field(pytree_node=False, default=9)
BROWN: int = struct.field(pytree_node=False, default=10)
PINK: int = struct.field(pytree_node=False, default=11)


# Only ~100 combinations so far, better to preallocate them
Expand All @@ -65,7 +61,6 @@ class Colors(struct.PyTreeNode):

WALKABLE = jnp.array(
(
Tiles.EMPTY,
Tiles.FLOOR,
Tiles.GOAL,
Tiles.DOOR_OPEN,
Expand All @@ -83,12 +78,7 @@ class Colors(struct.PyTreeNode):
)
)

FREE_TO_PUT_DOWN = jnp.array(
(
Tiles.EMPTY,
Tiles.FLOOR,
)
)
FREE_TO_PUT_DOWN = jnp.array((Tiles.FLOOR,))

LOS_BLOCKING = jnp.array(
(
Expand Down
4 changes: 2 additions & 2 deletions src/xminigrid/core/grid.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Callable, Union
from typing import Callable

import jax
import jax.numpy as jnp
Expand All @@ -22,7 +22,7 @@ def equal(tile1: Tile, tile2: Tile) -> Tile:

def get_neighbouring_tiles(grid: GridState, y: IntOrArray, x: IntOrArray) -> tuple[Tile, Tile, Tile, Tile]:
# end_of_map = TILES_REGISTRY[Tiles.END_OF_MAP, Colors.END_OF_MAP]
end_of_map = Tiles.END_OF_MAP
end_of_map = Tiles.EMPTY

up_tile = grid.at[y - 1, x].get(mode="fill", fill_value=end_of_map)
right_tile = grid.at[y, x + 1].get(mode="fill", fill_value=end_of_map)
Expand Down
6 changes: 3 additions & 3 deletions src/xminigrid/core/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def crop_field_of_view(grid: GridState, agent: AgentState, height: int, width: i
grid = jnp.pad(
grid,
pad_width=((height, height), (width, width), (0, 0)),
constant_values=Tiles.END_OF_MAP,
constant_values=Tiles.EMPTY,
)
# account for padding
y = agent.position[0] + height
Expand Down Expand Up @@ -110,8 +110,8 @@ def minigrid_field_of_view(grid: GridState, agent: AgentState, height: int, widt
fov_grid = crop_field_of_view(grid, agent, height, width)
fov_grid = align_with_up(fov_grid, agent.direction)
mask = generate_viz_mask_minigrid(fov_grid)
# set UNSEEN value for all layers (including colors, as UNSEEN color has same id value)
fov_grid = jnp.where(mask[..., None], fov_grid, Tiles.UNSEEN)
# set EMPTY as unseen value for all layers (including colors, as EMPTY color has same id value)
fov_grid = jnp.where(mask[..., None], fov_grid, Tiles.EMPTY)

# TODO: should we even do this? An agent with good memory can remember what it picked up.
# WARN: this can overwrite the tile the agent is standing on, GOAL for example.
Expand Down
Loading

0 comments on commit a401d41

Please sign in to comment.