[WIP] Add LoRA multihead attention module #1324
The diff to the LoRA layer module, starting with LoraLayer's in/out feature detection:

@@ -59,6 +59,9 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
             in_features, out_features = (
                 base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape
             )
+        elif isinstance(base_layer, nn.MultiheadAttention):
+            assert base_layer._qkv_same_embed_dim, "Only same embed dim supported as of now"
+            in_features, out_features = base_layer.embed_dim, 3 * base_layer.embed_dim
         elif hasattr(base_layer, "infeatures") and hasattr(base_layer, "outfeatures"):
             # QuantLinear
             in_features, out_features = base_layer.infeatures, base_layer.outfeatures
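The 3 * embed_dim comes from the packed projection that nn.MultiheadAttention uses when query, key and value share an embedding dimension; a quick check in plain PyTorch (not part of the diff):

import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=16, num_heads=4)
assert mha._qkv_same_embed_dim
# Q, K and V projections are packed into a single matrix of shape (3 * embed_dim, embed_dim)
print(mha.in_proj_weight.shape)  # torch.Size([48, 16])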
@@ -684,3 +687,172 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
     def __repr__(self) -> str:
         rep = super().__repr__()
         return "lora." + rep
+
+
+class MultiheadAttention(nn.Module, LoraLayer):
+    def __init__(
+        self,
+        base_layer,
+        adapter_name: str,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+        is_target_conv_1d_layer: bool = False,
+        init_lora_weights: Union[bool, str] = True,
+        use_rslora: bool = False,
+        **kwargs,
+    ) -> None:
+        # TODO work with separate weights
+        assert base_layer._qkv_same_embed_dim, "Only same embed dim supported as of now"
+
+        super().__init__()
+        LoraLayer.__init__(self, base_layer, **kwargs)
+        self.fan_in_fan_out = fan_in_fan_out
+
+        self._active_adapter = adapter_name
+        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
+        self.is_target_conv_1d_layer = is_target_conv_1d_layer
Review comment on the is_target_conv_1d_layer assignment: We can also just hard-code it to False.
A second thread concerns code further down in the new class (not shown here) that writes merged weights back to base_layer.in_proj_weight: Shouldn't this throw an exception? AFAICS we're assigning a tensor to a parameter value:
foo = torch.nn.Linear(10, 100)
foo.weight = foo.weight.detach() # raises
What am I missing?
It's true that we change the type here, I guess you could consider this part of the hack to make this work. At the end, through _restore_weights, the correct type is restored.
Ah, yes. I missed the del statement, which unregisters the parameter and thus removes the setattr constraint. WDYT about something along the lines of:
# unregister parameter implicitly and overwrite using merged weights; gradients are computed
# after forward and, thus, after unmerging (see forward()), therefore this is safe to do.
del base_layer.in_proj_weight
base_layer.in_proj_weight = orig_weights_in
Done
githubnemo marked this conversation as resolved.
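As an aside, a minimal standalone check of the setattr behaviour discussed above (plain PyTorch, not taken from the PR):

import torch
import torch.nn as nn

lin = nn.Linear(10, 100)
# lin.weight = lin.weight.detach()  # raises TypeError: a plain tensor cannot replace a registered nn.Parameter
del lin.weight                      # unregisters the parameter, lifting the __setattr__ constraint
lin.weight = torch.zeros(100, 10)   # now an ordinary attribute assignment, no error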
Another question: Do we also need to overwrite the modules() method?
Not needed, as modules calls named_modules under the hood. I added a comment to that effect.
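For reference, torch.nn.Module.modules is essentially a thin wrapper around named_modules, roughly:

# Simplified from torch/nn/modules/module.py
def modules(self):
    for _, module in self.named_modules():
        yield module

so overriding named_modules in the LoRA wrapper covers modules() as well.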
The second file in the diff, the LoRA model (tuner) module that dispatches to the new layer:
@@ -43,7 +43,7 @@
 from .config import LoraConfig
 from .gptq import QuantLinear
-from .layer import Conv2d, Embedding, Linear, LoraLayer
+from .layer import Conv2d, Embedding, Linear, LoraLayer, MultiheadAttention


 class LoraModel(BaseTuner):
@@ -193,11 +193,6 @@ def _replace_module(self, parent, child_name, new_module, child):
         if hasattr(child, "base_layer"):
             child = child.base_layer

-        if not hasattr(new_module, "base_layer"):
-            new_module.weight = child.weight
-            if hasattr(child, "bias"):
-                new_module.bias = child.bias
-
         if getattr(child, "state", None) is not None:
             if hasattr(new_module, "base_layer"):
                 new_module.base_layer.state = child.state

Why has this been removed?

Sorry, forgot to put this into the description of the PR. These lines have been obsolete for some time now. They only apply when we unload the model (otherwise, the […]). I removed them now because they were annoying with […].
@@ -208,7 +203,16 @@ def _replace_module(self, parent, child_name, new_module, child):
         # dispatch to correct device
         for name, module in new_module.named_modules():
             if (self.prefix in name) or ("ranknum" in name):
-                weight = child.qweight if hasattr(child, "qweight") else child.weight
+                if hasattr(child, "qweight"):
+                    weight = child.qweight
+                elif hasattr(child, "weight"):
+                    weight = child.weight
+                elif getattr(child, "in_proj_weight", None) is not None:  # MHA
+                    weight = child.in_proj_weight
+                elif getattr(child, "q_proj_weight", None) is not None:  # MHA
+                    weight = child.q_proj_weight
+                else:
+                    raise ValueError(f"Encountered unknown module type: {type(child)}")
                 module.to(weight.device)

     def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
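The two MHA branches are needed because nn.MultiheadAttention always defines both attributes and sets one of them to None, depending on whether query, key and value share an embedding dimension; in plain PyTorch (not part of the diff):

import torch.nn as nn

packed = nn.MultiheadAttention(embed_dim=16, num_heads=4)
print(packed.in_proj_weight.shape, packed.q_proj_weight)  # torch.Size([48, 16]) None

split = nn.MultiheadAttention(embed_dim=16, num_heads=4, kdim=8, vdim=8)
print(split.in_proj_weight, split.q_proj_weight.shape)    # None torch.Size([16, 16])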
@@ -290,6 +294,9 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
         elif isinstance(target_base_layer, torch.nn.Conv2d):
             kwargs.update(lora_config.loftq_config)
             new_module = Conv2d(target, adapter_name, **kwargs)
+        elif isinstance(target_base_layer, torch.nn.MultiheadAttention):
+            kwargs.update(lora_config.loftq_config)
+            new_module = MultiheadAttention(target, adapter_name, **kwargs)
         elif isinstance(target_base_layer, torch.nn.Linear):
             if kwargs["fan_in_fan_out"]:
                 warnings.warn(
@@ -333,7 +340,8 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
         else:
             raise ValueError(
                 f"Target module {target} is not supported. Currently, only the following modules are supported: "
-                "`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `transformers.pytorch_utils.Conv1D`."
+                "`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `transformers.pytorch_utils.Conv1D`, "
+                "`torch.nn.MultiheadAttention`."
             )

         return new_module
I don't think this is used?
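To round things off, a minimal usage sketch of what the PR enables. This assumes the branch from this PR is installed; the toy model and the target module name self_attn are made up for illustration:

import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # A vanilla nn.MultiheadAttention layer, the module type this PR adds support for
        self.self_attn = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)
        self.head = nn.Linear(32, 2)

    def forward(self, x):
        attn_out, _ = self.self_attn(x, x, x)
        return self.head(attn_out.mean(dim=1))

config = LoraConfig(r=8, lora_alpha=16, target_modules=["self_attn"])
model = get_peft_model(ToyModel(), config)
model.print_trainable_parameters()

out = model(torch.randn(4, 10, 32))  # (batch, seq, embed_dim) -> logits of shape (4, 2)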