[LoRA] loading LoRA into a quantized base model #10550

Closed
@sayakpaul

Description

Similar issues:

  1. [LoRA] Quanto Flux LoRA can't load #10512
  2. NF4 quantized flux models with loras #10496
Reproduction
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, FluxTransformer2DModel, FluxPipeline
from huggingface_hub import hf_hub_download


# Load the Flux transformer quantized to 8-bit with bitsandbytes.
transformer_8bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=DiffusersBitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer_8bit,
    torch_dtype=torch.bfloat16,
).to("cuda")

# Loading the LoRA into the quantized transformer triggers the error below.
pipe.load_lora_weights(
    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), 
    adapter_name="hyper-sd"
)
pipe.set_adapters("hyper-sd", adapter_weights=0.125)

prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."

image = pipe(
    prompt=prompt,
    height=1024,
    width=1024,
    max_sequence_length=512,
    num_inference_steps=8,
    guidance_scale=50,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("out.jpg")

Happens on main as well as on the v0.31.0-release branch.

Error
Traceback (most recent call last):
  File "/home/sayak/diffusers/load_loras_flux.py", line 18, in <module>
    pipe.load_lora_weights(
  File "/home/sayak/diffusers/src/diffusers/loaders/lora_pipeline.py", line 1846, in load_lora_weights
    self.load_lora_into_transformer(
  File "/home/sayak/diffusers/src/diffusers/loaders/lora_pipeline.py", line 1948, in load_lora_into_transformer
    inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/peft/mapping.py", line 260, in inject_adapter_in_model
    peft_model = tuner_cls(model, peft_config, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/peft/tuners/lora/model.py", line 141, in __init__
    super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 184, in __init__
    self.inject_adapter(self.model, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 501, in inject_adapter
    self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/peft/tuners/lora/model.py", line 239, in _create_and_replace
    self._replace_module(parent, target_name, new_module, target)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/peft/tuners/lora/model.py", line 263, in _replace_module
    new_module.to(child.weight.device)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1340, in to
    return self._apply(convert)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 900, in _apply
    module._apply(fn)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 900, in _apply
    module._apply(fn)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 927, in _apply
    param_applied = fn(param)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1333, in convert
    raise NotImplementedError(
NotImplementedError: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module from meta to a different device.
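
For context, the failure reduces to PyTorch's handling of meta tensors: torch.nn.Module.to() cannot copy parameters that have no data, which is exactly what PEFT hits at new_module.to(child.weight.device) in _replace_module. A minimal standalone sketch of that behavior (no diffusers or peft involved):

import torch

# Parameters created on the meta device have shapes but no storage.
lin = torch.nn.Linear(4, 4, device="meta")

try:
    lin.to("cpu")  # raises the same NotImplementedError as in the traceback above
except NotImplementedError as e:
    print(e)

# to_empty() is the alternative the error message suggests: it allocates
# (uninitialized) storage on the target device instead of copying from meta.
lin = lin.to_empty(device="cpu")
print(lin.weight.shape)  # torch.Size([4, 4]); values are uninitialized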

@BenjaminBossan any suggestions here?
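
For anyone hitting this in the meantime, here is one possible workaround as an untested sketch. It assumes the failure comes from the low-CPU-memory injection path (the low_cpu_mem_usage argument visible in the traceback) and that load_lora_weights forwards that flag down to PEFT:

# Untested sketch: low_cpu_mem_usage=False is assumed to make PEFT create the
# new LoRA modules with real storage instead of meta tensors, so that the
# new_module.to(child.weight.device) call has data to copy.
pipe.load_lora_weights(
    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"),
    adapter_name="hyper-sd",
    low_cpu_mem_usage=False,
)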
