Skip to content

Commit

Permalink
Obs wrappers to add direction and actual task rules (#23)
Browse files Browse the repository at this point in the history
* extended obs wip

* extended obs wip

* extended obs with rules wip

* additional comment

* fix ruff action

* fix render wrapper
  • Loading branch information
Howuhh authored Jul 12, 2024
1 parent b0dd4a2 commit 405d47a
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/codestyle.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
pip install -e ".[dev]"
- name: check codestyle
run: |
ruff --config pyproject.toml --diff .
ruff check --config pyproject.toml --diff .
- name: check type hints
run: |
pyright --project=pyproject.toml src/xminigrid
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ repos:

# pyright checking
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.350
rev: v1.1.371
hooks:
- id: pyright
args: [--project=pyproject.toml]
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,10 @@ ignore = [
[tool.ruff.format]
skip-magic-trailing-comma = false

[tool.ruff.isort]
[tool.ruff.lint.isort]
# see https://github.com/astral-sh/ruff/issues/8571
known-third-party = ["wandb"]


[tool.pyright]
include = ["src/xminigrid"]
exclude = [
Expand Down
2 changes: 1 addition & 1 deletion src/xminigrid/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def default_params(self, **kwargs: Any) -> EnvParamsT:
def num_actions(self, params: EnvParamsT) -> int:
return int(NUM_ACTIONS)

def observation_shape(self, params: EnvParamsT) -> tuple[int, int, int]:
def observation_shape(self, params: EnvParamsT) -> tuple[int, int, int] | dict[str, Any]:
return params.view_size, params.view_size, NUM_LAYERS

@abc.abstractmethod
Expand Down
25 changes: 22 additions & 3 deletions src/xminigrid/experimental/img_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,33 @@ def _render_obs(obs: jax.Array) -> jax.Array:

class RGBImgObservationWrapper(Wrapper):
def observation_shape(self, params):
return params.view_size * TILE_SIZE, params.view_size * TILE_SIZE, 3
new_shape = (params.view_size * TILE_SIZE, params.view_size * TILE_SIZE, 3)

base_shape = self._env.observation_shape(params)
if isinstance(base_shape, dict):
assert "img" in base_shape
obs_shape = {**base_shape, **{"img": new_shape}}
else:
obs_shape = new_shape

return obs_shape

def __convert_obs(self, timestep):
if isinstance(timestep.observation, dict):
assert "img" in timestep.observation
rendered_obs = {**timestep.observation, **{"img": _render_obs(timestep.observation["img"])}}
else:
rendered_obs = _render_obs(timestep.observation)

timestep = timestep.replace(observation=rendered_obs)
return timestep

def reset(self, params, key):
timestep = self._env.reset(params, key)
timestep = timestep.replace(observation=_render_obs(timestep.observation))
timestep = self.__convert_obs(timestep)
return timestep

def step(self, params, timestep, action):
timestep = self._env.step(params, timestep, action)
timestep = timestep.replace(observation=_render_obs(timestep.observation))
timestep = self.__convert_obs(timestep)
return timestep
2 changes: 1 addition & 1 deletion src/xminigrid/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class TimeStep(struct.PyTreeNode, Generic[EnvCarryT]):
step_type: StepType
reward: jax.Array
discount: jax.Array
observation: jax.Array
observation: jax.Array | dict[str, jax.Array]

def first(self):
return self.step_type == StepType.FIRST
Expand Down
103 changes: 102 additions & 1 deletion src/xminigrid/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Any

import jax

from .environment import Environment, EnvParamsT
Expand All @@ -19,7 +21,7 @@ def default_params(self, **kwargs) -> EnvParamsT:
def num_actions(self, params: EnvParamsT) -> int:
return self._env.num_actions(params)

def observation_shape(self, params: EnvParamsT) -> tuple[int, int, int]:
def observation_shape(self, params: EnvParamsT) -> tuple[int, int, int] | dict[str, Any]:
return self._env.observation_shape(params)

def _generate_problem(self, params: EnvParamsT, key: jax.Array) -> State[EnvCarryT]:
Expand Down Expand Up @@ -67,3 +69,102 @@ def step(self, params, timestep, action):
lambda: self._env.step(params, timestep, action),
)
return timestep


# Yes, these are a bit stupid, but a tmp workaround to not write an actual system for spaces.
# May be, in the future, I will port the entire API to some existing one, like functional Gymnasium.
# For now, faster to do this stuff with dicts instead...
# NB: if you do not want to use this (due to the dicts as obs),
# just get needed parts from the original TimeStep and State dataclasses
class DirectionObservationWrapper(Wrapper):
def observation_shape(self, params):
base_shape = self._env.observation_shape(params)
if isinstance(base_shape, dict):
assert "img" in base_shape
obs_shape = {**base_shape, **{"direction": 4}}
else:
obs_shape = {
"img": self._env.observation_shape(params),
"direction": 4,
}
return obs_shape

def __extend_obs(self, timestep):
direction = jax.nn.one_hot(timestep.state.agent.direction, num_classes=4)
if isinstance(timestep.observation, dict):
assert "img" in timestep.observation
extended_obs = {
**timestep.observation,
**{"direction": direction},
}
else:
extended_obs = {
"img": timestep.observation,
"direction": direction,
}

timestep = timestep.replace(observation=extended_obs)
return timestep

def reset(self, params, key):
timestep = self._env.reset(params, key)
timestep = self.__extend_obs(timestep)
return timestep

def step(self, params, timestep, action):
timestep = self._env.step(params, timestep, action)
timestep = self.__extend_obs(timestep)
return timestep


class RulesAndGoalsObservationWrapper(Wrapper):
def observation_shape(self, params):
base_shape = self._env.observation_shape(params)
if isinstance(base_shape, dict):
assert "img" in base_shape
obs_shape = {
**base_shape,
**{
"goal_encoding": params.ruleset.goal.shape,
"rule_encoding": params.ruleset.rules.shape,
},
}
else:
obs_shape = {
"img": self._env.observation_shape(params),
"goal_encoding": params.ruleset.goal.shape,
"rule_encoding": params.ruleset.rules.shape,
}
return obs_shape

def __extend_obs(self, timestep):
goal_encoding = timestep.state.goal_encoding
rule_encoding = timestep.state.rule_encoding
if isinstance(timestep.observation, dict):
assert "img" in timestep.observation
extended_obs = {
**timestep.observation,
**{
"goal_encoding": goal_encoding,
"rule_encoding": rule_encoding,
},
}
else:
extended_obs = {
"img": timestep.observation,
"goal_encoding": goal_encoding,
"rule_encoding": rule_encoding,
}

timestep = timestep.replace(observation=extended_obs)
return timestep

def reset(self, params, key):
timestep = self._env.reset(params, key)
timestep = self.__extend_obs(timestep)
return timestep

def step(self, params, timestep, action):
timestep = self._env.step(params, timestep, action)
timestep = self.__extend_obs(timestep)
return timestep

0 comments on commit 405d47a

Please sign in to comment.