[WIP] Add LoRA multihead attention module #1324
base: main
@@ -59,6 +59,9 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
             in_features, out_features = (
                 base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape
             )
+        elif isinstance(base_layer, nn.MultiheadAttention):
+            assert base_layer._qkv_same_embed_dim, "Only same embed dim supported as of now"
+            in_features, out_features = base_layer.embed_dim, 3 * base_layer.embed_dim
         elif hasattr(base_layer, "infeatures") and hasattr(base_layer, "outfeatures"):
             # QuantLinear
             in_features, out_features = base_layer.infeatures, base_layer.outfeatures
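For background (not part of the diff): with the default fused projection, `nn.MultiheadAttention` stores a single `in_proj_weight` of shape `(3 * embed_dim, embed_dim)`, which is why `in_features`/`out_features` are derived as `embed_dim` and `3 * embed_dim` here. A quick standalone check:

```python
import torch.nn as nn

# kdim/vdim left at their defaults -> q, k and v share one fused projection matrix
mha = nn.MultiheadAttention(embed_dim=16, num_heads=4)
print(mha._qkv_same_embed_dim)   # True
print(mha.in_proj_weight.shape)  # torch.Size([48, 16]) == (3 * embed_dim, embed_dim)
```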
@@ -684,3 +687,172 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
     def __repr__(self) -> str:
         rep = super().__repr__()
         return "lora." + rep
+
+
+class MultiheadAttention(nn.Module, LoraLayer):
+    def __init__(
+        self,
+        base_layer,
+        adapter_name: str,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+        is_target_conv_1d_layer: bool = False,
+        init_lora_weights: Union[bool, str] = True,
+        use_rslora: bool = False,
+        **kwargs,
+    ) -> None:
+        # TODO work with separate weights
+        assert base_layer._qkv_same_embed_dim, "Only same embed dim supported as of now"
+
+        super().__init__()
+        LoraLayer.__init__(self, base_layer, **kwargs)
+        self.fan_in_fan_out = fan_in_fan_out
+
+        self._active_adapter = adapter_name
+        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
+        self.is_target_conv_1d_layer = is_target_conv_1d_layer
Review comment (suggested change on `is_target_conv_1d_layer`): We can also just hard-code it to `False`.
+
+    def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
+        """
+        Merge the active adapter weights into the base weights
+
+        Args:
+            safe_merge (`bool`, *optional*):
+                If True, the merge operation will be performed in a copy of the original weights and check for NaNs
+                before merging the weights. This is useful if you want to check if the merge operation will produce
+                NaNs. Defaults to `False`.
+            adapter_names (`List[str]`, *optional*):
+                The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
+                to `None`.
+        """
+        if self.merged:
+            warnings.warn(
+                f"Already following adapters were merged {','.join(self.merged_adapters)}. "
+                f"You are now additionally merging {','.join(self.active_adapters)}."
+            )
+
+        if adapter_names is None:
+            adapter_names = self.active_adapters
+
+        # Implementation follows this:
+        # https://github.com/Baijiong-Lin/LoRA-Torch/blob/4bfed6820b64fcf47064c30f30606a190a4f0d2e/loratorch/layers.py#L73-L79
+        # Notably, instead of mutating the weight, we delete the original weight and replace it by the merged weight
+        # TODO: work with separate weights
+        for active_adapter in adapter_names:
+            if active_adapter in self.lora_A.keys():
+                base_layer = self.get_base_layer()
+                if safe_merge:
+                    orig_weights = base_layer.in_proj_weight.data.detach().clone()
+                    orig_weights += self.get_delta_weight(active_adapter)
+
+                    if not torch.isfinite(orig_weights).all():
+                        raise ValueError(
+                            f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
+                        )
+
+                    del base_layer.in_proj_weight
+                    base_layer.in_proj_weight = orig_weights
+                else:
+                    # TODO: work with separate weights
+                    weight_merged = base_layer.in_proj_weight.data.detach() + self.get_delta_weight(active_adapter)
+                    del base_layer.in_proj_weight
+                    base_layer.in_proj_weight = weight_merged
+                self.merged_adapters.append(active_adapter)
+
+    def unmerge(self) -> None:
+        """
+        This method unmerges all merged adapter layers from the base weights.
+        """
+        if not self.merged:
+            warnings.warn("Already unmerged. Nothing to do.")
+            return
+
+        # TODO work with separate weights
+        while len(self.merged_adapters) > 0:
+            active_adapter = self.merged_adapters.pop()
+            if active_adapter in self.lora_A.keys():
+                self.get_base_layer().in_proj_weight.data -= self.get_delta_weight(active_adapter)
+
+    def get_delta_weight(self, adapter) -> torch.Tensor:
+        """
+        Compute the delta weight for the given adapter.
+
+        Args:
+            adapter (str):
+                The name of the adapter for which the delta weight should be computed.
+        """
+        device = self.lora_B[adapter].weight.device
+        dtype = self.lora_B[adapter].weight.dtype
+
+        # In case users want to merge adapter weights that are in float16 while being on CPU, we need to cast the
+        # weights to float32, perform the merge, and then cast back to float16, because the `@` operator (and matmul
+        # in general) is not supported for torch + cpu + fp16.
+        cast_to_fp32 = device.type == "cpu" and dtype == torch.float16
+
+        weight_A = self.lora_A[adapter].weight
+        weight_B = self.lora_B[adapter].weight
+
+        if cast_to_fp32:
+            weight_A = weight_A.float()
+            weight_B = weight_B.float()
+
+        output_tensor = transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter]
+
+        if cast_to_fp32:
+            output_tensor = output_tensor.to(dtype=dtype)
+
+            # cast back the weights
+            self.lora_A[adapter].weight.data = weight_A.to(dtype)
+            self.lora_B[adapter].weight.data = weight_B.to(dtype)
+
+        return output_tensor
+
+    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
+        previous_dtype = x.dtype
+
+        if self.disable_adapters:
+            if self.merged:
+                self.unmerge()
+            result = self.base_layer(x, *args, **kwargs)
+        elif self.merged:
+            result = self.base_layer(x, *args, **kwargs)
+        else:
+            # merge all adapters that are active for this module
+            active_adapters = [a for a in self.active_adapters if a in self.lora_A]
+            try:
+                self.merge(adapter_names=active_adapters)
+                result = self.base_layer(x, *args, **kwargs)
+            finally:
+                # it's safe to call unmerge(), which unmerges all adapters, because we checked that not self.merged,
+                # i.e. there was no merged layer before
+                self.unmerge()
+
+        result = (result[0].to(previous_dtype), result[1].to(previous_dtype) if result[1] is not None else result[1])
+        return result
+
+    def _restore_weights(self):
+        # Restore the weights as registered parameters on the base layer.
+        # This is necessary because of the way the weights are merged/unmerged (which is necessary for forward to
+        # work correctly): the Module "forgets" these attributes. Therefore, we need to call register_parameter
+        # explicitly. We cannot call register_parameter during merging/unmerging because that cuts the weights off
+        # from the autograd graph. Note that this is hacky, since we need to ensure that _restore_weights is called
+        # by each method that needs it.
+
+        # TODO work with separate weights
+        base_layer = self.get_base_layer()
+        weight = base_layer.in_proj_weight.data
+        del base_layer.in_proj_weight
+        base_layer.register_parameter("in_proj_weight", nn.Parameter(weight))
+
+    def state_dict(self, *args, **kwargs):
+        self._restore_weights()
+        return super().state_dict(*args, **kwargs)
+
+    def named_modules(self, *args, **kwargs):
+        self._restore_weights()
+        return super().named_modules(*args, **kwargs)

Review discussion on the `named_modules` override — Q: do we also need to overwrite the […] method? A: Not needed, as […].

+    def __repr__(self) -> str:
+        rep = super().__repr__()
+        return "lora." + rep
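As a side note on the merge/`_restore_weights` mechanics above, here is a minimal standalone sketch (toy shapes, not the PR's code) of why the weight is deleted and reassigned rather than mutated in place: assigning a plain tensor that contains the LoRA delta keeps the delta inside the autograd graph, whereas writing into `.data` would not.

```python
import torch
import torch.nn as nn

embed_dim, r = 8, 4
mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=2)
lora_A = nn.Parameter(0.01 * torch.randn(r, embed_dim))
lora_B = nn.Parameter(torch.zeros(3 * embed_dim, r))

delta = lora_B @ lora_A                   # (3 * embed_dim, embed_dim), tracked by autograd
merged = mha.in_proj_weight.data + delta  # plain tensor that carries the graph of `delta`

del mha.in_proj_weight                    # drop the registered nn.Parameter
mha.in_proj_weight = merged               # ordinary attribute, picked up by forward()

x = torch.randn(5, 2, embed_dim)          # (seq_len, batch, embed_dim)
out, _ = mha(x, x, x)
out.sum().backward()
print(lora_A.grad is not None, lora_B.grad is not None)  # True True
```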
@@ -43,7 +43,7 @@
 from .config import LoraConfig
 from .gptq import QuantLinear
-from .layer import Conv2d, Embedding, Linear, LoraLayer
+from .layer import Conv2d, Embedding, Linear, LoraLayer, MultiheadAttention


 class LoraModel(BaseTuner):
@@ -193,11 +193,6 @@ def _replace_module(self, parent, child_name, new_module, child):
         if hasattr(child, "base_layer"):
             child = child.base_layer

-        if not hasattr(new_module, "base_layer"):
-            new_module.weight = child.weight
-            if hasattr(child, "bias"):
-                new_module.bias = child.bias
-
         if getattr(child, "state", None) is not None:
             if hasattr(new_module, "base_layer"):
                 new_module.base_layer.state = child.state

Review discussion on the removed block — Q: Why has this been removed? A: Sorry, forgot to put this into the description of the PR. These lines have been obsolete for some time now. They only apply when we unload the model (otherwise, the […]). I removed them now because they were annoying with […].

@@ -208,7 +203,16 @@ def _replace_module(self, parent, child_name, new_module, child):
         # dispatch to correct device
         for name, module in new_module.named_modules():
             if (self.prefix in name) or ("ranknum" in name):
-                weight = child.qweight if hasattr(child, "qweight") else child.weight
+                if hasattr(child, "qweight"):
+                    weight = child.qweight
+                elif hasattr(child, "weight"):
+                    weight = child.weight
+                elif getattr(child, "in_proj_weight", None) is not None:  # MHA
+                    weight = child.in_proj_weight
+                elif getattr(child, "q_proj_weight", None) is not None:  # MHA
+                    weight = child.q_proj_weight
+                else:
+                    raise ValueError(f"Encountered unknown module type: {type(child)}")
                 module.to(weight.device)

     def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
@@ -290,6 +294,9 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
         elif isinstance(target_base_layer, torch.nn.Conv2d):
             kwargs.update(lora_config.loftq_config)
             new_module = Conv2d(target, adapter_name, **kwargs)
+        elif isinstance(target_base_layer, torch.nn.MultiheadAttention):
+            kwargs.update(lora_config.loftq_config)
+            new_module = MultiheadAttention(target, adapter_name, **kwargs)
         elif isinstance(target_base_layer, torch.nn.Linear):
             if kwargs["fan_in_fan_out"]:
                 warnings.warn(
@@ -333,7 +340,8 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
         else:
             raise ValueError(
                 f"Target module {target} is not supported. Currently, only the following modules are supported: "
-                "`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `transformers.pytorch_utils.Conv1D`."
+                "`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `transformers.pytorch_utils.Conv1D`, "
+                "`torch.nn.MultiheadAttention.`"
             )

         return new_module
Review comment (on the diff above): I don't think this is used?
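For context, a rough usage sketch of what this PR enables once merged; the toy model and the module name `attn` are made up for illustration, while `LoraConfig` and `get_peft_model` are PEFT's existing entry points:

```python
import torch.nn as nn
from peft import LoraConfig, get_peft_model

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=32, num_heads=4)  # fused q/k/v path
        self.head = nn.Linear(32, 2)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        return self.head(attn_out)

# Targeting "attn" should now dispatch to the new lora.MultiheadAttention wrapper.
config = LoraConfig(target_modules=["attn"], r=8, lora_alpha=16)
peft_model = get_peft_model(TinyModel(), config)
peft_model.print_trainable_parameters()
```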