From 5b226e037cb3278c9103e04097ee3a5e8de084c1 Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Wed, 1 Oct 2025 11:34:17 +0200
Subject: [PATCH 1/2] ENH Add set_requires_grad method

This PR adds the set_requires_grad method to PEFT models (both PeftModel
and BaseTuner). As the name suggests, this is a method to set the
requires_grad attribute of the specified PEFT adapters.

For more general context, this is mostly relevant when dealing with
multiple adapters. As is, users can already set the active adapter(s) with
set_adapter, which automatically adjusts the requires_grad attribute too,
so that only the active adapters have grads enabled. However, there can be
situations where the activity status and the requires_grad attribute need
to differ. Right now, users would have to set requires_grad manually to
deal with that, which is error-prone (e.g. forgetting modules_to_save).
This PR closes this gap in the API.

As this functionality is quite general purpose, I added a set_requires_grad
function to functional.py for easier integration.

Note: The set_requires_grad method will raise an error when called with
prompt learning methods like prompt tuning. This is because these methods
don't have a universal base class (BaseTuner and BaseTunerLayer) that would
allow adding this API. Moreover, they only support a single adapter at a
time, so there is not much need for this method in the first place.

A side effect of not supporting prompt learning is that, on the PeftModel,
we are free to allow set_requires_grad to accept more than one adapter,
which would otherwise be difficult, because prompt learning only allows a
single adapter.
---
 docs/source/package_reference/functional.md |   4 +
 src/peft/functional.py                      |   3 +-
 src/peft/peft_model.py                      |  22 +++++
 src/peft/tuners/tuners_utils.py             |  51 ++++++++++
 src/peft/utils/other.py                     |  23 +++++
 tests/test_custom_models.py                 | 102 ++++++++++++++++++++
 tests/test_tuners_utils.py                  |  14 +++
 7 files changed, 218 insertions(+), 1 deletion(-)

diff --git a/docs/source/package_reference/functional.md b/docs/source/package_reference/functional.md
index 807dc615f8..52251bd490 100644
--- a/docs/source/package_reference/functional.md
+++ b/docs/source/package_reference/functional.md
@@ -28,6 +28,10 @@ The functions provided here can be considered "public API" of PEFT and hence are
 [[autodoc]] functional.set_adapter
     - all
 
+## Set the `requires_grad` attribute of the specified adapters
+[[autodoc]] functional.set_requires_grad
+    - all
+
 ## Load the weights of the PEFT state dict into the model
 [[autodoc]] functional.set_peft_model_state_dict
     - all
diff --git a/src/peft/functional.py b/src/peft/functional.py
index f14041ad65..60df690caf 100644
--- a/src/peft/functional.py
+++ b/src/peft/functional.py
@@ -19,7 +19,7 @@
 """
 
 from peft.mapping import inject_adapter_in_model
-from peft.tuners.tuners_utils import cast_adapter_dtype, delete_adapter, set_adapter
+from peft.tuners.tuners_utils import cast_adapter_dtype, delete_adapter, set_adapter, set_requires_grad
 from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
 
 
@@ -30,4 +30,5 @@
     "inject_adapter_in_model",
     "set_adapter",
     "set_peft_model_state_dict",
+    "set_requires_grad",
 ]
diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py
index 5cb1f7e424..7815b910f3 100644
--- a/src/peft/peft_model.py
+++ b/src/peft/peft_model.py
@@ -19,6 +19,7 @@
 import inspect
 import os
 import warnings
+from collections.abc import Sequence
 from contextlib import contextmanager, nullcontext
 from copy import deepcopy
 from dataclasses import dataclass
@@ -1461,6 +1462,27 @@ def set_adapter(self, adapter_name: str) -> None:
         # handle auxiliary modules
         _set_adapter(self, adapter_name)
 
+    def set_requires_grad(self, adapter_names: str | Sequence[str], requires_grad: bool = True) -> None:
+        """
+        Enable or disable gradients on the given adapter(s).
+
+        Note: Not supported for prompt learning methods like prompt tuning.
+
+        Args:
+            model (`nn.Module`):
+                The model from which the adapter should be deleted.
+            adapter_names (`str` or `Sequence[str]`):
+                The name of the adapter(s) whose gradients should be enabled/disabled.
+            requires_grad (`bool`, *optional*):
+                Whether to enable (`True`, default) or disable (`False`).
+        """
+        if self.active_peft_config.is_prompt_learning:
+            raise TypeError(
+                f"Setting `requires_grad` is not supported for prompt learning methods like {self.active_peft_config.peft_type}"
+            )
+
+        self.base_model.set_requires_grad(adapter_names=adapter_names, requires_grad=requires_grad)
+
     @property
     def base_model_torch_dtype(self):
         return getattr(self.base_model, "dtype", None)
diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py
index 4fd0d12843..66903296a8 100644
--- a/src/peft/tuners/tuners_utils.py
+++ b/src/peft/tuners/tuners_utils.py
@@ -20,6 +20,7 @@
 import textwrap
 import warnings
 from abc import ABC, abstractmethod
+from collections.abc import Sequence
 from contextlib import contextmanager, nullcontext
 from typing import Any, Optional, Union, overload
 
@@ -483,6 +484,18 @@ def delete_adapter(self, adapter_name: str) -> None:
             )
         self.active_adapter = new_adapter or []
 
+    def set_requires_grad(self, adapter_names: str | Sequence[str], requires_grad: bool = True) -> None:
+        """
+        Enable or disable gradients on the given adapter(s).
+
+        Args:
+            adapter_names (`str` or `Sequence[str]`):
+                The name of the adapter(s) whose gradients should be enabled/disabled.
+            requires_grad (`bool`, *optional*):
+                Whether to enable (`True`, default) or disable (`False`).
+        """
+        set_requires_grad(self.model, adapter_names=adapter_names, requires_grad=requires_grad)
+
     def _check_new_adapter_config(self, config: PeftConfig) -> None:
         """
         A helper method to check the config of a new adapter being added.
@@ -1353,6 +1366,27 @@ def delete_adapter(self, adapter_name: str) -> None:
             )
         self.set_adapter(remaining_adapters[0])
 
+    def set_requires_grad(self, adapter_names: str | Sequence[str], requires_grad: bool = True) -> None:
+        """
+        Enable or disable gradients on the given adapter(s).
+
+        Args:
+            adapter_names (`str` or `Sequence[str]`):
+                The name of the adapter(s) whose gradients should be enabled/disabled.
+            requires_grad (`bool`, *optional*):
+                Whether to enable (`True`, default) or disable (`False`).
+        """
+        if isinstance(adapter_names, str):
+            adapter_names_set = {adapter_names}
+        else:
+            adapter_names_set = set(adapter_names)
+
+        for layer_name in self.adapter_layer_names:
+            module_dict = getattr(self, layer_name)
+            for key, layer in module_dict.items():
+                if key in adapter_names_set:
+                    layer.requires_grad_(requires_grad)
+
     def _move_adapter_to_device_of_base_layer(self, adapter_name: str, device: Optional[torch.device] = None) -> None:
         """
         Move the adapter of the given name to the device of the base layer.
@@ -1877,3 +1911,20 @@ def cast_adapter_dtype(model: nn.Module, adapter_name: str, autocast_adapter_dty
         for param in submodule[adapter_name].parameters():
             if param.dtype in dtypes_to_convert_to_fp32:
                 param.data = param.data.to(torch.float32)
+
+
+def set_requires_grad(model, adapter_names: str | Sequence[str], requires_grad: bool = True) -> None:
+    """
+    Enable or disable gradients on the given adapter(s).
+
+    Args:
+        model (`nn.Module`):
+            The model whose adapters' gradients should be enabled/disabled.
+        adapter_names (`str` or `Sequence[str]`):
+            The name of the adapter(s) whose gradients should be enabled/disabled.
+        requires_grad (`bool`, *optional*):
+            Whether to enable (`True`, default) or disable (`False`).
+    """
+    for module in model.modules():
+        if isinstance(module, (BaseTunerLayer, AuxiliaryTrainingWrapper)):
+            module.set_requires_grad(adapter_names=adapter_names, requires_grad=requires_grad)
diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py
index 29878bcc91..825d50caee 100644
--- a/src/peft/utils/other.py
+++ b/src/peft/utils/other.py
@@ -21,6 +21,7 @@
 import warnings
 from collections.abc import Sequence
 from contextlib import nullcontext
+from operator import attrgetter
 from typing import Any, Optional, Union
 
 import accelerate
@@ -478,6 +479,28 @@ def delete_adapter(self, adapter_name: str, new_active_adapters: Optional[list[s
         """Delete an adapter from the layer, set a new active adapter if necessary"""
         raise NotImplementedError
 
+    def set_requires_grad(self, adapter_names: str | Sequence[str], requires_grad: bool = True) -> None:
+        """
+        Enable or disable gradients on the given adapter(s).
+
+        Args:
+            adapter_names (`str` or `Sequence[str]`):
+                The name of the adapter(s) whose gradients should be enabled/disabled.
+            requires_grad (`bool`, *optional*):
+                Whether to enable (`True`, default) or disable (`False`).
+ """ + if isinstance(adapter_names, str): + adapter_names_set = {adapter_names} + else: + adapter_names_set = set(adapter_names) + + for layer_name in self.adapter_layer_names: + # use attrgetter, as it resolves `.` in the attribute name + module_dict = attrgetter(layer_name)(self) + for key, layer in module_dict.items(): + if key in adapter_names_set: + layer.requires_grad_(requires_grad) + def adapter_state_dict(self, adapter_name): """Return the state dict of this module for a given adapter.""" raise NotImplementedError diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 33cded4116..5c8c002b27 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -2640,6 +2640,108 @@ def test_delete_adapter_multiple_adapters_with_trainable_token_indices(self): def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs): self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs) + @staticmethod + def _check_requires_grad(module, adapter_name, requires_grad): + # a bit of a clumsy way to test requires_grad on the PEFT parameters + for name in module.adapter_layer_names: + module_dict = getattr(module, name) + if adapter_name not in module_dict: + continue + attr = module_dict[adapter_name] + if isinstance(attr, nn.Module): + for param in attr.parameters(): + assert param.requires_grad == requires_grad + else: # it's an nn.Parameter + assert attr.requires_grad == requires_grad + + @pytest.mark.parametrize("config_cls", ALL_PEFT_CONFIG_CLASSES) + def test_set_requires_grad(self, config_cls): + # checks that the model.set_requires_grad method works as expected + if config_cls == TrainableTokensConfig: + pytest.skip( + "TrainableTokensConfig has a separate test for set_requires_grad, as it needs a different model." 
+            )
+
+        config_kwargs = {"target_modules": ["layers.0.lin0"]}
+        if config_cls == IA3Config:
+            config_kwargs["feedforward_modules"] = []
+        config0 = config_cls(**config_kwargs)
+        model = DeepMLP(size=256)  # a size that works with all adapters
+        model = get_peft_model(model, config0, adapter_name="adapter0").eval()
+
+        if config0.is_prompt_learning:
+            # prompt learning does not support this method (yet), so just check for the error and return
+            msg = "TODO"
+            with pytest.raises(TypeError, match=msg):
+                model.set_requires_grad(adapter_names="adpater0")
+            return
+
+        # check that it works with a single adapter
+        self._check_requires_grad(model.base_model.model.layers[0].lin0, adapter_name="adapter0", requires_grad=True)
+
+        # add another adapter with two target modules and with modules_to_save
+        config_kwargs["target_modules"] = ["layers.0.lin0", "layers.1.lin0"]
+        config_kwargs["modules_to_save"] = ["layers.2.lin0"]
+        config1 = config_cls(**config_kwargs)
+        model.add_adapter("adapter1", config1)
+
+        # adapter0 still has requires_grad=True, adapter1 has requires_grad=False
+        self._check_requires_grad(model.base_model.model.layers[0].lin0, adapter_name="adapter0", requires_grad=True)
+        self._check_requires_grad(model.base_model.model.layers[0].lin0, adapter_name="adapter1", requires_grad=False)
+        self._check_requires_grad(model.base_model.model.layers[1].lin0, adapter_name="adapter1", requires_grad=False)
+        self._check_requires_grad(model.base_model.model.layers[2].lin0, adapter_name="adapter1", requires_grad=False)
+
+        # enable grad for adapter1; adapter0 is unaffected
+        model.set_requires_grad(adapter_names="adapter1")
+        self._check_requires_grad(model.base_model.model.layers[0].lin0, adapter_name="adapter0", requires_grad=True)
+        self._check_requires_grad(model.base_model.model.layers[0].lin0, adapter_name="adapter1", requires_grad=True)
+        self._check_requires_grad(model.base_model.model.layers[1].lin0, adapter_name="adapter1", requires_grad=True)
+        self._check_requires_grad(model.base_model.model.layers[2].lin0, adapter_name="adapter1", requires_grad=True)
+
+        # disable grads for both adapters
+        model.set_requires_grad(adapter_names=["adapter0", "adapter1"], requires_grad=False)
+        self._check_requires_grad(model.base_model.model.layers[0].lin0, adapter_name="adapter0", requires_grad=False)
+        self._check_requires_grad(model.base_model.model.layers[0].lin0, adapter_name="adapter1", requires_grad=False)
+        self._check_requires_grad(model.base_model.model.layers[1].lin0, adapter_name="adapter1", requires_grad=False)
+
+    def test_set_requires_grad_trainable_tokens(self):
+        # same as test_set_requires_grad for trainable tokens
+        class EmbModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.emb0 = nn.Embedding(10, 10)
+                self.emb1 = nn.Embedding(10, 10)
+
+        config_kwargs = {"target_modules": ["emb0"], "token_indices": [0, 2, 4]}
+        config0 = TrainableTokensConfig(**config_kwargs)
+        model = EmbModel()
+        model = get_peft_model(model, config0, adapter_name="adapter0").eval()
+
+        # check that it works with a single adapter
+        self._check_requires_grad(model.base_model.model.emb0, adapter_name="adapter0", requires_grad=True)
+
+        # add another adapter which targets 2 embedding layers
+        config_kwargs["target_modules"] = ["emb0", "emb1"]
+        config1 = TrainableTokensConfig(**config_kwargs)
+        model.add_adapter("adapter1", config1)
+
+        # adapter0 still has requires_grad=True, adapter1 has requires_grad=False
+        self._check_requires_grad(model.base_model.model.emb0, adapter_name="adapter0", requires_grad=True)
+        self._check_requires_grad(model.base_model.model.emb0, adapter_name="adapter1", requires_grad=False)
+        self._check_requires_grad(model.base_model.model.emb1, adapter_name="adapter1", requires_grad=False)
+
+        # enable grad for adapter1; adapter0 is unaffected
+        model.set_requires_grad(adapter_names="adapter1")
+        self._check_requires_grad(model.base_model.model.emb0, adapter_name="adapter0", requires_grad=True)
+        self._check_requires_grad(model.base_model.model.emb0, adapter_name="adapter1", requires_grad=True)
+        self._check_requires_grad(model.base_model.model.emb1, adapter_name="adapter1", requires_grad=True)
+
+        # disable grads for both adapters
+        model.set_requires_grad(adapter_names=["adapter0", "adapter1"], requires_grad=False)
+        self._check_requires_grad(model.base_model.model.emb0, adapter_name="adapter0", requires_grad=False)
+        self._check_requires_grad(model.base_model.model.emb0, adapter_name="adapter1", requires_grad=False)
+        self._check_requires_grad(model.base_model.model.emb1, adapter_name="adapter1", requires_grad=False)
+
     def test_weight_bias_attributes(self):
         model = MLP()
         config = LoraConfig(target_modules=["lin0"])
diff --git a/tests/test_tuners_utils.py b/tests/test_tuners_utils.py
index de48ee020f..499a4a5502 100644
--- a/tests/test_tuners_utils.py
+++ b/tests/test_tuners_utils.py
@@ -771,6 +771,14 @@ def test_requires_grad_large(self, large_model):
         expected = [{"default": False, "other": True}, {"default": False}, {"other": True}, {"default": False}]
         assert result == expected
 
+        # change requires_grad; this is now inconsistent with the active/inactive adapter state
+        large_model.set_requires_grad("default", requires_grad=True)
+        large_model.set_requires_grad("other", requires_grad=False)
+        layer_status = large_model.get_layer_status()
+        result = [status.requires_grad for status in layer_status]
+        expected = [{"default": True, "other": False}, {"default": True}, {"other": False}, {"default": True}]
+        assert result == expected
+
     def test_requires_grad_irregular(self, large_model):
         # inject an embedding layer with requires_grad=False
         # this is an invalid state, but we should still test it
@@ -1114,6 +1122,12 @@ def test_model_requires_grad_model_large(self, large_model):
         model_status = large_model.get_model_status()
         assert model_status.requires_grad == {"default": False, "other": True}
 
+        # change requires_grad; this is now inconsistent with the active/inactive adapter state
+        large_model.set_requires_grad("default", requires_grad=True)
+        large_model.set_requires_grad("other", requires_grad=False)
+        model_status = large_model.get_model_status()
+        assert model_status.requires_grad == {"default": True, "other": False}
+
     def test_model_requires_grad_model_irregular(self, large_model):
         # inject an embedding layer with requires_grad=False
         # this is an invalid state, but we should still test it

From 8104d3bffe9355924f00ec1be33d620808744844 Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Wed, 1 Oct 2025 12:07:25 +0200
Subject: [PATCH 2/2] Small fixes:

- Fix docstring
- Fix error message
- Add test for prompt learning methods
---
 src/peft/peft_model.py       |  5 ++---
 tests/test_custom_models.py  |  7 -------
 tests/test_decoder_models.py | 20 ++++++++++++++++++++
 3 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py
index 7815b910f3..cdb8976acb 100644
--- a/src/peft/peft_model.py
+++ b/src/peft/peft_model.py
@@ -1469,8 +1469,6 @@ def set_requires_grad(self, adapter_names: str | Sequence[str], requires_grad: b
         Note: Not supported for prompt learning methods like prompt tuning.
 
         Args:
-            model (`nn.Module`):
-                The model from which the adapter should be deleted.
             adapter_names (`str` or `Sequence[str]`):
                 The name of the adapter(s) whose gradients should be enabled/disabled.
             requires_grad (`bool`, *optional*):
@@ -1478,7 +1476,8 @@ def set_requires_grad(self, adapter_names: str | Sequence[str], requires_grad: b
         """
         if self.active_peft_config.is_prompt_learning:
             raise TypeError(
-                f"Setting `requires_grad` is not supported for prompt learning methods like {self.active_peft_config.peft_type}"
+                "Setting `requires_grad` is not supported for prompt learning methods like "
+                f"{self.active_peft_config.peft_type.value}."
             )
 
         self.base_model.set_requires_grad(adapter_names=adapter_names, requires_grad=requires_grad)
diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py
index 5c8c002b27..a9d886d04d 100644
--- a/tests/test_custom_models.py
+++ b/tests/test_custom_models.py
@@ -2669,13 +2669,6 @@ def test_set_requires_grad(self, config_cls):
         model = DeepMLP(size=256)  # a size that works with all adapters
         model = get_peft_model(model, config0, adapter_name="adapter0").eval()
 
-        if config0.is_prompt_learning:
-            # prompt learning does not support this method (yet), so just check for the error and return
-            msg = "TODO"
-            with pytest.raises(TypeError, match=msg):
-                model.set_requires_grad(adapter_names="adpater0")
-            return
-
         # check that it works with a single adapter
         self._check_requires_grad(model.base_model.model.layers[0].lin0, adapter_name="adapter0", requires_grad=True)
 
diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py
index 2c0c402b1c..4c4fecc737 100644
--- a/tests/test_decoder_models.py
+++ b/tests/test_decoder_models.py
@@ -774,3 +774,23 @@ def test_save_pretrained_targeting_lora_to_embedding_layer(self, save_embedding_
             )
         else:
             assert not contains_embedding
+
+    @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
+    def test_set_requires_grad_prompt_learning_raises(self, config_cls, config_kwargs):
+        # Test that for prompt learning, calling set_requires_grad raises an error with an appropriate
+        # message. Note that for non-prompt learning methods, set_requires_grad is already tested with
+        # custom models, so there is no specific test here.
+        model_id = PEFT_DECODER_MODELS_TO_TEST[0]  # it's enough to test this with one model
+        config = config_cls(
+            base_model_name_or_path=model_id,
+            **config_kwargs,
+        )
+        if not config.is_prompt_learning:
+            pytest.skip("This test is only for prompt learning methods.")
+
+        with hub_online_once(model_id + config_kwargs.get("tokenizer_name_or_path", "")):
+            model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
+            model = get_peft_model(model, config)
+            msg = "Setting `requires_grad` is not supported for prompt learning methods like"
+            with pytest.raises(TypeError, match=msg):
+                model.set_requires_grad(adapter_names="adapter0")
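
Usage sketch (illustrative only, not part of the patch): the snippet below shows how the new API would
be called, assuming PEFT with this patch series applied; the toy nn.Sequential base model and the
adapter names are made up for illustration, only set_requires_grad itself comes from this PR.

    import torch.nn as nn

    from peft import LoraConfig, get_peft_model
    from peft.functional import set_requires_grad

    # toy base model; any module with a targetable Linear layer would do
    base_model = nn.Sequential(nn.Linear(10, 10))
    model = get_peft_model(base_model, LoraConfig(target_modules=["0"]), adapter_name="adapter0")
    model.add_adapter("adapter1", LoraConfig(target_modules=["0"]))

    # "adapter0" is active and trainable, "adapter1" has requires_grad=False; enable grads on
    # "adapter1" as well, without changing which adapter is active
    model.set_requires_grad("adapter1", requires_grad=True)

    # the functional API applies the same change to any model with injected adapters
    set_requires_grad(model, adapter_names=["adapter0", "adapter1"], requires_grad=False)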