-
Notifications
You must be signed in to change notification settings - Fork 2.1k
[WIP] Update LoraConfig
for KaSA implementation
#2698
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
nsbg
wants to merge
30
commits into
huggingface:main
Choose a base branch
from
nsbg:peft-kasa
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+145
−11
Open
Changes from 2 commits
Commits
Show all changes
30 commits
Select commit
Hold shift + click to select a range
387ad2a
Add KaSA implementation to layer.py
nsbg ae00a34
Add `use_kasa` argument to LoraConfig
nsbg 6588f1a
Add use_kasa parameter to Linear class
nsbg a824ac9
Add KasaLinearVariant class (just copy of DoraLinearVariant class) in…
nsbg 05e4e07
Add kasa description
nsbg d1e7e43
Remove unnecessary self.kasa
nsbg 9e53b9f
[WIP] update KasaLinearVariant class with SVD implementation
nsbg aa37111
Modify merge/unmerge method in KasaLinearVariant class
nsbg 9cfe65c
update KasaLinearVariant class with SVD implementation
nsbg f9d7cc7
fix type in init method
nsbg 84813a3
delete unnecessary part in layer.py
nsbg 39abcad
add original reference in layer.py
nsbg 06f76d8
merge main to peft-kasa
nsbg 0043ae3
re-add KaSA implementation to variants.py
nsbg ea59432
add use_kasa param to resolve_lora_variant in other layers
nsbg 574c1b8
delete unnecessary part in layer.py
nsbg 73fa58f
delete unnecessary part in variants.py
nsbg b9e3190
add _get_delta_weight static method in KasaLinearVariants class
nsbg 2649fdf
update module.get_delta_weight to KasaLinearVariant._get_delta_weight…
nsbg d06cafd
add kasa test
nsbg 4ee5da1
add dropout
nsbg 0377170
Update tests/test_custom_models.py
nsbg a536bbf
Update src/peft/tuners/lora/variants.py
nsbg f8d8057
Update src/peft/tuners/lora/variants.py
nsbg cd57c7b
Update src/peft/tuners/lora/variants.py
nsbg 2431a2c
add use_kasa param in LoraModel class
nsbg 7276b3b
restore output_tensor variable in Linear class get_delta_weight method
nsbg 5a67b1f
add use_kasa handling condition in resolve_lora_variant method
nsbg 3ec6b18
fix KaSA self, mat1, mat2 dtype error
nsbg 461a89c
fix make style error
nsbg File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,7 +33,7 @@ | |
from peft.utils.other import transpose | ||
|
||
from .config import LoraConfig | ||
|
||
from .dora import DoraConv2dLayer, DoraConv3dLayer, DoraEmbeddingLayer, DoraLinearLayer, _DoraConvNdLayer | ||
|
||
class LoraVariant: | ||
""" | ||
|
@@ -107,6 +107,9 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, * | |
self.cast_input_dtype_enabled: bool = True | ||
self.lora_variant: dict[str, LoraVariant] = {} | ||
self.kwargs = kwargs | ||
|
||
# Diag value | ||
self.lora_diag = nn.ParameterDict({}) | ||
|
||
base_layer = self.get_base_layer() | ||
if isinstance(base_layer, nn.Linear): | ||
|
@@ -217,7 +220,20 @@ def update_layer( | |
self.scaling[adapter_name] = lora_alpha / r | ||
|
||
self.use_dora[adapter_name] = use_dora | ||
|
||
############ kasa ############# | ||
self.lora_diag[adapter_name] = nn.Parameter(torch.randn(r), requires_grad=True) | ||
|
||
weight = self.get_base_layer().weight | ||
dtype = weight.dtype | ||
svd_rank = self.in_features - r | ||
weight = weight.to(torch.float32) | ||
U, S, Vh = torch.linalg.svd(weight.data, full_matrices=False) | ||
U_principle, S_principle, Vh_principle = U[:, :svd_rank], S[:svd_rank], Vh[:svd_rank, :] | ||
self.get_base_layer().weight.data = (U_principle @ torch.diag(S_principle) @ Vh_principle).to(dtype) | ||
|
||
######################### | ||
|
||
|
||
# for inits that require access to the base weight, use gather_param_ctx so that the weight is gathered when using DeepSpeed | ||
if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"): | ||
with gather_params_ctx(self.get_base_layer().weight): | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.