-
Notifications
You must be signed in to change notification settings - Fork 2.1k
[WIP] Update LoraConfig
for KaSA implementation
#2698
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 21 commits
387ad2a
ae00a34
6588f1a
a824ac9
05e4e07
d1e7e43
9e53b9f
aa37111
9cfe65c
f9d7cc7
84813a3
39abcad
06f76d8
0043ae3
ea59432
574c1b8
73fa58f
b9e3190
2649fdf
d06cafd
4ee5da1
0377170
a536bbf
f8d8057
cd57c7b
2431a2c
7276b3b
5a67b1f
3ec6b18
461a89c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -439,6 +439,82 @@ def init(module: Conv3d, adapter_name: str, **kwargs: Any) -> None: | |
_DoraConvNdVariant.init_convd_variant(module, adapter_name, dora_layer=dora_layer) | ||
|
||
|
||
class KasaLinearVariant(LoraVariant): | ||
@staticmethod | ||
def init(module: Linear, adapter_name: str, **kwargs: Any) -> None: | ||
if not module.lora_diag: | ||
nsbg marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
module.adapter_layer_names = module.adapter_layer_names[:] + ("lora_diag",) | ||
|
||
# Initialize lora_diag | ||
module.lora_diag[adapter_name] = nn.Parameter(torch.randn(module.r[adapter_name]), requires_grad=True) | ||
|
||
# see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L132 | ||
|
||
# SVD | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's add a reference here, so that we know the origin: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I put it in here, how is it? |
||
weight = module.get_base_layer().weight | ||
dtype = weight.dtype | ||
svd_rank = module.in_features - module.r[adapter_name] | ||
weight = weight.to(torch.float32) | ||
U, S, Vh = torch.linalg.svd(weight.data, full_matrices=False) | ||
U_principle, S_principle, Vh_principle = U[:, :svd_rank], S[:svd_rank], Vh[:svd_rank, :] | ||
module.get_base_layer().weight.data = (U_principle @ torch.diag(S_principle) @ Vh_principle).to(dtype) | ||
|
||
@staticmethod | ||
def _get_delta_weight(weight_A, weight_B, lora_diag, scaling, fan_in_fan_out): | ||
diag = torch.diag(lora_diag) | ||
delta = weight_B @ diag @ weight_A | ||
if fan_in_fan_out: | ||
delta = delta.transpose(0, 1) | ||
delta = delta * scaling | ||
return delta | ||
|
||
@staticmethod | ||
def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor: | ||
delta_weight = KasaLinearVariant._get_delta_weight( | ||
module.lora_A[active_adapter].weight, | ||
module.lora_B[active_adapter].weight, | ||
module.lora_diag[active_adapter], | ||
module.scaling[active_adapter], | ||
module.fan_in_fan_out | ||
) | ||
return orig_weight + delta_weight | ||
|
||
@staticmethod | ||
def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None: | ||
delta_weight = KasaLinearVariant._get_delta_weight( | ||
module.lora_A[active_adapter].weight, | ||
module.lora_B[active_adapter].weight, | ||
module.lora_diag[active_adapter], | ||
module.scaling[active_adapter], | ||
module.fan_in_fan_out, | ||
) | ||
orig_weight.data += delta_weight | ||
|
||
@staticmethod | ||
def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor: | ||
delta_weight = KasaLinearVariant._get_delta_weight( | ||
module.lora_A[active_adapter].weight, | ||
module.lora_B[active_adapter].weight, | ||
module.lora_diag[active_adapter], | ||
module.scaling[active_adapter], | ||
module.fan_in_fan_out, | ||
) | ||
return orig_weight - delta_weight | ||
|
||
@staticmethod | ||
def forward(module: Linear, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor: | ||
nsbg marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
lora_A = module.lora_A[active_adapter] | ||
lora_B = module.lora_B[active_adapter] | ||
dropout = module.lora_dropout[active_adapter] | ||
scaling = module.scaling[active_adapter] | ||
diag = torch.diag(module.lora_diag[active_adapter]) | ||
|
||
# KaSA calculation | ||
# see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L602C21-L602C110 | ||
lora_output = lora_B(torch.einsum("ijk,kl->ijl", lora_A(dropout(x)), diag)) * scaling | ||
nsbg marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
return result + lora_output | ||
|
||
|
||
class QALoraLinearVariant(LoraVariant): | ||
@staticmethod | ||
def init(module: Linear, adapter_name: str, **kwargs: Any) -> None: | ||
|
Uh oh!
There was an error while loading. Please reload this page.