
Commit 79f7213

BenjaminBossan and efraimdahl authored and committed
FIX Correctly determine no_split_modules (huggingface#2570)
See the discussion in huggingface/transformers#38141 for context. In the PEFT fsdp_auto_wrap_policy, we determine the _no_split_modules of the model. However, the current code neglects to visit the children of the model, which is required for some architectures. This PR fixes that.

Note that the _get_no_split_modules function is largely copied from transformers. One change is that it does not take the device_map argument; in transformers, that argument is only used inside an error message, not for the logic proper, so I think it is safe to remove. Moreover, I made an unrelated change to fsdp_auto_wrap_policy, namely making its local imports global (there was no reason for them to be local).
1 parent 86773e9 commit 79f7213

2 files changed: +54, -9 lines changed
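For orientation (not part of the diff): the policy returned by fsdp_auto_wrap_policy is typically assigned to accelerate's FSDP plugin. The sketch below is a minimal illustration only; it assumes an FSDP-enabled accelerate setup and uses facebook/opt-125m purely as a placeholder model.

    # Minimal usage sketch, assuming accelerate was configured for FSDP.
    from accelerate import Accelerator
    from transformers import AutoModelForCausalLM

    from peft.utils.other import fsdp_auto_wrap_policy

    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
    accelerator = Accelerator()
    fsdp_plugin = getattr(accelerator.state, "fsdp_plugin", None)
    if fsdp_plugin is not None:
        # With this fix, the policy collects layer classes from the model *and*
        # its children, which matters for composite models such as Llava.
        fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)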

src/peft/utils/other.py

Lines changed: 29 additions & 7 deletions
@@ -14,6 +14,7 @@
 from __future__ import annotations
 
 import copy
+import functools
 import inspect
 import os
 import re
@@ -24,12 +25,14 @@
 
 import accelerate
 import torch
+from accelerate import FullyShardedDataParallelPlugin
 from accelerate.hooks import add_hook_to_module, remove_hook_from_module
 from accelerate.utils import is_npu_available, is_xpu_available
 from huggingface_hub import file_exists
 from huggingface_hub.errors import EntryNotFoundError, HFValidationError
 from packaging import version
 from safetensors.torch import storage_ptr, storage_size
+from transformers import PreTrainedModel
 
 from ..import_utils import is_auto_gptq_available, is_gptqmodel_available, is_torch_tpu_available
 from .constants import (
@@ -942,12 +945,33 @@ def _prepare_prompt_learning_config(peft_config, model_config):
     return peft_config
 
 
-def fsdp_auto_wrap_policy(model):
-    import functools
-    import os
+def _get_no_split_modules(model) -> set[str]:
+    """
+    Get the modules of the model that should not be split when using device_map. We iterate through the modules to get
+    the underlying `_no_split_modules`.
+
+    Returns:
+        `List[str]`: List of modules that should not be split
+    """
+    # After discussion in https://github.com/huggingface/transformers/pull/38141, based on:
+    # https://github.com/huggingface/transformers/blob/1e921a3a9cea92b383ca4b0484ee45596bbdadc3/src/transformers/modeling_utils.py#L2677-L2704
+    _no_split_modules: set[str] = set()
+    if not hasattr(model, "_no_split_modules"):
+        return _no_split_modules
+
+    modules_to_check = [model]
+    while len(modules_to_check) > 0:
+        module = modules_to_check.pop(-1)
+        # if the module does not appear in _no_split_modules, we also check the children
+        if module.__class__.__name__ not in _no_split_modules:
+            if isinstance(module, PreTrainedModel):
+                if module._no_split_modules is not None:
+                    _no_split_modules = _no_split_modules | set(module._no_split_modules)
+            modules_to_check += list(module.children())
+    return _no_split_modules
 
-    from accelerate import FullyShardedDataParallelPlugin
 
+def fsdp_auto_wrap_policy(model):
     if hasattr(FullyShardedDataParallelPlugin, "get_module_class_from_name"):
         get_module_class_from_name = FullyShardedDataParallelPlugin.get_module_class_from_name
     else:
@@ -956,9 +980,7 @@ def fsdp_auto_wrap_policy(model):
 
     from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder
 
-    default_transformer_cls_names_to_wrap = (
-        ",".join(model._no_split_modules) if getattr(model, "_no_split_modules", None) is not None else ""
-    )
+    default_transformer_cls_names_to_wrap = ",".join(_get_no_split_modules(model))
     transformer_cls_names_to_wrap = os.environ.get(
         "FSDP_TRANSFORMER_CLS_TO_WRAP", default_transformer_cls_names_to_wrap
    ).split(",")
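As a side note on the changed default above (not part of the diff): the FSDP_TRANSFORMER_CLS_TO_WRAP environment variable still takes precedence over the value computed by _get_no_split_modules. A small sketch of that interaction, with facebook/opt-125m chosen only for illustration:

    # Sketch mirroring the changed lines: computed default vs. env var override.
    import os

    from transformers import AutoModelForCausalLM

    from peft.utils.other import _get_no_split_modules

    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
    default = ",".join(_get_no_split_modules(model))  # "OPTDecoderLayer"
    # If FSDP_TRANSFORMER_CLS_TO_WRAP is set, it overrides the computed default.
    cls_to_wrap = os.environ.get("FSDP_TRANSFORMER_CLS_TO_WRAP", default).split(",")
    print(cls_to_wrap)  # ["OPTDecoderLayer"] unless the env var is set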

tests/test_other.py

Lines changed: 25 additions & 2 deletions
@@ -17,10 +17,10 @@
 import pytest
 import torch
 from torch import nn
-from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification
+from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, LlavaForConditionalGeneration
 
 from peft import LoraConfig, PeftModel, VeraConfig, get_peft_model
-from peft.utils.other import ModulesToSaveWrapper
+from peft.utils.other import ModulesToSaveWrapper, _get_no_split_modules
 
 
 class ModelWithModuleDict(nn.Module):
@@ -507,3 +507,26 @@ def remove_adapter_portion(adapter_name, key):
         }
 
         assert adapter_invariant_keys1 == adapter_invariant_keys2
+
+
+class TestGetNoSplitModules:
+    # Ensure that children are considered when determining _no_split_modules
+    # see https://github.com/huggingface/transformers/pull/38141
+
+    def test_get_no_split_modules_simple(self):
+        # choose a model where recursively visiting children is *not* required
+        model_id = "facebook/opt-125m"
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        assert model._no_split_modules == ["OPTDecoderLayer"]
+        no_split_modules = _get_no_split_modules(model)
+        assert no_split_modules == {"OPTDecoderLayer"}
+
+    def test_get_no_split_modules_recursive(self):
+        # choose a model where recursively visiting children is required
+        model_id = "hf-internal-testing/tiny-random-LlavaForConditionalGeneration"
+        model = LlavaForConditionalGeneration.from_pretrained(model_id)
+        # sanity check: just visiting the model itself is not enough:
+        assert model._no_split_modules == []
+
+        no_split_modules = _get_no_split_modules(model)
+        assert no_split_modules == {"CLIPEncoderLayer", "LlamaDecoderLayer"}
