FEAT Allow LoRA to target nn.Parameter #2638
Conversation
Normally, nn.Parameter cannot be targeted with LoRA adapters. This can be problematic, e.g. when MoE layers use nn.Parameter directly, or when there is an nn.Linear but its weight is passed around directly instead of calling forward (e.g. MHA). It would be possible to craft a solution involving a special LoRA layer for each module that uses nn.Parameter directly (e.g. lora.MHA), but that doesn't scale. This PR is an attempt at implementing a direct way to target nn.Parameter by making use of torch.nn.utils.parametrize. The current state of the PR is WIP; the next step is to add a dispatching mechanism. This is not trivial, as we don't want the new changes to accidentally affect the current matching logic. Probably the best way is to add a completely new config variable (e.g. target_parameters) that does not interfere with the current code.
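To illustrate the mechanism, here is a minimal, hypothetical sketch (not this PR's actual implementation; the class and variable names are made up) of how torch.nn.utils.parametrize can apply a LoRA delta whenever a plain parameter is accessed:

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class LoraParametrization(nn.Module):
    # Hypothetical sketch, not PEFT's actual LoRA layer.
    def __init__(self, fan_out, fan_in, r=8, alpha=16):
        super().__init__()
        self.lora_A = nn.Parameter(torch.randn(r, fan_in) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(fan_out, r))
        self.scaling = alpha / r

    def forward(self, W):
        # Called every time the wrapped parameter is accessed; gradients flow
        # into lora_A/lora_B while the base weight can stay frozen.
        return W + self.scaling * (self.lora_B @ self.lora_A)


base = nn.Linear(32, 64, bias=False)  # weight shape: (64, 32)
base.weight.requires_grad_(False)
parametrize.register_parametrization(
    base, "weight", LoraParametrization(fan_out=64, fan_in=32)
)
out = base(torch.randn(4, 32))  # transparently uses W + scaling * (B @ A)
```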
Super cool feature :) Implementation looks good as well, some questions/remarks in the comments
Co-authored-by: githubnemo <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
The test tests/test_target_parameters.py::TestDecoderModelsTargetParameters::test_merge_layers[LoraConfig-config_kwargs2-trl-internal-testing/tiny-Llama4ForCausalLM] is failing for me locally with somewhat large numerical differences in the expected outputs. I'm not quite sure why that is. This test involves a mixture of normal LoRA and LoRA targeting nn.Parameter and then merging. Possibly this is a bug or possibly this just requires higher tolerance, I'll investigate later.
Thanks for the reviews @githubnemo and @qgallouedec. The comments should be addressed. I added another test config involving a mixture of normal LoRA and LoRA targeting nn.Parameter and then merging. This causes one test,
tests/test_target_parameters.py::TestDecoderModelsTargetParameters::test_merge_layers[LoraConfig-config_kwargs2-trl-internal-testing/tiny-Llama4ForCausalLM]
to fail for me locally with somewhat large numerical differences in the expected outputs. I'm not quite sure why that is. Possibly this is a bug, or possibly it just requires a higher tolerance; I'll investigate later.
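For context, "merging" here means folding the LoRA deltas back into the base weights. A rough sketch of such a mixed setup (module and parameter names are illustrative, not the test's actual config):

```python
from peft import LoraConfig

# Regular LoRA on nn.Linear modules plus parameter-targeted LoRA on a raw nn.Parameter.
config = LoraConfig(
    r=8,
    target_modules=["q_proj", "v_proj"],
    target_parameters=["feed_forward.experts.gate_up_proj"],  # illustrative path
)
# After wrapping a model with get_peft_model(model, config) and training,
# peft_model.merge_and_unload() folds both kinds of LoRA deltas into the base weights.
```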
if any(p.device == meta for p in adapter_layer.parameters()):
    continue

# TODO: weight is not necessarily defined here, leading to a NameError, fix that
Note: This is an existing bug and has nothing to do with the PR, just flagging it here.
LGTM! Thanks for the swift implementation :)
This issue was found in PR huggingface#2638 and is defined as follows:

> When calling `get_peft_model_state_dict(..., save_embedding_layers="auto")` we check if the embedding layer is targeted to determine if the embedding layers need saving. This is not done when `PeftConfig.target_modules` is a regex string, potentially failing to save the embeddings.

This is fixed by adding a check similar to the existing query of whether `EMBEDDING_LAYER_NAMES` is a subset of the defined target modules, only that the regex matching from `BaseTuner.inject_adapter` is used. To avoid code duplication, the matching was moved to its own utility function, `match_target_against_key`. The main complication was defining the test cases, as it was non-trivial to pin down what `save_embedding_layers="auto"` entails. I've assembled a list of cases that I think are correct in the corresponding unit test.
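As a rough illustration of the idea (the function body below is an assumption, not PEFT's actual implementation; only the names `match_target_against_key` and `EMBEDDING_LAYER_NAMES` come from the description above):

```python
import re


def match_target_against_key(target, key):
    # Sketch: a regex-string target is matched against the full module key,
    # while a list target is treated as module-name suffixes. The exact
    # semantics in PEFT may differ.
    if isinstance(target, str):
        return re.fullmatch(target, key) is not None
    return any(key == t or key.endswith("." + t) for t in target)


EMBEDDING_LAYER_NAMES = ["embed_tokens", "lm_head"]

# With a regex target, the "does this config target the embeddings?" check must
# reuse the same matching logic instead of a plain subset test:
target_modules = r".*(embed_tokens|lm_head).*"
save_embeddings = any(
    match_target_against_key(target_modules, name) for name in EMBEDDING_LAYER_NAMES
)
print(save_embeddings)  # True
```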
When the target_parameters feature for LoRA was introduced in huggingface#2638, there was one gap, namely the possibility to target multiple nn.Parameters on the same module (there was only a workaround involving multiple adapters, but that is not user friendly). With this PR, it is now possible to achieve this.

The mechanism to enable this is a bit crude, namely allowing multiple ParamWrappers to be nested. This should generally be fine as long as only a couple of nn.Parameters are targeted on the same module. When there are dozens or hundreds, this approach could lead to slowdowns or other issues.

A side effect of this implementation is that the ParamWrapper, when it removes the parametrization, now only removes its own parametrization. Using nn.utils.parametrize.remove_parametrizations would remove all parametrizations, which is bad when we have nested parametrizations (see the sketch after this description).

Alternative approaches

Some alternative approaches were discussed internally, but the chosen one was considered the most practical:

1. Allow more than one adapted parameter per LoRA layer. This would require nested dicts for the LoRA parameters, something like self.lora_A[adapter_name][parameter_name]. We don't have this anywhere so far, and it would probably break implicit assumptions about PEFT layers in many places (e.g. parsing of state_dict keys), requiring many adjustments.
2. Have an auxiliary module that contains the individual LoRA layers targeting the individual parameters. This could be the cleanest solution and would probably be more efficient if there is a huge number of targeted parameters per module. However, it also brings extra complexity, as it requires implementing the logic to route the information to the right parameter, and it may be a solution to a problem that is irrelevant in practice (a large number of targets per module).
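The following minimal sketch (an illustrative parametrization class, not PEFT's ParamWrapper) shows why this matters: torch stacks multiple parametrizations registered on the same tensor, but remove_parametrizations drops the whole chain at once.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class AddDelta(nn.Module):
    # Illustrative parametrization that adds a learnable delta to a weight.
    def __init__(self, shape):
        super().__init__()
        self.delta = nn.Parameter(torch.zeros(shape))

    def forward(self, W):
        return W + self.delta


lin = nn.Linear(8, 8, bias=False)
# Registering twice stacks (nests) two parametrizations on the same tensor.
parametrize.register_parametrization(lin, "weight", AddDelta(lin.weight.shape))
parametrize.register_parametrization(lin, "weight", AddDelta(lin.weight.shape))
print(len(lin.parametrizations.weight))  # 2

# remove_parametrizations removes the entire chain at once, so a wrapper that
# should only undo its own adapter has to unregister selectively instead.
parametrize.remove_parametrizations(lin, "weight", leave_parametrized=True)
print(parametrize.is_parametrized(lin, "weight"))  # False
```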
Normally, nn.Parameter cannot be targeted with LoRA adapters. This can be problematic, e.g. when there are MoE layers that use nn.Parameter directly, or when there is an nn.Linear but the weight is passed directly instead of calling forward (e.g. MHA). It would be possible to craft a solution involving a special LoRA layer for each of the modules that use nn.Parameter directly (e.g. lora.MHA), but that doesn't scale. This PR implements a direct way to target nn.Parameter, making use of torch.nn.utils.parametrize. Using the feature requires passing target_parameters to the LoraConfig. During the forward pass, when the parameter is accessed, the LoRA weights are added to it while still ensuring that gradients flow correctly to the LoRA weights. Right now, only LoRA supports this feature. Moreover, it is not possible to target multiple parameters of the same module with the same adapter; a workaround is to use multiple adapters (i.e. with different names).

---------

Co-authored-by: githubnemo <[email protected]>
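A hedged usage sketch of the feature (the model id and the parameter path are placeholders, not taken from the PR; the exact paths depend on the model architecture):

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Placeholder model id; pick a model whose MoE experts store weights as nn.Parameter.
model = AutoModelForCausalLM.from_pretrained("some-org/some-moe-model")

config = LoraConfig(
    r=8,
    lora_alpha=16,
    # target_parameters addresses nn.Parameter attributes directly;
    # the path below is illustrative.
    target_parameters=["feed_forward.experts.gate_up_proj"],
)
peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()
```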