-
Notifications
You must be signed in to change notification settings - Fork 2.1k
minor changes to OFT to make it faster #2805
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 6 commits
700f64e
8317b9d
96f7699
806a425
da73c33
bc597cc
85ed7ec
29833f4
ed56c57
c3bc6e0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -14,9 +14,12 @@ | |||||||||||||
|
||||||||||||||
from __future__ import annotations | ||||||||||||||
|
||||||||||||||
import warnings | ||||||||||||||
from dataclasses import dataclass, field | ||||||||||||||
from typing import Literal, Optional, Union | ||||||||||||||
|
||||||||||||||
import packaging.version | ||||||||||||||
|
||||||||||||||
from peft.config import PeftConfig | ||||||||||||||
from peft.utils import PeftType | ||||||||||||||
|
||||||||||||||
|
@@ -193,4 +196,14 @@ def check_kwargs(cls, **kwargs): | |||||||||||||
"with the latest version of OFT. Please retrain your adapter weights with newer PEFT versions. " | ||||||||||||||
"Alternatively, downgrade PEFT to version 0.13.0 to use the old adapter weights." | ||||||||||||||
) | ||||||||||||||
if kwargs.get("use_cayley_neumann", False): | ||||||||||||||
peft_version = kwargs.get("peft_version", "0.0.0") # if not present, set a low dummy version | ||||||||||||||
# remove commit hash, if present | ||||||||||||||
peft_version = peft_version.partition("@")[0] | ||||||||||||||
parsed_version = packaging.version.Version(peft_version) | ||||||||||||||
min_version = packaging.version.Version("0.18.0") | ||||||||||||||
# note: config.peft_version was added in 0.18.0, so if it's missing, it means we're below min version | ||||||||||||||
if parsed_version < min_version: | ||||||||||||||
msg = "The cayley-neumann parameterization has been slightly changed to be more numerically stable in PEFT 0.18.0. Please retrain your adapter weights with newer PEFT versions. Alternatively, downgrade PEFT to version 0.17.0 to use the old parameterization." | ||||||||||||||
|
msg = "The cayley-neumann parameterization has been slightly changed to be more numerically stable in PEFT 0.18.0. Please retrain your adapter weights with newer PEFT versions. Alternatively, downgrade PEFT to version 0.17.0 to use the old parameterization." | |
msg = ( | |
"The cayley-neumann parameterization has been slightly changed to be more numerically stable in " | |
"PEFT 0.18.0. Please retrain your adapter weights with newer PEFT versions. Alternatively, " | |
"downgrade PEFT to version 0.17.0 to use the old parameterization." | |
) |
Let's add some line breaks to keep each line at <= 120 characters.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, done
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -82,13 +82,15 @@ def __init__( | |
kernel_size=(0, 0), | ||
use_cayley_neumann=True, | ||
num_cayley_neumann_terms=5, | ||
device=None, | ||
): | ||
super().__init__() | ||
self.r = r | ||
self.n_elements = n_elements | ||
self.block_size = block_size | ||
self.in_features = in_features | ||
self.weight = nn.Parameter(torch.empty(r, n_elements)) | ||
self.weight = self.weight.to(device) | ||
self.coft = coft | ||
self.eps = eps | ||
self.block_share = block_share | ||
|
@@ -98,6 +100,8 @@ def __init__( | |
self.num_cayley_neumann_terms = num_cayley_neumann_terms | ||
# Create indices for upper triangle (excluding diagonal) | ||
self.rows, self.cols = torch.triu_indices(block_size, block_size, 1) | ||
self.rows = self.rows.to(device) | ||
self.cols = self.cols.to(device) | ||
|
||
def _pytorch_skew_symmetric(self, vec, block_size): | ||
batch_size = vec.shape[0] | ||
|
@@ -139,9 +143,11 @@ def _cayley_batch( | |
R.add_(Q_squared, alpha=2.0) | ||
|
||
Q_power = Q_squared | ||
for i in range(3, num_neumann_terms): | ||
for _ in range(3, num_neumann_terms - 1): | ||
Q_power = torch.bmm(Q_power, Q_skew) | ||
R.add_(Q_power, alpha=2.0) | ||
Q_power = torch.bmm(Q_power, Q_skew) | ||
R.add_(Q_power) | ||
BenjaminBossan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
id_mat = ( | ||
torch.eye(Q_skew.shape[-1], device=Q_skew.device) | ||
|
@@ -470,6 +476,7 @@ def update_layer( | |
block_share=block_share, | ||
use_cayley_neumann=use_cayley_neumann, | ||
num_cayley_neumann_terms=num_cayley_neumann_terms, | ||
device=self.get_base_layer().weight.device, | ||
|
||
) | ||
|
||
# Initialize weights | ||
|
@@ -770,6 +777,7 @@ def update_layer( | |
kernel_size=base_layer.kernel_size, | ||
use_cayley_neumann=use_cayley_neumann, | ||
num_cayley_neumann_terms=num_cayley_neumann_terms, | ||
device=self.get_base_layer().weight.device, | ||
|
||
) | ||
|
||
# Initialize weights | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -14,6 +14,7 @@ | |||||
|
||||||
import copy | ||||||
import itertools | ||||||
import json | ||||||
import math | ||||||
import platform | ||||||
import re | ||||||
|
@@ -43,6 +44,7 @@ | |||||
LoftQConfig, | ||||||
LoKrConfig, | ||||||
LoraConfig, | ||||||
OFTConfig, | ||||||
PeftMixedModel, | ||||||
PeftModel, | ||||||
PeftModelForCausalLM, | ||||||
|
@@ -1756,6 +1758,49 @@ def test_vblora_with_incompatible_vector_length_with_out_features(self): | |||||
get_peft_model(model, config) | ||||||
|
||||||
|
||||||
class TestOft: | ||||||
torch_device = infer_device() | ||||||
|
||||||
def get_model(self, bias=True):
    """Build the tiny model used by the OFT tests.

    Args:
        bias: Whether the single linear layer is created with a bias term.
            Defaults to ``True``, which matches ``nn.Linear``'s own default.

    Returns:
        A module with one 32x32 linear layer named ``lin``, switched to eval
        mode and moved to ``self.torch_device``.
    """

    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            # Forward `bias` so the parameter is actually honored instead of
            # being silently ignored.
            self.lin = nn.Linear(32, 32, bias=bias)

    return MyModule().eval().to(self.torch_device)
|
||||||
@pytest.mark.parametrize("peft_version", ["0.17.0", "0.18.0", None]) | ||||||
def test_load_outdated_oft_checkpoint_warns(self, peft_version, tmp_path, recwarn): | ||||||
# In PEFT v0.18.0, there was a small change in the OFT implementation with Cayley-Neumann enabled. As the | ||||||
# outputs change slightly, users need to be warned about it if the checkpoint stems from a PEFT version below | ||||||
# 0.18.0. When the 'peft_version' key is not in the config, it means that the version is below 0.18.0. | ||||||
config = OFTConfig(target_modules=["lin"], use_cayley_neumann=True) # only relevant when using Cayley-Neumann | ||||||
model = get_peft_model(self.get_model(), config) | ||||||
model.save_pretrained(tmp_path) | ||||||
del model | ||||||
|
||||||
# overwrite the peft_version | ||||||
with open(tmp_path / "adapter_config.json") as f: | ||||||
config_json = json.load(f) | ||||||
|
||||||
if peft_version is None: | ||||||
del config_json["peft_version"] | ||||||
|
del config_json["peft_version"] | |
config_json.pop("peft_version", None) |
Use `pop` with a default so that this does not raise a `KeyError` in case the key does not exist.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please replace it with the correct error message (start of the message is fine).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.