-
Notifications
You must be signed in to change notification settings - Fork 2.1k
minor changes to OFT to make it faster #2805
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
700f64e
8317b9d
96f7699
806a425
da73c33
bc597cc
85ed7ec
29833f4
ed56c57
c3bc6e0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -14,9 +14,12 @@ | |||||||||||||||||||||
|
||||||||||||||||||||||
from __future__ import annotations | ||||||||||||||||||||||
|
||||||||||||||||||||||
import warnings | ||||||||||||||||||||||
from dataclasses import dataclass, field | ||||||||||||||||||||||
from typing import Literal, Optional, Union | ||||||||||||||||||||||
|
||||||||||||||||||||||
import packaging.version | ||||||||||||||||||||||
|
||||||||||||||||||||||
from peft.config import PeftConfig | ||||||||||||||||||||||
from peft.utils import PeftType | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
@@ -193,4 +196,12 @@ def check_kwargs(cls, **kwargs): | |||||||||||||||||||||
"with the latest version of OFT. Please retrain your adapter weights with newer PEFT versions. " | ||||||||||||||||||||||
"Alternatively, downgrade PEFT to version 0.13.0 to use the old adapter weights." | ||||||||||||||||||||||
) | ||||||||||||||||||||||
if kwargs["use_caylay_neumann"]: | ||||||||||||||||||||||
peft_version = kwargs.get("peft_version", "unknown") | ||||||||||||||||||||||
parsed_version = packaging.version.Version(peft_version) | ||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||
min_version = packaging.version.Version("0.18.0") | ||||||||||||||||||||||
# note: config.peft_version was added in 0.18.0, so if it's missing, it means we're below min version | ||||||||||||||||||||||
if (peft_version == "unknown") or (parsed_version < min_version): | ||||||||||||||||||||||
|
peft_version = kwargs.get("peft_version", "unknown") | |
parsed_version = packaging.version.Version(peft_version) | |
min_version = packaging.version.Version("0.18.0") | |
# note: config.peft_version was added in 0.18.0, so if it's missing, it means we're below min version | |
if (peft_version == "unknown") or (parsed_version < min_version): | |
peft_version = kwargs.get("peft_version", "0.0.0") # if not present, set a low dummy version | |
parsed_version = packaging.version.Version(peft_version) | |
min_version = packaging.version.Version("0.18.0") | |
# note: config.peft_version was added in 0.18.0, so if it's missing, it means we're below min version | |
if parsed_version < min_version: |
I just tested and parsing with "unknown"
throws an error, so let's put a dummy value here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, this message was just a placeholder :-D Could you please add a short explanation here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure:)
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -82,13 +82,15 @@ def __init__( | |||||
kernel_size=(0, 0), | ||||||
use_cayley_neumann=True, | ||||||
num_cayley_neumann_terms=5, | ||||||
device=None, | ||||||
): | ||||||
super().__init__() | ||||||
self.r = r | ||||||
self.n_elements = n_elements | ||||||
self.block_size = block_size | ||||||
self.in_features = in_features | ||||||
self.weight = nn.Parameter(torch.empty(r, n_elements)) | ||||||
self.weight.to(device) | ||||||
self.coft = coft | ||||||
self.eps = eps | ||||||
self.block_share = block_share | ||||||
|
@@ -98,6 +100,8 @@ def __init__( | |||||
self.num_cayley_neumann_terms = num_cayley_neumann_terms | ||||||
# Create indices for upper triangle (excluding diagonal) | ||||||
self.rows, self.cols = torch.triu_indices(block_size, block_size, 1) | ||||||
self.rows.to(device) | ||||||
self.cols.to(device) | ||||||
|
||||||
|
||||||
def _pytorch_skew_symmetric(self, vec, block_size): | ||||||
batch_size = vec.shape[0] | ||||||
|
@@ -139,9 +143,11 @@ def _cayley_batch( | |||||
R.add_(Q_squared, alpha=2.0) | ||||||
|
||||||
Q_power = Q_squared | ||||||
for i in range(3, num_neumann_terms): | ||||||
for _ in range(3, num_neumann_terms - 1): | ||||||
Q_power = torch.bmm(Q_power, Q_skew) | ||||||
R.add_(Q_power, alpha=2.0) | ||||||
Q_power = torch.bmm(Q_power, Q_skew) | ||||||
R.add_(Q_power) | ||||||
BenjaminBossan marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
else: | ||||||
id_mat = ( | ||||||
torch.eye(Q_skew.shape[-1], device=Q_skew.device) | ||||||
|
@@ -470,6 +476,7 @@ def update_layer( | |||||
block_share=block_share, | ||||||
use_cayley_neumann=use_cayley_neumann, | ||||||
num_cayley_neumann_terms=num_cayley_neumann_terms, | ||||||
device=self.weight.device, | ||||||
|
device=self.weight.device, | |
device=self.get_base_layer().weight.device, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Slightly more robust:
device=self.weight.device, | |
device=self.get_base_layer().weight.device, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's make this more backwards compatible. Also, there was a typo:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.