
Commit 06ee86f

Authored by: BenjaminBossan, sywangyi, keakon, muellerzr, saeid93
ENH Add LoRA multihead attention module (huggingface#1324)
For now, only works with _qkv_same_embed_dim=True.

---------

Co-authored-by: Wang, Yi <[email protected]>
Co-authored-by: keakon <[email protected]>
Co-authored-by: Zach Mueller <[email protected]>
Co-authored-by: Saeid Ghafouri <[email protected]>
Co-authored-by: Fanli Lin <[email protected]>
Co-authored-by: githubnemo <[email protected]>
1 parent 04299fc commit 06ee86f
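
Example usage (a minimal sketch: the toy module below is hypothetical and mirrors the ModelMha test model added in this commit; only the _qkv_same_embed_dim=True case is supported for now):

import torch
from torch import nn
from peft import LoraConfig, get_peft_model

class ToyAttentionModel(nn.Module):
    def __init__(self):
        super().__init__()
        # q/k/v share the embedding dim, so _qkv_same_embed_dim is True
        self.mha = nn.MultiheadAttention(embed_dim=10, num_heads=2)
        self.lin0 = nn.Linear(10, 2)

    def forward(self, x):
        x, _ = self.mha(x, x, x)
        return self.lin0(x)

config = LoraConfig(target_modules=["mha"])  # nn.MultiheadAttention is now a supported target
peft_model = get_peft_model(ToyAttentionModel(), config)
peft_model.print_trainable_parameters()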

File tree

10 files changed: +764 −38 lines


src/peft/peft_model.py

Lines changed: 2 additions & 1 deletion
@@ -27,7 +27,7 @@
 import packaging.version
 import torch
 import transformers
-from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
+from accelerate import dispatch_model, infer_auto_device_map
 from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules
 from accelerate.utils import get_balanced_memory, named_module_tensors
 from huggingface_hub import HfFileSystem, ModelCard, ModelCardData, hf_hub_download
@@ -39,6 +39,7 @@
 from transformers.utils import PushToHubMixin

 from peft.utils.constants import DUMMY_MODEL_CONFIG, PEFT_TYPE_TO_PREFIX_MAPPING
+from peft.utils.integrations import init_empty_weights

 from . import __version__
 from .config import PeftConfig

src/peft/tuners/lora/eva.py

Lines changed: 2 additions & 2 deletions
@@ -32,10 +32,10 @@
 from peft.utils.other import _get_submodules, get_pattern_key

 from .config import LoraConfig
-from .layer import Embedding, LoraLayer, _ConvNd
+from .layer import Embedding, LoraLayer, MultiheadAttention, _ConvNd


-UNSUPPORTED_LORA_MODULES = (Embedding, _ConvNd)
+UNSUPPORTED_LORA_MODULES = (Embedding, MultiheadAttention, _ConvNd)


 class _Hook:
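
For context, EVA consumes this constant via isinstance checks. A hedged sketch (an illustrative helper, not actual eva.py code) of how unsupported LoRA layers such as the new MultiheadAttention would be filtered out:

from peft.tuners.lora.eva import UNSUPPORTED_LORA_MODULES
from peft.tuners.lora.layer import LoraLayer

def iter_eva_supported_layers(model):
    # yield only LoRA layers that EVA can handle; lora.MultiheadAttention is now excluded
    for name, module in model.named_modules():
        if isinstance(module, LoraLayer) and not isinstance(module, UNSUPPORTED_LORA_MODULES):
            yield name, module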

src/peft/tuners/lora/layer.py

Lines changed: 359 additions & 1 deletion
Large diffs are not rendered by default.
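
This file adds the new lora.MultiheadAttention layer itself. Conceptually, per the test comment below about a "merge-unmerge roundtrip happening in forward of lora MHA", the approach can be sketched as follows; this is an illustrative simplification with assumed tensor names, not the actual implementation:

import torch
from torch import nn

def lora_mha_forward_sketch(base_mha: nn.MultiheadAttention, lora_A, lora_B, scaling, q, k, v):
    # illustrative only: temporarily merge the LoRA delta into in_proj_weight
    # (shape: 3*embed_dim x embed_dim), run the unmodified nn.MultiheadAttention
    # forward, then unmerge again
    delta = (lora_B @ lora_A) * scaling
    base_mha.in_proj_weight.data += delta
    try:
        out, attn_weights = base_mha(q, k, v)
    finally:
        base_mha.in_proj_weight.data -= delta
    return out, attn_weights

The real layer additionally has to handle gradient flow, the out_proj sub-layer (which the tests show is wrapped as its own LoRA linear layer), merging/unloading, and re-registering parameters, which is where most of the added code goes.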

src/peft/tuners/lora/model.py

Lines changed: 18 additions & 19 deletions
@@ -247,14 +247,6 @@ def _replace_module(self, parent, child_name, new_module, child):
         if hasattr(child, "base_layer"):
             child = child.base_layer

-        if not hasattr(new_module, "base_layer"):
-            if hasattr(new_module, "W_q"):  # HQQ
-                new_module.W_q = child.W_q
-            else:
-                new_module.weight = child.weight
-            if hasattr(child, "bias"):
-                new_module.bias = child.bias
-
         if getattr(child, "state", None) is not None:
             if hasattr(new_module, "base_layer"):
                 new_module.base_layer.state = child.state
@@ -266,15 +258,16 @@ def _replace_module(self, parent, child_name, new_module, child):
         # dispatch to correct device
         for name, module in new_module.named_modules():
             if (self.prefix in name) or ("ranknum" in name):
-                weight = (
-                    child.qweight
-                    if hasattr(child, "qweight")
-                    else child.W_q
-                    if hasattr(child, "W_q")
-                    else child.weight
-                    if hasattr(child, "weight")
-                    else next(child.parameters())
-                )
+                if hasattr(child, "qweight"):
+                    weight = child.qweight
+                elif hasattr(child, "W_q"):
+                    weight = child.W_q
+                elif hasattr(child, "weight"):
+                    weight = child.weight
+                elif getattr(child, "in_proj_weight", None) is not None:  # MHA
+                    weight = child.in_proj_weight
+                else:
+                    weight = next(child.parameters())
                 if not any(p.device == meta for p in module.parameters()):
                     module.to(weight.device)

@@ -360,7 +353,7 @@ def dynamic_dispatch_func(target, adapter_name, lora_config, **kwargs):
         raise ValueError(
             f"Target module {target} is not supported. Currently, only the following modules are supported: "
             "`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `torch.nn.Conv3d`, "
-            "`transformers.pytorch_utils.Conv1D`."
+            "`transformers.pytorch_utils.Conv1D`, `torch.nn.MultiheadAttention`."
         )

     return new_module
@@ -509,7 +502,13 @@ def _unload_and_optionally_merge(
             except AttributeError:
                 continue
             with onload_layer(target):
-                if hasattr(target, "base_layer"):
+                if hasattr(target, "unload_and_optionally_merge_module"):
+                    # if layers have special unloading method, like MultiheadAttention, use that
+                    unloaded_module = target.unload_and_optionally_merge_module(
+                        merge=merge, safe_merge=safe_merge, adapter_names=adapter_names
+                    )
+                    self._replace_module(parent, target_name, unloaded_module, target)
+                elif hasattr(target, "base_layer"):
                     if merge:
                         target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
                     self._replace_module(parent, target_name, target.get_base_layer(), target)
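
The new branch above is what lets merge_and_unload() restore a plain nn.MultiheadAttention. A minimal sketch (hypothetical toy model; it mirrors the test_unload_adapter_multihead_attention test added below):

import torch
from torch import nn
from peft import LoraConfig, get_peft_model

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.mha = nn.MultiheadAttention(10, 2)

    def forward(self, x):
        out, _ = self.mha(x, x, x)
        return out

peft_model = get_peft_model(Toy(), LoraConfig(target_modules=["mha"], init_lora_weights=False))
# _unload_and_optionally_merge detects unload_and_optionally_merge_module on the MHA wrapper
# and swaps the restored base module back into the parent
restored = peft_model.merge_and_unload()
assert isinstance(restored.mha, nn.MultiheadAttention)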

src/peft/tuners/tuners_utils.py

Lines changed: 5 additions & 2 deletions
@@ -24,7 +24,6 @@
 from typing import Any, Optional, Union

 import torch
-from accelerate import init_empty_weights
 from accelerate.hooks import AlignDevicesHook
 from accelerate.utils import named_module_tensors, offload_state_dict
 from torch import nn
@@ -39,6 +38,7 @@
     MIN_TARGET_MODULES_FOR_OPTIMIZATION,
     SEQ_CLS_HEAD_NAMES,
 )
+from peft.utils.integrations import init_empty_weights
 from peft.utils.peft_types import PeftType, TaskType

 from ..config import PeftConfig
@@ -828,9 +828,12 @@ def _move_adapter_to_device_of_base_layer(self, adapter_name: str, device: Optio
         Move the adapter of the given name to the device of the base layer.
         """
         if device is None:
+            base_layer = self.get_base_layer()
+            if isinstance(base_layer, nn.MultiheadAttention):
+                base_layer = base_layer.out_proj
             # check weight and qweight (for GPTQ)
             for weight_name in ("weight", "qweight"):
-                weight = getattr(self.get_base_layer(), weight_name, None)
+                weight = getattr(base_layer, weight_name, None)
                 if weight is not None:
                     device = weight.device
                     dtype = weight.dtype
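
The out_proj indirection is needed because nn.MultiheadAttention exposes no top-level weight attribute to read device and dtype from; a quick illustration:

import torch
from torch import nn

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
print(hasattr(mha, "weight"))        # False: parameters live in in_proj_weight and out_proj
print(mha.in_proj_weight.shape)      # torch.Size([24, 8]) when _qkv_same_embed_dim=True
print(mha.out_proj.weight.device, mha.out_proj.weight.dtype)  # what the adapter is placed/cast to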

src/peft/utils/integrations.py

Lines changed: 108 additions & 0 deletions
@@ -14,12 +14,14 @@

 from __future__ import annotations

+import functools
 from contextlib import contextmanager
 from typing import Literal

 import packaging.version
 import torch
 import transformers
+from torch import nn


 @contextmanager
@@ -170,3 +172,109 @@ def map_cache_to_layer_device_map(model, cache) -> None:
         layer_device = layer_device_map[idx]
         cache.key_cache[idx] = cache.key_cache[idx].to(layer_device)
         cache.value_cache[idx] = cache.value_cache[idx].to(layer_device)
+
+
+##################################
+# START: ADAPTED FROM ACCELERATE #
+##################################
+#
+# Modified to support explicitly skipping layer initialization for faster switching between layer states
+# (necessary for supporting `nn.MultiheadAttention` adapters)
+
+
+@contextmanager
+def init_empty_weights(include_buffers: bool = None):
+    # adapted from accelerate.big_modeling.py
+    with _init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
+        yield f
+
+
+@contextmanager
+def _init_on_device(device: torch.device, include_buffers: bool = None):
+    # adapted from accelerate.big_modeling.py
+    old_register_parameter = nn.Module.register_parameter
+    if include_buffers:
+        old_register_buffer = nn.Module.register_buffer
+
+    def register_empty_parameter(module, name, param):
+        # This works because torch first initializes the parameters with torch.empty, thus not assigning any new memory.
+        # Then the parameter is moved to meta device before reset_parameters() is called, which then operates on the
+        # meta device, making any subsequent calls to initialization methods no-ops.
+        old_register_parameter(module, name, param)
+        if (param is not None) and (getattr(_init_on_device, "_skip", False) is not True):
+            param_cls = type(module._parameters[name])
+            kwargs = module._parameters[name].__dict__
+            kwargs["requires_grad"] = param.requires_grad
+            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
+
+    def register_empty_buffer(module, name, buffer, persistent=True):
+        old_register_buffer(module, name, buffer, persistent=persistent)
+        if buffer is not None:
+            module._buffers[name] = module._buffers[name].to(device)
+
+    # Patch tensor creation
+    if include_buffers:
+        tensor_constructors_to_patch = {
+            torch_function_name: getattr(torch, torch_function_name)
+            for torch_function_name in ["empty", "zeros", "ones", "full"]
+        }
+    else:
+        tensor_constructors_to_patch = {}
+
+    def patch_tensor_constructor(fn):
+        def wrapper(*args, **kwargs):
+            kwargs["device"] = device
+            return fn(*args, **kwargs)
+
+        return wrapper
+
+    try:
+        nn.Module.register_parameter = register_empty_parameter
+        if include_buffers:
+            nn.Module.register_buffer = register_empty_buffer
+        for torch_function_name in tensor_constructors_to_patch.keys():
+            setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
+        yield
+    finally:
+        nn.Module.register_parameter = old_register_parameter
+        if include_buffers:
+            nn.Module.register_buffer = old_register_buffer
+        for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
+            setattr(torch, torch_function_name, old_torch_function)
+
+
+@contextmanager
+def _skip_init_on_device():
+    # context manager to skip the _init_on_device context manager
+    old_val = getattr(_init_on_device, "_skip", False)
+    try:
+        _init_on_device._skip = True
+        yield
+    finally:
+        _init_on_device._skip = old_val
+
+
+def skip_init_on_device(func):
+    """
+    Ignore the init_on_device context manager when calling the decorated function.
+
+    This is a narrow use decorator that allows us to avoid initializing on meta device even when we're inside the
+    init_empty_weights context.
+    """
+    # The need for this functionality arose when working on MultiheadAttention, where we have to call _restore_weights
+    # repeatedly as parameters are overwritten and need to be re-registered. When using low_cpu_mem_usage=True, as
+    # register_parameter is patched inside of the init_empty_weights context, this would result in those parameters
+    # suddenly being moved to meta device. Using this decorator allows us to avoid this.
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        with _skip_init_on_device():
+            return func(*args, **kwargs)
+
+    return wrapper
+
+
+#######
+# END #
+#######
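
A small demonstration of the two helpers added above (uses only what is defined in this file):

import torch
from torch import nn
from peft.utils.integrations import init_empty_weights, skip_init_on_device

with init_empty_weights():
    meta_linear = nn.Linear(4, 4)          # parameters are registered on the meta device
print(meta_linear.weight.device)           # meta

@skip_init_on_device
def build_real_linear():
    # decorated functions keep their parameters on the real device, even inside init_empty_weights
    return nn.Linear(4, 4)

with init_empty_weights():
    real_linear = build_real_linear()
print(real_linear.weight.device)           # cpu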

tests/test_custom_models.py

Lines changed: 85 additions & 8 deletions
@@ -120,6 +120,8 @@
     ),
     ("Conv2d 1 LoRA with lora_b bias", "Conv2d", LoraConfig, {"target_modules": ["conv2d"], "lora_bias": True}),
     ("Conv3d 1 LoRA with lora_b bias", "Conv3d", LoraConfig, {"target_modules": ["conv3d"], "lora_bias": True}),
+    ("MHA 1 LoRA", "MHA", LoraConfig, {"target_modules": ["mha"]}),
+    ("MHA 2 LoRA", "MHA", LoraConfig, {"target_modules": ["mha", "lin0"]}),
     #######
     # IA³ #
     #######
@@ -872,6 +874,21 @@ def forward(self, X):
         return X


+class ModelMha(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.mha = nn.MultiheadAttention(10, 2)
+        self.lin0 = nn.Linear(10, 2)
+        self.sm = nn.LogSoftmax(dim=-1)
+
+    def forward(self, X):
+        X = X.float()
+        X, _ = self.mha(X, X, X)
+        X = self.lin0(X)
+        X = self.sm(X)
+        return X
+
+
 class MockTransformerWrapper:
     """Mock class to behave like a transformers model.

@@ -908,6 +925,9 @@ def from_pretrained(cls, model_id, torch_dtype=None):
         if model_id == "Conv2d2":
             return ModelConv2D2().to(torch_dtype)

+        if model_id == "MHA":
+            return ModelMha().to(torch_dtype)
+
         raise ValueError(f"model_id {model_id} not implemented")

@@ -1074,12 +1094,13 @@ def test_only_params_are_updated(self, test_name, model_id, config_cls, config_k
         model_before = copy.deepcopy(model)

         model.train()
-        # this high learning rate was found through testing to be necessary to avoid flakiness
-        lr = (
-            100.0
-            if (config_kwargs.get("use_dora") and model_id == "EmbConv1D") or issubclass(config_cls, VBLoRAConfig)
-            else 0.5
-        )
+        lr = 0.5
+        if (config_kwargs.get("use_dora") and model_id == "EmbConv1D") or issubclass(config_cls, VBLoRAConfig):
+            # this high learning rate was found through testing to be necessary to avoid flakiness
+            lr = 100
+        elif "mha" in model_id.lower():
+            # we get exploding gradients with MHA when learning rate is too high
+            lr = 1e-3
         optimizer = torch.optim.SGD(model.parameters(), lr=lr)

         # train at least 3 steps for all parameters to be updated (probably this is required because of symmetry
@@ -1117,8 +1138,13 @@ def test_parameters_after_loading_model(self, test_name, model_id, config_cls, c
         )
         model = get_peft_model(model, config)
         model.train()
-        lr = 0.5 if not config_kwargs.get("use_dora") else 0.1  # otherwise we get nan
-        if issubclass(config_cls, VBLoRAConfig):
+
+        lr = 0.5
+        if config_kwargs.get("use_dora"):
+            lr = 0.1  # otherwise we get nan
+        elif "mha" in model_id.lower():
+            lr = 1e-3  # we get exploding gradients with MHA when learning rate is too high
+        elif issubclass(config_cls, VBLoRAConfig):
             lr = 0.01  # otherwise we get nan
         optimizer = torch.optim.SGD(model.parameters(), lr=lr)

@@ -1775,6 +1801,14 @@ def test_gpt2_dora_merge_and_unload_safe_merge(self):
         # should not raise an error
         model.merge_and_unload(safe_merge=True)

+    def test_unload_adapter_multihead_attention(self):
+        # MultiheadAttention has special logic for unloading, that logic is covered by this test
+        self._test_unload_adapter(
+            model_id="MHA",
+            config_cls=LoraConfig,
+            config_kwargs={"target_modules": ["mha"], "init_lora_weights": False},
+        )
+
     def test_dora_save_and_load_remapping(self):
         # Here we test the refactor of DoRA which changed lora_magnitude_vector from a ParameterDict to a ModuleDict
         # with a DoraLayer instance. The old parameter is now the "weight" attribute of that layer. Since we want the
@@ -1810,6 +1844,37 @@ def test_dora_save_and_load_remapping(self):
         for k in state_dict:
             assert torch.allclose(state_dict[k], state_dict_loaded[k])

+    @parameterized.expand([False, True])
+    def test_mha_gradients_set_correctly(self, with_forward_call):
+        # check for this bug: https://github.com/huggingface/peft/issues/761#issuecomment-1893804738
+        base_model = ModelMha()
+        config = LoraConfig(target_modules=["mha"])
+        model = get_peft_model(base_model, config)
+        model = model.to(self.torch_device)
+
+        if with_forward_call:
+            # after the merge-unmerge roundtrip happening in forward of lora MHA, the base weights should be set to
+            # requires_grad=False
+            inputs = self.prepare_inputs_for_testing()
+            model(**inputs)
+
+        assert model.base_model.model.mha.base_layer.out_proj.base_layer.weight.requires_grad is False
+        assert model.base_model.model.mha.base_layer.in_proj_weight.requires_grad is False
+
+        # _restore_weights used to ignore the gradient, this checks that it is indeed considered
+        model.base_model.model.mha._restore_weights()
+        assert model.base_model.model.mha.base_layer.out_proj.base_layer.weight.requires_grad is False
+        assert model.base_model.model.mha.base_layer.in_proj_weight.requires_grad is False
+
+        model.base_model.model.mha.base_layer.out_proj.base_layer.weight.requires_grad = True
+        model.base_model.model.mha.base_layer.in_proj_weight.requires_grad = True
+        assert model.base_model.model.mha.base_layer.out_proj.base_layer.weight.requires_grad is True
+        assert model.base_model.model.mha.base_layer.in_proj_weight.requires_grad is True
+
+        model.base_model.model.mha._restore_weights()
+        assert model.base_model.model.mha.base_layer.out_proj.base_layer.weight.requires_grad is True
+        assert model.base_model.model.mha.base_layer.in_proj_weight.requires_grad is True
+

 class TestMultiRankAdapter(unittest.TestCase):
     """Tests related to multirank LoRA adapters"""
@@ -3630,6 +3695,18 @@ def test_mixed_adapter_batches_lora_conv2d(self):
         inputs = {"X": torch.arange(270).view(6, 5, 3, 3).to(self.torch_device)}
         self.run_checks(peft_model, inputs)

+    def test_mixed_adapter_batches_mha_raises(self):
+        base_model = ModelMha().to(self.torch_device).eval()
+        config0 = LoraConfig(target_modules=["mha"], init_lora_weights=False)
+        config1 = LoraConfig(target_modules=["mha"], r=16, init_lora_weights=False)
+        peft_model = get_peft_model(base_model, config0, "adapter0").eval()
+        peft_model.add_adapter("adapter1", config1)
+
+        inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)}
+        msg = "lora.MultiheadAttention does not support mixed adapter batches"
+        with pytest.raises(TypeError, match=msg):
+            self.run_checks(peft_model, inputs)
+
     def test_mixed_adapter_batches_lora_length_mismatch_raises(self, mlp_lora):
         inputs = {
             "X": torch.arange(90).view(-1, 10).to(self.torch_device),
