11 changes: 11 additions & 0 deletions src/peft/tuners/oft/config.py
@@ -14,9 +14,12 @@

from __future__ import annotations

import warnings
from dataclasses import dataclass, field
from typing import Literal, Optional, Union

import packaging.version

from peft.config import PeftConfig
from peft.utils import PeftType

@@ -193,4 +196,12 @@ def check_kwargs(cls, **kwargs):
"with the latest version of OFT. Please retrain your adapter weights with newer PEFT versions. "
"Alternatively, downgrade PEFT to version 0.13.0 to use the old adapter weights."
)
if kwargs["use_caylay_neumann"]:
Member: Let's make this more backwards compatible; also, there was a typo:

Suggested change
- if kwargs["use_caylay_neumann"]:
+ if kwargs.get("use_cayley_neumann", False):

Contributor Author: Done.
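
A minimal sketch of the backwards-compatibility point, using a made-up kwargs dict rather than a real OFT config: adapters saved before the option existed simply don't contain the key, so subscripting raises a KeyError while .get() falls back to the old default.

# Hypothetical kwargs loaded from an old adapter config (illustration only).
old_kwargs = {"r": 8, "module_dropout": 0.0}

# old_kwargs["use_cayley_neumann"]             # KeyError: key was never saved
use_cayley_neumann = old_kwargs.get("use_cayley_neumann", False)
print(use_cayley_neumann)                       # False -> old behaviour preserved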

peft_version = kwargs.get("peft_version", "unknown")
parsed_version = packaging.version.Version(peft_version)
Member:

Suggested change
- parsed_version = packaging.version.Version(peft_version)
+ # remove commit hash, if present
+ peft_version = peft_version.partition("@")[0]
+ parsed_version = packaging.version.Version(peft_version)
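
For context, a hedged sketch of why the commit hash has to go before parsing; the exact version string PEFT stores may differ, "0.18.0.dev0@1a2b3c4" is only an assumed example of a source install:

import packaging.version

peft_version = "0.18.0.dev0@1a2b3c4"           # assumed example with a trailing commit hash
# packaging.version.Version(peft_version)      # would raise InvalidVersion: "@..." is not PEP 440
peft_version = peft_version.partition("@")[0]  # -> "0.18.0.dev0"
parsed_version = packaging.version.Version(peft_version)
print(parsed_version)                           # 0.18.0.dev0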

min_version = packaging.version.Version("0.18.0")
# note: config.peft_version was added in 0.18.0, so if it's missing, it means we're below min version
if (peft_version == "unknown") or (parsed_version < min_version):
Member:

Suggested change
- peft_version = kwargs.get("peft_version", "unknown")
- parsed_version = packaging.version.Version(peft_version)
- min_version = packaging.version.Version("0.18.0")
- # note: config.peft_version was added in 0.18.0, so if it's missing, it means we're below min version
- if (peft_version == "unknown") or (parsed_version < min_version):
+ peft_version = kwargs.get("peft_version", "0.0.0")  # if not present, set a low dummy version
+ parsed_version = packaging.version.Version(peft_version)
+ min_version = packaging.version.Version("0.18.0")
+ # note: config.peft_version was added in 0.18.0, so if it's missing, it means we're below min version
+ if parsed_version < min_version:

I just tested and parsing "unknown" throws an error, so let's put a dummy value here.

Contributor Author: Done.
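
A small standalone sketch of the failure mode described above: "unknown" is not a valid PEP 440 version and makes the parser raise, whereas a dummy "0.0.0" parses and always compares below the 0.18.0 threshold.

import packaging.version

try:
    packaging.version.Version("unknown")
except packaging.version.InvalidVersion as exc:
    print(f"parsing failed: {exc}")            # Invalid version: 'unknown'

parsed_version = packaging.version.Version("0.0.0")   # low dummy value
min_version = packaging.version.Version("0.18.0")
print(parsed_version < min_version)            # True -> treated as "below min version"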

msg = "warning message that explains what is happening"
Member: Well, this message was just a placeholder :-D Could you please add a short explanation here?

Contributor Author: Sure :)

warnings.warn(msg)
return super().check_kwargs(**kwargs)
10 changes: 9 additions & 1 deletion src/peft/tuners/oft/layer.py
@@ -82,13 +82,15 @@ def __init__(
kernel_size=(0, 0),
use_cayley_neumann=True,
num_cayley_neumann_terms=5,
device=None,
):
super().__init__()
self.r = r
self.n_elements = n_elements
self.block_size = block_size
self.in_features = in_features
self.weight = nn.Parameter(torch.empty(r, n_elements))
self.weight.to(device)
self.coft = coft
self.eps = eps
self.block_share = block_share
@@ -98,6 +100,8 @@ def __init__(
self.num_cayley_neumann_terms = num_cayley_neumann_terms
# Create indices for upper triangle (excluding diagonal)
self.rows, self.cols = torch.triu_indices(block_size, block_size, 1)
self.rows.to(device)
self.cols.to(device)
Member: As rows and cols are tensors, calling .to() is not an in-place operation; you have to re-assign the values.

Contributor Author: Done.
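
A quick standalone illustration of the point (not the PR code itself): Tensor.to() returns a new tensor and leaves the original untouched, unlike Module.to(), so the result must be assigned back.

import torch

rows, cols = torch.triu_indices(4, 4, 1)
device = "cuda" if torch.cuda.is_available() else "cpu"

rows.to(device)           # returns a moved copy that is immediately discarded; rows is unchanged
rows = rows.to(device)    # correct: re-assign the returned tensor
cols = cols.to(device)
print(rows.device, cols.device)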


def _pytorch_skew_symmetric(self, vec, block_size):
batch_size = vec.shape[0]
@@ -139,9 +143,11 @@ def _cayley_batch(
R.add_(Q_squared, alpha=2.0)

Q_power = Q_squared
for i in range(3, num_neumann_terms):
for _ in range(3, num_neumann_terms - 1):
Q_power = torch.bmm(Q_power, Q_skew)
R.add_(Q_power, alpha=2.0)
Q_power = torch.bmm(Q_power, Q_skew)
R.add_(Q_power)
else:
id_mat = (
torch.eye(Q_skew.shape[-1], device=Q_skew.device)
@@ -470,6 +476,7 @@ def update_layer(
block_share=block_share,
use_cayley_neumann=use_cayley_neumann,
num_cayley_neumann_terms=num_cayley_neumann_terms,
device=self.weight.device,
Member: Slightly more robust:

Suggested change
- device=self.weight.device,
+ device=self.get_base_layer().weight.device,

Contributor Author: Done.
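
One plausible reading of "slightly more robust", illustrated with a toy stand-in rather than PEFT's real classes (assumption: get_base_layer() unwraps nested wrappers down to the module that owns the parameters): reading the device from the unwrapped base layer keeps working even when several adapters are stacked around the layer that actually holds the weights.

import torch.nn as nn

class ToyWrapper(nn.Module):
    """Toy stand-in for a tuner layer wrapping a (possibly already wrapped) base layer."""
    def __init__(self, base_layer: nn.Module):
        super().__init__()
        self.base_layer = base_layer

    def get_base_layer(self) -> nn.Module:
        # Unwrap until we reach the module that actually owns the weights.
        base = self.base_layer
        while hasattr(base, "base_layer"):
            base = base.base_layer
        return base

wrapped = ToyWrapper(ToyWrapper(nn.Linear(4, 4)))
print(wrapped.get_base_layer().weight.device)  # device of the innermost nn.Linear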

)

# Initialize weights
@@ -770,6 +777,7 @@ def update_layer(
kernel_size=base_layer.kernel_size,
use_cayley_neumann=use_cayley_neumann,
num_cayley_neumann_terms=num_cayley_neumann_terms,
device=self.weight.device,
Member: Slightly more robust:

Suggested change
- device=self.weight.device,
+ device=self.get_base_layer().weight.device,

Contributor Author: Done.

)

# Initialize weights