Add Dict Obs support for DQN #66
base: master
Hello, Please avoid any
My bad. I was using Ubuntu's mypy, which was not catching these errors; I'm not sure why. I'll probably create a new branch without these commits, as that will be cleaner.
Noted. Removed them in most places.
return key
def prepare_obs(  # type: ignore[override]
This preprocessing is done only for the data collection, right?
But it should not be needed, because the preprocessing is done inside __call__ by the multi-input Q-network?
I started rewriting SB3's preprocessing.py and then realized it should not be needed.
from typing import Union

import jax
import jax.numpy as jnp
from gymnasium import spaces
from stable_baselines3.common.preprocessing import is_image_space


def preprocess_obs(
    obs: Union[jax.Array, dict[str, jax.Array]],
    observation_space: spaces.Space,
    normalize_images: bool = True,
) -> Union[jax.Array, dict[str, jax.Array]]:
    """
    Preprocess observation to be fed to a neural network.
    For images, it normalizes the values by dividing them by 255 (to have values in [0, 1]).
    For discrete observations, it creates a one-hot vector.

    :param obs: Observation
    :param observation_space:
    :param normalize_images: Whether to normalize images or not
        (True by default)
    :return:
    """
    if isinstance(observation_space, spaces.Dict):
        # Do not modify the original observation by reference
        assert isinstance(obs, dict), f"Expected dict, got {type(obs)}"
        preprocessed_obs = jax.tree.map(
            lambda observation, obs_space: preprocess_obs(observation, obs_space, normalize_images=normalize_images),
            obs,
            observation_space.spaces,
        )
        return preprocessed_obs  # type: ignore[return-value]

    assert not isinstance(obs, dict)

    if isinstance(observation_space, spaces.Box):
        if normalize_images and is_image_space(observation_space):
            return obs.astype(jnp.float32) / 255.0
        return obs.astype(jnp.float32)
    elif isinstance(observation_space, spaces.Discrete):
        # One-hot encoding, converted to float to avoid errors
        return jax.nn.one_hot(obs, num_classes=int(observation_space.n)).astype(jnp.float32)
    elif isinstance(observation_space, spaces.MultiBinary):
        return obs.astype(jnp.float32)
    else:
        # MISSING: spaces.MultiDiscrete
        raise NotImplementedError(f"Preprocessing not implemented for {observation_space}")
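For the branch marked "MISSING: spaces.MultiDiscrete" in the snippet above, one possible approach is to one-hot encode each sub-dimension separately and concatenate the results, which is how SB3's PyTorch preprocessing handles that space. The following is only a rough sketch, not part of this PR: it uses NumPy instead of JAX to stay dependency-light, and the function name one_hot_multi_discrete and its arguments are hypothetical.

```python
import numpy as np


def one_hot_multi_discrete(obs: np.ndarray, nvec: np.ndarray) -> np.ndarray:
    """One-hot encode every sub-dimension of a batched MultiDiscrete observation.

    Hypothetical helper (not from the PR).
    obs:  integer array of shape (batch, len(nvec))
    nvec: number of categories for each sub-dimension
    Returns a float32 array of shape (batch, sum(nvec)).
    """
    obs = np.asarray(obs)
    nvec = np.asarray(nvec).ravel()
    pieces = [
        # Indexing an identity matrix with integer labels yields one-hot rows
        np.eye(int(n), dtype=np.float32)[obs[:, idx]]
        for idx, n in enumerate(nvec)
    ]
    return np.concatenate(pieces, axis=-1)


# Two observations from a space with 3 and 2 categories respectively
batch = np.array([[0, 1], [2, 0]])
encoded = one_hot_multi_discrete(batch, np.array([3, 2]))
```

Here each row of encoded has length 3 + 2 = 5, with exactly one active entry per sub-dimension.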
This preprocessing is done only for the data collection, right?
Yes.
But it should not be needed, because the preprocessing is done inside __call__ by the multi-input Q-network?
Yes, part of it is done inside __call__ (converting images from PyTorch format to JAX format, and normalising). Inside prepare_obs we convert non-vectorized observations to vectorized ones and image observations to PyTorch format.
Converting images to PyTorch format and then to JAX format might be redundant; we can find a way to avoid it.
But we still need the vectorizing part, so we need this prepare_obs.
This is my understanding.
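To make the two steps discussed here concrete, a minimal sketch of "vectorizing" a dict observation (adding a leading batch axis) and moving an image from channel-last to channel-first layout might look like the following. This is not the PR's prepare_obs; the helper names add_batch_dim and hwc_to_chw, the shapes dict, and the use of NumPy are all assumptions made for illustration.

```python
import numpy as np


def add_batch_dim(obs: dict, shapes: dict) -> dict:
    """Add a leading batch axis to entries that are not yet batched.

    Hypothetical helper: `shapes` maps each key to the unbatched shape
    of the corresponding sub-space.
    """
    batched = {}
    for key, value in obs.items():
        value = np.asarray(value)
        if value.shape == tuple(shapes[key]):
            value = value[None]  # (...) -> (1, ...)
        batched[key] = value
    return batched


def hwc_to_chw(image_batch: np.ndarray) -> np.ndarray:
    """Move channels from the last axis to axis 1: (N, H, W, C) -> (N, C, H, W)."""
    return np.transpose(image_batch, (0, 3, 1, 2))


shapes = {"image": (4, 4, 3), "vec": (2,)}
single_obs = {"image": np.zeros((4, 4, 3), dtype=np.uint8), "vec": np.zeros(2)}
vec_obs = add_batch_dim(single_obs, shapes)
chw_image = hwc_to_chw(vec_obs["image"])
```

If the network consumed channel-last images directly, the transpose step could be skipped and only the batching step would remain, which matches the redundancy raised in the comment above.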
Description
Add MultiInputPolicy for DQN.
Tried to make minimal changes to support this. Separated out the feature extractors, as in SB3's MultiInputPolicy.
Motivation and Context
Types of changes
Checklist:
- make format (required)
- make check-codestyle and make lint (required)
- make pytest and make type both pass. (required)
- make doc (required)
Note: You can run most of the checks using make commit-checks.
Note: we are using a maximum length of 127 characters per line