3030)
3131from diffusers .utils import logging
3232from diffusers .utils .import_utils import is_peft_available
33+ from diffusers .hooks .group_offloading import apply_group_offloading
3334
3435from ..testing_utils import (
3536 CaptureLogger ,
@@ -2367,3 +2368,42 @@ def test_lora_loading_model_cpu_offload(self):
23672368
23682369 output_lora_loaded = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
23692370 self .assertTrue (np .allclose (output_lora , output_lora_loaded , atol = 1e-3 , rtol = 1e-3 ))
2371+
2372+ @require_torch_accelerator
2373+ def test_lora_group_offloading_delete_adapters (self ):
2374+ components , _ , denoiser_lora_config = self .get_dummy_components ()
2375+ _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
2376+ pipe = self .pipeline_class (** components )
2377+ pipe = pipe .to (torch_device )
2378+ pipe .set_progress_bar_config (disable = None )
2379+
2380+ denoiser = pipe .transformer if self .unet_kwargs is None else pipe .unet
2381+ denoiser .add_adapter (denoiser_lora_config )
2382+ self .assertTrue (check_if_lora_correctly_set (denoiser ), "Lora not correctly set in denoiser." )
2383+
2384+ with tempfile .TemporaryDirectory () as tmpdirname :
2385+ modules_to_save = self ._get_modules_to_save (pipe , has_denoiser = True )
2386+ lora_state_dicts = self ._get_lora_state_dicts (modules_to_save )
2387+ self .pipeline_class .save_lora_weights (
2388+ save_directory = tmpdirname , safe_serialization = True , ** lora_state_dicts
2389+ )
2390+
2391+ components , _ , _ = self .get_dummy_components ()
2392+ pipe = self .pipeline_class (** components )
2393+ pipe .to (torch_device )
2394+
2395+ denoiser = pipe .transformer if self .unet_kwargs is None else pipe .unet
2396+
2397+ # Enable Group Offloading (leaf_level)
2398+ apply_group_offloading (
2399+ denoiser ,
2400+ onload_device = torch_device ,
2401+ offload_device = "cpu" ,
2402+ offload_type = "leaf_level" ,
2403+ )
2404+
2405+ pipe .load_lora_weights (tmpdirname , adapter_name = "default" )
2406+ pipe (** inputs , generator = torch .manual_seed (0 ))
2407+ # Delete the adapter
2408+ pipe .delete_adapters ("default" )
2409+ pipe (** inputs , generator = torch .manual_seed (0 ))
0 commit comments