A bug occurs in the PEFT library when using multiple LoRA adapters, each with its own modules_to_save configuration. The issue arises because the modules_to_save of the first LoRA adapter (e.g., adapter_1) is also applied to subsequently loaded adapters (e.g., adapter_2), rather than each adapter maintaining an independent configuration. As a result, modules specified in modules_to_save for adapter_1 also appear in adapter_2, leading to unintended behavior and possibly affecting fine-tuning accuracy. This incorrect handling of modules_to_save creates duplicate module copies where only the respective LoRA adapter's modules should be wrapped.
Information
The official example scripts
My own modified scripts
Tasks
An officially supported task in the examples folder
My own task or dataset (give details below)
Reproduction
The following example code demonstrates this issue, displaying the model structure where adapter_2 contains modules meant only for adapter_1.
Example Code
import os

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, PeftModel

# Get the directory of the current Python script
script_dir = os.path.dirname(os.path.abspath(__file__))

# Define relative paths for adapters
adapter_1_path = os.path.join(script_dir, "adapter_1")
adapter_2_path = os.path.join(script_dir, "adapter_2")

# Load base model
base_model = AutoModelForCausalLM.from_pretrained("gpt2")

# Define LoRA configs with different modules_to_save
lora_config_1 = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["c_attn"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    modules_to_save=["lm_head"],
)
lora_config_2 = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["c_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    modules_to_save=["wte"],
)

# Apply and save the first adapter
os.makedirs(adapter_1_path, exist_ok=True)
model_with_lora_1 = get_peft_model(base_model, lora_config_1, adapter_name="adapter_1")
model_with_lora_1.save_pretrained(adapter_1_path)

# Apply and save the second adapter
os.makedirs(adapter_2_path, exist_ok=True)
model_with_lora_2 = get_peft_model(base_model, lora_config_2, adapter_name="adapter_2")
model_with_lora_2.save_pretrained(adapter_2_path)

# Load a fresh base model and wrap it in PeftModel by loading the first adapter
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
peft_model = PeftModel.from_pretrained(base_model, os.path.join(adapter_1_path, "adapter_1"), adapter_name="adapter_1")

# Load the second adapter into the PeftModel
peft_model.load_adapter(os.path.join(adapter_2_path, "adapter_2"), adapter_name="adapter_2")

# Display structure and inspect unexpected 'modules_to_save' overlap
print("Expected `modules_to_save` for each adapter:")
print("Adapter 1 `modules_to_save`: ['lm_head']")
print("Adapter 2 `modules_to_save`: ['wte']")
print("\nActual model structure and `modules_to_save` contents:\n")
print(peft_model.transformer.wte)
print(peft_model.lm_head)
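As an aside, the saved adapter configs themselves still declare the expected values, so the duplication only shows up in the wrapped model structure after loading; a minimal check (assuming the standard PeftConfig.from_pretrained API and the paths from the script above):

import os

from peft import PeftConfig

# Each saved adapter config lists only its own modules_to_save.
cfg_1 = PeftConfig.from_pretrained(os.path.join(adapter_1_path, "adapter_1"))
cfg_2 = PeftConfig.from_pretrained(os.path.join(adapter_2_path, "adapter_2"))
print(cfg_1.modules_to_save)  # ['lm_head']
print(cfg_2.modules_to_save)  # ['wte']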
Thanks a lot for reporting this. Indeed, the handling of modules_to_save can be messy at times and the outcome you show should be avoided. I don't have the opportunity to test this right now, but my assumption is that this extra module won't disrupt the results for adapter 2 because it is a copy of the original layer and behaves exactly the same, is that right?
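A quick way to sanity-check that assumption (a hypothetical snippet that reuses the objects from the reproduction script above and assumes the ModulesToSaveWrapper exposes original_module and a per-adapter modules_to_save ModuleDict):

import torch

# Hypothetical sanity check: the redundant lm_head copy stored under
# "adapter_2" should carry the same weights as the original module and
# therefore produce identical outputs.
wrapper = peft_model.lm_head   # ModulesToSaveWrapper around lm_head
x = torch.randn(1, 768)        # 768 = GPT-2 hidden size
with torch.no_grad():
    out_original = wrapper.original_module(x)
    out_copy = wrapper.modules_to_save["adapter_2"](x)
print(torch.allclose(out_original, out_copy))  # expected: True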
No worries, glad to be of any help. As far as I have tested, it should be fine and use the correctly loaded layer; the only problem is the redundancy in the loaded modules. I also dug a bit deeper and noticed that the problem originates from this function:
The set is not updated to contain only the new adapter's layer; it still holds the old layer as well (which it shouldn't). For example, if I manually hack the above script, the problem is solved:
...
# Apply and save the second adapter
os.makedirs(adapter_2_path, exist_ok=True)
model_with_lora_2 = get_peft_model(base_model, lora_config_2, adapter_name="adapter_2")
model_with_lora_2.save_pretrained(adapter_2_path)

# Load a fresh base model and wrap it in PeftModel by loading the first adapter
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
peft_model = PeftModel.from_pretrained(base_model, os.path.join(adapter_1_path, "adapter_1"), adapter_name="adapter_1")
peft_model.modules_to_save = {"wte"}  # <----------- HERE manually changing the modules_to_save

# Load the second adapter into the PeftModel
peft_model.load_adapter(os.path.join(adapter_2_path, "adapter_2"), adapter_name="adapter_2")
...
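To see the accumulation in isolation, here is a simplified stand-in for the behaviour described above (a hypothetical sketch, not the actual PEFT source code):

# Hypothetical, simplified sketch of the accumulation behaviour described
# above; FakePeftModel is an illustration only, not PEFT's implementation.
class FakePeftModel:
    def __init__(self):
        self.modules_to_save = None  # one set shared across all adapters

    def load_adapter(self, modules_to_save):
        # The set is only ever extended, never scoped per adapter, so
        # entries from earlier adapters leak into later ones.
        if self.modules_to_save is None:
            self.modules_to_save = set(modules_to_save)
        else:
            self.modules_to_save.update(modules_to_save)


model = FakePeftModel()
model.load_adapter(["lm_head"])  # adapter_1
model.load_adapter(["wte"])      # adapter_2
print(model.modules_to_save)     # {'lm_head', 'wte'} -> both get wrapped for adapter_2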
System Info
Python 3.11.9
transformers==4.40.2
peft==0.11.2
Who can help?
@BenjaminBossan
Expected behavior
As you can see from the code output, adapter_2 is also built for the lm_head module, which it shouldn't be; the expected result is that lm_head is wrapped only for adapter_1 and wte only for adapter_2.
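Expressed as a check (hypothetical, assuming ModulesToSaveWrapper keeps one copy per adapter in its modules_to_save ModuleDict):

# Hypothetical check of the expected behaviour: each wrapper should only
# contain the adapter that actually declared the module in modules_to_save.
assert list(peft_model.lm_head.modules_to_save.keys()) == ["adapter_1"]
assert list(peft_model.transformer.wte.modules_to_save.keys()) == ["adapter_2"]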