Skip to content

Commit e5ef188

Browse files
authored
Ensure safe_to_tensor moves tensors to the specified device. (#831)
This PR fixes a bug in the `safe_to_tensor` utility: previously it did not move tensors to a new device according to the `device` kwarg which caused issues when there is more than one device available. The bug went unnoticed for a long while since our circleCI runners do not have GPUs enabled.
1 parent a8b079c commit e5ef188

File tree

10 files changed

+32
-16
lines changed

10 files changed

+32
-16
lines changed

.circleci/config.yml

+14-6
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ commands:
9090
steps:
9191
- run:
9292
name: install macOS packages
93-
command: HOMEBREW_NO_AUTO_UPDATE=1 brew install coreutils parallel gnu-getopt
93+
command: HOMEBREW_NO_AUTO_UPDATE=1 brew install coreutils gnu-getopt parallel [email protected] virtualenv
9494

9595
- checkout
9696

@@ -138,11 +138,13 @@ commands:
138138
# Download and cache dependencies
139139
- restore_cache:
140140
keys:
141-
- v11win-dependencies-{{ checksum "setup.py" }}-{{ checksum "ci/build_and_activate_venv.ps1" }}
141+
- v13win-dependencies-{{ checksum "setup.py" }}-{{ checksum "ci/build_and_activate_venv.ps1" }}
142142

143143
- run:
144144
name: install python
145-
command: choco install --allow-downgrade -y python --version=3.8.10
145+
# Use python3.9 in Windows instead of python3.8 because otherwise
146+
# pytest-notebook's indirect dependency pywinpty will fail to build.
147+
command: choco install --allow-downgrade -y python --version=3.9.13
146148
shell: powershell.exe
147149

148150
- run:
@@ -163,14 +165,20 @@ commands:
163165

164166
- run:
165167
name: install dependencies
166-
# Only create venv if it's not been restored from cache
167-
command: if (-not (Test-Path venv)) { .\ci\build_and_activate_venv.ps1 -venv venv }
168+
# Only create venv if it's not been restored from cache.
169+
# Need to throw error explicitly on error or else {} will get rid of
170+
# the exit code.
171+
command: |
172+
if (-not (Test-Path venv)) {
173+
.\ci\build_and_activate_venv.ps1 -venv venv
174+
if ($LASTEXITCODE -ne 0) { throw "Failed to create venv" }
175+
}
168176
shell: powershell.exe
169177

170178
- save_cache:
171179
paths:
172180
- .\venv
173-
key: v11win-dependencies-{{ checksum "setup.py" }}-{{ checksum "ci/build_and_activate_venv.ps1" }}
181+
key: v13win-dependencies-{{ checksum "setup.py" }}-{{ checksum "ci/build_and_activate_venv.ps1" }}
174182

175183
- run:
176184
name: install imitation

ci/build_and_activate_venv.ps1

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ If ($venv -eq $null) {
77
$venv = "venv"
88
}
99

10-
virtualenv -p python3.8 $venv
10+
virtualenv -p python3.9 $venv
1111
& $venv\Scripts\activate
1212
pip install ".[docs,parallel,test]"

mypy.ini

+5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
11
[mypy]
22
ignore_missing_imports = true
33
exclude = output
4+
5+
# torch had some type errors, we ignore them because they're not our fault
6+
[mypy-torch._dynamo.*]
7+
follow_imports = skip
8+
follow_imports_for_stubs = True

src/imitation/algorithms/adversarial/common.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def __init__(
207207
self.debug_use_ground_truth = debug_use_ground_truth
208208
self.venv = venv
209209
self.gen_algo = gen_algo
210-
self._reward_net = reward_net.to(gen_algo.device)
210+
self._reward_net: reward_nets.RewardNet = reward_net.to(gen_algo.device)
211211
self._log_dir = util.parse_path(log_dir)
212212

213213
# Create graph for optimising/recording stats on discriminator

src/imitation/data/serialize.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def save(path: AnyPath, trajectories: Sequence[Trajectory]) -> None:
2020
trajectories: The trajectories to save.
2121
"""
2222
p = util.parse_path(path)
23-
huggingface_utils.trajectories_to_dataset(trajectories).save_to_disk(p)
23+
huggingface_utils.trajectories_to_dataset(trajectories).save_to_disk(str(p))
2424
logging.info(f"Dumped demonstrations to {p}.")
2525

2626

src/imitation/policies/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def _predict(
3131
):
3232
np_actions = []
3333
if isinstance(obs, dict):
34-
np_obs = types.DictObs(
34+
np_obs: Union[types.DictObs, np.ndarray] = types.DictObs(
3535
{k: v.detach().cpu().numpy() for k, v in obs.items()},
3636
)
3737
else:

src/imitation/util/util.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -255,10 +255,7 @@ def safe_to_tensor(array: Union[np.ndarray, th.Tensor], **kwargs) -> th.Tensor:
255255
Returns:
256256
A PyTorch tensor with the same content as `array`.
257257
"""
258-
if isinstance(array, th.Tensor):
259-
return array
260-
261-
if not array.flags.writeable:
258+
if isinstance(array, np.ndarray) and not array.flags.writeable:
262259
array = array.copy()
263260

264261
return th.as_tensor(array, **kwargs)
@@ -476,6 +473,6 @@ def split_in_half(x: int) -> Tuple[int, int]:
476473
def clear_screen() -> None:
477474
"""Clears the console screen."""
478475
if os.name == "nt": # Windows
479-
os.system("cls")
476+
os.system("cls") # pragma: no cover
480477
else:
481478
os.system("clear")

tests/algorithms/test_sqil.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def test_sqil_performance_continuous(
247247
pytestconfig: pytest.Config,
248248
pendulum_single_venv: vec_env.VecEnv,
249249
rl_algo_class: Type[off_policy_algorithm.OffPolicyAlgorithm],
250-
):
250+
): # pragma: no cover
251251
rl_kwargs = dict(
252252
learning_starts=500,
253253
learning_rate=0.001,

tests/data/test_huggingface_utils.py

+4
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ def test_save_load_roundtrip(
6363

6464

6565
@hypothesis.given(st.data(), h_strats.trajectories_list)
66+
# the first run sometimes takes longer, so we give it more time
67+
@hypothesis.settings(deadline=datetime.timedelta(milliseconds=300))
6668
def test_sliced_access(data: st.DataObject, trajectories: Sequence[types.Trajectory]):
6769
"""Test that slicing a TrajectoryDatasetSequence behaves as expected."""
6870
# GIVEN
@@ -84,6 +86,8 @@ def test_sliced_access(data: st.DataObject, trajectories: Sequence[types.Traject
8486

8587

8688
@hypothesis.given(st.data(), h_strats.trajectory)
89+
# the first run sometimes takes longer, so we give it more time
90+
@hypothesis.settings(deadline=datetime.timedelta(milliseconds=300))
8791
def test_sliced_info_dict_access(
8892
data: st.DataObject,
8993
trajectory: types.Trajectory,

tests/util/test_util.py

+2
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ def test_safe_to_numpy():
103103
numpy = util.safe_to_numpy(tensor)
104104
assert (numpy == tensor.numpy()).all()
105105
assert util.safe_to_numpy(None) is None
106+
with pytest.warns(UserWarning, match=".*performance.*"):
107+
util.safe_to_numpy(tensor, warn=True)
106108

107109

108110
def test_tensor_iter_norm():

0 commit comments

Comments
 (0)