From 700f64eb31d776213e263bf1fc6ec8468915042d Mon Sep 17 00:00:00 2001
From: Florent Draye
Date: Mon, 29 Sep 2025 18:04:08 +0200
Subject: [PATCH 01/10] minor changes to OFT to make it faster

---
 src/peft/tuners/oft/layer.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/src/peft/tuners/oft/layer.py b/src/peft/tuners/oft/layer.py
index 6b14d015ae..86b6eab613 100644
--- a/src/peft/tuners/oft/layer.py
+++ b/src/peft/tuners/oft/layer.py
@@ -114,6 +114,7 @@ def _pytorch_skew_symmetric_inv(self, matrix, block_size):
         vec = matrix[:, self.rows, self.cols]
         return vec
 
+    @torch.compile
     def _cayley_batch(
         self, Q: torch.Tensor, block_size: int, use_cayley_neumann: bool = True, num_neumann_terms: int = 5
     ) -> torch.Tensor:
@@ -139,9 +140,11 @@ def _cayley_batch(
             R.add_(Q_squared, alpha=2.0)
 
             Q_power = Q_squared
-            for i in range(3, num_neumann_terms):
+            for _ in range(3, num_neumann_terms - 1):
                 Q_power = torch.bmm(Q_power, Q_skew)
                 R.add_(Q_power, alpha=2.0)
+            Q_power = torch.bmm(Q_power, Q_skew)
+            R.add_(Q_power)
         else:
             id_mat = (
                 torch.eye(Q_skew.shape[-1], device=Q_skew.device)
@@ -248,6 +251,10 @@ def forward(self, x):
         if required_dtype != self.weight.dtype:
             x = x.to(self.weight.dtype)
 
+        if self.rows.device != self.weight.device:
+            self.rows = self.rows.to(self.weight.device)
+            self.cols = self.cols.to(self.weight.device)
+
         orig_shape = x.shape
 
         if self.coft:
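
For reference, the truncated series that the restructured loop above evaluates can be sketched in isolation (illustrative code only, not the PEFT implementation; the function name and the assumption num_terms >= 4 are mine):

import torch

def cayley_neumann_sketch(Q_skew: torch.Tensor, num_terms: int = 5) -> torch.Tensor:
    # Approximate the Cayley transform R = (I + Q)(I - Q)^-1 for a batch of
    # skew-symmetric matrices Q with a truncated Neumann series. Mirroring the
    # loop in the patch, the last power enters once instead of twice:
    #   R ~= I + 2Q + 2Q^2 + ... + 2Q^(m-1) + Q^m = (I + Q)(I + Q + ... + Q^(m-1))
    b, n, _ = Q_skew.shape
    R = torch.eye(n, device=Q_skew.device, dtype=Q_skew.dtype).repeat(b, 1, 1)
    R.add_(Q_skew, alpha=2.0)
    Q_power = torch.bmm(Q_skew, Q_skew)
    R.add_(Q_power, alpha=2.0)
    for _ in range(3, num_terms - 1):
        Q_power = torch.bmm(Q_power, Q_skew)
        R.add_(Q_power, alpha=2.0)
    Q_power = torch.bmm(Q_power, Q_skew)
    R.add_(Q_power)  # final term without the factor 2
    return R

Because the truncation factors as (I + Q) times the truncated expansion of (I - Q)^-1, peeling the final term out of the loop keeps the approximation an exact product of that form; this appears to be the slight change in the parameterization that the later patches warn about for old checkpoints.
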
From 8317b9d45d0a9cd3b5cac5568632ad00141dfdb6 Mon Sep 17 00:00:00 2001
From: Florent Draye
Date: Wed, 1 Oct 2025 13:31:52 +0200
Subject: [PATCH 02/10] update oft

---
 src/peft/tuners/oft/config.py | 10 ++++++++++
 src/peft/tuners/oft/layer.py  | 11 ++++++-----
 2 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/src/peft/tuners/oft/config.py b/src/peft/tuners/oft/config.py
index 4b33f13cb3..ca6e1b447c 100644
--- a/src/peft/tuners/oft/config.py
+++ b/src/peft/tuners/oft/config.py
@@ -20,6 +20,8 @@
 from peft.config import PeftConfig
 from peft.utils import PeftType
 
+import packaging.version
+import warnings
 
 @dataclass
 class OFTConfig(PeftConfig):
@@ -193,4 +195,12 @@ def check_kwargs(cls, **kwargs):
                 "with the latest version of OFT. Please retrain your adapter weights with newer PEFT versions. "
                 "Alternatively, downgrade PEFT to version 0.13.0 to use the old adapter weights."
             )
+        if kwargs["use_caylay_neumann"]:
+            peft_version = kwargs.get("peft_version", "unknown")
+            parsed_version = packaging.version.Version(peft_version)
+            min_version = packaging.version.Version("0.18.0")
+            # note: config.peft_version was added in 0.18.0, so if it's missing, it means we're below min version
+            if (peft_version == "unknown") or (parsed_version < min_version):
+                msg = "warning message that explains what is happening"
+                warnings.warn(msg)
         return super().check_kwargs(**kwargs)
diff --git a/src/peft/tuners/oft/layer.py b/src/peft/tuners/oft/layer.py
index 86b6eab613..a05140d837 100644
--- a/src/peft/tuners/oft/layer.py
+++ b/src/peft/tuners/oft/layer.py
@@ -82,6 +82,7 @@ def __init__(
         kernel_size=(0, 0),
         use_cayley_neumann=True,
         num_cayley_neumann_terms=5,
+        device=None,
     ):
         super().__init__()
         self.r = r
@@ -89,6 +90,7 @@ def __init__(
         self.block_size = block_size
         self.in_features = in_features
         self.weight = nn.Parameter(torch.empty(r, n_elements))
+        self.weight.to(device)
         self.coft = coft
         self.eps = eps
         self.block_share = block_share
@@ -98,6 +100,8 @@ def __init__(
         self.num_cayley_neumann_terms = num_cayley_neumann_terms
         # Create indices for upper triangle (excluding diagonal)
         self.rows, self.cols = torch.triu_indices(block_size, block_size, 1)
+        self.rows.to(device)
+        self.cols.to(device)
 
     def _pytorch_skew_symmetric(self, vec, block_size):
         batch_size = vec.shape[0]
@@ -114,7 +118,6 @@ def _pytorch_skew_symmetric_inv(self, matrix, block_size):
         vec = matrix[:, self.rows, self.cols]
         return vec
 
-    @torch.compile
     def _cayley_batch(
         self, Q: torch.Tensor, block_size: int, use_cayley_neumann: bool = True, num_neumann_terms: int = 5
     ) -> torch.Tensor:
@@ -251,10 +254,6 @@ def forward(self, x):
         if required_dtype != self.weight.dtype:
             x = x.to(self.weight.dtype)
 
-        if self.rows.device != self.weight.device:
-            self.rows = self.rows.to(self.weight.device)
-            self.cols = self.cols.to(self.weight.device)
-
         orig_shape = x.shape
 
         if self.coft:
@@ -477,6 +476,7 @@ def update_layer(
             block_share=block_share,
             use_cayley_neumann=use_cayley_neumann,
             num_cayley_neumann_terms=num_cayley_neumann_terms,
+            device=self.weight.device,
         )
 
         # Initialize weights
@@ -777,6 +777,7 @@ def update_layer(
             kernel_size=base_layer.kernel_size,
             use_cayley_neumann=use_cayley_neumann,
             num_cayley_neumann_terms=num_cayley_neumann_terms,
+            device=self.weight.device,
         )
 
         # Initialize weights

From 96f769937fab687b443944a451bece9fa7716731 Mon Sep 17 00:00:00 2001
From: Florent Draye
Date: Wed, 1 Oct 2025 13:33:41 +0200
Subject: [PATCH 03/10] update with make style

---
 src/peft/tuners/oft/config.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/peft/tuners/oft/config.py b/src/peft/tuners/oft/config.py
index ca6e1b447c..c38f539fef 100644
--- a/src/peft/tuners/oft/config.py
+++ b/src/peft/tuners/oft/config.py
@@ -14,14 +14,15 @@
 
 from __future__ import annotations
 
+import warnings
 from dataclasses import dataclass, field
 from typing import Literal, Optional, Union
 
+import packaging.version
+
 from peft.config import PeftConfig
 from peft.utils import PeftType
 
-import packaging.version
-import warnings
 
 @dataclass
 class OFTConfig(PeftConfig):
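
One detail worth noting about the device handling added in patch 02: `self.weight.to(device)`, `self.rows.to(device)` and `self.cols.to(device)` discard their results, since `Tensor.to` returns a new tensor rather than modifying its input; the next patch rebinds the attributes to the returned tensors. A minimal illustration of that behaviour (a dtype conversion is used so it runs without a GPU):

import torch

t = torch.zeros(2, 3)    # float32
t.to(torch.float64)      # returns a converted copy; `t` itself is unchanged
print(t.dtype)           # torch.float32 -> the bare call was effectively a no-op

t = t.to(torch.float64)  # rebinding the name is what actually changes `t`
print(t.dtype)           # torch.float64
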
From 806a425e2ba89974748cd2be4a4e5b39d0530b89 Mon Sep 17 00:00:00 2001
From: Florent Draye
Date: Thu, 2 Oct 2025 11:02:27 +0200
Subject: [PATCH 04/10] oft add test init

---
 src/peft/tuners/oft/config.py |  8 +++---
 src/peft/tuners/oft/layer.py  | 10 ++++----
 tests/test_initialization.py  | 46 +++++++++++++++++++++++++++++++++++
 3 files changed, 55 insertions(+), 9 deletions(-)

diff --git a/src/peft/tuners/oft/config.py b/src/peft/tuners/oft/config.py
index c38f539fef..b39b78511e 100644
--- a/src/peft/tuners/oft/config.py
+++ b/src/peft/tuners/oft/config.py
@@ -196,12 +196,12 @@ def check_kwargs(cls, **kwargs):
                 "with the latest version of OFT. Please retrain your adapter weights with newer PEFT versions. "
                 "Alternatively, downgrade PEFT to version 0.13.0 to use the old adapter weights."
             )
-        if kwargs["use_caylay_neumann"]:
-            peft_version = kwargs.get("peft_version", "unknown")
+        if kwargs.get("use_cayley_neumann", False):
+            peft_version = kwargs.get("peft_version", "0.0.0")  # if not present, set a low dummy version
             parsed_version = packaging.version.Version(peft_version)
             min_version = packaging.version.Version("0.18.0")
             # note: config.peft_version was added in 0.18.0, so if it's missing, it means we're below min version
-            if (peft_version == "unknown") or (parsed_version < min_version):
-                msg = "warning message that explains what is happening"
+            if parsed_version < min_version:
+                msg = "The cayley-neumann parameterization has been slightly changed to be more numerically stable in PEFT 0.18.0. Please retrain your adapter weights with newer PEFT versions. Alternatively, downgrade PEFT to version 0.17.0 to use the old parameterization."
                 warnings.warn(msg)
         return super().check_kwargs(**kwargs)
diff --git a/src/peft/tuners/oft/layer.py b/src/peft/tuners/oft/layer.py
index a05140d837..95e9e3e833 100644
--- a/src/peft/tuners/oft/layer.py
+++ b/src/peft/tuners/oft/layer.py
@@ -90,7 +90,7 @@ def __init__(
         self.block_size = block_size
         self.in_features = in_features
         self.weight = nn.Parameter(torch.empty(r, n_elements))
-        self.weight.to(device)
+        self.weight = self.weight.to(device)
         self.coft = coft
         self.eps = eps
         self.block_share = block_share
@@ -100,8 +100,8 @@ def __init__(
         self.num_cayley_neumann_terms = num_cayley_neumann_terms
         # Create indices for upper triangle (excluding diagonal)
         self.rows, self.cols = torch.triu_indices(block_size, block_size, 1)
-        self.rows.to(device)
-        self.cols.to(device)
+        self.rows = self.rows.to(device)
+        self.cols = self.cols.to(device)
 
     def _pytorch_skew_symmetric(self, vec, block_size):
         batch_size = vec.shape[0]
@@ -476,7 +476,7 @@ def update_layer(
             block_share=block_share,
             use_cayley_neumann=use_cayley_neumann,
             num_cayley_neumann_terms=num_cayley_neumann_terms,
-            device=self.weight.device,
+            device=self.get_base_layer().weight.device,
         )
 
         # Initialize weights
@@ -777,7 +777,7 @@ def update_layer(
             kernel_size=base_layer.kernel_size,
             use_cayley_neumann=use_cayley_neumann,
             num_cayley_neumann_terms=num_cayley_neumann_terms,
-            device=self.weight.device,
+            device=self.get_base_layer().weight.device,
         )
 
         # Initialize weights
diff --git a/tests/test_initialization.py b/tests/test_initialization.py
index f37e0c2cbe..a26ddd5cda 100644
--- a/tests/test_initialization.py
+++ b/tests/test_initialization.py
@@ -35,6 +35,8 @@
 from torch.utils.data import DataLoader
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+import json
+
 from peft import (
     AdaLoraConfig,
     C3AConfig,
@@ -43,6 +45,7 @@
     LoftQConfig,
     LoKrConfig,
     LoraConfig,
+    OFTConfig,
     PeftMixedModel,
     PeftModel,
     PeftModelForCausalLM,
@@ -1756,6 +1759,49 @@ def test_vblora_with_incompatible_vector_length_with_out_features(self):
             get_peft_model(model, config)
 
 
+class TestOft:
+    torch_device = infer_device()
+
+    def get_model(self, bias=True):
+        class MyModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lin = nn.Linear(32, 32)
+
+        return MyModule().eval().to(self.torch_device)
+
+    @pytest.mark.parametrize("peft_version", ["0.17.0", "0.18.0", None])
+    def test_load_outdated_oft_checkpoint_warns(self, peft_version, tmp_path, recwarn):
+        # In PEFT v0.18.0, there was a small change in the OFT implementation with Cayley-Neumann enabled. As the
+        # outputs change slightly, users need to be warned about it if the checkpoint stems from a PEFT version below
+        # 0.18.0. When the 'peft_version' key is not in the config, it means that the version is below 0.18.0.
+        config = OFTConfig(target_modules=["lin"], use_cayley_neumann=True)  # only relevant when using Cayley-Neumann
+        model = get_peft_model(self.get_model(), config)
+        model.save_pretrained(tmp_path)
+        del model
+
+        # overwrite the peft_version
+        with open(tmp_path / "adapter_config.json") as f:
+            config_json = json.load(f)
+
+        if peft_version is None:
+            del config_json["peft_version"]
+        else:
+            config_json["peft_version"] = peft_version
+
+        with open(tmp_path / "adapter_config.json", "w") as f:
+            json.dump(config_json, f)
+
+        msg = "TODO"  # <= replace with final warning message
+        PeftModel.from_pretrained(self.get_model(), tmp_path)
+
+        warn_messages = [str(w.message) for w in recwarn.list]
+        if peft_version == "0.18.0":
+            assert not any(w.startswith(msg) for w in warn_messages)
+        else:
+            assert any(w.startswith(msg) for w in warn_messages)
+
+
 class TestC3AInitialization:
     torch_device = infer_device()

From da73c332e94a3b6cf8764f10889af79dbcb82f57 Mon Sep 17 00:00:00 2001
From: Florent Draye
Date: Thu, 2 Oct 2025 11:03:03 +0200
Subject: [PATCH 05/10] make style

---
 tests/test_initialization.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/test_initialization.py b/tests/test_initialization.py
index a26ddd5cda..94205670b6 100644
--- a/tests/test_initialization.py
+++ b/tests/test_initialization.py
@@ -14,6 +14,7 @@
 
 import copy
 import itertools
+import json
 import math
 import platform
 import re
@@ -35,8 +36,6 @@
 from torch.utils.data import DataLoader
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-import json
-
 from peft import (
     AdaLoraConfig,
     C3AConfig,

From bc597cce18e233d236b7ed5ad6a1aa4595229796 Mon Sep 17 00:00:00 2001
From: Florent Draye
Date: Thu, 2 Oct 2025 14:04:22 +0200
Subject: [PATCH 06/10] update oft config

---
 src/peft/tuners/oft/config.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/peft/tuners/oft/config.py b/src/peft/tuners/oft/config.py
index b39b78511e..a73156fb99 100644
--- a/src/peft/tuners/oft/config.py
+++ b/src/peft/tuners/oft/config.py
@@ -198,6 +198,8 @@ def check_kwargs(cls, **kwargs):
             )
         if kwargs.get("use_cayley_neumann", False):
             peft_version = kwargs.get("peft_version", "0.0.0")  # if not present, set a low dummy version
+            # remove commit hash, if present
+            peft_version = peft_version.partition("@")[0]
             parsed_version = packaging.version.Version(peft_version)
             min_version = packaging.version.Version("0.18.0")
             # note: config.peft_version was added in 0.18.0, so if it's missing, it means we're below min version
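
Taken together, the version gate that patches 02, 04 and 06 build up in `check_kwargs` boils down to the following standalone sketch (the helper name and boolean return are illustrative, not PEFT API):

import packaging.version


def needs_cayley_neumann_warning(config_kwargs: dict) -> bool:
    # Only relevant when the Cayley-Neumann parameterization is enabled.
    if not config_kwargs.get("use_cayley_neumann", False):
        return False
    # A missing peft_version key means the checkpoint predates PEFT 0.18.0.
    peft_version = config_kwargs.get("peft_version", "0.0.0")
    peft_version = peft_version.partition("@")[0]  # drop a trailing "@<commit hash>", if any
    return packaging.version.Version(peft_version) < packaging.version.Version("0.18.0")


print(needs_cayley_neumann_warning({"use_cayley_neumann": True, "peft_version": "0.17.0"}))          # True
print(needs_cayley_neumann_warning({"use_cayley_neumann": True, "peft_version": "0.18.0@abc1234"}))  # False
print(needs_cayley_neumann_warning({"use_cayley_neumann": False}))                                   # False
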
From 85ed7ec784ac8595e1389483fdce6f58f9a4a535 Mon Sep 17 00:00:00 2001
From: Florent Draye
Date: Fri, 3 Oct 2025 15:50:18 +0200
Subject: [PATCH 07/10] update code

---
 src/peft/tuners/oft/layer.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/peft/tuners/oft/layer.py b/src/peft/tuners/oft/layer.py
index 95e9e3e833..50c5370306 100644
--- a/src/peft/tuners/oft/layer.py
+++ b/src/peft/tuners/oft/layer.py
@@ -90,7 +90,6 @@ def __init__(
         self.block_size = block_size
         self.in_features = in_features
         self.weight = nn.Parameter(torch.empty(r, n_elements))
-        self.weight = self.weight.to(device)
         self.coft = coft
         self.eps = eps
         self.block_share = block_share
@@ -466,6 +465,7 @@ def update_layer(
 
         # Create weights with provided shape
         n_elements = oft_block_size * (oft_block_size - 1) // 2
+        device = self.get_base_layer().weight.device
         self.oft_R[adapter_name] = OFTRotationModule(
             r if not block_share else 1,
             n_elements,
@@ -476,7 +476,7 @@ def update_layer(
             block_share=block_share,
             use_cayley_neumann=use_cayley_neumann,
             num_cayley_neumann_terms=num_cayley_neumann_terms,
-            device=self.get_base_layer().weight.device,
+            device=device if device.type != "meta" else None,
         )
 
         # Initialize weights
@@ -766,6 +766,7 @@ def update_layer(
 
         # Create weights with provided shape
         n_elements = oft_block_size * (oft_block_size - 1) // 2
+        device = self.get_base_layer().weight.device
        self.oft_R[adapter_name] = OFTRotationModule(
             r if not block_share else 1,
             n_elements,
@@ -777,7 +778,7 @@ def update_layer(
             kernel_size=base_layer.kernel_size,
             use_cayley_neumann=use_cayley_neumann,
             num_cayley_neumann_terms=num_cayley_neumann_terms,
-            device=self.get_base_layer().weight.device,
+            device=device if device.type != "meta" else None,
         )
 
         # Initialize weights
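
The `device.type != "meta"` guard introduced in patch 07 matters when the base model is only a skeleton on the meta device (for example during low-memory loading); a minimal sketch of the situation it avoids, with made-up layer sizes and assuming a PyTorch version with meta-device factory support:

import torch
import torch.nn as nn

base = nn.Linear(8, 8, device="meta")   # stand-in for a base layer without materialized weights
device = base.weight.device
print(device.type)                       # "meta"

# Mirror of the guard: don't hand "meta" to freshly created adapter parameters.
target_device = device if device.type != "meta" else None
adapter_weight = nn.Parameter(torch.empty(4, 6, device=target_device))
print(adapter_weight.device)             # cpu (the default), so the parameter can be initialized
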
From 29833f4a4a190f7827a4107cc00859f186605d34 Mon Sep 17 00:00:00 2001
From: Florent Draye
Date: Fri, 3 Oct 2025 16:28:43 +0200
Subject: [PATCH 08/10] update oft with buffer cols/rows

---
 src/peft/tuners/oft/layer.py | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/src/peft/tuners/oft/layer.py b/src/peft/tuners/oft/layer.py
index 50c5370306..e3210176bb 100644
--- a/src/peft/tuners/oft/layer.py
+++ b/src/peft/tuners/oft/layer.py
@@ -82,7 +82,6 @@ def __init__(
         kernel_size=(0, 0),
         use_cayley_neumann=True,
         num_cayley_neumann_terms=5,
-        device=None,
     ):
         super().__init__()
         self.r = r
@@ -98,9 +97,9 @@ def __init__(
         self.use_cayley_neumann = use_cayley_neumann
         self.num_cayley_neumann_terms = num_cayley_neumann_terms
         # Create indices for upper triangle (excluding diagonal)
-        self.rows, self.cols = torch.triu_indices(block_size, block_size, 1)
-        self.rows = self.rows.to(device)
-        self.cols = self.cols.to(device)
+        rows, cols = torch.triu_indices(block_size, block_size, 1)
+        self.register_buffer('rows', rows, persistent=False)
+        self.register_buffer('cols', cols, persistent=False)
 
     def _pytorch_skew_symmetric(self, vec, block_size):
         batch_size = vec.shape[0]
@@ -465,7 +464,6 @@ def update_layer(
 
         # Create weights with provided shape
         n_elements = oft_block_size * (oft_block_size - 1) // 2
-        device = self.get_base_layer().weight.device
         self.oft_R[adapter_name] = OFTRotationModule(
             r if not block_share else 1,
             n_elements,
@@ -476,7 +474,6 @@ def update_layer(
             block_share=block_share,
             use_cayley_neumann=use_cayley_neumann,
             num_cayley_neumann_terms=num_cayley_neumann_terms,
-            device=device if device.type != "meta" else None,
         )
 
         # Initialize weights
@@ -766,7 +763,6 @@ def update_layer(
 
         # Create weights with provided shape
         n_elements = oft_block_size * (oft_block_size - 1) // 2
-        device = self.get_base_layer().weight.device
         self.oft_R[adapter_name] = OFTRotationModule(
             r if not block_share else 1,
             n_elements,
@@ -778,7 +774,6 @@ def update_layer(
             kernel_size=base_layer.kernel_size,
             use_cayley_neumann=use_cayley_neumann,
             num_cayley_neumann_terms=num_cayley_neumann_terms,
-            device=device if device.type != "meta" else None,
         )
 
         # Initialize weights
@@ -944,4 +939,4 @@ def dispatch_default(
             kwargs["fan_in_fan_out"] = oft_config.fan_in_fan_out = False
         new_module = Linear(target, adapter_name, **kwargs)
 
-    return new_module
+    return new_module
\ No newline at end of file

From ed56c57db06b0a4134ea0128ba35dd4c14d24411 Mon Sep 17 00:00:00 2001
From: Florent Draye
Date: Mon, 6 Oct 2025 14:12:41 +0200
Subject: [PATCH 09/10] update oft test

---
 src/peft/tuners/oft/config.py | 6 +++++-
 tests/test_initialization.py  | 4 ++--
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/peft/tuners/oft/config.py b/src/peft/tuners/oft/config.py
index a73156fb99..9c62e1bece 100644
--- a/src/peft/tuners/oft/config.py
+++ b/src/peft/tuners/oft/config.py
@@ -204,6 +204,10 @@ def check_kwargs(cls, **kwargs):
             min_version = packaging.version.Version("0.18.0")
             # note: config.peft_version was added in 0.18.0, so if it's missing, it means we're below min version
             if parsed_version < min_version:
-                msg = "The cayley-neumann parameterization has been slightly changed to be more numerically stable in PEFT 0.18.0. Please retrain your adapter weights with newer PEFT versions. Alternatively, downgrade PEFT to version 0.17.0 to use the old parameterization."
+                msg = (
+                    "The cayley-neumann parameterization has been slightly changed to be more numerically stable in "
+                    "PEFT 0.18.0. Please retrain your adapter weights with newer PEFT versions. Alternatively, "
+                    "downgrade PEFT to version 0.17.0 to use the old parameterization."
+                )
                 warnings.warn(msg)
         return super().check_kwargs(**kwargs)
diff --git a/tests/test_initialization.py b/tests/test_initialization.py
index 94205670b6..f69a317718 100644
--- a/tests/test_initialization.py
+++ b/tests/test_initialization.py
@@ -1784,14 +1784,14 @@ def test_load_outdated_oft_checkpoint_warns(self, peft_version, tmp_path, recwar
             config_json = json.load(f)
 
         if peft_version is None:
-            del config_json["peft_version"]
+            config_json.pop("peft_version", None)
         else:
             config_json["peft_version"] = peft_version
 
         with open(tmp_path / "adapter_config.json", "w") as f:
             json.dump(config_json, f)
 
-        msg = "TODO"  # <= replace with final warning message
+        msg = "The cayley-neumann parameterization has been slightly changed to be more numerically stable in PEFT 0.18.0."
         PeftModel.from_pretrained(self.get_model(), tmp_path)
 
         warn_messages = [str(w.message) for w in recwarn.list]

From c3bc6e0ffeb0ba0c4398196a8ec380591264960c Mon Sep 17 00:00:00 2001
From: Florent Draye
Date: Mon, 6 Oct 2025 14:13:16 +0200
Subject: [PATCH 10/10] run make quality

---
 src/peft/tuners/oft/layer.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/peft/tuners/oft/layer.py b/src/peft/tuners/oft/layer.py
index e3210176bb..9dd450d519 100644
--- a/src/peft/tuners/oft/layer.py
+++ b/src/peft/tuners/oft/layer.py
@@ -98,8 +98,8 @@ def __init__(
         self.num_cayley_neumann_terms = num_cayley_neumann_terms
         # Create indices for upper triangle (excluding diagonal)
         rows, cols = torch.triu_indices(block_size, block_size, 1)
-        self.register_buffer('rows', rows, persistent=False)
-        self.register_buffer('cols', cols, persistent=False)
+        self.register_buffer("rows", rows, persistent=False)
+        self.register_buffer("cols", cols, persistent=False)
 
     def _pytorch_skew_symmetric(self, vec, block_size):
         batch_size = vec.shape[0]
@@ -939,4 +939,4 @@ def dispatch_default(
             kwargs["fan_in_fan_out"] = oft_config.fan_in_fan_out = False
         new_module = Linear(target, adapter_name, **kwargs)
 
-    return new_module
\ No newline at end of file
+    return new_module
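
As a closing note on patch 08: registering `rows` and `cols` with `register_buffer(..., persistent=False)` makes them follow `Module.to()`/`.cuda()` automatically and keeps them out of the `state_dict`, which is why the per-forward device check from patch 01 could be dropped. A small self-contained sketch of that behaviour (the module below is illustrative, not the PEFT `OFTRotationModule`):

import torch
import torch.nn as nn


class IndexHolder(nn.Module):
    def __init__(self, block_size: int = 4):
        super().__init__()
        rows, cols = torch.triu_indices(block_size, block_size, 1)
        self.register_buffer("rows", rows, persistent=False)
        self.register_buffer("cols", cols, persistent=False)
        self.plain = torch.triu_indices(block_size, block_size, 1)  # ordinary attribute, for contrast


m = IndexHolder()
print("rows" in m.state_dict())  # False -> persistent=False keeps the indices out of checkpoints
if torch.cuda.is_available():
    m = m.cuda()
    print(m.rows.device)         # cuda:0 -> buffers follow Module.to()/.cuda()
    print(m.plain.device)        # cpu    -> plain tensor attributes do not
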