import torch
import torchrl.data.tensor_specs
from scipy.stats import chisquare
-from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase
+from tensordict import (
+    LazyStackedTensorDict,
+    NonTensorData,
+    NonTensorStack,
+    TensorDict,
+    TensorDictBase,
+)
from tensordict.utils import _unravel_key_to_tuple
from torchrl._utils import _make_ordinal_device

@@ -23,6 +29,7 @@
    Bounded,
    BoundedTensorSpec,
    Categorical,
+    Choice,
    Composite,
    CompositeSpec,
    ContinuousBox,
@@ -702,6 +709,63 @@ def test_change_batch_size(self, shape, is_complete, device, dtype):
        assert ts["nested"].shape == (3,)


+class TestChoiceSpec:
+    @pytest.mark.parametrize("input_type", ["spec", "nontensor", "nontensorstack"])
+    def test_choice(self, input_type):
+        if input_type == "spec":
+            choices = [Bounded(0, 2.5, ()), Bounded(10, 12, ())]
+            example_in = torch.tensor(11.0)
+            example_out = torch.tensor(9.0)
+        elif input_type == "nontensor":
+            choices = [NonTensorData("a"), NonTensorData("b")]
+            example_in = NonTensorData("b")
+            example_out = NonTensorData("c")
+        elif input_type == "nontensorstack":
+            choices = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
+            example_in = NonTensorStack("a", "b", "c")
+            example_out = NonTensorStack("a", "c", "b")
+        torch.manual_seed(0)
+        for _ in range(10):
+            spec = Choice(choices)
+            res = spec.rand()
+            assert spec.is_in(res)
+            assert spec.is_in(example_in)
+            assert not spec.is_in(example_out)
+
+    def test_errors(self):
+        with pytest.raises(TypeError, match="must be a list"):
+            Choice("abc")
+
+        with pytest.raises(
+            TypeError,
+            match="must be either a TensorSpec, NonTensorData, or NonTensorStack",
+        ):
+            Choice(["abc"])
+
+        with pytest.raises(TypeError, match="must be the same type"):
+            Choice([Bounded(0, 1, (1,)), Categorical(10, (1,))])
+
+        with pytest.raises(ValueError, match="must have the same shape"):
+            Choice([Categorical(10, (1,)), Categorical(10, (2,))])
+
+        with pytest.raises(ValueError, match="must have the same dtype"):
+            Choice(
+                [
+                    Categorical(10, (2,), dtype=torch.long),
+                    Categorical(10, (2,), dtype=torch.float),
+                ]
+            )
+
+        if torch.cuda.is_available():
+            with pytest.raises(ValueError, match="must have the same device"):
+                Choice(
+                    [
+                        Categorical(10, (2,), device="cpu"),
+                        Categorical(10, (2,), device="cuda"),
+                    ]
+                )
+
+
@pytest.mark.parametrize("shape", [(), (2, 3)])
@pytest.mark.parametrize("device", get_default_devices())
def test_create_composite_nested(shape, device):
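For reference, a minimal usage sketch of the Choice spec as the hunk above exercises it. This is a sketch only, not part of the patch; it assumes Choice and Bounded are importable from torchrl.data.tensor_specs, as the import hunk suggests.

    import torch
    from torchrl.data.tensor_specs import Bounded, Choice

    # A Choice spec wraps a list of homogeneous choices; rand() draws a value
    # from one of them, and is_in() accepts values covered by any choice.
    spec = Choice([Bounded(0, 2.5, ()), Bounded(10, 12, ())])
    sample = spec.rand()
    assert spec.is_in(sample)
    assert spec.is_in(torch.tensor(11.0))     # inside the second interval
    assert not spec.is_in(torch.tensor(9.0))  # in neither interval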
@@ -851,7 +915,7 @@ def test_equality_bounded(self):

        ts_other = Bounded(minimum, maximum + 1, torch.Size((1,)), device, dtype)
        assert ts != ts_other
-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
            ts_other = Bounded(minimum, maximum, torch.Size((1,)), "cuda:0", dtype)
            assert ts != ts_other

@@ -879,7 +943,7 @@ def test_equality_onehot(self):
        )
        assert ts != ts_other

-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
            ts_other = OneHot(
                n=n, device="cuda:0", dtype=dtype, use_register=use_register
            )
@@ -909,7 +973,7 @@ def test_equality_unbounded(self):
        ts_same = Unbounded(device=device, dtype=dtype)
        assert ts == ts_same

-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
            ts_other = Unbounded(device="cuda:0", dtype=dtype)
            assert ts != ts_other

@@ -942,7 +1006,7 @@ def test_equality_ndbounded(self):
        ts_other = Bounded(low=minimum, high=maximum + 1, device=device, dtype=dtype)
        assert ts != ts_other

-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
            ts_other = Bounded(low=minimum, high=maximum, device="cuda:0", dtype=dtype)
            assert ts != ts_other

@@ -970,7 +1034,7 @@ def test_equality_discrete(self):
        ts_other = Categorical(n=n + 1, shape=shape, device=device, dtype=dtype)
        assert ts != ts_other

-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
            ts_other = Categorical(n=n, shape=shape, device="cuda:0", dtype=dtype)
            assert ts != ts_other

@@ -1008,7 +1072,7 @@ def test_equality_ndunbounded(self, shape):
        ts_other = Unbounded(shape=other_shape, device=device, dtype=dtype)
        assert ts != ts_other

-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
            ts_other = Unbounded(shape=shape, device="cuda:0", dtype=dtype)
            assert ts != ts_other

@@ -1034,7 +1098,7 @@ def test_equality_binary(self):
        ts_other = Binary(n=n + 5, device=device, dtype=dtype)
        assert ts != ts_other

-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
            ts_other = Binary(n=n, device="cuda:0", dtype=dtype)
            assert ts != ts_other

@@ -1068,7 +1132,7 @@ def test_equality_multi_onehot(self, nvec):
        ts_other = MultiOneHot(nvec=other_nvec, device=device, dtype=dtype)
        assert ts != ts_other

-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
            ts_other = MultiOneHot(nvec=nvec, device="cuda:0", dtype=dtype)
            assert ts != ts_other

@@ -1102,7 +1166,7 @@ def test_equality_multi_discrete(self, nvec):
        ts_other = MultiCategorical(nvec=other_nvec, device=device, dtype=dtype)
        assert ts != ts_other

-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
            ts_other = MultiCategorical(nvec=nvec, device="cuda:0", dtype=dtype)
            assert ts != ts_other

@@ -1498,6 +1562,19 @@ def test_non_tensor(self):
        )
        assert spec.expand(2, 3, 4).example_data == "example_data"

+    @pytest.mark.parametrize("input_type", ["spec", "nontensor", "nontensorstack"])
+    def test_choice(self, input_type):
+        if input_type == "spec":
+            choices = [Bounded(0, 2.5, ()), Bounded(10, 12, ())]
+        elif input_type == "nontensor":
+            choices = [NonTensorData("a"), NonTensorData("b")]
+        elif input_type == "nontensorstack":
+            choices = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
+
+        spec = Choice(choices)
+        res = spec.expand([3])
+        assert res.shape == torch.Size([3])
+
    @pytest.mark.parametrize("shape1", [None, (), (5,)])
    @pytest.mark.parametrize("shape2", [(), (10,)])
    def test_onehot(self, shape1, shape2):
@@ -1701,6 +1778,19 @@ def test_non_tensor(self):
        assert spec.clone() is not spec
        assert spec.clone().example_data == "example_data"

+    @pytest.mark.parametrize("input_type", ["spec", "nontensor", "nontensorstack"])
+    def test_choice(self, input_type):
+        if input_type == "spec":
+            choices = [Bounded(0, 2.5, ()), Bounded(10, 12, ())]
+        elif input_type == "nontensor":
+            choices = [NonTensorData("a"), NonTensorData("b")]
+        elif input_type == "nontensorstack":
+            choices = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
+
+        spec = Choice(choices)
+        assert spec.clone() == spec
+        assert spec.clone() is not spec
+
    @pytest.mark.parametrize("shape1", [None, (), (5,)])
    def test_onehot(
        self,
@@ -1786,6 +1876,31 @@ def test_non_tensor(self):
        with pytest.raises(RuntimeError, match="Cannot enumerate a NonTensorSpec."):
            spec.cardinality()

+    @pytest.mark.parametrize(
+        "input_type",
+        ["bounded_spec", "categorical_spec", "nontensor", "nontensorstack"],
+    )
+    def test_choice(self, input_type):
+        if input_type == "bounded_spec":
+            choices = [Bounded(0, 2.5, ()), Bounded(10, 12, ())]
+        elif input_type == "categorical_spec":
+            choices = [Categorical(10), Categorical(20)]
+        elif input_type == "nontensor":
+            choices = [NonTensorData("a"), NonTensorData("b"), NonTensorData("c")]
+        elif input_type == "nontensorstack":
+            choices = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
+
+        spec = Choice(choices)
+
+        if input_type == "bounded_spec":
+            assert spec.cardinality() == float("inf")
+        elif input_type == "categorical_spec":
+            assert spec.cardinality() == 30
+        elif input_type == "nontensor":
+            assert spec.cardinality() == 3
+        elif input_type == "nontensorstack":
+            assert spec.cardinality() == 2
+
    @pytest.mark.parametrize("shape1", [(5,), (5, 6)])
    def test_onehot(
        self,
@@ -2096,6 +2211,23 @@ def test_non_tensor(self, device):
        assert spec.to(device).device == device
        assert spec.to(device).example_data == "example_data"

+    @pytest.mark.parametrize(
+        "input_type",
+        ["bounded_spec", "categorical_spec", "nontensor", "nontensorstack"],
+    )
+    def test_choice(self, input_type, device):
+        if input_type == "bounded_spec":
+            choices = [Bounded(0, 2.5, ()), Bounded(10, 12, ())]
+        elif input_type == "categorical_spec":
+            choices = [Categorical(10), Categorical(20)]
+        elif input_type == "nontensor":
+            choices = [NonTensorData("a"), NonTensorData("b")]
+        elif input_type == "nontensorstack":
+            choices = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
+
+        spec = Choice(choices)
+        assert spec.to(device).device == device
+
    @pytest.mark.parametrize("shape1", [(5,), (5, 6)])
    def test_onehot(self, shape1, device):
        if shape1 is None:
@@ -2363,6 +2495,33 @@ def test_stack_non_tensor(self, shape, stack_dim):
        assert new_spec.device == torch.device("cpu")
        assert new_spec.example_data == "example_data"

+    @pytest.mark.parametrize(
+        "input_type",
+        ["bounded_spec", "categorical_spec", "nontensor"],
+    )
+    def test_stack_choice(self, input_type, shape, stack_dim):
+        if input_type == "bounded_spec":
+            choices = [Bounded(0, 2.5, shape), Bounded(10, 12, shape)]
+        elif input_type == "categorical_spec":
+            choices = [Categorical(10, shape), Categorical(20, shape)]
+        elif input_type == "nontensor":
+            if len(shape) == 0:
+                choices = [NonTensorData("a"), NonTensorData("b"), NonTensorData("c")]
+            else:
+                choices = [
+                    NonTensorStack("a").expand(shape + (1,)).squeeze(-1),
+                    NonTensorStack("d").expand(shape + (1,)).squeeze(-1),
+                ]
+
+        spec0 = Choice(choices)
+        spec1 = Choice(choices)
+        res = torch.stack([spec0, spec1], stack_dim)
+        assert isinstance(res, Choice)
+        assert (
+            res.shape
+            == torch.stack([torch.empty(shape), torch.empty(shape)], stack_dim).shape
+        )
+
    def test_stack_onehot(self, shape, stack_dim):
        n = 5
        shape = (*shape, 5)
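Taken together, the new test_choice/test_stack_choice cases above exercise the rest of the spec surface: expand, clone, cardinality, device transfer, and stacking. A condensed sketch of those operations follows; as before, it is a sketch derived from the assertions in this diff, assuming Choice, Bounded, and Categorical are importable from torchrl.data.tensor_specs.

    import torch
    from torchrl.data.tensor_specs import Bounded, Categorical, Choice

    bounded_choice = Choice([Bounded(0, 2.5, ()), Bounded(10, 12, ())])
    cat_choice = Choice([Categorical(10), Categorical(20)])

    assert bounded_choice.expand([3]).shape == torch.Size([3])    # expand to a batch shape
    assert bounded_choice.clone() == bounded_choice               # clone compares equal ...
    assert bounded_choice.clone() is not bounded_choice           # ... but is a new object
    assert bounded_choice.cardinality() == float("inf")           # continuous choices
    assert cat_choice.cardinality() == 30                         # 10 + 20 discrete outcomes
    assert cat_choice.to(torch.device("cpu")).device == torch.device("cpu")

    choices = [Categorical(10), Categorical(20)]
    stacked = torch.stack([Choice(choices), Choice(choices)], 0)  # stacking specs
    assert isinstance(stacked, Choice)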