111 changes: 71 additions & 40 deletions src/peft/tuners/lora/model.py
@@ -777,48 +777,79 @@ def _generalized_task_arithmetic_weighted_adapter(
         density,
         majority_sign_method,
     ):
-        # account weights for LoRA A and B layers.
-        valid_weights_A = []
-        valid_weights_B = []
-        lora_A_deltas = []
-        lora_B_deltas = []
+        # Collect valid adapters and their weights with scaling
+        valid_adapters = []
+        valid_weights = []
+        is_embedding = any(adapter in target.lora_embedding_A for adapter in adapters)
         for adapter, weight in zip(adapters, weights):
-            if adapter in target.lora_A:
-                current_adapter_lora_A = target.lora_A[adapter].weight
-                current_adapter_lora_B = target.lora_B[adapter].weight
-            elif adapter in target.lora_embedding_A:
-                current_adapter_lora_A = target.lora_embedding_A[adapter]
-                current_adapter_lora_B = target.lora_embedding_B[adapter]
-            else:
-                continue
-            # Support negative weights: take absolute value for sqrt, then apply sign
-            weight_with_scaling = weight * target.scaling[adapter]
-            sign = 1 if weight_with_scaling >= 0 else -1
-            # apply sign only on one side of the weights, otherwise negative signs negate
-            valid_weights_A.append(math.sqrt(abs(weight_with_scaling)) * sign)
-            valid_weights_B.append(math.sqrt(abs(weight_with_scaling)))
-            lora_A_deltas.append(current_adapter_lora_A.data)
-            lora_B_deltas.append(current_adapter_lora_B.data)
-        valid_weights_A = torch.tensor(valid_weights_A).to(lora_A_deltas[0].device)
-        valid_weights_B = torch.tensor(valid_weights_B).to(lora_B_deltas[0].device)
-        valid_weights = [valid_weights_A, valid_weights_B]
-        lora_deltas = [lora_A_deltas, lora_B_deltas]
-        dtype = lora_A_deltas[0].dtype
-        for i, task_tensors in enumerate(lora_deltas):
-            if combination_type == "linear":
-                lora_deltas[i] = task_arithmetic(task_tensors, valid_weights[i])
-            elif combination_type == "ties":
-                lora_deltas[i] = ties(task_tensors, valid_weights[i], density, majority_sign_method)
-            elif combination_type == "dare_linear":
-                lora_deltas[i] = dare_linear(task_tensors, valid_weights[i], density)
-            elif combination_type == "dare_ties":
-                lora_deltas[i] = dare_ties(task_tensors, valid_weights[i], density, majority_sign_method)
-            elif combination_type == "magnitude_prune":
-                lora_deltas[i] = magnitude_prune(task_tensors, valid_weights[i], density)
-            else:
-                raise ValueError("Invalid combination type")
-        lora_deltas = [delta.to(dtype) for delta in lora_deltas]
-        return lora_deltas
+            if adapter in target.lora_A or adapter in target.lora_embedding_A:
+                valid_adapters.append(adapter)
+                valid_weights.append(weight * target.scaling[adapter])
+
+        if len(valid_adapters) == 0:
+            raise ValueError("No matching LoRAs found. Please raise an issue on GitHub.")
+
+        # Get the dtype and shape info from the first adapter
+        if valid_adapters[0] in target.lora_A:
+            dtype = target.lora_A[valid_adapters[0]].weight.dtype
+            lora_A_shape = target.lora_A[valid_adapters[0]].weight.shape
+            lora_B_shape = target.lora_B[valid_adapters[0]].weight.shape
+        else:
+            dtype = target.lora_embedding_A[valid_adapters[0]].dtype
+            lora_A_shape = target.lora_embedding_A[valid_adapters[0]].shape
+            lora_B_shape = target.lora_embedding_B[valid_adapters[0]].shape
+
+        # Compute full delta weights for each adapter to avoid cross-terms bug
+        # See https://github.com/huggingface/peft/issues/3004
+        delta_weights = [target.get_delta_weight(adapter) for adapter in valid_adapters]
+        valid_weights_tensor = torch.tensor(valid_weights).to(delta_weights[0].device)
+
+        # Apply the combination method to the full delta weights
+        if combination_type == "linear":
+            combined_delta = task_arithmetic(delta_weights, valid_weights_tensor)
+        elif combination_type == "ties":
+            combined_delta = ties(delta_weights, valid_weights_tensor, density, majority_sign_method)
+        elif combination_type == "dare_linear":
+            combined_delta = dare_linear(delta_weights, valid_weights_tensor, density)
+        elif combination_type == "dare_ties":
+            combined_delta = dare_ties(delta_weights, valid_weights_tensor, density, majority_sign_method)
+        elif combination_type == "magnitude_prune":
+            combined_delta = magnitude_prune(delta_weights, valid_weights_tensor, density)
+        else:
+            raise ValueError("Invalid combination type")
+
+        # Handle Conv2d layers - flatten for SVD, then reshape back
+        conv2d = isinstance(target, Conv2d)
+        if conv2d:
+            conv2d_1x1 = target.weight.size()[2:4] == (1, 1)
+            if not conv2d_1x1:
+                combined_delta = combined_delta.flatten(start_dim=1)
+            else:
+                combined_delta = combined_delta.squeeze()
+
+        # Handle transpose for embeddings and fan_in_fan_out layers
+        if (hasattr(target, "fan_in_fan_out") and target.fan_in_fan_out) or is_embedding:
+            combined_delta = combined_delta.T
+
+        # Decompose combined delta back into A and B using truncated SVD
+        # This ensures the result is mathematically correct without cross-terms
+        rank = lora_A_shape[0]  # r dimension
+        U, S, Vh = torch.linalg.svd(combined_delta, full_matrices=False)
+        U = U[:, :rank]
+        S = S[:rank]
+        Vh = Vh[:rank, :]
+
+        # Distribute singular values: A gets sqrt(S), B gets sqrt(S)
+        sqrt_S = torch.sqrt(S)
+        new_lora_B = (U * sqrt_S.unsqueeze(0)).to(dtype)  # [out_features, rank]
+        new_lora_A = (Vh * sqrt_S.unsqueeze(1)).to(dtype)  # [rank, in_features]
+
+        # Reshape for Conv2d layers
+        if conv2d:
+            new_lora_B = new_lora_B.reshape(lora_B_shape)
+            new_lora_A = new_lora_A.reshape(lora_A_shape)
+
+        return new_lora_A, new_lora_B

     def subtract_mutated_init(self, output_state_dict: dict[str, torch.Tensor], adapter_name: str, kwargs=None):
         """
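
Below the diff, two illustrative sketches of the math this change addresses. Both use toy dimensions and plain PyTorch, and are not code from the PR. First, the cross-terms bug (huggingface/peft#3004): merging adapters by summing sqrt-weighted A and B factors separately, as the removed code did for the linear path, does not reproduce the weighted sum of the full delta weights, because the product of the merged factors expands into spurious cross terms:

```python
import torch

torch.manual_seed(0)
out_f, in_f, r = 8, 6, 2

# Two toy LoRA adapters; each contributes delta_i = B_i @ A_i.
A1, B1 = torch.randn(r, in_f), torch.randn(out_f, r)
A2, B2 = torch.randn(r, in_f), torch.randn(out_f, r)
w = 0.5  # equal merge weights for both adapters

# Desired result: the weighted sum of the full delta weights.
want = w * (B1 @ A1) + w * (B2 @ A2)

# Per-factor merge: scale each factor by sqrt(w), then sum the factors.
s = w ** 0.5
got = (s * B1 + s * B2) @ (s * A1 + s * A2)

# (sB1 + sB2) @ (sA1 + sA2) = w*(B1A1 + B2A2) + w*(B1A2 + B2A1);
# the second term is the spurious cross-term contribution.
cross = w * (B1 @ A2 + B2 @ A1)
print(torch.allclose(got, want))          # False
print(torch.allclose(got, want + cross))  # True
```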
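Second, the repair the new code performs: combine the full delta weights first, then refactor the combined matrix into rank-r A and B via truncated SVD, splitting each singular value as sqrt(S) on both sides. A toy delta of exact rank r round-trips losslessly; a higher-rank combination instead gets its best rank-r approximation in the Frobenius norm (Eckart-Young):

```python
import torch

torch.manual_seed(0)
out_f, in_f, r = 8, 6, 4

# A combined delta that happens to be exactly rank r.
delta = torch.randn(out_f, r) @ torch.randn(r, in_f)

U, S, Vh = torch.linalg.svd(delta, full_matrices=False)
U, S, Vh = U[:, :r], S[:r], Vh[:r, :]

# Split the singular values evenly between the two factors.
sqrt_S = torch.sqrt(S)
new_B = U * sqrt_S.unsqueeze(0)   # [out_f, r], shaped like lora_B.weight
new_A = Vh * sqrt_S.unsqueeze(1)  # [r, in_f], shaped like lora_A.weight

# Exact here because delta has rank r.
print(torch.allclose(new_B @ new_A, delta, atol=1e-5))  # True
```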