@@ -123,7 +123,7 @@ def test_egreedy_masked(self, module, eps_init, spec_class):
123
123
{"observation" : torch .zeros (* batch_size , action_size )},
124
124
batch_size = batch_size ,
125
125
)
126
- with pytest .raises (KeyError , match = "Action mask key action_mask not found in " ):
126
+ with pytest .raises (RuntimeError , match = "Failed while executing module " ):
127
127
explorative_policy (td )
128
128
129
129
torch .manual_seed (0 )
@@ -182,9 +182,7 @@ def test_no_spec_error(
182
182
batch_size = batch_size ,
183
183
)
184
184
185
- with pytest .raises (
186
- RuntimeError , match = "spec must be provided to the exploration wrapper."
187
- ):
185
+ with pytest .raises (RuntimeError , match = "Failed while executing module" ):
188
186
explorative_policy (td )
189
187
190
188
@pytest .mark .parametrize ("module" , [True , False ])
@@ -201,9 +199,7 @@ def test_wrong_action_shape(self, module):
201
199
policy ,
202
200
)
203
201
td = TensorDict ({"observation" : torch .zeros (10 , 4 )}, batch_size = [10 ])
204
- with pytest .raises (
205
- ValueError , match = "Action spec shape does not match the action shape"
206
- ):
202
+ with pytest .raises (RuntimeError , match = "Failed while executing module" ):
207
203
explorative_policy (td )
208
204
209
205
@@ -383,9 +379,8 @@ def test_nested(
383
379
)
384
380
385
381
action_spec = env .action_spec
386
- d_act = action_spec .shape [- 1 ]
382
+ action_spec .shape [- 1 ]
387
383
388
- net = nn .LazyLinear (d_act ).to (device )
389
384
policy = TensorDictModule (
390
385
CountingEnvCountModule (action_spec = action_spec ),
391
386
in_keys = [("data" , "states" ) if nested_obs_action else "observation" ],
0 commit comments