Skip to content

Commit 1bf9b0c

Browse files
committed
enable
1 parent 06c77c7 commit 1bf9b0c

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -403,8 +403,8 @@ def test_to_dtype(self):
403403
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
404404

405405
@unittest.skipIf(
406-
torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),
407-
reason="CPU offload is only available with CUDA and `accelerate v0.14.0` or higher",
406+
not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),
407+
reason="CPU offload is only available with `accelerate v0.14.0` or higher",
408408
)
409409
def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4):
410410
components = self.get_dummy_components()
@@ -419,7 +419,7 @@ def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4):
419419
inputs = self.get_dummy_inputs(generator_device)
420420
output_without_offload = pipe(**inputs).frames[0]
421421

422-
pipe.enable_sequential_cpu_offload()
422+
pipe.enable_sequential_cpu_offload(device=torch_device)
423423

424424
inputs = self.get_dummy_inputs(generator_device)
425425
output_with_offload = pipe(**inputs).frames[0]

0 commit comments

Comments
 (0)