3 changes: 3 additions & 0 deletions src/diffusers/loaders/peft.py
@@ -22,6 +22,7 @@
import safetensors
import torch

from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
from ..utils import (
    MIN_PEFT_VERSION,
    USE_PEFT_BACKEND,
@@ -792,6 +793,8 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
            if hasattr(self, "peft_config"):
                self.peft_config.pop(adapter_name, None)

        _maybe_remove_and_reapply_group_offloading(self)

    def enable_lora_hotswap(
        self, target_rank: int = 128, check_compiled: Literal["error", "warn", "ignore"] = "error"
    ) -> None:
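As the helper's name suggests, `_maybe_remove_and_reapply_group_offloading(self)` tears down and re-installs the group offloading hooks (when the model has group offloading enabled) after the adapter layers have been deleted, so the denoiser is left in a consistent state. A minimal sketch of the user-facing scenario this covers (the model/LoRA ids, the `pipe.transformer` attribute, and the prompt are illustrative placeholders, not taken from this PR):

```python
# Minimal sketch (not from this PR): delete a LoRA adapter from a denoiser
# that has group offloading enabled. Model/LoRA ids and the prompt are placeholders.
import torch
from diffusers import DiffusionPipeline
from diffusers.hooks.group_offloading import apply_group_offloading

pipe = DiffusionPipeline.from_pretrained("some/model-id", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Group offloading installs hooks on the denoiser's submodules
# (use pipe.unet instead of pipe.transformer for UNet-based pipelines).
apply_group_offloading(
    pipe.transformer,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
)

pipe.load_lora_weights("some/lora-id", adapter_name="default")
image_lora = pipe("a prompt").images[0]

# With this change, delete_adapters() also removes and re-applies the group
# offloading hooks, so inference keeps working after the adapter is gone.
pipe.delete_adapters("default")
image_no_lora = pipe("a prompt").images[0]
```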
41 changes: 41 additions & 0 deletions tests/lora/utils.py
@@ -28,6 +28,7 @@
    AutoencoderKL,
    UNet2DConditionModel,
)
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.utils import logging
from diffusers.utils.import_utils import is_peft_available

@@ -2367,3 +2368,43 @@ def test_lora_loading_model_cpu_offload(self):

        output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(np.allclose(output_lora, output_lora_loaded, atol=1e-3, rtol=1e-3))

    @require_torch_accelerator
    def test_lora_group_offloading_delete_adapters(self):
        components, _, denoiser_lora_config = self.get_dummy_components()
        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config)
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        with tempfile.TemporaryDirectory() as tmpdirname:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
            self.pipeline_class.save_lora_weights(
                save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
            )

            components, _, _ = self.get_dummy_components()
            pipe = self.pipeline_class(**components)
            pipe.to(torch_device)

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet

            # Enable Group Offloading (leaf_level)
            apply_group_offloading(
                denoiser,
                onload_device=torch_device,
                offload_device="cpu",
                offload_type="leaf_level",
            )

            pipe.load_lora_weights(tmpdirname, adapter_name="default")
            out_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            # Delete the adapter
            pipe.delete_adapters("default")
            out_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(np.allclose(out_lora, out_no_lora, atol=1e-3, rtol=1e-3))
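The test mirrors the intended round trip: the dummy LoRA weights are saved, reloaded with `adapter_name="default"` into a fresh pipeline whose denoiser uses leaf-level group offloading, and then deleted. The final `assertFalse(np.allclose(...))` confirms that the output produced with the adapter differs from the output after deletion, i.e. the adapter was actually removed and inference still runs once the offloading hooks have been re-applied.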