@suijth

@suijth suijth commented Feb 21, 2025

Description

Add MultiInputPolicy for DQN.

I tried to keep the changes minimal. The feature extractors are separated out, as in SB3's MultiInputPolicy.

Motivation and Context

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist:

  • I've read the CONTRIBUTION guide (required)
  • I have updated the changelog accordingly (required).
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • I have reformatted the code using make format (required)
  • I have checked the codestyle using make check-codestyle and make lint (required)
  • I have ensured make pytest and make type both pass. (required)
  • I have checked that the documentation builds using make doc (required)

Note: You can run most of the checks using make commit-checks.

Note: we are using a maximum length of 127 characters per line

@araffin
Owner

araffin commented Feb 25, 2025

Hello,
thanks for the PR.
When you check items in the checklist (especially all the `make` commands), it means you have executed them, not that you will let the CI do the job.

Please avoid any type: ignore unless necessary too.

@suijth
Author

suijth commented Feb 26, 2025

> Hello, thanks for the PR. When you check items in the checklist (especially all the `make` commands), it means you have executed them, not that you will let the CI do the job.

My bad. I was using Ubuntu's mypy, which was not catching these errors; I'm not sure why. I'll probably create a new branch without these commits, since that will be cleaner.

> Please avoid any type: ignore unless necessary too.

Noted. I removed them in most places.

@araffin araffin self-requested a review May 15, 2025 07:52

return key

def prepare_obs( # type: ignore[override]
Owner

this preprocessing is done only for the data collection, right?

but it should not be needed because the preprocessing is done inside __call__ by the multi input q network?

I started rewriting SB3 preprocessing.py and then realized it should not be needed.

from typing import Union

import jax
import jax.numpy as jnp
from gymnasium import spaces
from stable_baselines3.common.preprocessing import is_image_space


def preprocess_obs(
    obs: Union[jax.Array, dict[str, jax.Array]],
    observation_space: spaces.Space,
    normalize_images: bool = True,
) -> Union[jax.Array, dict[str, jax.Array]]:
    """
    Preprocess observation to be fed to a neural network.
    For images, it normalizes the values by dividing them by 255 (to have values in [0, 1]).
    For discrete observations, it creates a one-hot vector.

    :param obs: Observation
    :param observation_space:
    :param normalize_images: Whether to normalize images or not
        (True by default)
    :return:
    """
    if isinstance(observation_space, spaces.Dict):
        # Do not modify by reference the original observation
        assert isinstance(obs, dict), f"Expected dict, got {type(obs)}"
        preprocessed_obs = jax.tree.map(
            lambda observation, obs_space: preprocess_obs(observation, obs_space, normalize_images=normalize_images),
            obs,
            observation_space.spaces,
        )
        return preprocessed_obs  # type: ignore[return-value]

    assert not isinstance(obs, dict)

    if isinstance(observation_space, spaces.Box):
        if normalize_images and is_image_space(observation_space):
            return obs.astype(jnp.float32) / 255.0
        return obs.astype(jnp.float32)

    elif isinstance(observation_space, spaces.Discrete):
        # One hot encoding and convert to float to avoid errors
        return jax.nn.one_hot(obs, num_classes=int(observation_space.n)).astype(jnp.float32)

    elif isinstance(observation_space, spaces.MultiBinary):
        return obs.astype(jnp.float32)
    else:
        # MISSING: spaces.MultiDiscrete
        raise NotImplementedError(f"Preprocessing not implemented for {observation_space}")
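
For the `MultiDiscrete` case marked as missing above, one possible approach (a sketch only, not part of the PR) is to one-hot encode each sub-space and concatenate the results along the last axis, which is the idea SB3's PyTorch preprocessing follows. The helper name `one_hot_multi_discrete` is hypothetical, and NumPy is used here purely for illustration; in the JAX function above, `jax.nn.one_hot` and `jnp.concatenate` would play the same roles.

```python
import numpy as np


def one_hot_multi_discrete(obs, nvec):
    """One-hot encode a batch of MultiDiscrete observations (hypothetical helper).

    obs:  integer array of shape (batch, len(nvec))
    nvec: number of categories of each sub-space
    """
    obs = np.asarray(obs, dtype=np.int64)
    # One block of width n per sub-space, built by indexing an identity matrix.
    blocks = [np.eye(int(n), dtype=np.float32)[obs[..., idx]] for idx, n in enumerate(nvec)]
    # Concatenate the blocks so the output has shape (batch, sum(nvec)).
    return np.concatenate(blocks, axis=-1)


encoded = one_hot_multi_discrete([[0, 2], [1, 0]], nvec=[2, 3])
print(encoded.shape)
```

The output width is the sum of the entries of `nvec`, so a Q-network consuming it would need its input size computed the same way.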

Author

> this preprocessing is done only for the data collection, right?

Yes.

> but it should not be needed because the preprocessing is done inside __call__ by the multi input q network?

Yes, part of it is done inside __call__ (converting images from the PyTorch format to the JAX format, and normalising them). Inside prepare_obs we convert non-vectorized observations to vectorized ones and image observations to the PyTorch format.
Converting images to the PyTorch format and then back to the JAX format might be redundant; we can find a way to avoid it.
But we still need the vectorizing part, so we need this prepare_obs.

This is my understanding.
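
The vectorizing step described above can be sketched roughly as follows. This is a minimal illustration, not the PR's prepare_obs: the helper name `vectorize_dict_obs` and the plain `shapes` mapping (standing in for the per-key shapes of a `spaces.Dict`) are assumptions made for the example, and the image-layout conversion is left out.

```python
import numpy as np


def vectorize_dict_obs(obs, shapes):
    """Add a leading batch axis to every entry that does not have one yet.

    obs:    dict mapping key -> array-like observation
    shapes: dict mapping key -> expected per-sample shape (hypothetical stand-in
            for the shapes of the sub-spaces of a Dict observation space)
    """
    batched = {}
    for key, value in obs.items():
        value = np.asarray(value)
        expected = tuple(shapes[key])
        if value.shape == expected:
            # Single (non-vectorized) observation: add the batch axis.
            value = value[None, ...]
        elif value.shape[1:] != expected:
            raise ValueError(f"Unexpected shape {value.shape} for key '{key}', expected {expected}")
        batched[key] = value
    return batched


shapes = {"image": (4, 4, 3), "vec": (2,)}
single = {"image": np.zeros((4, 4, 3)), "vec": np.ones(2)}
print({key: value.shape for key, value in vectorize_dict_obs(single, shapes).items()})
```

Observations that already carry a batch axis pass through unchanged, so the same helper could be applied to both single-environment and vectorized-environment inputs.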

@suijth suijth requested a review from araffin May 26, 2025 18:06
@araffin araffin added the "Maintainers on vacation" label May 27, 2025
2 participants