[WIP] Add LoRA multihead attention module #1324
Changes from 59 commits
@@ -247,14 +247,6 @@ def _replace_module(self, parent, child_name, new_module, child):
         if hasattr(child, "base_layer"):
             child = child.base_layer
 
-        if not hasattr(new_module, "base_layer"):
-            if hasattr(new_module, "W_q"):  # HQQ
-                new_module.W_q = child.W_q
-            else:
-                new_module.weight = child.weight
-            if hasattr(child, "bias"):
-                new_module.bias = child.bias
-
         if getattr(child, "state", None) is not None:
             if hasattr(new_module, "base_layer"):
                 new_module.base_layer.state = child.state
@@ -266,15 +258,18 @@ def _replace_module(self, parent, child_name, new_module, child):
         # dispatch to correct device
         for name, module in new_module.named_modules():
             if (self.prefix in name) or ("ranknum" in name):
-                weight = (
-                    child.qweight
-                    if hasattr(child, "qweight")
-                    else child.W_q
-                    if hasattr(child, "W_q")
-                    else child.weight
-                    if hasattr(child, "weight")
-                    else next(child.parameters())
-                )
+                if hasattr(child, "qweight"):
+                    weight = child.qweight
+                elif hasattr(child, "W_q"):
+                    weight = child.W_q
+                elif hasattr(child, "weight"):
+                    weight = child.weight
+                elif getattr(child, "in_proj_weight", None) is not None:  # MHA
+                    weight = child.in_proj_weight
+                elif getattr(child, "q_proj_weight", None) is not None:  # MHA
+                    weight = child.q_proj_weight
+                else:
+                    weight = next(child.parameters())
                 if not any(p.device == meta for p in module.parameters()):
                     module.to(weight.device)
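For context: `nn.MultiheadAttention` packs its query/key/value projections into a single `in_proj_weight` when the embedding dims match, and exposes separate `q_proj_weight`/`k_proj_weight`/`v_proj_weight` otherwise. In each case the unused attribute still exists but is `None`, which is presumably why the new branches check `getattr(..., None) is not None` rather than `hasattr`. A minimal check of that behavior in plain PyTorch, independent of this PR:

```python
# Minimal check of the two nn.MultiheadAttention weight layouts (plain PyTorch).
import torch.nn as nn

packed = nn.MultiheadAttention(embed_dim=16, num_heads=4)
print(packed.in_proj_weight.shape)   # torch.Size([48, 16])
print(packed.q_proj_weight)          # None -> hasattr() would still be True

unpacked = nn.MultiheadAttention(embed_dim=16, num_heads=4, kdim=8, vdim=8)
print(unpacked.in_proj_weight)       # None
print(unpacked.q_proj_weight.shape)  # torch.Size([16, 16])
```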
@@ -360,7 +355,7 @@ def dynamic_dispatch_func(target, adapter_name, lora_config, **kwargs):
         raise ValueError(
             f"Target module {target} is not supported. Currently, only the following modules are supported: "
             "`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `torch.nn.Conv3d`, "
-            "`transformers.pytorch_utils.Conv1D`."
+            "`transformers.pytorch_utils.Conv1D`, `torch.nn.MultiheadAttention`."
         )
 
     return new_module
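With the dispatcher extended, targeting an MHA layer by name should work the same way as for the other supported layer types. A rough usage sketch; the model and module names are made up, and it assumes the support added by this PR is present in the installed PEFT version:

```python
# Sketch: assumes this PR's nn.MultiheadAttention support is available.
import torch.nn as nn
from peft import LoraConfig, get_peft_model

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim=32, num_heads=4)
        self.head = nn.Linear(32, 2)

    def forward(self, x):
        attn_out, _ = self.mha(x, x, x)
        return self.head(attn_out)

# target the MultiheadAttention module by name, like any other supported layer
config = LoraConfig(r=8, target_modules=["mha"])
peft_model = get_peft_model(ToyModel(), config)
```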
@@ -509,7 +504,13 @@ def _unload_and_optionally_merge(
             except AttributeError:
                 continue
             with onload_layer(target):
-                if hasattr(target, "base_layer"):
+                if hasattr(target, "unload_and_optionally_merge_module"):
+                    # if layers have special unloading method, like MultiheadAttention, use that
+                    unloaded_module = target.unload_and_optionally_merge_module(
+                        merge=merge, safe_merge=safe_merge, adapter_names=adapter_names
+                    )
+                    self._replace_module(parent, target_name, unloaded_module, target)
+                elif hasattr(target, "base_layer"):
                     if merge:
                         target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
                     self._replace_module(parent, target_name, target.get_base_layer(), target)
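The new first branch is a duck-typed escape hatch: a layer that knows how to rebuild its original module (here, the new LoRA `MultiheadAttention`) exposes `unload_and_optionally_merge_module` and returns the module to put back into the parent, bypassing the generic `get_base_layer()` path. A toy illustration of that protocol; names and details are simplified and this is not the PR's actual implementation:

```python
import torch.nn as nn

class ToyAdapterLayer(nn.Module):
    """Stand-in for an adapter layer that needs custom unloading."""

    def __init__(self, base_layer: nn.Module):
        super().__init__()
        self.base_layer = base_layer

    def merge(self, safe_merge=False, adapter_names=None):
        pass  # here the adapter weights would be merged into base_layer

    def unload_and_optionally_merge_module(self, *, merge, safe_merge, adapter_names):
        # the layer itself decides which module gets handed back to the model tree
        if merge:
            self.merge(safe_merge=safe_merge, adapter_names=adapter_names)
        return self.base_layer

parent = nn.Sequential(ToyAdapterLayer(nn.Linear(8, 8)))
target = parent[0]
if hasattr(target, "unload_and_optionally_merge_module"):
    parent[0] = target.unload_and_optionally_merge_module(
        merge=True, safe_merge=False, adapter_names=None
    )
print(type(parent[0]))  # <class 'torch.nn.modules.linear.Linear'>
```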
Why has this been removed?
Sorry, I forgot to put this into the description of the PR.

These lines have been obsolete for some time now. They only apply when we unload the model (otherwise, the `if` does not match). Remember that when we made the `base_layer` switch, we ensured that unloading simply returns the `base_layer`; there is no longer any need to create a new layer (say, a new `nn.Linear` when using `lora.Linear`) and replace the new layer's `weight` with the parent layer's `weight`. The `base_layer` already has the original `weight`, so these lines are unnecessary.

I removed them now because they were getting in the way of `MultiheadAttention`: that layer has no `weight` attribute, so this line would fail.
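
To make that argument concrete, here is a toy version of the invariant the `base_layer` switch guarantees (simplified stand-in classes, not the real PEFT layers): the wrapper keeps the original module untouched, so unloading returns that very object and there is nothing to copy back; and `nn.MultiheadAttention` has no `.weight` to copy in the first place.

```python
import torch.nn as nn

class ToyLoraWrapper(nn.Module):
    """Simplified stand-in for a LoRA layer after the base_layer switch."""

    def __init__(self, base_layer: nn.Module):
        super().__init__()
        self.base_layer = base_layer  # original module kept as-is

    def get_base_layer(self):
        return self.base_layer

base = nn.Linear(4, 4)
wrapped = ToyLoraWrapper(base)
assert wrapped.get_base_layer() is base                  # same object, weights untouched
assert wrapped.get_base_layer().weight is base.weight    # nothing to copy back

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
assert not hasattr(mha, "weight")                        # the old copy-back code would fail here
```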