Merged (61 commits)
49fab86
[WIP] Add LoRA multihead attention module
BenjaminBossan Jan 5, 2024
d8e9589
Make style
BenjaminBossan Jan 5, 2024
0e188a3
Remove commented code
BenjaminBossan Jan 5, 2024
b409d81
Remove assignment of weight to new module
BenjaminBossan Jan 5, 2024
173062c
Make state_dict and named_parameters work
BenjaminBossan Jan 5, 2024
1e007f5
Extend test coverage a bit
BenjaminBossan Jan 8, 2024
557c4a1
Clean ups after reviewer feedback:
BenjaminBossan Jan 9, 2024
add1f51
Reviewer feedback: removed another unnecessary arg
BenjaminBossan Jan 9, 2024
e44e030
Make style
BenjaminBossan Jan 9, 2024
8d62579
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Jan 9, 2024
c5d8a6b
Apply LoRA also to the out_proj of MHA
BenjaminBossan Jan 12, 2024
9dc4a4d
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Feb 7, 2024
c3fb2ce
Fix bug with incorrectly set gradient
BenjaminBossan Feb 7, 2024
17d407b
Fix failing tests
BenjaminBossan Feb 7, 2024
4cbf6e9
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Feb 26, 2024
e0cae11
Move to pytest style asserts
BenjaminBossan Feb 26, 2024
52c8d9b
Fix safe merging code
BenjaminBossan Feb 26, 2024
977c84b
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Mar 11, 2024
96d376d
No need to set bias for MHA anymore, see #1530
BenjaminBossan Mar 11, 2024
0c17476
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Mar 26, 2024
4b8db0c
Fix style
BenjaminBossan Mar 26, 2024
7e91712
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan May 21, 2024
e12070b
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Jul 25, 2024
7b6c7cb
Remove duplicate merge
BenjaminBossan Jul 25, 2024
e6ab8ed
Raise error for multi adapter batch inference
BenjaminBossan Jul 25, 2024
8ec6c3c
Raise error for DoRA + MHA
BenjaminBossan Jul 25, 2024
f6ba465
Fix error when adding multiple adapters to MHA
BenjaminBossan Jul 25, 2024
fb18886
Better way of param initialization
BenjaminBossan Jul 26, 2024
4ff2ec3
Add tests for broken loading and workaround
BenjaminBossan Jul 26, 2024
d1f6ab2
make style
BenjaminBossan Jul 26, 2024
65363be
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Sep 3, 2024
7ba2e68
Fix wrong merge conflict resolution in test
BenjaminBossan Sep 4, 2024
6ef04b0
Ensure that base weights have requires_grad False
BenjaminBossan Sep 4, 2024
07c7240
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Sep 4, 2024
cc3ac3d
Remove xpass-ing test
BenjaminBossan Sep 4, 2024
03c466f
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Sep 12, 2024
e558caa
MAINT: Give stale bot permissions for PRs too (#2064)
BenjaminBossan Sep 12, 2024
38f4a98
ENH BOFT don't save boft_P buffer (#2050)
sywangyi Sep 13, 2024
7e5c61d
FIX Command line args in PiSSA preprocess (#2053)
keakon Sep 13, 2024
183bf52
MNT Update deprecated evaluation_strategy (#1664)
muellerzr Sep 13, 2024
b970607
ENH Multi adapters in same batch: modules_to_save (#1990)
saeid93 Sep 17, 2024
732e8e7
FIX Bug that prevents BOFT from loading 2 adapters (#2068)
BenjaminBossan Sep 18, 2024
79e2b38
TST Skip some quantization tests on XPU (#2074)
faaany Sep 18, 2024
61e6934
Improve test coverage for initialization of MHA
BenjaminBossan Sep 18, 2024
ced2f15
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Oct 14, 2024
4c31bbc
Fix bug with unloading multihead attention layer
BenjaminBossan Oct 21, 2024
1dbb9a5
Fix bug in unloading
BenjaminBossan Oct 22, 2024
e094234
Fix for low_cpu_mem_usage
BenjaminBossan Nov 1, 2024
e90af48
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Nov 1, 2024
30a08e7
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Nov 1, 2024
09f5ea6
Add tests for init_empty_weights
BenjaminBossan Nov 26, 2024
6a83bd7
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Nov 26, 2024
3b0471a
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Dec 9, 2024
465a85e
Add MHA to modules unsupported by EVA
BenjaminBossan Dec 9, 2024
266f9da
Add comment on why/how empty init works
BenjaminBossan Jan 6, 2025
39e755e
Expose attributes of underlying MHA module
BenjaminBossan Jan 6, 2025
4857858
Apply suggestions from code review
BenjaminBossan Jan 6, 2025
74cbba6
Remove trailing whitespace
BenjaminBossan Jan 6, 2025
14deb9f
Linting..
BenjaminBossan Jan 6, 2025
ba2a8dd
Reviewer comment: Add comments for clarification
BenjaminBossan Jan 8, 2025
ac10b18
Reviewer feedback: Remove q_proj_weight
BenjaminBossan Jan 8, 2025
3 changes: 2 additions & 1 deletion src/peft/peft_model.py
@@ -27,7 +27,7 @@
import packaging.version
import torch
import transformers
from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules
from accelerate.utils import get_balanced_memory, named_module_tensors
from huggingface_hub import HfFileSystem, ModelCard, ModelCardData, hf_hub_download
@@ -39,6 +39,7 @@
from transformers.utils import PushToHubMixin

from peft.utils.constants import DUMMY_MODEL_CONFIG, PEFT_TYPE_TO_PREFIX_MAPPING
from peft.utils.integrations import init_empty_weights

from . import __version__
from .config import PeftConfig
4 changes: 2 additions & 2 deletions src/peft/tuners/lora/eva.py
@@ -32,10 +32,10 @@
from peft.utils.other import _get_submodules, get_pattern_key

from .config import LoraConfig
from .layer import Embedding, LoraLayer, _ConvNd
from .layer import Embedding, LoraLayer, MultiheadAttention, _ConvNd


UNSUPPORTED_LORA_MODULES = (Embedding, _ConvNd)
UNSUPPORTED_LORA_MODULES = (Embedding, MultiheadAttention, _ConvNd)


class _Hook:
360 changes: 359 additions & 1 deletion src/peft/tuners/lora/layer.py

Large diffs are not rendered by default.

37 changes: 18 additions & 19 deletions src/peft/tuners/lora/model.py
@@ -247,14 +247,6 @@ def _replace_module(self, parent, child_name, new_module, child):
if hasattr(child, "base_layer"):
child = child.base_layer

if not hasattr(new_module, "base_layer"):
Contributor:
Why has this been removed?

Member (Author):
Sorry, forgot to put this into the description of the PR.

These lines have been obsolete for some time now. They only apply when we unload the model (otherwise, the if does not match). Remember that when we made the base_layer switch, we ensured that unloading simply returns the base_layer; there is no longer any need to create a new layer (say, a new nn.Linear when using lora.Linear) and replace the new layer's weight with the parent layer's weight. The base_layer already has the original weight. Therefore, these lines are unnecessary.

I removed them now because they got in the way with MultiheadAttention: that layer has no weight attribute, so this line would fail.

if hasattr(new_module, "W_q"): # HQQ
new_module.W_q = child.W_q
else:
new_module.weight = child.weight
if hasattr(child, "bias"):
new_module.bias = child.bias

if getattr(child, "state", None) is not None:
if hasattr(new_module, "base_layer"):
new_module.base_layer.state = child.state
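The review discussion above is worth a quick illustration: since the base_layer refactor, a LoRA wrapper keeps the original module untouched under `base_layer`, so unloading can simply hand that module back instead of constructing a fresh layer and copying weights into it. A minimal sketch with a made-up class (not PEFT's actual `lora.Linear`):

```python
import torch.nn as nn


class LoraLinearSketch(nn.Module):
    """Toy stand-in for a LoRA-wrapped linear layer."""

    def __init__(self, base_layer: nn.Linear, r: int = 8):
        super().__init__()
        self.base_layer = base_layer  # original module, weights untouched
        self.lora_A = nn.Linear(base_layer.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base_layer.out_features, bias=False)

    def unload(self) -> nn.Linear:
        # no new nn.Linear is created and no weight is copied over:
        # the base layer already holds the original weight
        return self.base_layer


lin = nn.Linear(10, 10)
wrapped = LoraLinearSketch(lin)
assert wrapped.unload() is lin  # the very same module comes back
```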
@@ -266,15 +258,16 @@ def _replace_module(self, parent, child_name, new_module, child):
# dispatch to correct device
for name, module in new_module.named_modules():
if (self.prefix in name) or ("ranknum" in name):
weight = (
child.qweight
if hasattr(child, "qweight")
else child.W_q
if hasattr(child, "W_q")
else child.weight
if hasattr(child, "weight")
else next(child.parameters())
)
if hasattr(child, "qweight"):
weight = child.qweight
elif hasattr(child, "W_q"):
weight = child.W_q
elif hasattr(child, "weight"):
weight = child.weight
elif getattr(child, "in_proj_weight", None) is not None: # MHA
weight = child.in_proj_weight
else:
weight = next(child.parameters())
if not any(p.device == meta for p in module.parameters()):
module.to(weight.device)

@@ -360,7 +353,7 @@ def dynamic_dispatch_func(target, adapter_name, lora_config, **kwargs):
raise ValueError(
f"Target module {target} is not supported. Currently, only the following modules are supported: "
"`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `torch.nn.Conv3d`, "
"`transformers.pytorch_utils.Conv1D`."
"`transformers.pytorch_utils.Conv1D`, `torch.nn.MultiheadAttention.`."
)

return new_module
@@ -509,7 +502,13 @@ def _unload_and_optionally_merge(
except AttributeError:
continue
with onload_layer(target):
if hasattr(target, "base_layer"):
if hasattr(target, "unload_and_optionally_merge_module"):
# if layers have special unloading method, like MultiheadAttention, use that
unloaded_module = target.unload_and_optionally_merge_module(
merge=merge, safe_merge=safe_merge, adapter_names=adapter_names
)
self._replace_module(parent, target_name, unloaded_module, target)
elif hasattr(target, "base_layer"):
if merge:
target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
self._replace_module(parent, target_name, target.get_base_layer(), target)
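With this change, `torch.nn.MultiheadAttention` joins the supported target modules listed in the error message above. A minimal usage sketch; the model class here is invented for illustration but mirrors the `ModelMha` test model added later in this PR:

```python
import torch
import torch.nn as nn

from peft import LoraConfig, get_peft_model


class ModelMha(nn.Module):
    # same layout as the test model added in tests/test_custom_models.py
    def __init__(self):
        super().__init__()
        self.mha = nn.MultiheadAttention(10, 2)
        self.lin0 = nn.Linear(10, 2)

    def forward(self, X):
        X, _ = self.mha(X, X, X)
        return self.lin0(X)


config = LoraConfig(target_modules=["mha"])
peft_model = get_peft_model(ModelMha(), config)

# default nn.MultiheadAttention expects (seq_len, batch, embed_dim)
out = peft_model(torch.randn(6, 5, 10))
print(out.shape)  # torch.Size([6, 5, 2])
```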
7 changes: 5 additions & 2 deletions src/peft/tuners/tuners_utils.py
@@ -24,7 +24,6 @@
from typing import Any, Optional, Union

import torch
from accelerate import init_empty_weights
from accelerate.hooks import AlignDevicesHook
from accelerate.utils import named_module_tensors, offload_state_dict
from torch import nn
@@ -39,6 +38,7 @@
MIN_TARGET_MODULES_FOR_OPTIMIZATION,
SEQ_CLS_HEAD_NAMES,
)
from peft.utils.integrations import init_empty_weights
from peft.utils.peft_types import PeftType, TaskType

from ..config import PeftConfig
Expand Down Expand Up @@ -828,9 +828,12 @@ def _move_adapter_to_device_of_base_layer(self, adapter_name: str, device: Optio
Move the adapter of the given name to the device of the base layer.
"""
if device is None:
base_layer = self.get_base_layer()
if isinstance(base_layer, nn.MultiheadAttention):
base_layer = base_layer.out_proj
# check weight and qweight (for GPTQ)
for weight_name in ("weight", "qweight"):
weight = getattr(self.get_base_layer(), weight_name, None)
weight = getattr(base_layer, weight_name, None)
if weight is not None:
device = weight.device
dtype = weight.dtype
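For context on the `out_proj` special case above (and the `in_proj_weight` branch in `model.py`): `nn.MultiheadAttention` exposes no `weight` attribute of its own, so neither the device lookup nor the dispatch code can rely on it. A quick check, assuming the default configuration where query/key/value share one stacked projection:

```python
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=10, num_heads=2)
print(hasattr(mha, "weight"))     # False, unlike nn.Linear and friends
print(mha.in_proj_weight.shape)   # torch.Size([30, 10]), stacked q/k/v projections
print(mha.out_proj.weight.shape)  # torch.Size([10, 10]), nested output projection
```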
108 changes: 108 additions & 0 deletions src/peft/utils/integrations.py
@@ -14,12 +14,14 @@

from __future__ import annotations

import functools
from contextlib import contextmanager
from typing import Literal

import packaging.version
import torch
import transformers
from torch import nn


@contextmanager
@@ -170,3 +172,109 @@ def map_cache_to_layer_device_map(model, cache) -> None:
layer_device = layer_device_map[idx]
cache.key_cache[idx] = cache.key_cache[idx].to(layer_device)
cache.value_cache[idx] = cache.value_cache[idx].to(layer_device)


##################################
# START: ADAPTED FROM ACCELERATE #
##################################
#
# Modified to support explicitly skipping layer initialization for faster switching between layer states
# (necessary for supporting `nn.MultiheadAttention` adapters)


@contextmanager
def init_empty_weights(include_buffers: bool = None):
# adapted from accelerate.big_modeling.py
with _init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
yield f


@contextmanager
def _init_on_device(device: torch.device, include_buffers: bool = None):
# adapted from accelerate.big_modeling.py
old_register_parameter = nn.Module.register_parameter
if include_buffers:
old_register_buffer = nn.Module.register_buffer

def register_empty_parameter(module, name, param):
# This works because torch first initializes the parameters with torch.empty, thus not assigning any new memory.
# Then the parameter is moved to meta device before reset_parameters() is called, which then operates on the
# meta device, making any subsequent calls to initialization methods no-ops.
old_register_parameter(module, name, param)
if (param is not None) and (getattr(_init_on_device, "_skip", False) is not True):
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)

def register_empty_buffer(module, name, buffer, persistent=True):
old_register_buffer(module, name, buffer, persistent=persistent)
if buffer is not None:
module._buffers[name] = module._buffers[name].to(device)

# Patch tensor creation
if include_buffers:
tensor_constructors_to_patch = {
torch_function_name: getattr(torch, torch_function_name)
for torch_function_name in ["empty", "zeros", "ones", "full"]
}
else:
tensor_constructors_to_patch = {}

def patch_tensor_constructor(fn):
def wrapper(*args, **kwargs):
kwargs["device"] = device
return fn(*args, **kwargs)

return wrapper

try:
nn.Module.register_parameter = register_empty_parameter
if include_buffers:
nn.Module.register_buffer = register_empty_buffer
for torch_function_name in tensor_constructors_to_patch.keys():
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
yield
finally:
nn.Module.register_parameter = old_register_parameter
if include_buffers:
nn.Module.register_buffer = old_register_buffer
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
setattr(torch, torch_function_name, old_torch_function)


@contextmanager
def _skip_init_on_device():
# context manager to skip the _init_on_device context manager
old_val = getattr(_init_on_device, "_skip", False)
try:
_init_on_device._skip = True
yield
finally:
_init_on_device._skip = old_val


def skip_init_on_device(func):
"""
Ignore the init_on_device context manager when calling the decorated function.

This is a narrow use decorator that allows us to avoid initializing on meta device even when we're inside the
init_empty_weights context.

"""

# The need for this functionality arose when working on MultiheadAttention, where we have to call _restore_weights
# repeatedly as parameters are overwritten and need to be re-registered. When using low_cpu_mem_usage=True, as
# register_parameter is patched inside of the init_empty_weights context, this would result in those parameters
# suddenly being moved to meta device. Using this decorator allows us to avoid this.
@functools.wraps(func)
def wrapper(*args, **kwargs):
with _skip_init_on_device():
return func(*args, **kwargs)

return wrapper


#######
# END #
#######
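A small sketch of how these helpers interact; `TinyModule` and its `restore` method are invented for illustration (the real caller is `_restore_weights` in `lora.MultiheadAttention`). Inside `init_empty_weights`, newly registered parameters are redirected to the meta device, whereas a method decorated with `skip_init_on_device` re-registers parameters without that redirect:

```python
import torch
import torch.nn as nn

from peft.utils.integrations import init_empty_weights, skip_init_on_device


class TinyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    @skip_init_on_device
    def restore(self):
        # re-assigning a parameter triggers register_parameter, but the decorator
        # prevents the init_empty_weights patch from moving it to the meta device
        self.linear.weight = nn.Parameter(torch.zeros(4, 4))


with init_empty_weights():
    module = TinyModule()
    print(module.linear.weight.device)  # meta
    module.restore()
    print(module.linear.weight.device)  # cpu
```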
93 changes: 85 additions & 8 deletions tests/test_custom_models.py
@@ -120,6 +120,8 @@
),
("Conv2d 1 LoRA with lora_b bias", "Conv2d", LoraConfig, {"target_modules": ["conv2d"], "lora_bias": True}),
("Conv3d 1 LoRA with lora_b bias", "Conv3d", LoraConfig, {"target_modules": ["conv3d"], "lora_bias": True}),
("MHA 1 LoRA", "MHA", LoraConfig, {"target_modules": ["mha"]}),
("MHA 2 LoRA", "MHA", LoraConfig, {"target_modules": ["mha", "lin0"]}),
#######
# IA³ #
#######
@@ -872,6 +874,21 @@ def forward(self, X):
return X


class ModelMha(nn.Module):
def __init__(self):
super().__init__()
self.mha = nn.MultiheadAttention(10, 2)
self.lin0 = nn.Linear(10, 2)
self.sm = nn.LogSoftmax(dim=-1)

def forward(self, X):
X = X.float()
X, _ = self.mha(X, X, X)
X = self.lin0(X)
X = self.sm(X)
return X


class MockTransformerWrapper:
"""Mock class to behave like a transformers model.

@@ -908,6 +925,9 @@ def from_pretrained(cls, model_id, torch_dtype=None):
if model_id == "Conv2d2":
return ModelConv2D2().to(torch_dtype)

if model_id == "MHA":
return ModelMha().to(torch_dtype)

raise ValueError(f"model_id {model_id} not implemented")


@@ -1074,12 +1094,13 @@ def test_only_params_are_updated(self, test_name, model_id, config_cls, config_k
model_before = copy.deepcopy(model)

model.train()
# this high learning rate was found through testing to be necessary to avoid flakiness
lr = (
100.0
if (config_kwargs.get("use_dora") and model_id == "EmbConv1D") or issubclass(config_cls, VBLoRAConfig)
else 0.5
)
lr = 0.5
if (config_kwargs.get("use_dora") and model_id == "EmbConv1D") or issubclass(config_cls, VBLoRAConfig):
# this high learning rate was found through testing to be necessary to avoid flakiness
lr = 100
elif "mha" in model_id.lower():
# we get exploding gradients with MHA when learning rate is too high
lr = 1e-3
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

# train at least 3 steps for all parameters to be updated (probably this is required because of symmetry
@@ -1117,8 +1138,13 @@ def test_parameters_after_loading_model(self, test_name, model_id, config_cls, c
)
model = get_peft_model(model, config)
model.train()
lr = 0.5 if not config_kwargs.get("use_dora") else 0.1 # otherwise we get nan
if issubclass(config_cls, VBLoRAConfig):

lr = 0.5
if config_kwargs.get("use_dora"):
lr = 0.1 # otherwise we get nan
elif "mha" in model_id.lower():
lr = 1e-3 # we get exploding gradients with MHA when learning rate is too high
elif issubclass(config_cls, VBLoRAConfig):
lr = 0.01 # otherwise we get nan
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

@@ -1775,6 +1801,14 @@ def test_gpt2_dora_merge_and_unload_safe_merge(self):
# should not raise an error
model.merge_and_unload(safe_merge=True)

def test_unload_adapter_multihead_attention(self):
# MultiheadAttention has special logic for unloading, that logic is covered by this test
self._test_unload_adapter(
model_id="MHA",
config_cls=LoraConfig,
config_kwargs={"target_modules": ["mha"], "init_lora_weights": False},
)

def test_dora_save_and_load_remapping(self):
# Here we test the refactor of DoRA which changed lora_magnitude_vector from a ParameterDict to a ModuleDict
# with a DoraLayer instance. The old parameter is now the "weight" attribute of that layer. Since we want the
@@ -1810,6 +1844,37 @@ def test_dora_save_and_load_remapping(self):
for k in state_dict:
assert torch.allclose(state_dict[k], state_dict_loaded[k])

@parameterized.expand([False, True])
def test_mha_gradients_set_correctly(self, with_forward_call):
# check for this bug: https://github.com/huggingface/peft/issues/761#issuecomment-1893804738
base_model = ModelMha()
config = LoraConfig(target_modules=["mha"])
model = get_peft_model(base_model, config)
model = model.to(self.torch_device)

if with_forward_call:
# after the merge-unmerge roundtrip happening in forward of lora MHA, the base weights should be set to
# requires_grad=False
inputs = self.prepare_inputs_for_testing()
model(**inputs)

assert model.base_model.model.mha.base_layer.out_proj.base_layer.weight.requires_grad is False
assert model.base_model.model.mha.base_layer.in_proj_weight.requires_grad is False

# _restore_weights used to ignore the gradient, this checks that it is indeed considered
model.base_model.model.mha._restore_weights()
assert model.base_model.model.mha.base_layer.out_proj.base_layer.weight.requires_grad is False
assert model.base_model.model.mha.base_layer.in_proj_weight.requires_grad is False

model.base_model.model.mha.base_layer.out_proj.base_layer.weight.requires_grad = True
model.base_model.model.mha.base_layer.in_proj_weight.requires_grad = True
assert model.base_model.model.mha.base_layer.out_proj.base_layer.weight.requires_grad is True
assert model.base_model.model.mha.base_layer.in_proj_weight.requires_grad is True

model.base_model.model.mha._restore_weights()
assert model.base_model.model.mha.base_layer.out_proj.base_layer.weight.requires_grad is True
assert model.base_model.model.mha.base_layer.in_proj_weight.requires_grad is True


class TestMultiRankAdapter(unittest.TestCase):
"""Tests related to multirank LoRA adapters"""
@@ -3630,6 +3695,18 @@ def test_mixed_adapter_batches_lora_conv2d(self):
inputs = {"X": torch.arange(270).view(6, 5, 3, 3).to(self.torch_device)}
self.run_checks(peft_model, inputs)

def test_mixed_adapter_batches_mha_raises(self):
base_model = ModelMha().to(self.torch_device).eval()
config0 = LoraConfig(target_modules=["mha"], init_lora_weights=False)
config1 = LoraConfig(target_modules=["mha"], r=16, init_lora_weights=False)
peft_model = get_peft_model(base_model, config0, "adapter0").eval()
peft_model.add_adapter("adapter1", config1)

inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)}
msg = "lora.MultiheadAttention does not support mixed adapter batches"
with pytest.raises(TypeError, match=msg):
self.run_checks(peft_model, inputs)

def test_mixed_adapter_batches_lora_length_mismatch_raises(self, mlp_lora):
inputs = {
"X": torch.arange(90).view(-1, 10).to(self.torch_device),