Skip to content

Commit 9368ca6

Browse files
kurtamohler authored and vmoens committed
[Feature] Add Choice spec
ghstack-source-id: afa315a
Pull Request resolved: #2713
1 parent 20a19fe commit 9368ca6

File tree

6 files changed: +288 -13 lines changed

benchmarks/ecosystem/gym_env_throughput.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@
3131

3232
if __name__ == "__main__":
3333
avail_devices = ("cpu",)
34-
if torch.cuda.device_count():
34+
if torch.cuda.is_available():
3535
avail_devices = avail_devices + ("cuda:0",)
3636

3737
for envname in [

examples/distributed/collectors/multi_nodes/rpc.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -74,7 +74,7 @@
7474
"slurm_gpus_per_task": args.slurm_gpus_per_task,
7575
}
7676
device_str = "device" if num_workers <= 1 else "devices"
77-
if torch.cuda.device_count():
77+
if torch.cuda.is_available():
7878
collector_kwargs = {device_str: "cuda:0", f"storing_{device_str}": "cuda:0"}
7979
else:
8080
# The second key must be an f-string so it expands to "storing_device" /
# "storing_devices" (depending on device_str), exactly like the CUDA branch;
# without the prefix the literal key "storing_{device_str}" is emitted.
collector_kwargs = {device_str: "cpu", f"storing_{device_str}": "cpu"}

test/test_specs.py

+169-10
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,13 @@
1212
import torch
1313
import torchrl.data.tensor_specs
1414
from scipy.stats import chisquare
15-
from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase
15+
from tensordict import (
16+
LazyStackedTensorDict,
17+
NonTensorData,
18+
NonTensorStack,
19+
TensorDict,
20+
TensorDictBase,
21+
)
1622
from tensordict.utils import _unravel_key_to_tuple
1723
from torchrl._utils import _make_ordinal_device
1824

@@ -23,6 +29,7 @@
2329
Bounded,
2430
BoundedTensorSpec,
2531
Categorical,
32+
Choice,
2633
Composite,
2734
CompositeSpec,
2835
ContinuousBox,
@@ -702,6 +709,63 @@ def test_change_batch_size(self, shape, is_complete, device, dtype):
702709
assert ts["nested"].shape == (3,)
703710

704711

712+
class TestChoiceSpec:
    """Tests for the ``Choice`` spec: sampling, membership and constructor validation."""

    @pytest.mark.parametrize("input_type", ["spec", "nontensor", "nontensorstack"])
    def test_choice(self, input_type):
        """Random samples must be members; fixed examples must be accepted or rejected."""
        if input_type == "spec":
            candidates = [Bounded(0, 2.5, ()), Bounded(10, 12, ())]
            # 11.0 lies within [10, 12]; 9.0 lies within neither interval.
            member = torch.tensor(11.0)
            outsider = torch.tensor(9.0)
        elif input_type == "nontensor":
            candidates = [NonTensorData("a"), NonTensorData("b")]
            member = NonTensorData("b")
            outsider = NonTensorData("c")
        elif input_type == "nontensorstack":
            candidates = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
            member = NonTensorStack("a", "b", "c")
            # Same items as the first candidate, but in a different order.
            outsider = NonTensorStack("a", "c", "b")

        # Fixed seed so the sequence of random draws is reproducible.
        torch.manual_seed(0)
        for _ in range(10):
            choice_spec = Choice(candidates)
            sample = choice_spec.rand()
            assert choice_spec.is_in(sample)
            assert choice_spec.is_in(member)
            assert not choice_spec.is_in(outsider)

    def test_errors(self):
        """Malformed ``choices`` arguments must raise the expected exception and message."""
        # A plain string is not accepted as the container of choices.
        with pytest.raises(TypeError, match="must be a list"):
            Choice("abc")

        # Elements must be specs or non-tensor data, not raw strings.
        bad_element_msg = "must be either a TensorSpec, NonTensorData, or NonTensorStack"
        with pytest.raises(TypeError, match=bad_element_msg):
            Choice(["abc"])

        # Mixing different spec classes is rejected.
        with pytest.raises(TypeError, match="must be the same type"):
            Choice([Bounded(0, 1, (1,)), Categorical(10, (1,))])

        # Mismatched shapes are rejected.
        with pytest.raises(ValueError, match="must have the same shape"):
            Choice([Categorical(10, (1,)), Categorical(10, (2,))])

        # Mismatched dtypes are rejected.
        with pytest.raises(ValueError, match="must have the same dtype"):
            Choice(
                [
                    Categorical(10, (2,), dtype=torch.long),
                    Categorical(10, (2,), dtype=torch.float),
                ]
            )

        # The device-mismatch case can only be built when CUDA is present.
        if not torch.cuda.is_available():
            return
        with pytest.raises(ValueError, match="must have the same device"):
            Choice(
                [
                    Categorical(10, (2,), device="cpu"),
                    Categorical(10, (2,), device="cuda"),
                ]
            )
