Open
Description
Describe the bug
We have a KeyError when the state dict goes to load into the transformer.
Reproduction
import torch
from diffusers.models import FluxTransformer2DModel
from peft import LoraConfig, set_peft_model_state_dict
from diffusers import FluxPipeline
from diffusers.utils import convert_unet_state_dict_to_peft
from accelerate import Accelerator
from optimum.quanto import quantize, freeze, qint8
# Step 1: FluxTransformer2DModel Model
print("loading model")
model = FluxTransformer2DModel.from_pretrained('black-forest-labs/FLUX.1-dev', subfolder='transformer')
accelerator = Accelerator()
# Step 2: Add LoRA Adapter
config = LoraConfig(r=8, lora_alpha=16, target_modules=["to_q", "to_k"], lora_dropout=0.1)
print("adding adapter (random)")
model.add_adapter(config)
# Step 3: Quantize the Model
print("quantizing model")
quantize(model, weights=qint8)
print("freezing model")
freeze(model)
# Step 4: prepare model
print("prepare model")
model = accelerator.prepare(model)
# Step 5: Load LoRA State Dictionary
lora_path = "checkpoint/"
print("retrieve lora state dict")
lora_state_dict = FluxPipeline.lora_state_dict(lora_path)
transformer_state_dict = {
f'{k.replace("transformer.", "")}': v
for k, v in lora_state_dict.items()
if k.startswith("unet.")
}
print("convert state dict")
transformer_state_dict = convert_unet_state_dict_to_peft(
transformer_state_dict
)
incompatible_keys = set_peft_model_state_dict(
model, transformer_state_dict, adapter_name="default"
)
print(f"unexpected keys: {incompatible_keys.unexpected_keys}")
Logs
adding adapter (random)
quantizing model
freezing model
prepare model
retrieve lora state dict
convert state dict
peft_model_state_dict: dict_keys([])
Loading with state_dict odict_keys([])
state_dict: dict_keys([])
prefix: time_text_embed.timestep_embedder.linear_1.weight.
Traceback (most recent call last):
File "/Users/bghira/src/SimpleTuner/test_flux.py", line 41, in <module>
)
File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/peft/utils/save_and_load.py", line 354, in set_peft_model_state_dict
load_result = model.load_state_dict(peft_model_state_dict, strict=False)
File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2201, in load_state_dict
load(self, state_dict)
File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2189, in load
load(child, child_state_dict, child_prefix) # noqa: F821
File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2189, in load
load(child, child_state_dict, child_prefix) # noqa: F821
File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2189, in load
load(child, child_state_dict, child_prefix) # noqa: F821
File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2183, in load
module._load_from_state_dict(
File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/optimum/quanto/nn/qmodule.py", line 159, in _load_from_state_dict
deserialized_weight = QBytesTensor.load_from_state_dict(
File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/optimum/quanto/tensor/qbytes.py", line 92, in load_from_state_dict
inner_tensors_dict[name] = state_dict.pop(prefix + name)
KeyError: 'time_text_embed.timestep_embedder.linear_1.weight._data'
System Info
latest git main