Skip to content

Commit b5283c2

Browse files
committed
Versioning
1 parent 0795ac5 commit b5283c2

File tree

4 files changed

+9
-13
lines changed

4 files changed

+9
-13
lines changed

.github/scripts/td_script.sh

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#!/bin/bash
22

3-
export TORCHRL_BUILD_VERSION=0.7.1
3+
export TORCHRL_BUILD_VERSION=0.7.2
44

5-
${CONDA_RUN} pip install git+https://github.com/pytorch/tensordict.git -U
5+
#${CONDA_RUN} pip install git+https://github.com/pytorch/tensordict.git -U
6+
pip install tensordict==0.7.2

.github/scripts/version_script.bat

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@echo off
2-
set TORCHRL_BUILD_VERSION=0.7.1
2+
set TORCHRL_BUILD_VERSION=0.7.2
33
echo TORCHRL_BUILD_VERSION is set to %TORCHRL_BUILD_VERSION%
44

55
@echo on

test/test_exploration.py

+4-9
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def test_egreedy_masked(self, module, eps_init, spec_class):
123123
{"observation": torch.zeros(*batch_size, action_size)},
124124
batch_size=batch_size,
125125
)
126-
with pytest.raises(KeyError, match="Action mask key action_mask not found in"):
126+
with pytest.raises(RuntimeError, match="Failed while executing module"):
127127
explorative_policy(td)
128128

129129
torch.manual_seed(0)
@@ -182,9 +182,7 @@ def test_no_spec_error(
182182
batch_size=batch_size,
183183
)
184184

185-
with pytest.raises(
186-
RuntimeError, match="spec must be provided to the exploration wrapper."
187-
):
185+
with pytest.raises(RuntimeError, match="Failed while executing module"):
188186
explorative_policy(td)
189187

190188
@pytest.mark.parametrize("module", [True, False])
@@ -201,9 +199,7 @@ def test_wrong_action_shape(self, module):
201199
policy,
202200
)
203201
td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10])
204-
with pytest.raises(
205-
ValueError, match="Action spec shape does not match the action shape"
206-
):
202+
with pytest.raises(RuntimeError, match="Failed while executing module"):
207203
explorative_policy(td)
208204

209205

@@ -383,9 +379,8 @@ def test_nested(
383379
)
384380

385381
action_spec = env.action_spec
386-
d_act = action_spec.shape[-1]
382+
action_spec.shape[-1]
387383

388-
net = nn.LazyLinear(d_act).to(device)
389384
policy = TensorDictModule(
390385
CountingEnvCountModule(action_spec=action_spec),
391386
in_keys=[("data", "states") if nested_obs_action else "observation"],

version.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.7.1
1+
0.7.2

0 commit comments

Comments
 (0)