767+
768+
705769
@pytest.mark.parametrize("shape", [(), (2, 3)])
706770
@pytest.mark.parametrize("device", get_default_devices())
707771
def test_create_composite_nested(shape, device):
@@ -851,7 +915,7 @@ def test_equality_bounded(self):
851915

852916
ts_other = Bounded(minimum, maximum + 1, torch.Size((1,)), device, dtype)
853917
assert ts != ts_other
854-
if torch.cuda.device_count():
918+
if torch.cuda.is_available():
855919
ts_other = Bounded(minimum, maximum, torch.Size((1,)), "cuda:0", dtype)
856920
assert ts != ts_other
857921

@@ -879,7 +943,7 @@ def test_equality_onehot(self):
879943
)
880944
assert ts != ts_other
881945

882-
if torch.cuda.device_count():
946+
if torch.cuda.is_available():
883947
ts_other = OneHot(
884948
n=n, device="cuda:0", dtype=dtype, use_register=use_register
885949
)
@@ -909,7 +973,7 @@ def test_equality_unbounded(self):
909973
ts_same = Unbounded(device=device, dtype=dtype)
910974
assert ts == ts_same
911975

912-
if torch.cuda.device_count():
976+
if torch.cuda.is_available():
913977
ts_other = Unbounded(device="cuda:0", dtype=dtype)
914978
assert ts != ts_other
915979

@@ -942,7 +1006,7 @@ def test_equality_ndbounded(self):
9421006
ts_other = Bounded(low=minimum, high=maximum + 1, device=device, dtype=dtype)
9431007
assert ts != ts_other
9441008

945-
if torch.cuda.device_count():
1009+
if torch.cuda.is_available():
9461010
ts_other = Bounded(low=minimum, high=maximum, device="cuda:0", dtype=dtype)
9471011
assert ts != ts_other
9481012

@@ -970,7 +1034,7 @@ def test_equality_discrete(self):
9701034
ts_other = Categorical(n=n + 1, shape=shape, device=device, dtype=dtype)
9711035
assert ts != ts_other
9721036

973-
if torch.cuda.device_count():
1037+
if torch.cuda.is_available():
9741038
ts_other = Categorical(n=n, shape=shape, device="cuda:0", dtype=dtype)
9751039
assert ts != ts_other
9761040

@@ -1008,7 +1072,7 @@ def test_equality_ndunbounded(self, shape):
10081072
ts_other = Unbounded(shape=other_shape, device=device, dtype=dtype)
10091073
assert ts != ts_other
10101074

1011-
if torch.cuda.device_count():
1075+
if torch.cuda.is_available():
10121076
ts_other = Unbounded(shape=shape, device="cuda:0", dtype=dtype)
10131077
assert ts != ts_other
10141078

@@ -1034,7 +1098,7 @@ def test_equality_binary(self):
10341098
ts_other = Binary(n=n + 5, device=device, dtype=dtype)
10351099
assert ts != ts_other
10361100

