ENH: Tie weights for target_modules in Lora (#2864) #2879
base: main
Changes to the first file:

@@ -150,6 +150,19 @@ def update_layer(

```python
        adapter_name: str,
        r: int,
        lora_alpha: int,
        lora_dropout,
```
Contributor (Author): I think this was added by a merge commit, let me take a look and remove this if not required.

Collaborator: Yeah, this looks like a wrong merge. #2960 refactored this code to only include
```python
        init_lora_weights,
        use_rslora,
        use_dora: bool = False,
        use_alora: bool = False,
        use_qalora: bool = False,
        lora_bias: bool = False,
        arrow_config: ArrowConfig = None,
        qalora_group_size: int = 32,
        inference_mode: bool = False,
        tied_adapter: Optional[dict[str, nn.Parameter]] = None,
        lora_ga_config=None,
        use_bdlora=None,
        config: LoraConfig,
        **kwargs,
    ) -> None:
```
@@ -190,6 +203,17 @@ def update_layer(

```python
        # Actual trainable parameters
        self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False)
        self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=lora_bias)

        # Tying adapters is only implemented for Linear layers
        # where the source is the embedding layer.
        # Currently, this is the most prevalent way of tying layers (weight tying).
        if tied_adapter:
            lora_A_params = tied_adapter["lora_A"]
            lora_B_params = tied_adapter["lora_B"]

            self.lora_A[adapter_name].weight = torch.nn.Parameter(lora_A_params)
            self.lora_B[adapter_name].weight = torch.nn.Parameter(lora_B_params)
```
BenjaminBossan marked this conversation as resolved.
```python
        self.lora_bias[adapter_name] = lora_bias

        if use_rslora:
```
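For context on what "weight tying" means here: in many decoder-only language models the output projection reuses the input embedding matrix, so there is only one underlying weight tensor. The sketch below is a minimal, hypothetical illustration of that pattern (not taken from PEFT or this PR); the model and layer names are made up for demonstration.

```python
import torch
from torch import nn


class TinyTiedLM(nn.Module):
    """Toy decoder with tied input embedding and output head (hypothetical example)."""

    def __init__(self, vocab_size: int = 100, hidden_dim: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_dim)
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)
        # Weight tying: the head reuses the embedding matrix, so there is
        # only one Parameter shared by both modules.
        self.lm_head.weight = self.embed_tokens.weight

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.lm_head(self.embed_tokens(input_ids))


model = TinyTiedLM()
# Both modules point at the same storage; updating one updates the other.
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
```

Because the base weights are a single shared tensor, the PR aims to let the LoRA adapters on the tied modules share parameters as well, instead of learning two independent adapter pairs.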
@@ -743,6 +767,16 @@ def __init__(

```python
            adapter_name,
            r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            init_lora_weights=init_lora_weights,
            use_rslora=use_rslora,
            use_dora=use_dora,
            use_alora=use_alora,
            lora_bias=lora_bias,
            arrow_config=arrow_config,
            tied_adapter=kwargs.pop("tied_adapter", None),
            lora_ga_config=lora_ga_config,
            use_bdlora=use_bdlora,
```
Comment on lines +770 to +779

Collaborator: Remove, merge artifact.
```python
            config=config,
            **kwargs,
        )
```
Changes to the second file:

@@ -15,6 +15,7 @@

```python
import math
import operator
import re
import warnings
from contextlib import contextmanager
from dataclasses import replace
```
@@ -27,11 +28,7 @@

```python
from torch import nn

from peft.import_utils import is_bnb_4bit_available, is_bnb_available
# removed by this PR:
# from peft.tuners.tuners_utils import (
#     BaseTuner,
#     BaseTunerLayer,
#     replicate_layers,
# )
# added by this PR:
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, find_parameter_name_by_tensor, replicate_layers
from peft.utils import (
    TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
    AuxiliaryTrainingWrapper,
```
@@ -202,13 +199,25 @@ def _create_and_replace(

```python
        r = lora_config.rank_pattern.get(r_key, lora_config.r)
        alpha = lora_config.alpha_pattern.get(alpha_key, lora_config.lora_alpha)

        # Check whether the target is marked as a tied layer.
        # If it is, we add a reference to the LoRA adapters of the embedding layer in `tied_adapter`.
        is_tied = target_name in (getattr(lora_config, "target_modules_to_tie", []) or [])
```
romitjain marked this conversation as resolved.
```python
        tied_adapter = {}
        if is_tied:
            tied_module = self.model.get_input_embeddings()
            emb_A = tied_module.lora_embedding_A[adapter_name]
            emb_B = tied_module.lora_embedding_B[adapter_name]

            tied_adapter = {"lora_A": emb_B.t(), "lora_B": emb_A.t()}

        kwargs = {
            "r": r,
            "lora_alpha": alpha,
            "target_name": current_key,
            "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
            "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
            "parameter_name": parameter_name,
            "tied_adapter": tied_adapter,
        }

        # for torchao merging, we need the get_apply_tensor_subclass from the quantization config
```
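The transposes in `tied_adapter` line up the embedding-adapter shapes with what a Linear layer expects. Below is a minimal shape check, assuming the usual PEFT convention that `lora_embedding_A` has shape `(r, num_embeddings)` and `lora_embedding_B` has shape `(embedding_dim, r)`; the sizes are made up for illustration.

```python
import torch

# Hypothetical sizes for illustration only.
vocab_size, hidden_dim, r = 100, 16, 8

# Adapters attached to the embedding layer.
emb_A = torch.randn(r, vocab_size)   # lora_embedding_A: (r, num_embeddings)
emb_B = torch.randn(hidden_dim, r)   # lora_embedding_B: (embedding_dim, r)

# The tied head is an nn.Linear(hidden_dim, vocab_size), whose LoRA weights
# would normally have shapes:
#   lora_A.weight: (r, hidden_dim)
#   lora_B.weight: (vocab_size, r)
# Transposing the embedding adapters yields exactly those shapes, which is
# what the `tied_adapter` dict above passes along.
tied_lora_A = emb_B.t()   # (r, hidden_dim)
tied_lora_B = emb_A.t()   # (vocab_size, r)

assert tied_lora_A.shape == (r, hidden_dim)
assert tied_lora_B.shape == (vocab_size, r)
```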
@@ -249,6 +258,7 @@ def _create_and_replace(

```python
        if adapter_name not in self.active_adapters:
            # adding an additional adapter: it is not automatically trainable
            new_module.requires_grad_(False)

        self._replace_module(parent, target_name, new_module, target)

    def _replace_module(self, parent, child_name, new_module, child):
```
@@ -857,8 +867,86 @@ def subtract_mutated_init(self, output_state_dict: dict[str, torch.Tensor], adap

```python
        return tensors_lora

    # removed (previous implementation):
    # def _add_modules_to_tie(self, peft_config, tied_weight_keys):
    #     modules_to_save = set(getattr(peft_config, "modules_to_save", []) or [])
    #     missing_keys = set(tied_weight_keys) - modules_to_save
    def _add_modules_to_save_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[str]):
        """
        Add embedding layer to `modules_to_save` and remove the rest of the tied layers from `modules_to_save`.
        Maintain a separate set for layers to be tied in `peft_config.tied_weights_keys`.

        Args:
            peft_config (LoraConfig) -- The configuration of the Lora model.
            tied_weight_keys (list[str]) -- Contains the layers tied to the embedding layer.
        """
        tied_weight_keys = set(tied_weight_keys)
        peft_config.modules_to_tie = tied_weight_keys

        modules_to_save = getattr(peft_config, "modules_to_save", []) or []

        embed_layer_name = find_parameter_name_by_tensor(self.model, self.model.get_input_embeddings())
        # find_parameter_name_by_tensor returns the parameter name, so we need to strip the weight from the name
        if embed_layer_name.endswith(".weight"):
            embed_layer_name = embed_layer_name.removesuffix(".weight")
```
Comment on lines +886 to +887

Collaborator: I'm not sure if it is ever the case that
```python
        prefix, sep, suffix = embed_layer_name.partition(".")
        if sep and "model" in prefix:
            embed_layer_name = suffix

        if embed_layer_name not in modules_to_save:
            modules_to_save.append(embed_layer_name)
```
Comment on lines +892 to +893

Collaborator: I think that it's not a problem but a reader might wonder
Suggested change
```python
        # Iterate over `tied_weight_keys` which are
        # fully qualified keys and remove matching keys from
        # `modules_to_save`. It will only remove the first encounter
        # in `modules_to_save`, which should be safe, because `tied_weight_keys`
        # is a unique set of keys.
```
Comment on lines +895 to +899

Collaborator: I like this comment but it would be vastly improved if it would explain to what end we're doing this. Could you add the reason to this comment to inform the reader about why we're doing this?
```python
        for key in tied_weight_keys:
            for m in modules_to_save:
                if re.match(rf"(^|.*\.){m}($|\..*)", key):
                    modules_to_save.remove(m)
                    break

        peft_config.modules_to_save = modules_to_save
```
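To make the pruning step above concrete, here is a small, self-contained sketch of what the regex does; the module names and tied keys are made up for illustration and are not taken from the PR or any particular model.

```python
import re

# Hypothetical inputs: keys reported as tied to the embedding, and a user's
# modules_to_save list.
tied_weight_keys = {"lm_head.weight"}
modules_to_save = ["lm_head", "score"]

for key in tied_weight_keys:
    for m in list(modules_to_save):
        # Matches `m` only as a complete dotted path component of `key`,
        # so "lm_head" matches "lm_head.weight" but not "my_lm_head.weight".
        if re.match(rf"(^|.*\.){m}($|\..*)", key):
            modules_to_save.remove(m)
            break

print(modules_to_save)  # ['score'] -- the tied module was pruned
```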
```python
    def _add_targets_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[str]):
        """
        Add embedding layer to `target_modules` and remove the rest of the tied layers from `target_modules`.
        Maintain a separate set for layers to be tied in `peft_config.target_modules_to_tie`.

        Args:
            peft_config (LoraConfig) -- The configuration of the Lora model.
            tied_weight_keys (list[str]) -- Contains the layers tied to the embedding layer.
        """
        tied_weight_keys = set(tied_weight_keys)
        peft_config.target_modules_to_tie = tied_weight_keys

        raw_target_modules = getattr(peft_config, "target_modules", None)
```
Contributor (Author): @BenjaminBossan Please review this logic. I know this is a bit hacky! I am open to suggestions.

Member: Hmm yeah, this is rough. We can't really operate on the string like this, as there are too many possible ways that the regex could be formed. I wonder if we should just leave it be and deal with the tied module edge case in

Contributor (Author): It should be possible, it would just make the flow very convoluted.

Contributor (Author): I redid this a bit. We just need to make sure that
|
|
||||||||||||
| embed_layer_name = find_parameter_name_by_tensor(self.model, self.model.get_input_embeddings()) | ||||||||||||
| # find_parameter_name_by_tensor returns the parameter name, so we need to strip the weight from the name | ||||||||||||
| if embed_layer_name.endswith(".weight"): | ||||||||||||
| embed_layer_name = embed_layer_name.removesuffix(".weight") | ||||||||||||
|
Comment on lines +924 to +925

Collaborator: Same comment as above.
```python
        prefix, sep, suffix = embed_layer_name.partition(".")
        if sep and "model" in prefix:
            embed_layer_name = suffix
```
Comment on lines +926 to +928

Collaborator: Why do we split the fully-qualified name and make it a less precise name only to add it to the
```python
        if isinstance(raw_target_modules, str):
```
Collaborator: I think that
```python
            # The way weight tying is handled for adapters, we always want to add
            # LoRA adapters to the input embedding layer (embed_tokens)
            # instead of the output embedding layer.
            raw_target_modules = rf"(?:{raw_target_modules}|.*{embed_layer_name}$)"
            peft_config.target_modules = raw_target_modules
            return

        # removed (previous implementation):
        # peft_config.modules_to_tie = missing_keys
        target_modules = set(raw_target_modules or [])
        target_modules.add(embed_layer_name)
```
```python
        # Iterate over `tied_weight_keys` which are
        # fully qualified keys and remove matching keys from
        # `target_modules`. It will only remove the first encounter
        # in `target_modules`, which should be safe, because `tied_weight_keys`
        # is a unique set of keys.
```
Comment on lines +941 to +944

Collaborator: Same as above, misses the 'why'.
```python
        for key in tied_weight_keys:
            for m in target_modules:
                if re.match(rf"(^|.*\.){m}($|\..*)", key):
                    target_modules.remove(m)
                    break

        peft_config.target_modules = target_modules
```
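When `target_modules` is already a regex string, the string branch above simply extends that regex so the resolved embedding module also matches. The sketch below uses made-up values and mirrors only the pattern construction from the line shown earlier; the `re.fullmatch` calls are purely for demonstration and do not reproduce PEFT's full module-matching logic.

```python
import re

# Made-up example values; `embed_layer_name` stands in for whatever the helper
# resolved the input embedding to (e.g. "embed_tokens" after stripping ".weight").
raw_target_modules = "q_proj|v_proj"
embed_layer_name = "embed_tokens"

# Keep the user's regex and additionally match any module whose name ends
# with the embedding layer name.
pattern = rf"(?:{raw_target_modules}|.*{embed_layer_name}$)"

print(pattern)
# (?:q_proj|v_proj|.*embed_tokens$)

print(bool(re.fullmatch(pattern, "q_proj")))              # True
print(bool(re.fullmatch(pattern, "model.embed_tokens")))  # True
print(bool(re.fullmatch(pattern, "lm_head")))             # False
```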