2 changes: 2 additions & 0 deletions Makefile
@@ -68,3 +68,5 @@ tests_training:
accelerate launch --config_file tests/training/deepspeed_config.yaml tests/training/training.py --quant 8bit $(if $(IS_GITHUB_CI),--report-log "training_deepspeed_8bit.log",)
accelerate launch --config_file tests/training/fsdp_config.yaml tests/training/training.py $(if $(IS_GITHUB_CI),--report-log "training_fsdp.log",)
accelerate launch --config_file tests/training/fsdp_config.yaml tests/training/training.py --quant 4bit $(if $(IS_GITHUB_CI),--report-log "training_fsdp_4bit.log",)
+accelerate launch --config_file tests/training/fsdp_config.yaml tests/training/test_fsdp_adapters.py --test all $(if $(IS_GITHUB_CI),--report-log "training_fsdp_adapters.log",)
+accelerate launch --config_file tests/training/fsdp_config.yaml tests/training/test_fsdp_adapters.py --test all --quant 4bit $(if $(IS_GITHUB_CI),--report-log "training_fsdp_adapters_4bit.log",)
29 changes: 23 additions & 6 deletions src/peft/tuners/tuners_utils.py
@@ -1389,7 +1389,18 @@ def enable_adapters(self, enabled: bool) -> None:
            # disable grads on all adapter layers
            for layer_name in self.adapter_layer_names:
                layer = getattr(self, layer_name)
-                layer.requires_grad_(False)
+                # Handle FSDP case where params may be non-leaf tensors
Collaborator comment: Maybe something like this to explain why they are non-leaf tensors?

Suggested change:
-                # Handle FSDP case where params may be non-leaf tensors
+                # Handle FSDP case where params may be non-leaf tensors by being wrapped in DTensors

+                # layer.parameters() returns an iterator, so we need to check if layer is a module
+                if hasattr(layer, "parameters"):
+                    for param in layer.parameters():
+                        if param.is_leaf:
+                            param.requires_grad_(False)
+                else:
+                    # layer is a parameter/tensor itself (e.g., from ParameterDict)
+                    # In this case we need to iterate through the dict
+                    for param in layer.values():
+                        if param.is_leaf:
+                            param.requires_grad_(False)
            self._disable_adapters = True

    def set_adapter(self, adapter_names: str | list[str], inference_mode: bool = False) -> None:
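For context on the is_leaf guard introduced in this diff, here is a minimal standalone sketch (not part of the diff). It assumes, as the comment above suggests, that FSDP can expose adapter parameters as non-leaf tensors (for example DTensor-backed views), and that PyTorch only allows in-place requires_grad changes on leaf tensors:

import torch

leaf = torch.nn.Parameter(torch.zeros(2))  # a freshly created Parameter is a leaf
non_leaf = leaf * 2                        # any tensor produced by an op is not a leaf

for t in (leaf, non_leaf):
    if t.is_leaf:
        t.requires_grad_(False)  # allowed on leaf tensors
    else:
        # calling t.requires_grad_(False) here would raise
        # "you can only change requires_grad flags of leaf variables"
        pass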
@@ -1411,12 +1422,18 @@ def set_adapter(self, adapter_names: str | list[str], inference_mode: bool = False) -> None:
        for layer_name in self.adapter_layer_names:
            module_dict = getattr(self, layer_name)
            for key, layer in module_dict.items():
-                if (key in adapter_names) and (not inference_mode):
-                    # Note: It is possible that not a single layer is called with requires_grad_(True) here. This may
-                    # happen if a completely different adapter layer is being activated.
Collaborator comment on lines -1414 to -1416: I think the note comment from before can stay as it doesn't seem to be invalidated.

-                    layer.requires_grad_(True)
+                should_require_grad = (key in adapter_names) and (not inference_mode)
+                # Handle FSDP case where params may be non-leaf tensors
+                # Check if layer is a module or a parameter/tensor directly
+                if isinstance(layer, (torch.nn.Parameter, torch.Tensor)):
+                    # layer is a parameter/tensor itself (e.g., from ParameterDict)
+                    if layer.is_leaf:
+                        layer.requires_grad_(should_require_grad)
                else:
-                    layer.requires_grad_(False)
+                    # layer is a module with parameters
+                    for param in layer.parameters():
+                        if param.is_leaf:
+                            param.requires_grad_(should_require_grad)

        self._active_adapter = adapter_names

256 changes: 256 additions & 0 deletions tests/training/test_fsdp_adapters.py
Collaborator comment: Let's rename this to adapters.py to avoid it being gobbled up by pytest, and drop the fsdp prefix, since this technically is a test for whatever distributed training we use.

@@ -0,0 +1,256 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Script to test FSDP adapter operations (disable_adapters, set_adapter, etc.) in a distributed environment.

This script is designed to be run with `accelerate launch` to properly test FSDP behavior across multiple GPUs.

Usage:
accelerate launch --config_file tests/training/fsdp_config.yaml tests/training/test_fsdp_adapters.py
accelerate launch --config_file tests/training/fsdp_config.yaml tests/training/test_fsdp_adapters.py --test disable_adapters
accelerate launch --config_file tests/training/fsdp_config.yaml tests/training/test_fsdp_adapters.py --test set_adapter
"""

import argparse
import tempfile

import torch
from accelerate import PartialState
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

from peft import LoraConfig, get_peft_model


def print_if_process_zero(*args, **kwargs):
    PartialState().print(*args, **kwargs)


def test_disable_adapters(model_id: str, quant: str | None):
"""Test that disable_adapters() works correctly with FSDP."""
print_if_process_zero("=" * 50)
print_if_process_zero(f"Testing disable_adapters with {model_id=}, {quant=}")
print_if_process_zero("=" * 50)

if quant == "4bit":
quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype="bfloat16",
            bnb_4bit_quant_storage="bfloat16",
            bnb_4bit_use_double_quant=True,
        )
    else:
        quant_config = None

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=quant_config,
        torch_dtype=torch.bfloat16,
        device_map={"": PartialState().process_index},
    )

    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, peft_config)
    print_if_process_zero(model)
    if PartialState().is_local_main_process:
        model.print_trainable_parameters()

    data = load_dataset("ybelkada/english_quotes_copy")
    data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)

    with tempfile.TemporaryDirectory() as tmp_dir:
        trainer = Trainer(
            model=model,
            train_dataset=data["train"],
            optimizer_cls_and_kwargs=(torch.optim.SGD, {"lr": 2e-4}),
            args=TrainingArguments(
                per_device_train_batch_size=4,
                gradient_accumulation_steps=4,
                warmup_steps=2,
                max_steps=5,
                learning_rate=2e-4,
                bf16=True,
                logging_steps=1,
                output_dir=tmp_dir,
            ),
            data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
        )

        # Train for a few steps first
        trainer.train()

        # Test disable_adapters - should not raise
        print_if_process_zero("Testing disable_adapters()...")
        model.disable_adapters()
        print_if_process_zero("disable_adapters() succeeded!")

        # Test enable_adapters - should not raise
        print_if_process_zero("Testing enable_adapters()...")
        model.enable_adapters()
        print_if_process_zero("enable_adapters() succeeded!")

        # Test context manager - should not raise
        print_if_process_zero("Testing disable_adapter() context manager...")
        with model.disable_adapter():
            pass
        print_if_process_zero("Context manager succeeded!")

        # Train a few more steps after re-enabling
        trainer.train()

        print_if_process_zero("All disable_adapters tests passed!")


def test_set_adapter(model_id: str, quant: str | None):
"""Test that set_adapter() works correctly with FSDP."""
print_if_process_zero("=" * 50)
print_if_process_zero(f"Testing set_adapter with {model_id=}, {quant=}")
print_if_process_zero("=" * 50)

if quant == "4bit":
quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype="bfloat16",
            bnb_4bit_quant_storage="bfloat16",
            bnb_4bit_use_double_quant=True,
        )
    else:
        quant_config = None

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=quant_config,
        torch_dtype=torch.bfloat16,
        device_map={"": PartialState().process_index},
    )

    # Create first adapter
    peft_config1 = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, peft_config1, adapter_name="adapter1")

    # Add second adapter
    peft_config2 = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model.add_adapter("adapter2", peft_config2)

    print_if_process_zero(model)
    if PartialState().is_local_main_process:
        model.print_trainable_parameters()

    data = load_dataset("ybelkada/english_quotes_copy")
    data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)

    with tempfile.TemporaryDirectory() as tmp_dir:
        trainer = Trainer(
            model=model,
            train_dataset=data["train"],
            optimizer_cls_and_kwargs=(torch.optim.SGD, {"lr": 2e-4}),
            args=TrainingArguments(
                per_device_train_batch_size=4,
                gradient_accumulation_steps=4,
                warmup_steps=2,
                max_steps=5,
                learning_rate=2e-4,
                bf16=True,
                logging_steps=1,
                output_dir=tmp_dir,
            ),
            data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
        )

        # Train with adapter1
        trainer.train()

        # Test set_adapter - should not raise
        print_if_process_zero("Testing set_adapter('adapter2')...")
        model.set_adapter("adapter2")
        print_if_process_zero("set_adapter('adapter2') succeeded!")

        # Test switching back
        print_if_process_zero("Testing set_adapter('adapter1')...")
        model.set_adapter("adapter1")
        print_if_process_zero("set_adapter('adapter1') succeeded!")
Collaborator comment: I wonder if the test should mimic a more realistic training scenario, i.e. switching to adapter 1, training a bit, switching to adapter 2, training that.
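A rough, hypothetical sketch of that flow (not part of the PR), reusing the model and trainer defined in this function; in practice the Trainer or at least its optimizer would likely need to be re-created per adapter so the optimizer tracks the right parameters:

# Hypothetical sketch: alternate adapters and train a few steps on each,
# so requires_grad toggling is exercised between real optimizer steps.
model.set_adapter("adapter1")
trainer.train()  # should only update adapter1 weights

model.set_adapter("adapter2")
trainer.train()  # should only update adapter2 weights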


        # Test with list of adapters
        print_if_process_zero("Testing set_adapter(['adapter1', 'adapter2'])...")
        model.set_adapter(["adapter1", "adapter2"])
        print_if_process_zero("set_adapter(['adapter1', 'adapter2']) succeeded!")

Collaborator comment: Let's also test if the base weights and the adapter2 weights are untouched and only adapter1 weights have changed.
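A hypothetical sketch of such a check (not part of the PR), reusing the model and trainer from this function and glossing over FSDP sharding (full parameters may need to be gathered before comparing):

# Snapshot all weights, train adapter1 only, then verify that nothing but
# adapter1's LoRA weights moved (base model and adapter2 stay untouched).
before = {name: param.detach().clone() for name, param in model.named_parameters()}
model.set_adapter("adapter1")
trainer.train()

for name, param in model.named_parameters():
    changed = not torch.allclose(before[name], param.detach())
    if ".adapter1." in name:
        # lora_B starts at zero, so a few steps are needed before lora_A moves too
        assert changed, f"{name} was expected to change"
    else:
        assert not changed, f"{name} was expected to stay unchanged"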

print_if_process_zero("All set_adapter tests passed!")


def main(test_name: str, model_id: str, quant: str | None):
    if test_name == "disable_adapters":
        test_disable_adapters(model_id, quant)
    elif test_name == "set_adapter":
        test_set_adapter(model_id, quant)
    elif test_name == "all":
        test_disable_adapters(model_id, quant)
        test_set_adapter(model_id, quant)
    else:
        raise ValueError(f"Unknown test: {test_name}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str, required=False, default="Qwen/Qwen3-0.6B")
    parser.add_argument("--quant", type=str, choices=["4bit"], required=False, default=None)
    parser.add_argument(
        "--test",
        type=str,
        choices=["disable_adapters", "set_adapter", "all"],
        required=False,
        default="all",
        help="Which test to run",
    )
    args = parser.parse_args()
    main(test_name=args.test, model_id=args.model_id, quant=args.quant)