1037-
if torch.cuda.device_count():
1101+
if torch.cuda.is_available():
10381102
ts_other = Binary(n=n, device="cuda:0", dtype=dtype)
10391103
assert ts != ts_other
10401104

@@ -1068,7 +1132,7 @@ def test_equality_multi_onehot(self, nvec):
10681132
ts_other = MultiOneHot(nvec=other_nvec, device=device, dtype=dtype)
10691133
assert ts != ts_other
10701134

1071-
if torch.cuda.device_count():
1135+
if torch.cuda.is_available():
10721136
ts_other = MultiOneHot(nvec=nvec, device="cuda:0", dtype=dtype)
10731137
assert ts != ts_other
10741138

@@ -1102,7 +1166,7 @@ def test_equality_multi_discrete(self, nvec):
11021166
ts_other = MultiCategorical(nvec=other_nvec, device=device, dtype=dtype)
11031167
assert ts != ts_other
11041168

1105-
if torch.cuda.device_count():
1169+
if torch.cuda.is_available():
11061170
ts_other = MultiCategorical(nvec=nvec, device="cuda:0", dtype=dtype)
11071171
assert ts != ts_other
11081172

@@ -1498,6 +1562,19 @@ def test_non_tensor(self):
14981562
)
14991563
assert spec.expand(2, 3, 4).example_data == "example_data"
15001564

1565+
@pytest.mark.parametrize("input_type", ["spec", "nontensor", "nontensorstack"])
def test_choice(self, input_type):
    """Expanding a ``Choice`` spec to ``[3]`` must give a result of shape ``torch.Size([3])``."""
    if input_type == "spec":
        candidates = [Bounded(0, 2.5, ()), Bounded(10, 12, ())]
    elif input_type == "nontensor":
        candidates = [NonTensorData("a"), NonTensorData("b")]
    elif input_type == "nontensorstack":
        candidates = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]

    choice_spec = Choice(candidates)
    expanded = choice_spec.expand([3])
    expected_shape = torch.Size([3])
    assert expanded.shape == expected_shape
1577+
15011578
@pytest.mark.parametrize("shape1", [None, (), (5,)])
15021579
@pytest.mark.parametrize("shape2", [(), (10,)])
15031580
def test_onehot(self, shape1, shape2):
@@ -1701,6 +1778,19 @@ def test_non_tensor(self):
17011778
assert spec.clone() is not spec
17021779
assert spec.clone().example_data == "example_data"
17031780

1781+
@pytest.mark.parametrize("input_type", ["spec", "nontensor", "nontensorstack"])
def test_choice(self, input_type):
    """A cloned ``Choice`` spec must compare equal to the source yet be a distinct object."""
    if input_type == "spec":
        candidates = [Bounded(0, 2.5, ()), Bounded(10, 12, ())]
    elif input_type == "nontensor":
        candidates = [NonTensorData("a"), NonTensorData("b")]
    elif input_type == "nontensorstack":
        candidates = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]

    source_spec = Choice(candidates)
    # Equality by value ...
    assert source_spec.clone() == source_spec
    # ... but never the same object.
    assert source_spec.clone() is not source_spec
