25 changes: 17 additions & 8 deletions src/peft/tuners/lora/layer.py
@@ -1025,6 +1025,9 @@ def __init__(
super().__init__()
LoraLayer.__init__(self, base_layer)

if base_layer.groups > 1:
warnings.warn("LoRA adapter added to ConvNd layer with groups > 1. Merging is not supported.")

self._active_adapter = adapter_name
self._kernel_dim = base_layer.weight.dim()

@@ -1061,7 +1064,9 @@ def update_layer(
conv_layer = type(base_layer)
out_kernel = out_stride = (1,) * (self._kernel_dim - 2)
self.lora_A[adapter_name] = conv_layer(self.in_features, r, kernel_size, stride, padding, bias=False)
self.lora_B[adapter_name] = conv_layer(r, self.out_features, out_kernel, out_stride, bias=lora_bias)
self.lora_B[adapter_name] = conv_layer(
r, self.out_features // base_layer.groups, out_kernel, out_stride, bias=lora_bias
)
self.lora_bias[adapter_name] = lora_bias

if use_rslora:
@@ -1126,6 +1131,11 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
for active_adapter in adapter_names:
if active_adapter in self.lora_A.keys():
base_layer = self.get_base_layer()

if base_layer.groups > 1:
# https://github.com/huggingface/peft/pull/2403
raise NotImplementedError("Merging is not supported for _ConvNd layers with groups > 1!")

if safe_merge:
# Note that safe_merge will be slower than the normal merge
# because of the copy operation.
@@ -1243,13 +1253,12 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
3
) * self.scaling[adapter]
else:
output_tensor = (
self.conv_fn(
weight_A.transpose(0, 1),
weight_B,
).transpose(0, 1)
* self.scaling[adapter]
)
output_tensor = self.conv_fn(weight_A.transpose(0, 1), weight_B)

if self.get_base_layer().groups > 1:
output_tensor = output_tensor * self.scaling[adapter]
else:
output_tensor = output_tensor.transpose(0, 1) * self.scaling[adapter]

if cast_to_fp32:
output_tensor = output_tensor.to(dtype=dtype)
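For context on the shape change in update_layer above: PyTorch stores a grouped convolution's weight with in_channels // groups input channels per filter, so any update that is added to such a weight has to follow the same layout. A minimal standalone sketch of the shapes involved (not part of the diff; the variable names are illustrative only):

import torch.nn as nn

# A grouped convolution: 6 input channels, 6 output channels, 3 groups.
grouped = nn.Conv2d(6, 6, kernel_size=3, groups=3)
print(grouped.weight.shape)   # torch.Size([6, 2, 3, 3]) -> only in_channels // groups per filter

# The depthwise case used by the new ModelConv2DGroups test fixture (groups == channels).
depthwise = nn.Conv2d(5, 5, kernel_size=3, groups=5)
print(depthwise.weight.shape)  # torch.Size([5, 1, 3, 3])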
54 changes: 54 additions & 0 deletions tests/test_custom_models.py
@@ -115,6 +115,8 @@
("Conv2d 2 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d", "lin0"]}),
("Conv2d 1 LoRA with DoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"], "use_dora": True}),
("Conv2d 2 LoRA with DoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d", "lin0"], "use_dora": True}),
("Conv2d Groups LoRA", "Conv2dGroups", LoraConfig, {"target_modules": ["conv2d"]}),
("Conv2d Groups LoRA with DoRA", "Conv2dGroups", LoraConfig, {"target_modules": ["conv2d"], "use_dora": True}),
("Conv3d 1 LoRA", "Conv3d", LoraConfig, {"target_modules": ["conv3d"]}),
("Conv3d 2 LoRA", "Conv3d", LoraConfig, {"target_modules": ["conv3d", "lin0"]}),
("Conv3d 1 LoRA with DoRA", "Conv3d", LoraConfig, {"target_modules": ["conv3d"], "use_dora": True}),
@@ -903,6 +905,25 @@ def forward(self, X):
return X


class ModelConv2DGroups(nn.Module):
def __init__(self):
super().__init__()
self.conv2d = nn.Conv2d(5, 5, 3, groups=5)
self.relu = nn.ReLU()
self.flat = nn.Flatten()
self.lin0 = nn.Linear(5, 2)
self.sm = nn.LogSoftmax(dim=-1)

def forward(self, X):
X = X.float().reshape(-1, 5, 3, 3)
X = self.conv2d(X)
X = self.relu(X)
X = self.flat(X)
X = self.lin0(X)
X = self.sm(X)
return X


class ModelConv3D(nn.Module):
def __init__(self):
super().__init__()
@@ -967,6 +988,9 @@ def from_pretrained(cls, model_id, torch_dtype=None):
if model_id == "Conv2d":
return ModelConv2D().to(torch_dtype)

if model_id == "Conv2dGroups":
return ModelConv2DGroups().to(torch_dtype)

if model_id == "Conv3d":
return ModelConv3D().to(torch_dtype)

@@ -1038,6 +1062,12 @@ def test_load_multiple_adapters(self, test_name, model_id, config_cls, config_kw

@parameterized.expand(TEST_CASES)
def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs):
# https://github.com/huggingface/peft/pull/2403
if model_id in ["Conv2dGroups"]:
pytest.skip(
f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)"
)

config_kwargs = config_kwargs.copy()
if issubclass(config_cls, LoraConfig):
config_kwargs["init_lora_weights"] = False
@@ -1055,6 +1085,12 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs):

@parameterized.expand(TEST_CASES)
def test_merge_layers_fp16(self, test_name, model_id, config_cls, config_kwargs):
# https://github.com/huggingface/peft/pull/2403
if model_id in ["Conv2dGroups"]:
pytest.skip(
f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)"
)

config_kwargs = config_kwargs.copy()
if issubclass(config_cls, LoraConfig):
config_kwargs["init_lora_weights"] = False
@@ -1064,6 +1100,12 @@ def test_merge_layers_fp16(self, test_name, model_id, config_cls, config_kwargs)

@parameterized.expand(TEST_CASES)
def test_merge_layers_is_idempotent(self, test_name, model_id, config_cls, config_kwargs):
# https://github.com/huggingface/peft/pull/2403
if model_id in ["Conv2dGroups"]:
pytest.skip(
f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)"
)

# calling merge twice with the same arguments should not change the output
config_kwargs = config_kwargs.copy()
if issubclass(config_cls, LoraConfig):
@@ -1074,6 +1116,12 @@ def test_merge_layers_is_idempotent(self, test_name, model_id, config_cls, confi

@parameterized.expand(TEST_CASES)
def test_safe_merge(self, test_name, model_id, config_cls, config_kwargs):
# https://github.com/huggingface/peft/pull/2403
if model_id in ["Conv2dGroups"]:
pytest.skip(
f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)"
)

# calling merge twice with the same arguments should not change the output
config_kwargs = config_kwargs.copy()
if issubclass(config_cls, LoraConfig):
@@ -1290,6 +1338,12 @@ def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs):

@parameterized.expand(TEST_CASES)
def test_disable_adapters_with_merging(self, test_name, model_id, config_cls, config_kwargs):
# https://github.com/huggingface/peft/pull/2403
if model_id in ["Conv2dGroups"]:
pytest.skip(
f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)"
)

# same as test_disable_adapters, but with merging
X = self.prepare_inputs_for_testing()
model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
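To illustrate how these changes surface for a user, below is a minimal sketch (not part of the PR) of adding a LoRA adapter to a grouped convolution and then attempting to merge it. It assumes the public peft API (LoraConfig, get_peft_model, PeftModel.merge_and_unload); the GroupedConvModel class is a hypothetical stand-in modeled on the ModelConv2DGroups test fixture above.

import warnings
import torch.nn as nn
from peft import LoraConfig, get_peft_model

class GroupedConvModel(nn.Module):
    # Hypothetical toy model: a single depthwise convolution (groups == channels).
    def __init__(self):
        super().__init__()
        self.conv2d = nn.Conv2d(5, 5, 3, groups=5)

    def forward(self, X):
        return self.conv2d(X)

config = LoraConfig(target_modules=["conv2d"])

# Adding the adapter emits the new warning about groups > 1.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model = get_peft_model(GroupedConvModel(), config)
print(any("groups > 1" in str(w.message) for w in caught))  # True

# Merging is rejected explicitly for groups > 1.
try:
    model.merge_and_unload()
except NotImplementedError as exc:
    print(exc)  # Merging is not supported for _ConvNd layers with groups > 1!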