
Commit 8476c23

SP1029 authored and efraimdahl committed
FIX Improved handling of conv groups (huggingface#2567)
More generalized handling of the groups argument in LoRA/DoRA conv layers (previous solution: huggingface#2403).
1 parent f339af2 commit 8476c23

File tree

4 files changed (+129 −15 lines):
src/peft/tuners/lora/dora.py
src/peft/tuners/lora/layer.py
tests/test_custom_models.py
tests/test_initialization.py

4 files changed

+129
-15
lines changed

src/peft/tuners/lora/dora.py
Lines changed: 4 additions & 2 deletions

@@ -49,7 +49,8 @@ def update_layer(self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=False):
 
         weight = dequantize_module_weight(base_layer)
         if weight.data.ndim >= 3:  # For handling LoRAs applied to Conv layers.
-            lora_weight = torch.mm(lora_B.flatten(start_dim=1), lora_A.flatten(start_dim=1))
+            r = lora_A.shape[0]
+            lora_weight = torch.mm(lora_B.view([-1, r]), lora_A.view([r, -1]))
             lora_weight = lora_weight.reshape(weight.shape)
         else:
             lora_weight = lora_B @ lora_A
@@ -145,7 +146,8 @@ def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_result=None):
         output.
         """
         weight = base_layer.weight
-        lora_weight = torch.mm(lora_B.weight.flatten(start_dim=1), lora_A.weight.flatten(start_dim=1))
+        r = lora_A.weight.shape[0]
+        lora_weight = torch.mm(lora_B.weight.view([-1, r]), lora_A.weight.view([r, -1]))
         lora_weight = lora_weight.reshape(weight.shape)
         magnitude = self.weight
         weight_norm = self.get_weight_norm(weight, lora_weight.detach(), scaling)
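The switch from `flatten(start_dim=1)` to `view` matters because `lora_B` can now itself be a grouped convolution: its weight then has shape `(out_features, r // groups, 1, 1)`, so flattening from dim 1 yields an inner dimension of `r // groups` and the matrix product with `lora_A` is no longer conformable. Reshaping through `view([-1, r])` folds the group dimension into the rows instead. A minimal shape check of this arithmetic (a standalone sketch with made-up layer sizes, not code from the commit):

```python
import torch
import torch.nn as nn

# Illustrative sizes: a grouped conv targeted by LoRA/DoRA.
in_features, out_features, k, groups, r = 16, 32, 3, 2, 8

base = nn.Conv2d(in_features, out_features, k, padding=1, groups=groups)
lora_A = nn.Conv2d(in_features, r, k, padding=1, bias=False)
lora_B = nn.Conv2d(r, out_features, 1, groups=groups, bias=False)

# lora_B.weight has shape (32, 4, 1, 1): flatten(start_dim=1) would give
# (32, 4), whose inner dim no longer matches r=8; view([-1, r]) gives (16, 8).
lora_weight = torch.mm(lora_B.weight.view([-1, r]), lora_A.weight.view([r, -1]))
lora_weight = lora_weight.reshape(base.weight.shape)
print(lora_weight.shape)  # torch.Size([32, 8, 3, 3]) == base.weight.shape
```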

src/peft/tuners/lora/layer.py
Lines changed: 8 additions & 1 deletion

@@ -1078,6 +1078,13 @@ def __init__(
         if base_layer.groups > 1:
             warnings.warn("LoRA adapter added to ConvNd layer with groups > 1. Merging is not supported.")
 
+        if r % base_layer.groups != 0:
+            raise ValueError(
+                f"Targeting a {base_layer.__class__.__name__} with groups={base_layer.groups} and rank {r}. "
+                "Currently, support is limited to conv layers where the rank is divisible by groups. "
+                "Either choose a different rank or do not target this specific layer."
+            )
+
         self._active_adapter = adapter_name
         self._kernel_dim = base_layer.weight.dim()
 
@@ -1123,7 +1130,7 @@ def update_layer(
         out_kernel = out_stride = (1,) * (self._kernel_dim - 2)
         self.lora_A[adapter_name] = conv_layer(self.in_features, r, kernel_size, stride, padding, bias=False)
         self.lora_B[adapter_name] = conv_layer(
-            r, self.out_features // base_layer.groups, out_kernel, out_stride, bias=lora_bias
+            r, self.out_features, out_kernel, out_stride, groups=base_layer.groups, bias=lora_bias
         )
         self.lora_bias[adapter_name] = lora_bias
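Two changes here. The `__init__` guard rejects ranks not divisible by `base_layer.groups`, since the grouped `lora_B` needs `r // groups` input channels per group. In `update_layer`, the old code shrank `lora_B`'s output channels to `out_features // groups`, which produced mis-shaped activations; the new code keeps the full `out_features` and passes `groups` through instead. A quick sanity check of the resulting decomposition (a sketch with illustrative sizes, not the PEFT implementation):

```python
import torch
import torch.nn as nn

in_features, out_features, groups, r = 16, 32, 2, 8
base = nn.Conv2d(in_features, out_features, 3, padding=1, groups=groups)

# lora_A mirrors the base layer's spatial hyperparameters; lora_B is a 1x1
# conv that inherits the base layer's groups (this requires r % groups == 0).
lora_A = nn.Conv2d(in_features, r, 3, padding=1, bias=False)
lora_B = nn.Conv2d(r, out_features, 1, groups=groups, bias=False)

x = torch.randn(2, in_features, 20, 20)
assert lora_B(lora_A(x)).shape == base(x).shape  # both (2, 32, 20, 20)
```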

tests/test_custom_models.py
Lines changed: 44 additions & 12 deletions

@@ -118,7 +118,9 @@
     ("Conv2d 1 LoRA with DoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"], "use_dora": True}),
     ("Conv2d 2 LoRA with DoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d", "lin0"], "use_dora": True}),
     ("Conv2d Groups LoRA", "Conv2dGroups", LoraConfig, {"target_modules": ["conv2d"]}),
+    ("Conv2d Groups2 LoRA", "Conv2dGroups2", LoraConfig, {"target_modules": ["conv2d"]}),
     ("Conv2d Groups LoRA with DoRA", "Conv2dGroups", LoraConfig, {"target_modules": ["conv2d"], "use_dora": True}),
+    ("Conv2d Groups2 LoRA with DoRA", "Conv2dGroups2", LoraConfig, {"target_modules": ["conv2d"], "use_dora": True}),
     ("Conv3d 1 LoRA", "Conv3d", LoraConfig, {"target_modules": ["conv3d"]}),
     ("Conv3d 2 LoRA", "Conv3d", LoraConfig, {"target_modules": ["conv3d", "lin0"]}),
     ("Conv3d 1 LoRA with DoRA", "Conv3d", LoraConfig, {"target_modules": ["conv3d"], "use_dora": True}),
@@ -1082,16 +1084,43 @@ def forward(self, X):
 class ModelConv2DGroups(nn.Module):
     def __init__(self):
         super().__init__()
-        self.conv2d = nn.Conv2d(5, 5, 3, groups=5)
+        self.lin0 = nn.Linear(90, 288)
+        # groups is set as 8 since default r=8
+        # hence to make r divisible by groups
+        self.conv2d = nn.Conv2d(16, 16, 3, groups=8)
         self.relu = nn.ReLU()
         self.flat = nn.Flatten()
-        self.lin0 = nn.Linear(5, 2)
+        self.lin1 = nn.Linear(16, 2)
         self.sm = nn.LogSoftmax(dim=-1)
         self.dtype = torch.float
 
     def forward(self, X):
         X = X.to(self.dtype)
-        X = X.reshape(-1, 5, 3, 3)
+        X = X.flatten()
+        X = self.lin0(X)
+        X = X.reshape(2, 16, 3, 3)
+        X = self.conv2d(X)
+        X = self.relu(X)
+        X = self.flat(X)
+        X = self.lin1(X)
+        X = self.sm(X)
+        return X
+
+
+class ModelConv2DGroups2(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv2d = nn.Conv2d(16, 32, 3, padding=1, groups=2)
+        self.relu = nn.ReLU()
+        self.flat = nn.Flatten()
+        self.lin0 = nn.Linear(12800, 2)
+        self.sm = nn.LogSoftmax(dim=-1)
+        self.dtype = torch.float
+
+    def forward(self, X):
+        # Note: needs a different input shape, thus ignore original input
+        X = torch.arange(9 * 16 * 20 * 20).view([9, 16, 20, 20]).to(self.conv2d.weight.device)
+        X = X.to(self.dtype)
         X = self.conv2d(X)
         X = self.relu(X)
         X = self.flat(X)
@@ -1170,6 +1199,9 @@ def from_pretrained(cls, model_id, torch_dtype=None):
         if model_id == "Conv2dGroups":
             return ModelConv2DGroups().to(torch_dtype)
 
+        if model_id == "Conv2dGroups2":
+            return ModelConv2DGroups2().to(torch_dtype)
+
         if model_id == "Conv3d":
             return ModelConv3D().to(torch_dtype)
 
@@ -1242,7 +1274,7 @@ def test_load_multiple_adapters(self, test_name, model_id, config_cls, config_kwargs):
     @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
     def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs):
         # https://github.com/huggingface/peft/pull/2403
-        if model_id in ["Conv2dGroups"]:
+        if model_id in ["Conv2dGroups", "Conv2dGroups2"]:
             pytest.skip(
                 f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)"
             )
@@ -1265,7 +1297,7 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs):
     @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
     def test_merge_layers_fp16(self, test_name, model_id, config_cls, config_kwargs):
         # https://github.com/huggingface/peft/pull/2403
-        if model_id in ["Conv2dGroups"]:
+        if model_id in ["Conv2dGroups", "Conv2dGroups2"]:
             pytest.skip(
                 f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)"
             )
@@ -1280,7 +1312,7 @@ def test_merge_layers_fp16(self, test_name, model_id, config_cls, config_kwargs):
     @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
     def test_merge_layers_is_idempotent(self, test_name, model_id, config_cls, config_kwargs):
         # https://github.com/huggingface/peft/pull/2403
-        if model_id in ["Conv2dGroups"]:
+        if model_id in ["Conv2dGroups", "Conv2dGroups2"]:
             pytest.skip(
                 f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)"
             )
@@ -1296,7 +1328,7 @@ def test_merge_layers_is_idempotent(self, test_name, model_id, config_cls, config_kwargs):
     @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
     def test_safe_merge(self, test_name, model_id, config_cls, config_kwargs):
         # https://github.com/huggingface/peft/pull/2403
-        if model_id in ["Conv2dGroups"]:
+        if model_id in ["Conv2dGroups", "Conv2dGroups2"]:
             pytest.skip(
                 f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)"
             )
@@ -1390,7 +1422,7 @@ def test_forward_float16(self, test_name, model_id, config_cls, config_kwargs):
         # check that none of this raises an error
         model(**X)
 
-        if model_id in ["Conv2dGroups"]:
+        if model_id in ["Conv2dGroups", "Conv2dGroups2"]:
             # this model does not support merging
             return
 
@@ -1432,7 +1464,7 @@ def test_forward_bfloat16(self, test_name, model_id, config_cls, config_kwargs):
         # check that none of this raises an error
         model(**X)
 
-        if model_id in ["Conv2dGroups"]:
+        if model_id in ["Conv2dGroups", "Conv2dGroups2"]:
             # this model does not support merging
             return
 
@@ -1473,7 +1505,7 @@ def test_forward_float16_no_autocast(self, test_name, model_id, config_cls, config_kwargs):
         # check that none of this raises an error
         model(**X)
 
-        if model_id in ["Conv2dGroups"]:
+        if model_id in ["Conv2dGroups", "Conv2dGroups2"]:
             # this model does not support merging
             return
 
@@ -1514,7 +1546,7 @@ def test_forward_bfloat16_no_autocast(self, test_name, model_id, config_cls, config_kwargs):
         # check that none of this raises an error
         model(**X)
 
-        if model_id in ["Conv2dGroups"]:
+        if model_id in ["Conv2dGroups", "Conv2dGroups2"]:
             # this model does not support merging
             return
 
@@ -1685,7 +1717,7 @@ def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs):
     @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
     def test_disable_adapters_with_merging(self, test_name, model_id, config_cls, config_kwargs):
         # https://github.com/huggingface/peft/pull/2403
-        if model_id in ["Conv2dGroups"]:
+        if model_id in ["Conv2dGroups", "Conv2dGroups2"]:
             pytest.skip(
                 f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)"
            )
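Outside the test suite, the newly supported path can be exercised directly. A sketch (assuming a PEFT version that includes this fix; the single-layer model and its sizes are made up): applying LoRA with DoRA to a grouped conv and running a forward pass works, while merging remains unsupported for groups > 1, hence the skips above.

```python
import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model

# Hypothetical grouped-conv model; "0" is the Sequential child's module name.
model = nn.Sequential(nn.Conv2d(16, 32, 3, padding=1, groups=2))
config = LoraConfig(r=8, target_modules=["0"], use_dora=True)  # r divisible by groups
peft_model = get_peft_model(model, config)

out = peft_model(torch.randn(2, 16, 20, 20))
print(out.shape)  # torch.Size([2, 32, 20, 20]); merging this adapter is still unsupported
```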

tests/test_initialization.py
Lines changed: 73 additions & 0 deletions

@@ -1324,6 +1324,79 @@ def test_lora_incompatible_mamba_modules(self):
         with pytest.raises(ValueError, match=msg):
             get_peft_model(model, config)
 
+    def get_model_conv2d_groups(self):
+        class ModelConv2DGroups(nn.Module):
+            """For testing when groups argument is used in conv layer"""
+
+            def __init__(self):
+                super().__init__()
+                self.conv2d = nn.Conv2d(16, 32, 3, padding=1, groups=2)
+                self.relu = nn.ReLU()
+                self.flat = nn.Flatten()
+                self.lin0 = nn.Linear(12800, 2)
+                self.sm = nn.LogSoftmax(dim=-1)
+                self.dtype = torch.float
+
+            def forward(self, X):
+                # This is ignoring input since main usage is for checking raising of error when peft is applied
+                X = torch.arange(9 * 16 * 20 * 20).view([9, 16, 20, 20]).to(self.conv2d.weight.device)
+                X = X.to(self.dtype)
+                X = self.conv2d(X)
+                X = self.relu(X)
+                X = self.flat(X)
+                X = self.lin0(X)
+                X = self.sm(X)
+                return X
+
+        return ModelConv2DGroups().eval().to(self.torch_device)
+
+    @pytest.mark.parametrize(
+        "config_cls, config_kwargs",
+        [
+            pytest.param(LoraConfig, {"r": 8, "target_modules": ["conv2d"]}, id="lora with rank divisible by groups"),
+            pytest.param(LoraConfig, {"r": 2, "target_modules": ["conv2d"]}, id="lora with rank equal to groups"),
+            pytest.param(
+                LoraConfig, {"r": 1, "target_modules": ["conv2d"]}, id="lora with rank not divisible by groups"
+            ),
+            pytest.param(
+                LoraConfig,
+                {"r": 8, "target_modules": ["conv2d"], "use_dora": True},
+                id="dora with rank divisible by groups",
+            ),
+            pytest.param(
+                LoraConfig,
+                {"r": 2, "target_modules": ["conv2d"], "use_dora": True},
+                id="dora with rank equal to groups",
+            ),
+            pytest.param(
+                LoraConfig,
+                {"r": 1, "target_modules": ["conv2d"], "use_dora": True},
+                id="dora with rank not divisible by groups",
+            ),
+        ],
+    )
+    def test_error_raised_if_rank_not_divisible_by_groups(self, config_cls, config_kwargs):
+        # This test checks if error is raised when rank is not divisible by groups for conv layer since
+        # currently, support is limited to conv layers where the rank is divisible by groups in lora and dora
+        base_model = self.get_model_conv2d_groups()
+        peft_config = config_cls(**config_kwargs)
+        r = config_kwargs["r"]
+        base_layer = base_model.conv2d
+        groups = base_layer.groups
+        if r % groups != 0:
+            with pytest.raises(
+                ValueError,
+                match=(
+                    f"Targeting a {base_layer.__class__.__name__} with groups={base_layer.groups} and rank {r}. "
+                    "Currently, support is limited to conv layers where the rank is divisible by groups. "
+                    "Either choose a different rank or do not target this specific layer."
+                ),
+            ):
+                peft_model = get_peft_model(base_model, peft_config)
+        else:
+            # No error should be raised
+            peft_model = get_peft_model(base_model, peft_config)
 
 
 class TestLokrInitialization:
     torch_device = infer_device()
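For reference, the failure mode this test guards against looks like the following in user code (a hypothetical minimal model, with the rank deliberately not divisible by groups):

```python
import torch.nn as nn
from peft import LoraConfig, get_peft_model

model = nn.Sequential(nn.Conv2d(16, 32, 3, padding=1, groups=2))
# r=1 is not divisible by groups=2, so get_peft_model raises:
# ValueError: Targeting a Conv2d with groups=2 and rank 1. Currently, support
# is limited to conv layers where the rank is divisible by groups. Either
# choose a different rank or do not target this specific layer.
get_peft_model(model, LoraConfig(r=1, target_modules=["0"]))
```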
