4 changes: 4 additions & 0 deletions docs/source/package_reference/functional.md
@@ -28,6 +28,10 @@ The functions provided here can be considered "public API" of PEFT and hence are
[[autodoc]] functional.set_adapter
- all

## Set the `requires_grad` attribute of the specified adapters
[[autodoc]] functional.set_requires_grad
- all

## Load the weights of the PEFT state dict into the model
[[autodoc]] functional.set_peft_model_state_dict
- all
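
For orientation, a minimal usage sketch of the newly documented `functional.set_requires_grad` API (the toy model, module name, and adapter name are illustrative and not part of this PR):

```python
import torch.nn as nn

from peft import LoraConfig, get_peft_model
from peft.functional import set_requires_grad


class TinyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(8, 8)


model = get_peft_model(TinyMLP(), LoraConfig(target_modules=["lin"]))

# freeze the parameters of the "default" adapter without touching anything else
set_requires_grad(model, adapter_names="default", requires_grad=False)

# re-enable gradients on it later
set_requires_grad(model, adapter_names="default", requires_grad=True)
```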
3 changes: 2 additions & 1 deletion src/peft/functional.py
@@ -19,7 +19,7 @@
"""

from peft.mapping import inject_adapter_in_model
from peft.tuners.tuners_utils import cast_adapter_dtype, delete_adapter, set_adapter
from peft.tuners.tuners_utils import cast_adapter_dtype, delete_adapter, set_adapter, set_requires_grad
from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict


@@ -30,4 +30,5 @@
"inject_adapter_in_model",
"set_adapter",
"set_peft_model_state_dict",
"set_requires_grad",
]
21 changes: 21 additions & 0 deletions src/peft/peft_model.py
@@ -19,6 +19,7 @@
import inspect
import os
import warnings
from collections.abc import Sequence
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from dataclasses import dataclass
@@ -1461,6 +1462,26 @@ def set_adapter(self, adapter_name: str) -> None:
# handle auxiliary modules
_set_adapter(self, adapter_name)

def set_requires_grad(self, adapter_names: str | Sequence[str], requires_grad: bool = True) -> None:
"""
Enable or disable gradients on the given adapter(s).

Note: Not supported for prompt learning methods like prompt tuning.

Args:
adapter_names (`str` or `Sequence[str]`):
The name(s) of the adapter(s) whose gradients should be enabled/disabled.
requires_grad (`bool`, *optional*):
Whether to enable (`True`, default) or disable (`False`) gradients.
"""
if self.active_peft_config.is_prompt_learning:
raise TypeError(
"Setting `requires_grad` is not supported for prompt learning methods like "
f"{self.active_peft_config.peft_type.value}."
)

self.base_model.set_requires_grad(adapter_names=adapter_names, requires_grad=requires_grad)

@property
def base_model_torch_dtype(self):
return getattr(self.base_model, "dtype", None)
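
A hedged sketch of how the new `PeftModel.set_requires_grad` method can be used with multiple adapters (model and adapter names are made up; the default `requires_grad` state of newly added adapters matches the tests further below):

```python
import torch.nn as nn

from peft import LoraConfig, get_peft_model


class TinyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(16, 16)


model = get_peft_model(TinyMLP(), LoraConfig(target_modules=["lin"]), adapter_name="adapter0")
model.add_adapter("adapter1", LoraConfig(target_modules=["lin"]))

# adapters added via add_adapter start with requires_grad=False; enable adapter1 explicitly
model.set_requires_grad("adapter1", requires_grad=True)

# freeze both adapters at once, e.g. for an evaluation-only phase
model.set_requires_grad(["adapter0", "adapter1"], requires_grad=False)

# for prompt learning methods (e.g. prompt tuning), this call raises a TypeError instead
```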
51 changes: 51 additions & 0 deletions src/peft/tuners/tuners_utils.py
@@ -20,6 +20,7 @@
import textwrap
import warnings
from abc import ABC, abstractmethod
from collections.abc import Sequence
from contextlib import contextmanager, nullcontext
from typing import Any, Optional, Union, overload

@@ -483,6 +484,18 @@ def delete_adapter(self, adapter_name: str) -> None:
)
self.active_adapter = new_adapter or []

def set_requires_grad(self, adapter_names: str | Sequence[str], requires_grad: bool = True) -> None:
"""
Enable or disable gradients on the given adapter(s).

Args:
adapter_names (`str` or `Sequence[str]`):
The name(s) of the adapter(s) whose gradients should be enabled/disabled.
requires_grad (`bool`, *optional*):
Whether to enable (`True`, default) or disable (`False`) gradients.
"""
set_requires_grad(self.model, adapter_names=adapter_names, requires_grad=requires_grad)

def _check_new_adapter_config(self, config: PeftConfig) -> None:
"""
A helper method to check the config of a new adapter being added.
@@ -1353,6 +1366,27 @@ def delete_adapter(self, adapter_name: str) -> None:
)
self.set_adapter(remaining_adapters[0])

def set_requires_grad(self, adapter_names: str | Sequence[str], requires_grad: bool = True) -> None:
"""
Enable or disable gradients on the given adapter(s).

Args:
adapter_names (`str` or `Sequence[str]`):
The name(s) of the adapter(s) whose gradients should be enabled/disabled.
requires_grad (`bool`, *optional*):
Whether to enable (`True`, default) or disable (`False`) gradients.
"""
if isinstance(adapter_names, str):
adapter_names_set = {adapter_names}
else:
adapter_names_set = set(adapter_names)

for layer_name in self.adapter_layer_names:
module_dict = getattr(self, layer_name)
for key, layer in module_dict.items():
if key in adapter_names_set:
layer.requires_grad_(requires_grad)

def _move_adapter_to_device_of_base_layer(self, adapter_name: str, device: Optional[torch.device] = None) -> None:
"""
Move the adapter of the given name to the device of the base layer.
@@ -1877,3 +1911,20 @@ def cast_adapter_dtype(model: nn.Module, adapter_name: str, autocast_adapter_dty
for param in submodule[adapter_name].parameters():
if param.dtype in dtypes_to_convert_to_fp32:
param.data = param.data.to(torch.float32)


def set_requires_grad(model: nn.Module, adapter_names: str | Sequence[str], requires_grad: bool = True) -> None:
"""
Enable or disable gradients on the given adapter(s).

Args:
model (`nn.Module`):
The model whose adapter gradients should be enabled or disabled.
adapter_names (`str` or `Sequence[str]`):
The name(s) of the adapter(s) whose gradients should be enabled/disabled.
requires_grad (`bool`, *optional*):
Whether to enable (`True`, default) or disable (`False`) gradients.
"""
for module in model.modules():
if isinstance(module, (BaseTunerLayer, AuxiliaryTrainingWrapper)):
module.set_requires_grad(adapter_names=adapter_names, requires_grad=requires_grad)
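
Because `set_requires_grad` is also defined on the individual tuner layers, gradients can in principle be toggled per layer rather than model-wide. A sketch under that assumption (the model and the layer selection logic are illustrative only):

```python
import torch.nn as nn

from peft import LoraConfig, get_peft_model
from peft.tuners.tuners_utils import BaseTunerLayer


class TwoLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(8, 8)
        self.lin1 = nn.Linear(8, 8)


model = get_peft_model(TwoLinear(), LoraConfig(target_modules=["lin0", "lin1"]))

# freeze the "default" adapter in lin0 only, keeping it trainable in lin1
for name, module in model.named_modules():
    if isinstance(module, BaseTunerLayer) and name.endswith("lin0"):
        module.set_requires_grad("default", requires_grad=False)
```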
23 changes: 23 additions & 0 deletions src/peft/utils/other.py
@@ -21,6 +21,7 @@
import warnings
from collections.abc import Sequence
from contextlib import nullcontext
from operator import attrgetter
from typing import Any, Optional, Union

import accelerate
@@ -478,6 +479,28 @@ def delete_adapter(self, adapter_name: str, new_active_adapters: Optional[list[s
"""Delete an adapter from the layer, set a new active adapter if necessary"""
raise NotImplementedError

def set_requires_grad(self, adapter_names: str | Sequence[str], requires_grad: bool = True) -> None:
"""
Enable or disable gradients on the given adapter(s).

Args:
adapter_names (`str` or `Sequence[str]`):
The name(s) of the adapter(s) whose gradients should be enabled/disabled.
requires_grad (`bool`, *optional*):
Whether to enable (`True`, default) or disable (`False`) gradients.
"""
if isinstance(adapter_names, str):
adapter_names_set = {adapter_names}
else:
adapter_names_set = set(adapter_names)

for layer_name in self.adapter_layer_names:
# use attrgetter, as it resolves `.` in the attribute name
module_dict = attrgetter(layer_name)(self)
for key, layer in module_dict.items():
if key in adapter_names_set:
layer.requires_grad_(requires_grad)

def adapter_state_dict(self, adapter_name):
"""Return the state dict of this module for a given adapter."""
raise NotImplementedError
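
The `attrgetter` call above matters because entries in `adapter_layer_names` can contain dots (nested attributes), which plain `getattr` cannot resolve. A standalone illustration of the difference (the classes here are purely hypothetical):

```python
from operator import attrgetter


class Inner:
    def __init__(self):
        self.weights = {"default": "..."}


class Outer:
    def __init__(self):
        self.inner = Inner()


obj = Outer()
# getattr(obj, "inner.weights") raises AttributeError because of the dot,
# while attrgetter resolves the dotted path:
assert attrgetter("inner.weights")(obj) == {"default": "..."}
```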
95 changes: 95 additions & 0 deletions tests/test_custom_models.py
@@ -2640,6 +2640,101 @@ def test_delete_adapter_multiple_adapters_with_trainable_token_indices(self):
def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs):
self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs)

@staticmethod
def _check_requires_grad(module, adapter_name, requires_grad):
# a bit of a clumsy way to test requires_grad on the PEFT parameters
for name in module.adapter_layer_names:
module_dict = getattr(module, name)
if adapter_name not in module_dict:
continue
attr = module_dict[adapter_name]
if isinstance(attr, nn.Module):
for param in attr.parameters():
assert param.requires_grad == requires_grad
else: # it's an nn.Parameter
assert attr.requires_grad == requires_grad

@pytest.mark.parametrize("config_cls", ALL_PEFT_CONFIG_CLASSES)
def test_set_requires_grad(self, config_cls):
# checks that the model.set_requires_grad method works as expected
if config_cls == TrainableTokensConfig:
pytest.skip(
"TrainableTokensConfig has a separate test for set_requires_grad, as it needs a different model."
)

config_kwargs = {"target_modules": ["layers.0.lin0"]}
if config_cls == IA3Config:
config_kwargs["feedforward_modules"] = []
config0 = config_cls(**config_kwargs)
model = DeepMLP(size=256) # a size that works with all adapters
model = get_peft_model(model, config0, adapter_name="adapter0").eval()

# check that it works with a single adapter
self._check_requires_grad(model.base_model.model.layers[0].lin0, adapter_name="adapter0", requires_grad=True)

# add another adapter with two target modules and with modules_to_save
config_kwargs["target_modules"] = ["layers.0.lin0", "layers.1.lin0"]
config_kwargs["modules_to_save"] = ["layers.2.lin0"]
config1 = config_cls(**config_kwargs)
model.add_adapter("adapter1", config1)

# adapter0 still has requires_grad=True, adapter1 has requires_grad=False
self._check_requires_grad(model.base_model.model.layers[0].lin0, adapter_name="adapter0", requires_grad=True)
self._check_requires_grad(model.base_model.model.layers[0].lin0, adapter_name="adapter1", requires_grad=False)
self._check_requires_grad(model.base_model.model.layers[1].lin0, adapter_name="adapter1", requires_grad=False)
self._check_requires_grad(model.base_model.model.layers[2].lin0, adapter_name="adapter1", requires_grad=False)

# enable grad for adapter1; adapter0 is unaffected
model.set_requires_grad(adapter_names="adapter1")
self._check_requires_grad(model.base_model.model.layers[0].lin0, adapter_name="adapter0", requires_grad=True)
self._check_requires_grad(model.base_model.model.layers[0].lin0, adapter_name="adapter1", requires_grad=True)
self._check_requires_grad(model.base_model.model.layers[1].lin0, adapter_name="adapter1", requires_grad=True)
self._check_requires_grad(model.base_model.model.layers[2].lin0, adapter_name="adapter1", requires_grad=True)

# disable gradients for both adapters
model.set_requires_grad(adapter_names=["adapter0", "adapter1"], requires_grad=False)
self._check_requires_grad(model.base_model.model.layers[0].lin0, adapter_name="adapter0", requires_grad=False)
self._check_requires_grad(model.base_model.model.layers[0].lin0, adapter_name="adapter1", requires_grad=False)
self._check_requires_grad(model.base_model.model.layers[1].lin0, adapter_name="adapter1", requires_grad=False)

def test_set_requires_grad_trainable_tokens(self):
# same as test_set_requires_grad, but for trainable tokens
class EmbModel(nn.Module):
def __init__(self):
super().__init__()
self.emb0 = nn.Embedding(10, 10)
self.emb1 = nn.Embedding(10, 10)

config_kwargs = {"target_modules": ["emb0"], "token_indices": [0, 2, 4]}
config0 = TrainableTokensConfig(**config_kwargs)
model = EmbModel()
model = get_peft_model(model, config0, adapter_name="adapter0").eval()

# check that it works with a single adapter
self._check_requires_grad(model.base_model.model.emb0, adapter_name="adapter0", requires_grad=True)

# add another adapter which targets 2 embedding layers
config_kwargs["target_modules"] = ["emb0", "emb1"]
config1 = TrainableTokensConfig(**config_kwargs)
model.add_adapter("adapter1", config1)

# adapter0 still has requires_grad=True, adapter1 has requires_grad=False
self._check_requires_grad(model.base_model.model.emb0, adapter_name="adapter0", requires_grad=True)
self._check_requires_grad(model.base_model.model.emb0, adapter_name="adapter1", requires_grad=False)
self._check_requires_grad(model.base_model.model.emb1, adapter_name="adapter1", requires_grad=False)

# enable grad for adapter1; adapter0 is unaffected
model.set_requires_grad(adapter_names="adapter1")
self._check_requires_grad(model.base_model.model.emb0, adapter_name="adapter0", requires_grad=True)
self._check_requires_grad(model.base_model.model.emb0, adapter_name="adapter1", requires_grad=True)
self._check_requires_grad(model.base_model.model.emb1, adapter_name="adapter1", requires_grad=True)

# disable gradients for both adapters
model.set_requires_grad(adapter_names=["adapter0", "adapter1"], requires_grad=False)
self._check_requires_grad(model.base_model.model.emb0, adapter_name="adapter0", requires_grad=False)
self._check_requires_grad(model.base_model.model.emb0, adapter_name="adapter1", requires_grad=False)
self._check_requires_grad(model.base_model.model.emb1, adapter_name="adapter1", requires_grad=False)

def test_weight_bias_attributes(self):
model = MLP()
config = LoraConfig(target_modules=["lin0"])
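
As an aside to the `_check_requires_grad` helper at the top of this test section, PEFT parameter names embed the adapter name, so a similar check could be sketched directly against `named_parameters` (a hedged alternative under that assumption, not what the tests above do):

```python
def check_requires_grad(model, adapter_name, requires_grad):
    # parameter names look like "...lora_A.adapter0.weight", so match the adapter
    # name as a dotted path component
    params = [p for n, p in model.named_parameters() if adapter_name in n.split(".")]
    assert params, f"no parameters found for adapter {adapter_name!r}"
    assert all(p.requires_grad == requires_grad for p in params)
```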
20 changes: 20 additions & 0 deletions tests/test_decoder_models.py
@@ -774,3 +774,23 @@ def test_save_pretrained_targeting_lora_to_embedding_layer(self, save_embedding_
)
else:
assert not contains_embedding

@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
def test_set_requires_grad_prompt_learning_raises(self, config_cls, config_kwargs):
# Test that for prompt learning, calling set_requires_grad raises an error with an appropriate error message.
# Note that for non-prompt-learning methods, set_requires_grad is already tested with custom models, so there is no
# specific test here.
model_id = PEFT_DECODER_MODELS_TO_TEST[0] # it's enough to test this with one model
config = config_cls(
base_model_name_or_path=model_id,
**config_kwargs,
)
if not config.is_prompt_learning:
pytest.skip("This test is only for prompt learning methods.")

with hub_online_once(model_id + config_kwargs.get("tokenizer_name_or_path", "")):
model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
model = get_peft_model(model, config)
msg = "Setting `requires_grad` is not supported for prompt learning methods like"
with pytest.raises(TypeError, match=msg):
model.set_requires_grad(adapter_names="adapter0")
14 changes: 14 additions & 0 deletions tests/test_tuners_utils.py
@@ -771,6 +771,14 @@ def test_requires_grad_large(self, large_model):
expected = [{"default": False, "other": True}, {"default": False}, {"other": True}, {"default": False}]
assert result == expected

# change requires_grad; it is now inconsistent with which adapters are active/inactive
large_model.set_requires_grad("default", requires_grad=True)
large_model.set_requires_grad("other", requires_grad=False)
layer_status = large_model.get_layer_status()
result = [status.requires_grad for status in layer_status]
expected = [{"default": True, "other": False}, {"default": True}, {"other": False}, {"default": True}]
assert result == expected

def test_requires_grad_irregular(self, large_model):
# inject an embedding layer with requires_grad=False
# this is an invalid state, but we should still test it
@@ -1114,6 +1122,12 @@ def test_model_requires_grad_model_large(self, large_model):
model_status = large_model.get_model_status()
assert model_status.requires_grad == {"default": False, "other": True}

# change requires_grad; it is now inconsistent with which adapters are active/inactive
large_model.set_requires_grad("default", requires_grad=True)
large_model.set_requires_grad("other", requires_grad=False)
model_status = large_model.get_model_status()
assert model_status.requires_grad == {"default": True, "other": False}

def test_model_requires_grad_model_irregular(self, large_model):
# inject an embedding layer with requires_grad=False
# this is an invalid state, but we should still test it
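
The status helpers exercised in these tests can also be used to verify the effect of `set_requires_grad` in user code; a sketch, assuming a single-adapter LoRA model similar to the toy examples above:

```python
import torch.nn as nn

from peft import LoraConfig, get_peft_model


class TinyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(8, 8)


model = get_peft_model(TinyMLP(), LoraConfig(target_modules=["lin"]))
model.set_requires_grad("default", requires_grad=False)

# get_model_status() reports the per-adapter requires_grad state
status = model.get_model_status()
assert status.requires_grad == {"default": False}
```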