305 changes: 292 additions & 13 deletions src/peft/tuners/oft/layer.py
@@ -329,6 +329,8 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
"""
self.base_layer = base_layer
self.oft_R = nn.ModuleDict({})
# For Embedding layer
self.oft_embedding_R = nn.ModuleDict({})
self.oft_block_size = {}
self.r = {}
self.oft_block_size = {}
@@ -345,6 +347,8 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
in_features, out_features = base_layer.in_features, base_layer.out_features
elif isinstance(base_layer, nn.Conv2d):
in_features, out_features = base_layer.in_channels, base_layer.out_channels
elif isinstance(base_layer, nn.Embedding):
in_features, out_features = base_layer.embedding_dim, base_layer.num_embeddings
elif hasattr(base_layer, "infeatures") and hasattr(base_layer, "outfeatures"):
# QuantLinear
in_features, out_features = base_layer.infeatures, base_layer.outfeatures
@@ -376,10 +380,6 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
self.in_features = in_features
self.out_features = out_features

@property
def _available_adapters(self) -> set[str]:
return {*self.oft_R}

def set_scale(self, adapter, scale):
if adapter not in self.scaling:
# Ignore the case where the adapter is not in the layer
@@ -492,15 +492,25 @@ def reset_oft_parameters(self, adapter_name, init_weights):
Reset the OFT parameters.
"""
if init_weights is False:
nn.init.normal_(self.oft_R[adapter_name].weight, mean=0.0, std=0.1)
return
if adapter_name in self.oft_R.keys():
nn.init.normal_(self.oft_R[adapter_name].weight, mean=0.0, std=0.1)
return
if adapter_name in self.oft_embedding_R.keys():
nn.init.normal_(self.oft_embedding_R[adapter_name].weight, mean=0.0, std=0.1)
return

if adapter_name in self.oft_R.keys():
if init_weights is True:
# initialize oft_R to zero
nn.init.zeros_(self.oft_R[adapter_name].weight)
else:
raise ValueError(f"Unknown initialization {init_weights=}")
if adapter_name in self.oft_embedding_R.keys():
if init_weights is True:
# initialize oft_embedding_R to zero
nn.init.zeros_(self.oft_embedding_R[adapter_name].weight)
else:
raise ValueError(f"Unknown initialization {init_weights=}")
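
For intuition: `init_weights=True` zeros the rotation parameters so the adapter starts as a no-op. A minimal sketch of why, assuming `OFTRotationModule` builds each orthogonal block from its parameter vector via the Cayley transform (the `use_cayley_neumann` option suggests this parameterization, but the module's internals are not part of this diff):

```python
import torch

b = 4  # oft_block_size
params = torch.zeros(b * (b - 1) // 2)  # what init_weights=True produces

# Rebuild the skew-symmetric matrix Q from the strict upper triangle.
Q = torch.zeros(b, b)
iu = torch.triu_indices(b, b, offset=1)
Q[iu[0], iu[1]] = params
Q = Q - Q.T

# Cayley transform: R = (I + Q)(I - Q)^{-1} is orthogonal for skew-symmetric Q.
I = torch.eye(b)
R = (I + Q) @ torch.linalg.inv(I - Q)
assert torch.allclose(R, I)  # zero params -> identity rotation -> base model unchanged
```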

def adjust_oft_parameters(self, in_features, params):
"""
@@ -582,7 +592,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
return

for active_adapter in adapter_names:
if active_adapter in self._available_adapters:
if active_adapter in self.oft_R.keys():
base_layer = self.get_base_layer()
orig_dtype = base_layer.weight.dtype
if safe_merge:
@@ -636,16 +646,33 @@ def unmerge(self) -> None:

base_layer.weight.data = orig_weights.to(orig_dtype)

def get_delta_weight(self, adapter_name) -> tuple[torch.Tensor, torch.Tensor]:
def get_delta_weight(self, adapter) -> torch.Tensor:
"""
Compute the delta weight for the given adapter.

Args:
adapter (str):
The name of the adapter for which the delta weight should be computed.
"""

return self.oft_R[adapter_name].get_weight()
device = self.oft_R[adapter].weight.device
dtype = self.oft_R[adapter].weight.dtype

# In case users want to merge adapter weights that are in
# (b)float16 while on CPU, we cast the weights to float32, perform the merge,
# and then cast back to (b)float16, because some CPUs have slow bf16/fp16 matmuls.
cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)

oft_R_module = self.oft_R[adapter]

if cast_to_fp32:
# Temporarily work in fp32 for faster CPU matmul
original_weight = oft_R_module.weight.data
oft_R_module.weight.data = oft_R_module.weight.data.float()
oft_mat = oft_R_module.get_weight()
oft_R_module.weight.data = original_weight # restore
return oft_mat.to(dtype)
else:
return oft_R_module.get_weight()
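
Note that, unlike LoRA, OFT's "delta" is a multiplicative rotation rather than an additive update: `get_delta_weight` returns a square `(in_features, in_features)` matrix that is multiplied into the base weight on merge. A quick property check one could run (hypothetical usage, assuming a wrapped `Linear` layer with an adapter named `"default"`):

```python
# R should be (numerically) orthogonal: R @ R.T == I
R = layer.get_delta_weight("default")
I = torch.eye(R.shape[0], dtype=R.dtype, device=R.device)
assert torch.allclose(R @ R.T, I, atol=1e-4)
```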

def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
previous_dtype = x.dtype
@@ -884,16 +911,33 @@ def unmerge(self) -> None:

base_layer.weight.data = orig_weights.to(orig_dtype)

def get_delta_weight(self, adapter_name) -> tuple[torch.Tensor, torch.Tensor]:
def get_delta_weight(self, adapter) -> torch.Tensor:
"""
Compute the delta weight for the given adapter.

Args:
adapter (str):
The name of the adapter for which the delta weight should be computed.
"""

return self.oft_R[adapter_name].get_weight()
device = self.oft_R[adapter].weight.device
dtype = self.oft_R[adapter].weight.dtype

# In case users want to merge adapter weights that are in
# (b)float16 while on CPU, we cast the weights to float32, perform the merge,
# and then cast back to (b)float16, because some CPUs have slow bf16/fp16 matmuls.
cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)

oft_R_module = self.oft_R[adapter]

if cast_to_fp32:
# Temporarily work in fp32 for faster CPU matmul
original_weight = oft_R_module.weight.data
oft_R_module.weight.data = oft_R_module.weight.data.float()
oft_mat = oft_R_module.get_weight()
oft_R_module.weight.data = original_weight # restore
return oft_mat.to(dtype)
else:
return oft_R_module.get_weight()

def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
previous_dtype = x.dtype
@@ -923,6 +967,237 @@ def __repr__(self) -> str:
return "oft." + rep


class Embedding(nn.Module, OFTLayer):
# OFT implemented in an Embedding layer
def __init__(
self,
base_layer: nn.Module,
adapter_name: str,
r: int = 8,
oft_block_size: int = 0,
module_dropout: float = 0.0,
coft: bool = False,
eps: float = 6e-5,
block_share: bool = False,
use_cayley_neumann: bool = False,
num_cayley_neumann_terms: int = 5,
fan_in_fan_out: bool = False, # unused for embedding, kept for API parity
init_weights: Union[bool, str] = True,
**kwargs,
) -> None:
super().__init__()
OFTLayer.__init__(self, base_layer)
self.fan_in_fan_out = fan_in_fan_out

self._active_adapter = adapter_name
self.update_layer(
adapter_name,
r,
oft_block_size=oft_block_size,
module_dropout=module_dropout,
coft=coft,
eps=eps,
block_share=block_share,
init_weights=init_weights,
use_cayley_neumann=use_cayley_neumann,
num_cayley_neumann_terms=num_cayley_neumann_terms,
)

def update_layer(
self,
adapter_name,
r,
oft_block_size,
module_dropout,
coft,
eps,
block_share,
init_weights,
use_cayley_neumann,
num_cayley_neumann_terms,
inference_mode: bool = False,
**kwargs,
):
# collect the kwargs
kwargs = locals().copy()
del kwargs["self"]

if r == 0 and oft_block_size != 0:
if self.in_features % oft_block_size != 0 or oft_block_size > self.in_features:
old_oft_block_size = oft_block_size
oft_block_size = self.adjust_oft_parameters(self.in_features, oft_block_size)
warnings.warn(
f"Invalid `oft_block_size` ({old_oft_block_size})! Adjusted `oft_block_size` to ({oft_block_size})."
)
r = int(self.in_features // oft_block_size)
elif r != 0 and oft_block_size == 0:
if self.in_features % r != 0 or r > self.in_features:
old_r = r
r = self.adjust_oft_parameters(self.in_features, r)
warnings.warn(f"Invalid `r` ({old_r})! Adjusted `r` to ({r}).")
oft_block_size = int(self.in_features // r)
else:
raise ValueError(
"Something went wrong, please report this error: https://github.com/huggingface/peft/issues"
)
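
`adjust_oft_parameters` is not shown in this diff; a plausible sketch of what it must do, assuming it snaps the requested value to the nearest divisor of `in_features` so the features split into equal blocks (names and behavior here are assumptions):

```python
def nearest_divisor(n: int, target: int) -> int:
    # Hypothetical stand-in for adjust_oft_parameters: choose the divisor
    # of n closest to the requested block size (or rank).
    divisors = [d for d in range(1, n + 1) if n % d == 0]
    return min(divisors, key=lambda d: abs(d - target))

# e.g. embedding_dim = 768 with a requested oft_block_size of 100
# would be adjusted to 96, giving r = 768 // 96 = 8 blocks
assert nearest_divisor(768, 100) == 96
```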

# Create weights with provided shape
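# Each oft_block_size x oft_block_size orthogonal block is parameterized by a
# skew-symmetric matrix, which has block_size * (block_size - 1) / 2 free entries.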
n_elements = oft_block_size * (oft_block_size - 1) // 2
self.oft_embedding_R[adapter_name] = OFTRotationModule(
r if not block_share else 1,
n_elements,
oft_block_size,
self.in_features,
coft=coft,
eps=eps,
block_share=block_share,
use_cayley_neumann=use_cayley_neumann,
num_cayley_neumann_terms=num_cayley_neumann_terms,
)

# Initialize weights
self.reset_oft_parameters(adapter_name, init_weights)

# set oft r and block size
self.r[adapter_name] = r
self.oft_block_size[adapter_name] = oft_block_size

# Move new weights to device
self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters, inference_mode=inference_mode)

def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights

Args:
safe_merge (`bool`, *optional*):
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
adapter_names (`list[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
to `None`.
"""
adapter_names = check_adapters_to_merge(self, adapter_names)
if not adapter_names:
# no adapter to merge
return

for active_adapter in adapter_names:
if active_adapter in self.oft_embedding_R.keys():
base_layer = self.get_base_layer()
orig_dtype = base_layer.weight.dtype
if safe_merge:
# Note that safe_merge will be slower than the normal merge
orig_weights = base_layer.weight.data
oft_mat = self.get_delta_weight(active_adapter)
orig_weights = torch.mm(orig_weights.to(oft_mat.dtype), oft_mat)

if not torch.isfinite(orig_weights).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)

base_layer.weight.data = orig_weights.contiguous().to(orig_dtype)
else:
orig_weights = base_layer.weight.data
oft_mat = self.get_delta_weight(active_adapter)
orig_weights = torch.mm(orig_weights.to(oft_mat.dtype), oft_mat)

base_layer.weight.data = orig_weights.contiguous().to(orig_dtype)
self.merged_adapters.append(active_adapter)
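
Since the embedding weight is `(num_embeddings, embedding_dim)` and the rotation is `(embedding_dim, embedding_dim)`, merging right-multiplies the table and rotates every embedding vector; unmerging applies the inverse. A standalone round-trip sketch of the same algebra (not using the class):

```python
import torch

W = torch.randn(10, 4)                     # (num_embeddings, embedding_dim)
R, _ = torch.linalg.qr(torch.randn(4, 4))  # stand-in orthogonal rotation

merged = W @ R                             # what merge() does
restored = merged @ torch.linalg.inv(R)    # what unmerge() does
assert torch.allclose(restored, W, atol=1e-5)
```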

def unmerge(self) -> None:
"""
This method unmerges all merged adapter layers from the base weights.
"""
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return

base_layer = self.get_base_layer()
orig_dtype = base_layer.weight.dtype
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self.oft_embedding_R.keys():
oft_mat = self.get_delta_weight(active_adapter)

previous_dtype = oft_mat.dtype
if previous_dtype != torch.float32:
oft_mat = oft_mat.to(torch.float32)

orig_weights = self.get_base_layer().weight.data
orig_weights = torch.mm(orig_weights.to(oft_mat.dtype), torch.linalg.inv(oft_mat))

base_layer.weight.data = orig_weights.to(orig_dtype)

def get_delta_weight(self, adapter) -> torch.Tensor:
"""
Compute the delta weight for the given adapter.

Args:
adapter (str):
The name of the adapter for which the delta weight should be computed.
"""
device = self.oft_embedding_R[adapter].weight.device
dtype = self.oft_embedding_R[adapter].weight.dtype

# In case users want to merge adapter weights that are in
# (b)float16 while on CPU, we cast the weights to float32, perform the merge,
# and then cast back to (b)float16, because some CPUs have slow bf16/fp16 matmuls.
cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)

oft_R_module = self.oft_embedding_R[adapter]

if cast_to_fp32:
# Temporarily work in fp32 for faster CPU matmul
original_weight = oft_R_module.weight.data
oft_R_module.weight.data = oft_R_module.weight.data.float()
oft_mat = oft_R_module.get_weight()
oft_R_module.weight.data = original_weight # restore
return oft_mat.to(dtype)
else:
return oft_R_module.get_weight()

def _embed(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
base_layer = self.get_base_layer()
return F.embedding(
input,
weight,
padding_idx=base_layer.padding_idx,
max_norm=base_layer.max_norm,
norm_type=base_layer.norm_type,
scale_grad_by_freq=base_layer.scale_grad_by_freq,
sparse=base_layer.sparse,
)

def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
# x is token ids (usually LongTensor); rotation is applied to embedding outputs
if self.disable_adapters:
if self.merged:
self.unmerge()
return self.base_layer(x, *args, **kwargs)
if self.merged:
return self.base_layer(x, *args, **kwargs)

result = self.base_layer(x, *args, **kwargs)
out_dtype = result.dtype

for active_adapter in self.active_adapters:
if active_adapter not in self.oft_embedding_R:
continue
oft_embedding_R = self.oft_embedding_R[active_adapter]
result = self._cast_input_dtype(result, oft_embedding_R.weight.dtype)
result = oft_embedding_R(result)

return result.to(out_dtype)
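
The unmerged path rotates embeddings after lookup. Assuming `OFTRotationModule.forward` right-multiplies its input by the same rotation that `merge()` uses, this is equivalent to looking up rows of the merged table, because row indexing commutes with right-multiplication:

```python
import torch
import torch.nn.functional as F

W = torch.randn(10, 4)
R, _ = torch.linalg.qr(torch.randn(4, 4))
ids = torch.tensor([[1, 3, 7]])

merged_lookup = F.embedding(ids, W @ R)   # merged path
rotated_lookup = F.embedding(ids, W) @ R  # unmerged forward() path
assert torch.allclose(merged_lookup, rotated_lookup, atol=1e-5)
```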

def __repr__(self) -> str:
rep = super().__repr__()
return "oft." + rep


def dispatch_default(
target: torch.nn.Module,
adapter_name: str,
@@ -946,5 +1221,9 @@ def dispatch_default(
)
kwargs["fan_in_fan_out"] = oft_config.fan_in_fan_out = False
new_module = Linear(target, adapter_name, **kwargs)
elif isinstance(target_base_layer, torch.nn.Embedding):
embedding_kwargs = kwargs.copy()
embedding_kwargs.pop("fan_in_fan_out", None)
new_module = Embedding(target, adapter_name, **embedding_kwargs)

return new_module
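
With the dispatcher change, targeting an embedding module should work through the usual config path. A hypothetical end-to-end sketch (the `target_modules` name is model-specific and assumed here; per the validation in `update_layer` above, exactly one of `r` / `oft_block_size` should be nonzero):

```python
from transformers import AutoModelForCausalLM
from peft import OFTConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained("gpt2")
config = OFTConfig(
    r=0,
    oft_block_size=32,        # 32 divides gpt2's embedding_dim of 768
    target_modules=["wte"],   # gpt2's token embedding (assumed name)
)
peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()
```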