Commit 0c15622

ENH DoRA optimization for ConvNd if dropout=0. (huggingface#2371)
1 parent 1e21444 commit 0c15622

File tree: 2 files changed (+19, -5 lines changed)


src/peft/tuners/lora/dora.py (11 additions, 4 deletions)

@@ -139,7 +139,7 @@ def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
         weight_norm = weight.norm(p=2, dim=dim, keepdim=True).transpose(1, 0)
         return weight_norm
 
-    def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
+    def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_result=None):
         """
         For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
         output.
@@ -157,8 +157,9 @@ def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
         # during backpropagation"
         weight_norm = weight_norm.detach()
         mag_norm_scale = magnitude / weight_norm
-        result_dora = (mag_norm_scale - 1) * (
-            self.conv_fn(
+
+        if base_result is None:
+            base_result = self.conv_fn(
                 x,
                 weight,
                 bias=None,
@@ -167,8 +168,14 @@ def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
                 dilation=base_layer.dilation,
                 groups=base_layer.groups,
             )
-        ) + mag_norm_scale * lora_B(lora_A(x)) * scaling
+        else:
+            bias = base_layer.bias
+            if bias is not None:
+                # reshape bias to (1, -1, 1, ...)
+                bias_shape = (1, -1) + (1,) * (base_result.dim() - 2)
+                base_result = base_result - bias.view(*bias_shape)
 
+        result_dora = (mag_norm_scale - 1) * base_result + mag_norm_scale * lora_B(lora_A(x)) * scaling
         return result_dora
 
     def __repr__(self) -> str:
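
The new else branch relies on a simple identity: the base layer output that layer.py already computed includes the bias, so subtracting the broadcast bias recovers the bias-free convolution that DoRA rescales. A minimal, self-contained sketch (not part of the commit; layer shape and tolerance are illustrative) that checks this identity:

import torch
import torch.nn.functional as F

conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
x = torch.randn(2, 3, 16, 16)

# what the base layer already computed in layer.py (bias included)
base_result = conv(x)

# reshape bias to (1, -1, 1, ...), as the new else branch does
bias_shape = (1, -1) + (1,) * (base_result.dim() - 2)
debiased = base_result - conv.bias.view(*bias_shape)

# the recomputation that the optimization now skips
reference = F.conv2d(
    x,
    conv.weight,
    bias=None,
    stride=conv.stride,
    padding=conv.padding,
    dilation=conv.dilation,
    groups=conv.groups,
)

assert torch.allclose(debiased, reference, atol=1e-6)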

src/peft/tuners/lora/layer.py (8 additions, 1 deletion)

@@ -1272,6 +1272,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
             result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
         elif self.merged:
             result = self.base_layer(x, *args, **kwargs)
+
         else:
             result = self.base_layer(x, *args, **kwargs)
             torch_result_dtype = result.dtype
@@ -1288,13 +1289,19 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                 if not self.use_dora[active_adapter]:
                     result = result + lora_B(lora_A(dropout(x))) * scaling
                 else:
-                    x = dropout(x)
+                    if isinstance(dropout, nn.Identity) or not self.training:
+                        base_result = result
+                    else:
+                        x = dropout(x)
+                        base_result = None
+
                     result = result + self.lora_magnitude_vector[active_adapter](
                         x,
                         lora_A=lora_A,
                         lora_B=lora_B,
                         scaling=scaling,
                         base_layer=self.get_base_layer(),
+                        base_result=base_result,
                     )
 
             result = result.to(torch_result_dtype)
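
On the layer.py side, base_result is only handed to the DoRA layer when the cached output is still valid for it: when the LoRA dropout module is nn.Identity (the case for lora_dropout=0) or the module is in eval mode. Otherwise dropout changes the input, base_result is set to None, and the DoRA layer recomputes the convolution itself. A hedged usage sketch (the tiny model and module names below are illustrative assumptions; only the LoraConfig fields come from PEFT):

import torch.nn as nn
from peft import LoraConfig, get_peft_model

class TinyConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(8, 8, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv2(self.conv1(x))

config = LoraConfig(
    r=8,
    target_modules=["conv1", "conv2"],  # hypothetical module names
    use_dora=True,
    lora_dropout=0.0,  # dropout becomes nn.Identity, so the optimized path is used even in training
)
model = get_peft_model(TinyConvNet(), config)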