1793+
17041794
@pytest.mark.parametrize("shape1", [None, (), (5,)])
17051795
def test_onehot(
17061796
self,
@@ -1786,6 +1876,31 @@ def test_non_tensor(self):
17861876
with pytest.raises(RuntimeError, match="Cannot enumerate a NonTensorSpec."):
17871877
spec.cardinality()
17881878

1879+
@pytest.mark.parametrize(
    "input_type",
    ["bounded_spec", "categorical_spec", "nontensor", "nontensorstack"],
)
def test_choice(self, input_type):
    """``Choice.cardinality()`` must match the expected count for each kind of choices."""
    # Each branch pairs the candidate list with the cardinality the test expects.
    if input_type == "bounded_spec":
        candidates = [Bounded(0, 2.5, ()), Bounded(10, 12, ())]
        expected = float("inf")
    elif input_type == "categorical_spec":
        candidates = [Categorical(10), Categorical(20)]
        expected = 30  # 10 + 20
    elif input_type == "nontensor":
        candidates = [NonTensorData("a"), NonTensorData("b"), NonTensorData("c")]
        expected = 3
    elif input_type == "nontensorstack":
        candidates = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
        expected = 2

    choice_spec = Choice(candidates)
    assert choice_spec.cardinality() == expected
1903+
17891904
@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
17901905
def test_onehot(
17911906
self,
@@ -2096,6 +2211,23 @@ def test_non_tensor(self, device):
20962211
assert spec.to(device).device == device
20972212
assert spec.to(device).example_data == "example_data"
20982213

2214+
@pytest.mark.parametrize(
    "input_type",
    ["bounded_spec", "categorical_spec", "nontensor", "nontensorstack"],
)
def test_choice(self, input_type, device):
    """Moving a ``Choice`` spec with ``.to(device)`` must report that device afterwards."""
    if input_type == "bounded_spec":
        candidates = [Bounded(0, 2.5, ()), Bounded(10, 12, ())]
    elif input_type == "categorical_spec":
        candidates = [Categorical(10), Categorical(20)]
    elif input_type == "nontensor":
        candidates = [NonTensorData("a"), NonTensorData("b")]
    elif input_type == "nontensorstack":
        candidates = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]

    choice_spec = Choice(candidates)
    moved = choice_spec.to(device)
    assert moved.device == device
2230+
20992231
@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
21002232
def test_onehot(self, shape1, device):
21012233
if shape1 is None:
@@ -2363,6 +2495,33 @@ def test_stack_non_tensor(self, shape, stack_dim):
23632495
assert new_spec.device == torch.device("cpu")
23642496
assert new_spec.example_data == "example_data"
23652497

2498+
@pytest.mark.parametrize(
    "input_type",
    ["bounded_spec", "categorical_spec", "nontensor"],
)
def test_stack_choice(self, input_type, shape, stack_dim):
    """Stacking two ``Choice`` specs must give a ``Choice`` with the stacked shape."""
    if input_type == "bounded_spec":
        candidates = [Bounded(0, 2.5, shape), Bounded(10, 12, shape)]
    elif input_type == "categorical_spec":
        candidates = [Categorical(10, shape), Categorical(20, shape)]
    elif input_type == "nontensor":
        if len(shape) > 0:
            # Expand with a trailing singleton dim, then squeeze it away.
            # Presumably this yields a non-tensor stack of batch shape `shape` - TODO confirm.
            candidates = [
                NonTensorStack(item).expand(shape + (1,)).squeeze(-1)
                for item in ("a", "d")
            ]
        else:
            candidates = [NonTensorData("a"), NonTensorData("b"), NonTensorData("c")]

    first_spec = Choice(candidates)
    second_spec = Choice(candidates)
    stacked = torch.stack([first_spec, second_spec], stack_dim)
    assert isinstance(stacked, Choice)

    # Reference shape: stack two plain tensors of the same shape along the same dim.
    reference = torch.stack([torch.empty(shape), torch.empty(shape)], stack_dim)
    assert stacked.shape == reference.shape
2524+
23662525
def test_stack_onehot(self, shape, stack_dim):
23672526
n = 5
23682527
shape = (*shape, 5)

torchrl/collectors/distributed/rpc.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -449,7 +449,7 @@ def _init_master_rpc(
449449
):
450450
"""Init RPC on main node."""
451451
options = rpc.TensorPipeRpcBackendOptions(**self.tensorpipe_options)
452-
if torch.cuda.device_count():
452+
if torch.cuda.is_available():
453453
if self.visible_devices:
454454
for i in range(self.num_workers):
455455
rank = i + 1

torchrl/data/__init__.py

+1
Original file line number | Diff line number | Diff line change
@@ -76,6 +76,7 @@
7676
BoundedContinuous,
7777
BoundedTensorSpec,
7878
Categorical,
79+
Choice,
7980
Composite,
8081
CompositeSpec,
8182
DEVICE_TYPING,

0 commit comments

Comments
 (0)