1212import torch
1313import torchrl .data .tensor_specs
1414from scipy .stats import chisquare
15- from tensordict import LazyStackedTensorDict , TensorDict , TensorDictBase
15+ from tensordict import (
16+ LazyStackedTensorDict ,
17+ NonTensorData ,
18+ NonTensorStack ,
19+ TensorDict ,
20+ TensorDictBase ,
21+ )
1622from tensordict .utils import _unravel_key_to_tuple
1723from torchrl ._utils import _make_ordinal_device
1824
2329 Bounded ,
2430 BoundedTensorSpec ,
2531 Categorical ,
32+ Choice ,
2633 Composite ,
2734 CompositeSpec ,
2835 ContinuousBox ,
@@ -702,6 +709,63 @@ def test_change_batch_size(self, shape, is_complete, device, dtype):
702709 assert ts ["nested" ].shape == (3 ,)
703710
704711
712+ class TestChoiceSpec :
713+ @pytest .mark .parametrize ("input_type" , ["spec" , "nontensor" , "nontensorstack" ])
714+ def test_choice (self , input_type ):
715+ if input_type == "spec" :
716+ choices = [Bounded (0 , 2.5 , ()), Bounded (10 , 12 , ())]
717+ example_in = torch .tensor (11.0 )
718+ example_out = torch .tensor (9.0 )
719+ elif input_type == "nontensor" :
720+ choices = [NonTensorData ("a" ), NonTensorData ("b" )]
721+ example_in = NonTensorData ("b" )
722+ example_out = NonTensorData ("c" )
723+ elif input_type == "nontensorstack" :
724+ choices = [NonTensorStack ("a" , "b" , "c" ), NonTensorStack ("d" , "e" , "f" )]
725+ example_in = NonTensorStack ("a" , "b" , "c" )
726+ example_out = NonTensorStack ("a" , "c" , "b" )
727+ torch .manual_seed (0 )
728+ for _ in range (10 ):
729+ spec = Choice (choices )
730+ res = spec .rand ()
731+ assert spec .is_in (res )
732+ assert spec .is_in (example_in )
733+ assert not spec .is_in (example_out )
734+
735+ def test_errors (self ):
736+ with pytest .raises (TypeError , match = "must be a list" ):
737+ Choice ("abc" )
738+
739+ with pytest .raises (
740+ TypeError ,
741+ match = "must be either a TensorSpec, NonTensorData, or NonTensorStack" ,
742+ ):
743+ Choice (["abc" ])
744+
745+ with pytest .raises (TypeError , match = "must be the same type" ):
746+ Choice ([Bounded (0 , 1 , (1 ,)), Categorical (10 , (1 ,))])
747+
748+ with pytest .raises (ValueError , match = "must have the same shape" ):
749+ Choice ([Categorical (10 , (1 ,)), Categorical (10 , (2 ,))])
750+
751+ with pytest .raises (ValueError , match = "must have the same dtype" ):
752+ Choice (
753+ [
754+ Categorical (10 , (2 ,), dtype = torch .long ),
755+ Categorical (10 , (2 ,), dtype = torch .float ),
756+ ]
757+ )
758+
759+ if torch .cuda .is_available ():
760+ with pytest .raises (ValueError , match = "must have the same device" ):
761+ Choice (
762+ [
763+ Categorical (10 , (2 ,), device = "cpu" ),
764+ Categorical (10 , (2 ,), device = "cuda" ),
765+ ]
766+ )
767+
768+
705769@pytest .mark .parametrize ("shape" , [(), (2 , 3 )])
706770@pytest .mark .parametrize ("device" , get_default_devices ())
707771def test_create_composite_nested (shape , device ):
@@ -851,7 +915,7 @@ def test_equality_bounded(self):
851915
852916 ts_other = Bounded (minimum , maximum + 1 , torch .Size ((1 ,)), device , dtype )
853917 assert ts != ts_other
854- if torch .cuda .device_count ():
918+ if torch .cuda .is_available ():
855919 ts_other = Bounded (minimum , maximum , torch .Size ((1 ,)), "cuda:0" , dtype )
856920 assert ts != ts_other
857921
@@ -879,7 +943,7 @@ def test_equality_onehot(self):
879943 )
880944 assert ts != ts_other
881945
882- if torch .cuda .device_count ():
946+ if torch .cuda .is_available ():
883947 ts_other = OneHot (
884948 n = n , device = "cuda:0" , dtype = dtype , use_register = use_register
885949 )
@@ -909,7 +973,7 @@ def test_equality_unbounded(self):
909973 ts_same = Unbounded (device = device , dtype = dtype )
910974 assert ts == ts_same
911975
912- if torch .cuda .device_count ():
976+ if torch .cuda .is_available ():
913977 ts_other = Unbounded (device = "cuda:0" , dtype = dtype )
914978 assert ts != ts_other
915979
@@ -942,7 +1006,7 @@ def test_equality_ndbounded(self):
9421006 ts_other = Bounded (low = minimum , high = maximum + 1 , device = device , dtype = dtype )
9431007 assert ts != ts_other
9441008
945- if torch .cuda .device_count ():
1009+ if torch .cuda .is_available ():
9461010 ts_other = Bounded (low = minimum , high = maximum , device = "cuda:0" , dtype = dtype )
9471011 assert ts != ts_other
9481012
@@ -970,7 +1034,7 @@ def test_equality_discrete(self):
9701034 ts_other = Categorical (n = n + 1 , shape = shape , device = device , dtype = dtype )
9711035 assert ts != ts_other
9721036
973- if torch .cuda .device_count ():
1037+ if torch .cuda .is_available ():
9741038 ts_other = Categorical (n = n , shape = shape , device = "cuda:0" , dtype = dtype )
9751039 assert ts != ts_other
9761040
@@ -1008,7 +1072,7 @@ def test_equality_ndunbounded(self, shape):
10081072 ts_other = Unbounded (shape = other_shape , device = device , dtype = dtype )
10091073 assert ts != ts_other
10101074
1011- if torch .cuda .device_count ():
1075+ if torch .cuda .is_available ():
10121076 ts_other = Unbounded (shape = shape , device = "cuda:0" , dtype = dtype )
10131077 assert ts != ts_other
10141078
@@ -1034,7 +1098,7 @@ def test_equality_binary(self):
10341098 ts_other = Binary (n = n + 5 , device = device , dtype = dtype )
10351099 assert ts != ts_other
10361100
1037- if torch .cuda .device_count ():
1101+ if torch .cuda .is_available ():
10381102 ts_other = Binary (n = n , device = "cuda:0" , dtype = dtype )
10391103 assert ts != ts_other
10401104
@@ -1068,7 +1132,7 @@ def test_equality_multi_onehot(self, nvec):
10681132 ts_other = MultiOneHot (nvec = other_nvec , device = device , dtype = dtype )
10691133 assert ts != ts_other
10701134
1071- if torch .cuda .device_count ():
1135+ if torch .cuda .is_available ():
10721136 ts_other = MultiOneHot (nvec = nvec , device = "cuda:0" , dtype = dtype )
10731137 assert ts != ts_other
10741138
@@ -1102,7 +1166,7 @@ def test_equality_multi_discrete(self, nvec):
11021166 ts_other = MultiCategorical (nvec = other_nvec , device = device , dtype = dtype )
11031167 assert ts != ts_other
11041168
1105- if torch .cuda .device_count ():
1169+ if torch .cuda .is_available ():
11061170 ts_other = MultiCategorical (nvec = nvec , device = "cuda:0" , dtype = dtype )
11071171 assert ts != ts_other
11081172
@@ -1498,6 +1562,19 @@ def test_non_tensor(self):
14981562 )
14991563 assert spec .expand (2 , 3 , 4 ).example_data == "example_data"
15001564
1565+ @pytest .mark .parametrize ("input_type" , ["spec" , "nontensor" , "nontensorstack" ])
1566+ def test_choice (self , input_type ):
1567+ if input_type == "spec" :
1568+ choices = [Bounded (0 , 2.5 , ()), Bounded (10 , 12 , ())]
1569+ elif input_type == "nontensor" :
1570+ choices = [NonTensorData ("a" ), NonTensorData ("b" )]
1571+ elif input_type == "nontensorstack" :
1572+ choices = [NonTensorStack ("a" , "b" , "c" ), NonTensorStack ("d" , "e" , "f" )]
1573+
1574+ spec = Choice (choices )
1575+ res = spec .expand ([3 ])
1576+ assert res .shape == torch .Size ([3 ])
1577+
15011578 @pytest .mark .parametrize ("shape1" , [None , (), (5 ,)])
15021579 @pytest .mark .parametrize ("shape2" , [(), (10 ,)])
15031580 def test_onehot (self , shape1 , shape2 ):
@@ -1701,6 +1778,19 @@ def test_non_tensor(self):
17011778 assert spec .clone () is not spec
17021779 assert spec .clone ().example_data == "example_data"
17031780
1781+ @pytest .mark .parametrize ("input_type" , ["spec" , "nontensor" , "nontensorstack" ])
1782+ def test_choice (self , input_type ):
1783+ if input_type == "spec" :
1784+ choices = [Bounded (0 , 2.5 , ()), Bounded (10 , 12 , ())]
1785+ elif input_type == "nontensor" :
1786+ choices = [NonTensorData ("a" ), NonTensorData ("b" )]
1787+ elif input_type == "nontensorstack" :
1788+ choices = [NonTensorStack ("a" , "b" , "c" ), NonTensorStack ("d" , "e" , "f" )]
1789+
1790+ spec = Choice (choices )
1791+ assert spec .clone () == spec
1792+ assert spec .clone () is not spec
1793+
17041794 @pytest .mark .parametrize ("shape1" , [None , (), (5 ,)])
17051795 def test_onehot (
17061796 self ,
@@ -1786,6 +1876,31 @@ def test_non_tensor(self):
17861876 with pytest .raises (RuntimeError , match = "Cannot enumerate a NonTensorSpec." ):
17871877 spec .cardinality ()
17881878
1879+ @pytest .mark .parametrize (
1880+ "input_type" ,
1881+ ["bounded_spec" , "categorical_spec" , "nontensor" , "nontensorstack" ],
1882+ )
1883+ def test_choice (self , input_type ):
1884+ if input_type == "bounded_spec" :
1885+ choices = [Bounded (0 , 2.5 , ()), Bounded (10 , 12 , ())]
1886+ elif input_type == "categorical_spec" :
1887+ choices = [Categorical (10 ), Categorical (20 )]
1888+ elif input_type == "nontensor" :
1889+ choices = [NonTensorData ("a" ), NonTensorData ("b" ), NonTensorData ("c" )]
1890+ elif input_type == "nontensorstack" :
1891+ choices = [NonTensorStack ("a" , "b" , "c" ), NonTensorStack ("d" , "e" , "f" )]
1892+
1893+ spec = Choice (choices )
1894+
1895+ if input_type == "bounded_spec" :
1896+ assert spec .cardinality () == float ("inf" )
1897+ elif input_type == "categorical_spec" :
1898+ assert spec .cardinality () == 30
1899+ elif input_type == "nontensor" :
1900+ assert spec .cardinality () == 3
1901+ elif input_type == "nontensorstack" :
1902+ assert spec .cardinality () == 2
1903+
17891904 @pytest .mark .parametrize ("shape1" , [(5 ,), (5 , 6 )])
17901905 def test_onehot (
17911906 self ,
@@ -2096,6 +2211,23 @@ def test_non_tensor(self, device):
20962211 assert spec .to (device ).device == device
20972212 assert spec .to (device ).example_data == "example_data"
20982213
2214+ @pytest .mark .parametrize (
2215+ "input_type" ,
2216+ ["bounded_spec" , "categorical_spec" , "nontensor" , "nontensorstack" ],
2217+ )
2218+ def test_choice (self , input_type , device ):
2219+ if input_type == "bounded_spec" :
2220+ choices = [Bounded (0 , 2.5 , ()), Bounded (10 , 12 , ())]
2221+ elif input_type == "categorical_spec" :
2222+ choices = [Categorical (10 ), Categorical (20 )]
2223+ elif input_type == "nontensor" :
2224+ choices = [NonTensorData ("a" ), NonTensorData ("b" )]
2225+ elif input_type == "nontensorstack" :
2226+ choices = [NonTensorStack ("a" , "b" , "c" ), NonTensorStack ("d" , "e" , "f" )]
2227+
2228+ spec = Choice (choices )
2229+ assert spec .to (device ).device == device
2230+
20992231 @pytest .mark .parametrize ("shape1" , [(5 ,), (5 , 6 )])
21002232 def test_onehot (self , shape1 , device ):
21012233 if shape1 is None :
@@ -2363,6 +2495,33 @@ def test_stack_non_tensor(self, shape, stack_dim):
23632495 assert new_spec .device == torch .device ("cpu" )
23642496 assert new_spec .example_data == "example_data"
23652497
2498+ @pytest .mark .parametrize (
2499+ "input_type" ,
2500+ ["bounded_spec" , "categorical_spec" , "nontensor" ],
2501+ )
2502+ def test_stack_choice (self , input_type , shape , stack_dim ):
2503+ if input_type == "bounded_spec" :
2504+ choices = [Bounded (0 , 2.5 , shape ), Bounded (10 , 12 , shape )]
2505+ elif input_type == "categorical_spec" :
2506+ choices = [Categorical (10 , shape ), Categorical (20 , shape )]
2507+ elif input_type == "nontensor" :
2508+ if len (shape ) == 0 :
2509+ choices = [NonTensorData ("a" ), NonTensorData ("b" ), NonTensorData ("c" )]
2510+ else :
2511+ choices = [
2512+ NonTensorStack ("a" ).expand (shape + (1 ,)).squeeze (- 1 ),
2513+ NonTensorStack ("d" ).expand (shape + (1 ,)).squeeze (- 1 ),
2514+ ]
2515+
2516+ spec0 = Choice (choices )
2517+ spec1 = Choice (choices )
2518+ res = torch .stack ([spec0 , spec1 ], stack_dim )
2519+ assert isinstance (res , Choice )
2520+ assert (
2521+ res .shape
2522+ == torch .stack ([torch .empty (shape ), torch .empty (shape )], stack_dim ).shape
2523+ )
2524+
23662525 def test_stack_onehot (self , shape , stack_dim ):
23672526 n = 5
23682527 shape = (* shape , 5 )
0 commit comments