Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/peft/tuners/oft/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@

from __future__ import annotations

import warnings
from dataclasses import dataclass, field
from typing import Literal, Optional, Union

import packaging.version

from peft.config import PeftConfig
from peft.utils import PeftType

Expand Down Expand Up @@ -193,4 +196,14 @@ def check_kwargs(cls, **kwargs):
"with the latest version of OFT. Please retrain your adapter weights with newer PEFT versions. "
"Alternatively, downgrade PEFT to version 0.13.0 to use the old adapter weights."
)
if kwargs.get("use_cayley_neumann", False):
peft_version = kwargs.get("peft_version", "0.0.0") # if not present, set a low dummy version
# remove commit hash, if present
peft_version = peft_version.partition("@")[0]
parsed_version = packaging.version.Version(peft_version)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
parsed_version = packaging.version.Version(peft_version)
# remove commit hash, if present
peft_version = peft_version.partition("@")[0]
parsed_version = packaging.version.Version(peft_version)

min_version = packaging.version.Version("0.18.0")
# note: config.peft_version was added in 0.18.0, so if it's missing, it means we're below min version
if parsed_version < min_version:
msg = "The cayley-neumann parameterization has been slightly changed to be more numerically stable in PEFT 0.18.0. Please retrain your adapter weights with newer PEFT versions. Alternatively, downgrade PEFT to version 0.17.0 to use the old parameterization."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
msg = "The cayley-neumann parameterization has been slightly changed to be more numerically stable in PEFT 0.18.0. Please retrain your adapter weights with newer PEFT versions. Alternatively, downgrade PEFT to version 0.17.0 to use the old parameterization."
msg = (
"The cayley-neumann parameterization has been slightly changed to be more numerically stable in "
"PEFT 0.18.0. Please retrain your adapter weights with newer PEFT versions. Alternatively, "
"downgrade PEFT to version 0.17.0 to use the old parameterization."
)

Let's add some line breaks to keep <= 120 chars.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, done

warnings.warn(msg)
return super().check_kwargs(**kwargs)
10 changes: 7 additions & 3 deletions src/peft/tuners/oft/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ def __init__(
self.use_cayley_neumann = use_cayley_neumann
self.num_cayley_neumann_terms = num_cayley_neumann_terms
# Create indices for upper triangle (excluding diagonal)
self.rows, self.cols = torch.triu_indices(block_size, block_size, 1)
rows, cols = torch.triu_indices(block_size, block_size, 1)
self.register_buffer('rows', rows, persistent=False)
self.register_buffer('cols', cols, persistent=False)

def _pytorch_skew_symmetric(self, vec, block_size):
batch_size = vec.shape[0]
Expand Down Expand Up @@ -139,9 +141,11 @@ def _cayley_batch(
R.add_(Q_squared, alpha=2.0)

Q_power = Q_squared
for i in range(3, num_neumann_terms):
for _ in range(3, num_neumann_terms - 1):
Q_power = torch.bmm(Q_power, Q_skew)
R.add_(Q_power, alpha=2.0)
Q_power = torch.bmm(Q_power, Q_skew)
R.add_(Q_power)
else:
id_mat = (
torch.eye(Q_skew.shape[-1], device=Q_skew.device)
Expand Down Expand Up @@ -935,4 +939,4 @@ def dispatch_default(
kwargs["fan_in_fan_out"] = oft_config.fan_in_fan_out = False
new_module = Linear(target, adapter_name, **kwargs)

return new_module
return new_module
45 changes: 45 additions & 0 deletions tests/test_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import copy
import itertools
import json
import math
import platform
import re
Expand Down Expand Up @@ -43,6 +44,7 @@
LoftQConfig,
LoKrConfig,
LoraConfig,
OFTConfig,
PeftMixedModel,
PeftModel,
PeftModelForCausalLM,
Expand Down Expand Up @@ -1756,6 +1758,49 @@ def test_vblora_with_incompatible_vector_length_with_out_features(self):
get_peft_model(model, config)


class TestOft:
torch_device = infer_device()

def get_model(self, bias=True):
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.lin = nn.Linear(32, 32)

return MyModule().eval().to(self.torch_device)

@pytest.mark.parametrize("peft_version", ["0.17.0", "0.18.0", None])
def test_load_outdated_oft_checkpoint_warns(self, peft_version, tmp_path, recwarn):
# In PEFT v0.18.0, there was a small change in the OFT implementation with Cayley-Neumann enabled. As the
# outputs change slightly, users need to be warned about it if the checkpoint stems from a PEFT version below
# 0.18.0. When the 'peft_version' key is not in the config, it means that the version is below 0.18.0.
config = OFTConfig(target_modules=["lin"], use_cayley_neumann=True) # only relevant when using Cayley-Neumann
model = get_peft_model(self.get_model(), config)
model.save_pretrained(tmp_path)
del model

# overwrite the peft_version
with open(tmp_path / "adapter_config.json") as f:
config_json = json.load(f)

if peft_version is None:
del config_json["peft_version"]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
del config_json["peft_version"]
config_json.pop("peft_version", None)

In case the key does not exist.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

else:
config_json["peft_version"] = peft_version

with open(tmp_path / "adapter_config.json", "w") as f:
json.dump(config_json, f)

msg = "TODO" # <= replace with final warning message
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please replace it with the correct error message (start of the message is fine).

PeftModel.from_pretrained(self.get_model(), tmp_path)

warn_messages = [str(w.message) for w in recwarn.list]
if peft_version == "0.18.0":
assert not any(w.startswith(msg) for w in warn_messages)
else:
assert any(w.startswith(msg) for w in warn_messages)


class TestC3AInitialization:
torch_device = infer_device()

Expand Down