-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Fix error of PEFT with disable adapters and FSDP #3001
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Isalia20
wants to merge
6
commits into
huggingface:main
Choose a base branch
from
Isalia20:fix-fsdp-disable-adapters
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+191
−6
Open
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
91a89be
fix
Isalia20 1c93c14
fix (#1)
Isalia20 13b9a75
Merge branch 'huggingface:main' into main
Isalia20 34d22fb
Merge branch 'main' into fix-fsdp-disable-adapters
Isalia20 daa01af
fix bugs
Isalia20 d4258c1
simplify tests, update comments
Isalia20 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,165 @@ | ||
| # Copyright 2025-present the HuggingFace Inc. team. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """ | ||
| Script to test FSDP adapter operations (disable_adapters, set_adapter, etc.) in a distributed environment. | ||
|
|
||
| This script is designed to be run with `accelerate launch` to properly test FSDP behavior while running one pass with autograd and another with adapters being disabled. | ||
|
|
||
| Usage: | ||
| accelerate launch --config_file tests/training/fsdp_config.yaml tests/training/adapters.py | ||
| """ | ||
|
|
||
| import argparse | ||
| import tempfile | ||
|
|
||
| import torch | ||
| from accelerate import PartialState | ||
| from datasets import load_dataset | ||
| from torch import nn | ||
| from torch.distributed.fsdp import FullyShardedDataParallel as FSDP | ||
| from transformers import ( | ||
| AutoModelForCausalLM, | ||
| AutoTokenizer, | ||
| DataCollatorForLanguageModeling, | ||
| Trainer, | ||
| TrainingArguments, | ||
| ) | ||
|
|
||
| from peft import LoraConfig, get_peft_model | ||
|
|
||
|
|
||
| def get_base_model_weights(peft_model): | ||
| """Extract base model weights (non-LoRA weights).""" | ||
| base_weights = {} | ||
| for name, param in peft_model.named_parameters(): | ||
| if "lora" not in name.lower(): | ||
| base_weights[name] = param.detach().clone() | ||
| return base_weights | ||
|
|
||
|
|
||
| def get_adapter_weights(peft_model, adapter_name): | ||
| """Extract weights for a specific adapter.""" | ||
| adapter_weights = {} | ||
| for name, param in peft_model.named_parameters(): | ||
| if adapter_name in name: | ||
| adapter_weights[name] = param.detach().clone() | ||
| return adapter_weights | ||
|
|
||
|
|
||
| def verify_weights_unchanged(initial_weights, final_weights, weight_type): | ||
| """Verify that weights have not changed during training.""" | ||
| for name in initial_weights: | ||
| if name not in final_weights: | ||
| raise AssertionError(f"{weight_type} weight missing after training: {name}") | ||
| torch.testing.assert_close( | ||
| initial_weights[name].to(device=final_weights[name].device, dtype=final_weights[name].dtype), | ||
| final_weights[name], | ||
| ) | ||
|
|
||
|
|
||
| class Model(nn.Module): | ||
| def __init__(self, model_id): | ||
| super().__init__() | ||
| model = AutoModelForCausalLM.from_pretrained( | ||
| model_id, | ||
| torch_dtype=torch.bfloat16, | ||
| ) | ||
| self.tokenizer = AutoTokenizer.from_pretrained(model_id) | ||
|
|
||
| peft_config = LoraConfig( | ||
| r=16, | ||
| lora_alpha=32, | ||
| target_modules=["q_proj", "v_proj"], | ||
| lora_dropout=0.05, | ||
| bias="none", | ||
| task_type="CAUSAL_LM", | ||
| ) | ||
| self.peft_model = get_peft_model(model, peft_config) | ||
|
|
||
| # Second adapter config (will remain disabled/unused throughout training) | ||
| peft_config_second = LoraConfig( | ||
| r=8, | ||
| lora_alpha=16, | ||
| target_modules=["q_proj", "v_proj"], | ||
| lora_dropout=0.05, | ||
| bias="none", | ||
| task_type="CAUSAL_LM", | ||
| ) | ||
| self.peft_model.add_adapter("second_adapter", peft_config_second) | ||
|
|
||
| self.peft_model.set_adapter("default") | ||
| self.peft_model.to(torch.bfloat16) | ||
|
|
||
| for name, param in self.peft_model.named_parameters(): | ||
| param.requires_grad = "lora_" in name.lower() and "second_adapter" not in name | ||
|
|
||
| def forward(self, input_ids=None, attention_mask=None, labels=None): | ||
| out1 = self.peft_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) | ||
| with self.peft_model.disable_adapter(): | ||
| out2 = self.peft_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) | ||
| combined_loss = out1.loss + out2.loss | ||
| return (combined_loss,) | ||
|
|
||
|
|
||
| def test_training(model_id: str): | ||
| state = PartialState() | ||
| torch.manual_seed(42) | ||
| model = Model(model_id) | ||
|
|
||
| initial_base_weights = get_base_model_weights(model.peft_model) | ||
| initial_second_adapter_weights = get_adapter_weights(model.peft_model, "second_adapter") | ||
|
|
||
| if state.is_main_process: | ||
| print(f"Number of base model weight tensors: {len(initial_base_weights)}") | ||
| print(f"Number of second_adapter weight tensors: {len(initial_second_adapter_weights)}") | ||
|
|
||
| data = load_dataset("ybelkada/english_quotes_copy") | ||
| data = data.map(lambda samples: model.tokenizer(samples["quote"]), batched=True) | ||
|
|
||
| with tempfile.TemporaryDirectory() as tmp_dir: | ||
| trainer = Trainer( | ||
| model=model, | ||
| train_dataset=data["train"], | ||
| optimizer_cls_and_kwargs=(torch.optim.SGD, {"lr": 2e-4}), | ||
| args=TrainingArguments( | ||
| per_device_train_batch_size=4, | ||
| gradient_accumulation_steps=4, | ||
| warmup_steps=2, | ||
| max_steps=5, | ||
| learning_rate=2e-4, | ||
| bf16=True, | ||
| logging_steps=1, | ||
| output_dir=tmp_dir, | ||
| ), | ||
| data_collator=DataCollatorForLanguageModeling(model.tokenizer, mlm=False), | ||
| ) | ||
| trainer.train() | ||
| with FSDP.summon_full_params(trainer.model): | ||
| final_base_weights = get_base_model_weights(model.peft_model) | ||
| final_second_adapter_weights = get_adapter_weights(model.peft_model, "second_adapter") | ||
|
|
||
| verify_weights_unchanged(initial_base_weights, final_base_weights, "Base model") | ||
| verify_weights_unchanged(initial_second_adapter_weights, final_second_adapter_weights, "second_adapter") | ||
|
|
||
|
|
||
| def main(model_id: str): | ||
| test_training(model_id) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| parser = argparse.ArgumentParser() | ||
| parser.add_argument("--model_id", type=str, required=False, default="Qwen/Qwen3-0.6B") | ||
| args = parser.parse_args() | ||
| main(model_id=args.model_id) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the note comment from before can stay as it doesn't seem to be invalidated.