Skip to content

Commit d1b71f7

Browse files
committed
Test: Add regression test for group offloading + delete_adapters
1 parent ff7e6af commit d1b71f7

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

tests/lora/utils.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
)
3131
from diffusers.utils import logging
3232
from diffusers.utils.import_utils import is_peft_available
33+
from diffusers.hooks.group_offloading import apply_group_offloading
3334

3435
from ..testing_utils import (
3536
CaptureLogger,
@@ -2367,3 +2368,42 @@ def test_lora_loading_model_cpu_offload(self):
23672368

23682369
output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
23692370
self.assertTrue(np.allclose(output_lora, output_lora_loaded, atol=1e-3, rtol=1e-3))
2371+
2372+
@require_torch_accelerator
2373+
def test_lora_group_offloading_delete_adapters(self):
2374+
components, _, denoiser_lora_config = self.get_dummy_components()
2375+
_, _, inputs = self.get_dummy_inputs(with_generator=False)
2376+
pipe = self.pipeline_class(**components)
2377+
pipe = pipe.to(torch_device)
2378+
pipe.set_progress_bar_config(disable=None)
2379+
2380+
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
2381+
denoiser.add_adapter(denoiser_lora_config)
2382+
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
2383+
2384+
with tempfile.TemporaryDirectory() as tmpdirname:
2385+
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
2386+
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
2387+
self.pipeline_class.save_lora_weights(
2388+
save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
2389+
)
2390+
2391+
components, _, _ = self.get_dummy_components()
2392+
pipe = self.pipeline_class(**components)
2393+
pipe.to(torch_device)
2394+
2395+
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
2396+
2397+
# Enable Group Offloading (leaf_level)
2398+
apply_group_offloading(
2399+
denoiser,
2400+
onload_device=torch_device,
2401+
offload_device="cpu",
2402+
offload_type="leaf_level",
2403+
)
2404+
2405+
pipe.load_lora_weights(tmpdirname, adapter_name="default")
2406+
pipe(**inputs, generator=torch.manual_seed(0))
2407+
# Delete the adapter
2408+
pipe.delete_adapters("default")
2409+
pipe(**inputs, generator=torch.manual_seed(0))

0 commit comments

Comments
 (0)