Key mismatch when trying to load a LORA adapter into an XLORA model #2132

Open
p4arth opened this issue Oct 5, 2024 · 4 comments

Comments

@p4arth commented Oct 5, 2024

System Info

peft==0.13.0

Who can help?

@EricLBuehler

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

Code to train the sample LoRA adapter:

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from peft import get_peft_model, LoraConfig, TaskType

# Load a small BERT model
model_name = "prajjwal1/bert-tiny"
tokenizer = AutoTokenizer.from_pretrained(model_name)

imdb_dataset = load_dataset("imdb")
imdb_dataset = imdb_dataset["train"].shuffle(seed=42).select(range(1000))

def tokenize_imdb(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=512)

tokenized_imdb = imdb_dataset.map(tokenize_imdb, batched=True)

# LoRA configuration for sentiment analysis (IMDB)
lora_config_sentiment = LoraConfig(
    r=4,
    lora_alpha=16,
    target_modules=["query", "value"],
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.SEQ_CLS,
    use_dora=False
)

# Training function
def train_lora(dataset, num_labels, lora_config, output_dir):
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
    peft_model = get_peft_model(model, lora_config)
    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate=1e-3,
        per_device_train_batch_size=8,
        num_train_epochs=1,
        weight_decay=0.01,
        push_to_hub=False,
    )

    trainer = Trainer(
        model=peft_model,
        args=training_args,
        train_dataset=dataset,
    )

    trainer.train()

train_lora(tokenized_imdb, num_labels=2, lora_config=lora_config_sentiment, output_dir="./lora_sentiment")

Code to load the trained LoRA adapter into an X-LoRA model:

from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
from peft import get_peft_model, XLoraConfig, TaskType, PeftConfig, PeftModel

# Load pre-trained model and tokenizer
xlora_model_name = "prajjwal1/bert-tiny"
xlora_model = AutoModelForSequenceClassification.from_pretrained(xlora_model_name, use_cache=False)
tokenizer = AutoTokenizer.from_pretrained(xlora_model_name)
xlora_model_config = AutoConfig.from_pretrained(xlora_model_name)

xlora_peft_config = XLoraConfig(
    task_type="SEQ_CLS",
    adapters={
        "adapter_1" : "./lora_sentiment/checkpoint-125",
    },
)

# Apply XLoRA to the model
model = get_peft_model(xlora_model, xlora_peft_config)

Expected behavior

The expected behaviour is that the LoRA adapter loads successfully into the X-LoRA model.
The problem arises in the function _load_adapter_into_lora_model in src/peft/tuners/xlora/model.py.
That function adds an extra model. prefix to the keys of the adapter's state_dict.
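For reference, a minimal sketch to inspect the keys the loader starts from, assuming the adapter was saved as adapter_model.safetensors under the checkpoint directory used above:

from safetensors.torch import load_file

# Diagnostic only: print a few keys of the saved adapter so they can be
# compared with the keys _load_adapter_into_lora_model produces.
# The path mirrors the training script above; adjust it if your checkpoint differs.
adapter_weights = load_file("./lora_sentiment/checkpoint-125/adapter_model.safetensors")
for key in list(adapter_weights)[:4]:
    print(key)  # typically names of the form base_model.model.<...>.lora_A.weight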

@BenjaminBossan (Member)

Yes, I can confirm that it's not working. I condensed the example a little bit and switched to the normal BERT model:

import torch
from transformers import AutoModelForSequenceClassification
from peft import get_peft_model, LoraConfig, TaskType, XLoraConfig

model_name = "google-bert/bert-base-uncased"

lora_config_sentiment = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    init_lora_weights=False,
)

model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
peft_model = get_peft_model(model, lora_config_sentiment)
lora_path = "/tmp/peft/2132"
peft_model.save_pretrained(lora_path)

del model, peft_model

# Load pre-trained model and tokenizer
xlora_model = AutoModelForSequenceClassification.from_pretrained(model_name, use_cache=False)
xlora_peft_config = XLoraConfig(
    task_type="SEQ_CLS",
    adapters={
        "adapter_1" : lora_path,
    },
)

# Apply XLoRA to the model: this raises an error
model = get_peft_model(xlora_model, xlora_peft_config)

The error I get is:

Traceback (most recent call last):
  File "/home/name/work/forks/peft/foo.py", line 31, in <module>
    model = get_peft_model(xlora_model, xlora_peft_config)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/work/forks/peft/src/peft/mapping.py", line 193, in get_peft_model
    return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/work/forks/peft/src/peft/peft_model.py", line 1402, in __init__
    super().__init__(model, peft_config, adapter_name, **kwargs)
  File "/home/name/work/forks/peft/src/peft/peft_model.py", line 171, in __init__
    self.base_model = cls(model, {adapter_name: peft_config}, adapter_name)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/work/forks/peft/src/peft/tuners/xlora/model.py", line 279, in __init__
    _load_adapter_into_lora_model(
  File "/home/name/work/forks/peft/src/peft/tuners/xlora/model.py", line 148, in _load_adapter_into_lora_model
    raise ValueError(
ValueError: Got unexpected keys! Please raise an issue and tag @EricLBuehler.

unexpected_keys=['model.model.bert.encoder.layer.0.attention.self.query.lora_A.0.weight', 'model.model.bert.encoder.layer.0.attention.self.query.lora_B.0.weight', 'model.model.bert.encoder.layer.0.attention.self.value.lora_A.0.weight', 'model.model.bert.encoder.layer.0.attention.self.value.lora_B.0.weight', 'model.model.bert.encoder.layer.1.attention.self.query.lora_A.0.weight', 'model.model.bert.encoder.layer.1.attention.self.query.lora_B.0.weight', 'model.model.bert.encoder.layer.1.attention.self.value.lora_A.0.weight', 'model.model.bert.encoder.layer.1.attention.self.value.lora_B.0.weight', 'model.model.bert.encoder.layer.10.attention.self.query.lora_A.0.weight', 'model.model.bert.encoder.layer.10.attention.self.query.lora_B.0.weight', 'model.model.bert.encoder.layer.10.attention.self.value.lora_A.0.weight', 'model.model.bert.encoder.layer.10.attention.self.value.lora_B.0.weight', 'model.model.bert.encoder.layer.11.attention.self.query.lora_A.0.weight', 'model.model.bert.encoder.layer.11.attention.self.query.lora_B.0.weight', 'model.model.bert.encoder.layer.11.attention.self.value.lora_A.0.weight', 'model.model.bert.encoder.layer.11.attention.self.value.lora_B.0.weight', 'model.model.bert.encoder.layer.2.attention.self.query.lora_A.0.weight', 'model.model.bert.encoder.layer.2.attention.self.query.lora_B.0.weight', 'model.model.bert.encoder.layer.2.attention.self.value.lora_A.0.weight', 'model.model.bert.encoder.layer.2.attention.self.value.lora_B.0.weight', 'model.model.bert.encoder.layer.3.attention.self.query.lora_A.0.weight', 'model.model.bert.encoder.layer.3.attention.self.query.lora_B.0.weight', 'model.model.bert.encoder.layer.3.attention.self.value.lora_A.0.weight', 'model.model.bert.encoder.layer.3.attention.self.value.lora_B.0.weight', 'model.model.bert.encoder.layer.4.attention.self.query.lora_A.0.weight', 'model.model.bert.encoder.layer.4.attention.self.query.lora_B.0.weight', 'model.model.bert.encoder.layer.4.attention.self.value.lora_A.0.weight', 'model.model.bert.encoder.layer.4.attention.self.value.lora_B.0.weight', 'model.model.bert.encoder.layer.5.attention.self.query.lora_A.0.weight', 'model.model.bert.encoder.layer.5.attention.self.query.lora_B.0.weight', 'model.model.bert.encoder.layer.5.attention.self.value.lora_A.0.weight', 'model.model.bert.encoder.layer.5.attention.self.value.lora_B.0.weight', 'model.model.bert.encoder.layer.6.attention.self.query.lora_A.0.weight', 'model.model.bert.encoder.layer.6.attention.self.query.lora_B.0.weight', 'model.model.bert.encoder.layer.6.attention.self.value.lora_A.0.weight', 'model.model.bert.encoder.layer.6.attention.self.value.lora_B.0.weight', 'model.model.bert.encoder.layer.7.attention.self.query.lora_A.0.weight', 'model.model.bert.encoder.layer.7.attention.self.query.lora_B.0.weight', 'model.model.bert.encoder.layer.7.attention.self.value.lora_A.0.weight', 'model.model.bert.encoder.layer.7.attention.self.value.lora_B.0.weight', 'model.model.bert.encoder.layer.8.attention.self.query.lora_A.0.weight', 'model.model.bert.encoder.layer.8.attention.self.query.lora_B.0.weight', 'model.model.bert.encoder.layer.8.attention.self.value.lora_A.0.weight', 'model.model.bert.encoder.layer.8.attention.self.value.lora_B.0.weight', 'model.model.bert.encoder.layer.9.attention.self.query.lora_A.0.weight', 'model.model.bert.encoder.layer.9.attention.self.query.lora_B.0.weight', 'model.model.bert.encoder.layer.9.attention.self.value.lora_A.0.weight', 'model.model.bert.encoder.layer.9.attention.self.value.lora_B.0.weight', 
'model.model.classifier.modules_to_save.0.bias', 'model.model.classifier.modules_to_save.0.weight']
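As a quick sanity check (a rough sketch using the same base model, not part of the script above), the names that a plain LoRA model registers for its LoRA weights carry only a single model. prefix relative to the inner LoraModel, which matches the report that an extra model. prefix is being added:

from transformers import AutoModelForSequenceClassification
from peft import LoraConfig, TaskType, get_peft_model

model_name = "google-bert/bert-base-uncased"
base = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
plain_lora = get_peft_model(base, LoraConfig(task_type=TaskType.SEQ_CLS))

# Parameter names relative to the inner LoraModel (plain_lora.base_model);
# expect names like model.bert.encoder.layer.0.attention.self.query.lora_A.default.weight
lora_names = [name for name, _ in plain_lora.base_model.named_parameters() if "lora_" in name]
print(lora_names[:2])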

@EricLBuehler could you please take a look?

@p4arth (Author) commented Oct 19, 2024

Could it be a while until this is addressed?

@BenjaminBossan (Member)

@EricLBuehler Do you know if you have time to take a look at this soon?

@SongHanKen

Hi there,
I tried to fix this bug; here is my solution:

for old_key in adapter_weights.keys():
    key: str = old_key
    # Remove all the prefixes until we have model.<...>
    while not (key.startswith("model.") and not key.startswith("model.model.")):
        key = key[key.find(".") + 1 :]

    # change: strip the remaining "model." prefix before it is re-added below
    if key.startswith("model.model") or key.startswith("model."):
        key = key.replace("model.", "")

    # Re-add the "model." prefix (originally commented "We always want model.model")
    key = "model." + key
    new_adapter_weights[key] = adapter_weights[old_key]

I just added the if key.startswith("model.model") or key.startswith("model."): key = key.replace("model.", "") check around line 135 of model.py. I am not sure it is an effective fix, but I hope it is helpful.
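To sanity-check the idea outside of PEFT, the loop can be run standalone on a made-up key of the base_model.model.<...> shape (just an illustration; the real keys inside _load_adapter_into_lora_model may look slightly different):

# Standalone illustration of the remapping with the proposed change applied;
# the input key is a made-up example, not taken from an actual checkpoint.
adapter_weights = {
    "base_model.model.bert.encoder.layer.0.attention.self.query.lora_A.weight": None,
}
new_adapter_weights = {}

for old_key in adapter_weights.keys():
    key = old_key
    # Remove all the prefixes until we have model.<...>
    while not (key.startswith("model.") and not key.startswith("model.model.")):
        key = key[key.find(".") + 1 :]
    # Proposed change: drop the remaining "model." before it is re-added below
    if key.startswith("model.model") or key.startswith("model."):
        key = key.replace("model.", "")
    key = "model." + key
    new_adapter_weights[key] = adapter_weights[old_key]

print(list(new_adapter_weights))
# -> ['model.bert.encoder.layer.0.attention.self.query.lora_A.weight']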
