Skip to content

can't resume training with quantized base model #9108

Open
@bghira

Description

@bghira

Describe the bug

A KeyError is raised when the LoRA state dict is loaded into the quantized transformer while resuming training.

Reproduction

import torch
from diffusers.models import FluxTransformer2DModel
from peft import LoraConfig, set_peft_model_state_dict
from diffusers import FluxPipeline
from diffusers.utils import convert_unet_state_dict_to_peft
from accelerate import Accelerator
from optimum.quanto import quantize, freeze, qint8

# Minimal reproduction: attach a LoRA adapter to the Flux transformer, quantize
# and freeze the model with optimum-quanto, then try to load a saved LoRA
# checkpoint back into it through PEFT. The statement order below is what
# triggers the reported error, so it is kept exactly as run.

# Step 1: Load the FluxTransformer2DModel from the FLUX.1-dev "transformer" subfolder
print("loading model")
model = FluxTransformer2DModel.from_pretrained('black-forest-labs/FLUX.1-dev', subfolder='transformer')
accelerator = Accelerator()
# Step 2: Add a LoRA adapter targeting the "to_q" and "to_k" modules
config = LoraConfig(r=8, lora_alpha=16, target_modules=["to_q", "to_k"], lora_dropout=0.1)
print("adding adapter (random)")
model.add_adapter(config)

# Step 3: Quantize the model weights to qint8 (after the adapter has been added)
print("quantizing model")
quantize(model, weights=qint8)
print("freezing model")
freeze(model)

# Step 4: Prepare the model with Accelerate
print("prepare model")
model = accelerator.prepare(model)

# Step 5: Load the LoRA state dictionary from the checkpoint directory
lora_path = "checkpoint/"

print("retrieve lora state dict")
lora_state_dict = FluxPipeline.lora_state_dict(lora_path)
# NOTE(review): the filter below keeps only keys that start with "unet.",
# while the rename strips "transformer.". The logs in this report show an
# empty state dict (dict_keys([])), presumably because no key matches the
# "unet." prefix -- confirm which prefix was intended.
# NOTE(review): str.replace removes every occurrence of "transformer.", not
# just a leading prefix -- verify that is harmless for the checkpoint's keys.
transformer_state_dict = {
    f'{k.replace("transformer.", "")}': v
    for k, v in lora_state_dict.items()
    if k.startswith("unet.")
}
print("convert state dict")
transformer_state_dict = convert_unet_state_dict_to_peft(
    transformer_state_dict
)
# This is the call that raises the KeyError shown in the traceback: the
# failure surfaces inside model.load_state_dict(..., strict=False) when the
# optimum-quanto module looks up "<prefix>weight._data" in the state dict.
incompatible_keys = set_peft_model_state_dict(
    model, transformer_state_dict, adapter_name="default"
)
# Not reached in the reported run; the exception above is raised first.
print(f"unexpected keys: {incompatible_keys.unexpected_keys}")

Logs

adding adapter (random)
quantizing model
freezing model
prepare model
retrieve lora state dict
convert state dict
peft_model_state_dict: dict_keys([])
Loading with state_dict odict_keys([])
state_dict: dict_keys([])
prefix: time_text_embed.timestep_embedder.linear_1.weight.
Traceback (most recent call last):
  File "/Users/bghira/src/SimpleTuner/test_flux.py", line 41, in <module>
    )
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/peft/utils/save_and_load.py", line 354, in set_peft_model_state_dict
    load_result = model.load_state_dict(peft_model_state_dict, strict=False)
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2201, in load_state_dict
    load(self, state_dict)
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2189, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2189, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2189, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2183, in load
    module._load_from_state_dict(
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/optimum/quanto/nn/qmodule.py", line 159, in _load_from_state_dict
    deserialized_weight = QBytesTensor.load_from_state_dict(
  File "/Users/bghira/src/SimpleTuner/.venv/lib/python3.10/site-packages/optimum/quanto/tensor/qbytes.py", line 92, in load_from_state_dict
    inner_tensors_dict[name] = state_dict.pop(prefix + name)
KeyError: 'time_text_embed.timestep_embedder.linear_1.weight._data'

System Info

latest git main

Who can help?

@sayakpaul

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working) · stale (Issues that haven't received updates)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions