[WIP] Update LoraConfig for KaSA implementation
#2698
base: main
Conversation
BenjaminBossan
left a comment
Thank you for resuming your work on KaSA.
Implementation-wise, we need to take a different approach. Right now, KaSA is just added to the normal LoRA code, but we only want to activate it if the user opts in. Therefore, it should be implemented in a separate class, something like KasaVariant, in peft/tuners/lora/variants.py. Please check how DoRA is implemented and use a similar approach, as I have detailed in my previous comment. If anything is unclear, feel free to ask.
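For orientation, here is a rough sketch of what such a variant class could look like. It assumes the static init/forward/merge hooks that the variants in peft/tuners/lora/variants.py (e.g. DoRA) appear to use, and it reuses the lora_diag and einsum pieces quoted later in this thread; treat names and signatures as illustrative rather than the final API:

```python
import torch
from torch import nn

# Illustrative imports; in the real file these would be the local definitions in variants.py/layer.py.
from peft.tuners.lora.layer import Linear
from peft.tuners.lora.variants import LoraVariant


class KasaLinearVariant(LoraVariant):
    @staticmethod
    def init(module: Linear, adapter_name: str, **kwargs) -> None:
        # KaSA-specific state: a learnable diagonal; the one-off SVD truncation of the
        # base weight (quoted further down in this review) would also move here.
        module.lora_diag[adapter_name] = nn.Parameter(
            torch.randn(module.r[adapter_name]), requires_grad=True
        )

    @staticmethod
    def forward(module: Linear, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
        # KaSA forward pass: scale the LoRA activations by the learned diagonal
        # (dropout omitted here for brevity).
        lora_A = module.lora_A[active_adapter]
        lora_B = module.lora_B[active_adapter]
        scaling = module.scaling[active_adapter]
        diag = torch.diag(module.lora_diag[active_adapter])
        return result + lora_B(torch.einsum("ijk,kl->ijl", lora_A(x), diag)) * scaling
```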
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

gentle ping @NSBG

Thank you for your alert! I spent some time looking over the KaSA paper and code to get ready for more serious work, but it does seem pretty difficult 🥲 My goal is to upload code that's ready for review before the end of September, so I'm going to try even harder. Right now, I'm stuck at the 'Extend LoRA variant resolution' stage you mentioned. Honestly, this seems like the most important part, but it's hard for me to figure out where to start, specifically which file and class I should work on first. Could you help me with this?
That's great to see, thanks for picking this back up.

You're already on the right track with what you've added so far. Next, about resolving the variants: as a first step, let's revert the changes you made to the existing LoRA code. Then let's look at these lines in src/peft/tuners/lora/layer.py (lines 636 to 642 at a3197b1). Here we need to extend the functionality to add KaSA. The updated method could be something like:

```python
def resolve_lora_variant(self, *, use_dora: bool, use_kasa: bool, **kwargs) -> Optional[LoraVariant]:
    if use_dora and use_kasa:
        raise ValueError("Cannot use DoRA and KaSA at the same time, please choose only one.")

    variant = None
    if use_dora:
        from .variants import DoraLinearVariant

        variant = DoraLinearVariant()
    elif use_kasa:
        ...
    return variant
```

Does that make sense? Similarly, we'd have to update the other resolve_lora_variant implementations. I would suggest that you work on this as a next step, then we'll see what else needs to be done.
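If it helps, the elided use_kasa branch would presumably mirror the DoRA case. A sketch only, assuming the new class ends up being called KasaLinearVariant (the name used later in this thread):

```python
    elif use_kasa:
        from .variants import KasaLinearVariant  # hypothetical at this point in the thread

        variant = KasaLinearVariant()
```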
wow I really appreciate your sincere feedback. I'll read your advice carefully and then move forward 🤗

@BenjaminBossan I modified the code in the files below based on what you explained. Please give me feedback if there are parts that still need fixing, and then we can discuss the next steps.

1. variants.py
2. layer.py
BenjaminBossan
left a comment
Thanks for integrating my feedback. I gave this another review and noted the next few changes that are necessary. Please check my comments.
Apart from this, the branch is now encountering merge conflicts. Could you please bring your fork up-to-date with the remote and then merge with, or rebase on, the latest main branch from PEFT? If you have questions on how to resolve the merge conflicts, don't hesitate to ask.
Furthermore, please always run make style on your changes before pushing to make our linter happy.
More of a note for myself: Since KaSA updates the base weights of the model, we will have to take extra care to ensure that it works correctly when saving and loading the adapter.
src/peft/tuners/lora/layer.py
Outdated
```
    """
    return None
    if use_dora and use_kasa:
```
Let's undo the changes in this method body and return None. Instead, since this KaSA layer is implemented for Linear only, add the logic to lora.Linear.resolve_lora_variant.
Also, we should update the resolve_lora_variant methods of the other layer types, like lora.Embedding.resolve_lora_variant, to accept the use_kasa argument but raise an error if it's True. Otherwise, users may add it to non-supported layers and not notice that it doesn't actually do anything there.
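A minimal sketch of what that could look like for an unsupported layer type, with the signature modeled on the Linear method quoted above (not the final implementation):

```python
# e.g. in lora.Embedding; the other non-Linear layer types would follow the same pattern
def resolve_lora_variant(self, *, use_dora: bool, use_kasa: bool = False, **kwargs) -> Optional[LoraVariant]:
    if use_kasa:
        # KaSA is only implemented for lora.Linear, so fail loudly instead of silently doing nothing.
        raise ValueError("KaSA is not supported for this layer type, only lora.Linear is supported.")
    ...
```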
src/peft/tuners/lora/layer.py
Outdated
```python
############ kasa #############
self.lora_diag[adapter_name] = nn.Parameter(torch.randn(r), requires_grad=True)

weight = self.get_base_layer().weight
dtype = weight.dtype
svd_rank = self.in_features - r
weight = weight.to(torch.float32)
U, S, Vh = torch.linalg.svd(weight.data, full_matrices=False)
U_principle, S_principle, Vh_principle = U[:, :svd_rank], S[:svd_rank], Vh[:svd_rank, :]
self.get_base_layer().weight.data = (U_principle @ torch.diag(S_principle) @ Vh_principle).to(dtype)
#########################
```
All of this can be removed, since it's part of KasaLinearVariant.init, right?
src/peft/tuners/lora/variants.py
Outdated
```python
# initialize lora_diag
module.lora_diag[adapter_name] = nn.Parameter(torch.randn(module.r[adapter_name]), requires_grad=True)

# SVD
```
Let's add a reference here, so that we know the origin:

```python
# see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L132
```
```python
# initialize lora_diag
module.lora_diag[adapter_name] = nn.Parameter(torch.randn(module.r[adapter_name]), requires_grad=True)

# see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L132
# SVD
```

I put it in here; how does it look?
```python
@staticmethod
def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
    delta_weight = module.get_delta_weight(active_adapter)
    return orig_weight + delta_weight

@staticmethod
def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None:
    delta_weight = module.get_delta_weight(active_adapter)
    orig_weight.data += delta_weight

@staticmethod
def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
    delta_weight = module.get_delta_weight(active_adapter)
    return orig_weight - delta_weight
```
KaSA should have an influence on the merged weights, should it not?
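To illustrate the point: given the forward pass quoted further down, lora_B applied to diag times lora_A(x) times scaling, the merged delta would presumably need the diagonal term as well, unlike plain LoRA's B @ A. A sketch under that assumption, not the actual implementation:

```python
@staticmethod
def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
    # Hypothetical KaSA-aware delta that includes the learned diagonal.
    lora_A = module.lora_A[active_adapter].weight  # shape (r, in_features)
    lora_B = module.lora_B[active_adapter].weight  # shape (out_features, r)
    diag = torch.diag(module.lora_diag[active_adapter])
    scaling = module.scaling[active_adapter]
    delta_weight = (lora_B @ diag @ lora_A) * scaling
    return orig_weight + delta_weight.to(orig_weight.dtype)
```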
Although this PR is closed, it seems I've incorporated everything else except for this comment (of course, you'd have to look at the code). Could you explain this question in more detail?
src/peft/tuners/lora/variants.py
Outdated
```python
x = dropout(x)

# KaSA calculation
lora_output = lora_B(torch.einsum('ijk,kl->ijl', lora_A(x), diag)) * scaling
```
Again, let's add a reference:

```python
# see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L602C21-L602C110
```
```python
# KaSA calculation
# see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L602C21-L602C110
lora_output = lora_B(torch.einsum('ijk,kl->ijl', lora_A(x), diag)) * scaling
return result + lora_output
```

I inserted this near where the actual calculation logic begins, rather than just in an empty space. I think this is a bit better.
(branch force-pushed from 39abcad to 20a9829)
@BenjaminBossan oh I didn't mean to close the branch, but it seems to have closed while I was merging with the main branch. I guess I'll have to open a new PR, right? 😰
P.S. When I tried to sync with the main branch, I ended up discarding all my commits, so did that cause it to close?

I don't know what happened, but I could re-open the PR and there are some changes visible. Can you double-check that everything looks as expected? If for some reason it's not what's expected, you can create a new PR and push your local branch.
I usually handle merges in the terminal, and I suspect the pull request was closed because I accidentally wiped the commit history while using the 'Sync fork' feature on GitHub. I'll be more careful in the future. Thanks for reopening it. I'll review the changes and open a new PR if needed. Sorry to keep bothering you with this.

No worries. If the diff on this PR looks good, let me know and I'll do a review. Only open a new PR if, for some reason, the code here does not correspond to what it should be.

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Commits:
- Check …
- …apter types, enhancing compatibility checks in the initialization process.
- …ve readability in LoraModel class.
- …re SVD is applied only once, while also cleaning up whitespace in multiple locations.
I've addressed the points you mentioned and applied the requested changes. Regarding the SVD value caching, I gave it some thought and realized I was stuck on the idea that 'caching is always efficient.' Since the base weights are already updated by the first adapter even when using multiple KaSA adapters, we can simply reuse those values subsequently. So, I modified the code to skip the calculation as you suggested.
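A sketch of how that skip could look; the marker attribute name is hypothetical, and the SVD itself is the snippet quoted earlier in this review:

```python
base_layer = module.get_base_layer()
# Apply the KaSA SVD truncation of the base weight only once, then mark the layer.
if not getattr(base_layer, "_kasa_svd_applied", False):  # hypothetical flag name
    weight = base_layer.weight.data.to(torch.float32)
    svd_rank = module.in_features - module.r[adapter_name]
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
    base_layer.weight.data = (
        U[:, :svd_rank] @ torch.diag(S[:svd_rank]) @ Vh[:svd_rank, :]
    ).to(base_layer.weight.dtype)
    base_layer._kasa_svd_applied = True
```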
BenjaminBossan
left a comment
Thanks for the new updates. We just merged another LoRA variant, which created merge conflicts with your PR, but it should be easy to resolve. Could you please take care? Thanks.
tests/test_initialization.py
Outdated
```python
config1 = LoraConfig(
    r=8,
    target_modules=["linear"],
    init_lora_weights=True,
```
You can remove this line, as it's irrelevant.
tests/test_initialization.py
Outdated
```python
config2 = LoraConfig(
    r=16,
    target_modules=["linear"],
    init_lora_weights=True,
```
You can remove this line, as it's irrelevant.
```python
# src/peft/tuners/lora/model.py
if len(self.peft_config) > 1:
    kasa_count = sum(1 for cfg in self.peft_config.values() if cfg.use_kasa)
    non_kasa_count = len(self.peft_config) - kasa_count
    if kasa_count > 0 and non_kasa_count > 0:
        raise ValueError("KaSA adapters cannot be mixed with other adapter types.")
```

I understood this to mean that since it's handled in this section, it's irrelevant elsewhere. Is my understanding correct?
Oh, this was a misunderstanding. I meant that the single line I commented on (init_lora_weights=True,) can be removed; the test as a whole is good to keep :) Please restore these tests.
ah okay haha
I changed the tests back :) !
Commit: …tLoraInitialization, simplifying the test suite and focusing on essential compatibility checks.
I applied what you mentioned and resolved the conflicts. Please take a look!
Commit: …dapter types in TestLoraInitialization, ensuring compatibility checks are enforced in both configurations.
BenjaminBossan
left a comment
PR is close to the finish line. I found a small issue, please check. Also, once ready to commit, please call make style.
src/peft/tuners/lora/model.py
Outdated
```python
if (len(self.peft_config) > 1) and (config.bias != "none"):
    raise ValueError(
        f"{self.__class__.__name__} supports only 1 adapter with bias. When using multiple adapters, "
        "set bias to 'none' for all adapters."
    )
```
Let's remove this and call super()._check_new_adapter_config(config) instead.
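A sketch of how that could look, assuming the parent tuner class implements the shared bias check in _check_new_adapter_config (the KaSA-specific part is only indicated):

```python
def _check_new_adapter_config(self, config: LoraConfig) -> None:
    # Delegate the shared checks (e.g. the bias restriction) to the parent class.
    super()._check_new_adapter_config(config)

    # KaSA-specific validation, e.g. rejecting a mix of KaSA and non-KaSA adapters, follows here.
    ...
```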
Commit: …class method, improving code clarity and ensuring consistent behavior across adapter types.
Is this the final step? Please let me know if there's anything else needed.

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@iambogeumkim Could you please run make style?
Commit: …sting formatting and line breaks in LoraLayer class.
I did run make style. Also, thank you for your patience with all my questions, even the trivial ones. I know I might have been a bit of a bother 😅 Since this was my first code contribution, I learned so much thanks to your guidance. Wishing you a warm and happy holiday season!
Something doesn't seem to work right, as the formatter is still complaining. These changes should resolve it:

modified src/peft/tuners/lora/config.py

```diff
@@ -764,8 +764,9 @@ class LoraConfig(PeftConfig):
                 "singular value decomposition (SVD) with knowledge-aware singular values to dynamically "
                 "activate parametric knowledge according to its relevance to downstream tasks."
             )
-        }
+        },
     )
+
     def to_dict(self):
         """
         Returns the configuration for your adapter model as a dictionary. Removes runtime configurations.
```

modified tests/test_custom_models.py

```diff
@@ -1265,10 +1265,12 @@ def _skip_tests_with_multiple_adapters_with_target_parameters(config_cls, config
     if (config_cls == LoraConfig) and config_kwargs.get("target_parameters"):
         pytest.skip("LoRA with multiple adapters with target_parameters is not supported")
 
+
 def _skip_test_disable_adapters(config_cls, config_kwargs):
     if (config_cls == LoraConfig) and config_kwargs.get("use_kasa"):
         pytest.skip("KaSA modifies base weights, so adapter disable test is skipped")
 
+
 class MLP(nn.Module):
     def __init__(self, bias=True):
         super().__init__()
```
Don't worry, there's always a first time for everyone. Happy to hear that you learned a lot.
Commit: …le configurations with multiple adapters, enhancing clarity and maintainability.
I double-checked whether there were any unpushed files related to KaSA. Aside from those two files, everything seems to be pushed, so it should be ready to be merged now.

@iambogeumkim There are a bunch of failing tests because …
Commit: …d adapter configuration support.
I've updated the code accordingly.
Thanks for the latest changes. There are still some errors, this time caused by X-LoRA. I checked and the issue there is that X-LoRA models can have PEFT configs that contain both normal LoRA and X-LoRA configs. Since X-LoRA configs don't have use_kasa, the check errors out for them.

It's a bit of an edge case, but let's add handling for it.
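One way such a guard could look, shown only as a sketch built on the multi-adapter check quoted earlier rather than the reviewer's actual wording: getattr with a False default means configs without the flag are simply counted as non-KaSA.

```python
if len(self.peft_config) > 1:
    # getattr covers configs (e.g. X-LoRA) that do not define use_kasa at all.
    kasa_count = sum(1 for cfg in self.peft_config.values() if getattr(cfg, "use_kasa", False))
    non_kasa_count = len(self.peft_config) - kasa_count
    if kasa_count > 0 and non_kasa_count > 0:
        raise ValueError("KaSA adapters cannot be mixed with other adapter types.")
```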
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

not stale
cc @BenjaminBossan
I was delayed in updating the code because I was focusing on company work, but now I'm planning to resume the project in earnest. If I have any questions about implementing the code, may I continue to ask you?
I apologize for opening a new pull request, as the previous one was closed 🥲 Thank you for your understanding.