Skip to content

Commit

Permalink
Revert OrderedDict key ordering in Dict space (#1291)
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts authored Jan 8, 2025
1 parent e6e3521 commit c6c5815
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
6 changes: 4 additions & 2 deletions gymnasium/spaces/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import collections.abc
import typing
from collections import OrderedDict
from typing import Any, KeysView, Sequence

import numpy as np
Expand Down Expand Up @@ -66,8 +67,9 @@ def __init__(
seed: Optionally, you can use this argument to seed the RNGs of the spaces that make up the :class:`Dict` space.
**spaces_kwargs: If ``spaces`` is ``None``, you need to pass the constituent spaces as keyword arguments, as described above.
"""
# Convert the spaces into an OrderedDict
if isinstance(spaces, collections.abc.Mapping):
if isinstance(spaces, OrderedDict):
spaces = dict(spaces.items())
elif isinstance(spaces, collections.abc.Mapping):
# for legacy reasons, we need to preserve the sorted dictionary items.
# as this could matter for projects flatten the dictionary.
try:
Expand Down
16 changes: 15 additions & 1 deletion tests/spaces/test_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,22 @@ def test_dict_init():
assert a == b == c == d
assert len(caught_warnings) == 0

# test sorting
with warnings.catch_warnings(record=True) as caught_warnings:
Dict({1: Discrete(2), "a": Discrete(3)})
# Sorting is applied to the keys
a = Dict({"b": Box(low=0.0, high=1.0), "a": Discrete(2)})
assert a.keys() == {"a", "b"}

# Sorting is not applied to the keys
b = Dict(OrderedDict(b=Box(low=0.0, high=1.0), a=Discrete(2)))
c = Dict((("b", Box(low=0.0, high=1.0)), ("a", Discrete(2))))
d = Dict(b=Box(low=0.0, high=1.0), a=Discrete(2))
assert b.keys() == c.keys() == d.keys() == {"b", "a"}
assert len(caught_warnings) == 0

# test sorting with different classes
with warnings.catch_warnings(record=True) as caught_warnings:
assert Dict({1: Discrete(2), "a": Discrete(3)}).keys() == {1, "a"}
assert len(caught_warnings) == 0


Expand Down

0 comments on commit c6c5815

Please sign in to comment.