
Commit cd51f2a

ariG23498 authored and yaswanth19 committed
ENH Faster DoRA when no dropout / eval mode (huggingface#2122)
1 parent 6ada0cc commit cd51f2a

File tree

6 files changed: +77 −45 lines changed


docs/source/developer_guides/lora.md

Lines changed: 13 additions & 1 deletion

````diff
@@ -138,9 +138,21 @@ from peft import PeftModel
 model = PeftModel.from_pretrained(base_model, peft_model_id, ephemeral_gpu_offload=True)
 ```
 
+DoRA is optimized (computes faster and takes less memory) for models in the evaluation mode, or when dropout is set to 0. We reuse the
+base result at those times to get the speedup.
+Running [dora finetuning](https://github.com/huggingface/peft/blob/main/examples/dora_finetuning/dora_finetuning.py)
+with `CUDA_VISIBLE_DEVICES=0 time python examples/dora_finetuning/dora_finetuning.py --quantize --lora_dropout 0 --batch_size 16 --eval_step 2 --use_dora`
+on a 4090 with gradient accumulation set to 2 and max step to 20 resulted with the following observations:
+
+|  | Without Optimization | With Optimization |
+| :--: | :--: | :--: |
+| train_runtime | 359.7298 | **279.2676** |
+| train_samples_per_second | 1.779 | **2.292** |
+| train_steps_per_second | 0.056 | **0.072** |
+
 #### Caveats
 
-- DoRA only supports linear and Conv2d layers at the moment.
+- DoRA only supports embedding, linear, and Conv2d layers at the moment.
 - DoRA introduces a bigger overhead than pure LoRA, so it is recommended to merge weights for inference, see [`LoraModel.merge_and_unload`].
 - DoRA should work with weights quantized with bitsandbytes ("QDoRA"). However, issues have been reported when using QDoRA with DeepSpeed Zero2.
````
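To benefit from the optimization described in the added docs, DoRA only needs to run with dropout disabled (or with the model in eval mode). A minimal sketch of such a configuration, using a placeholder base model and arbitrary hyperparameters that are not part of this commit:

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Placeholder base model; any model supported by LoRA/DoRA works the same way.
base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.0,  # dropout of 0 makes the dropout module nn.Identity -> fast DoRA path
    use_dora=True,
)
model = get_peft_model(base_model, config)

# Evaluation mode also takes the fast path, regardless of the configured dropout.
model.eval()
```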

src/peft/tuners/lora/bnb.py

Lines changed: 20 additions & 12 deletions

```diff
@@ -235,20 +235,24 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                         x = x.to(compute_dtype)
 
                 if not self.use_dora[active_adapter]:
-                    output = lora_B(lora_A(dropout(x))) * scaling
+                    result = result + lora_B(lora_A(dropout(x))) * scaling
                 else:
-                    x = dropout(x)
-                    output = self.lora_magnitude_vector[active_adapter](
+                    if isinstance(dropout, torch.nn.Identity) or not self.training:
+                        base_result = result
+                    else:
+                        x = dropout(x)
+                        base_result = None
+
+                    result = result + self.lora_magnitude_vector[active_adapter](
                         x,
                         lora_A=lora_A,
                         lora_B=lora_B,
                         scaling=scaling,
                         base_layer=self.get_base_layer(),
+                        base_result=base_result,
                     )
                 if requires_conversion:
-                    output = output.to(expected_dtype)
-
-                result = result + output
+                    result = result.to(expected_dtype)
 
         return result
@@ -486,20 +490,24 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                     x = x.to(lora_A.weight.dtype)
 
                 if not self.use_dora[active_adapter]:
-                    output = lora_B(lora_A(dropout(x))) * scaling
+                    result = result + lora_B(lora_A(dropout(x))) * scaling
                 else:
-                    x = dropout(x)
-                    output = self.lora_magnitude_vector[active_adapter](
+                    if isinstance(dropout, torch.nn.Identity) or not self.training:
+                        base_result = result
+                    else:
+                        x = dropout(x)
+                        base_result = None
+
+                    result = result + self.lora_magnitude_vector[active_adapter](
                         x,
                         lora_A=lora_A,
                         lora_B=lora_B,
                         scaling=scaling,
                         base_layer=self.get_base_layer(),
+                        base_result=base_result,
                     )
                 if requires_conversion:
-                    output = output.to(expected_dtype)
-
-                result = result + output
+                    result = result.to(expected_dtype)
 
         return result
```
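The same `base_result` gating recurs in every `forward` touched by this commit (hqq.py, layer.py, and tp_layer.py below). As a standalone sketch of just that condition, with a hypothetical helper name that does not exist in the codebase:

```python
import torch.nn as nn

def dora_inputs(x, dropout: nn.Module, result, training: bool):
    """Illustrative helper (not part of the commit): decide whether the
    already computed base layer output can be reused by DoRA.

    Dropout changes x, so the base output computed from the un-dropped x is
    only reusable when dropout is a no-op (nn.Identity) or the module is in
    eval mode, where nn.Dropout is also a no-op.
    """
    if isinstance(dropout, nn.Identity) or not training:
        return x, result      # reuse: pass the base output as base_result
    return dropout(x), None   # recompute: DoRA falls back to x @ W.T itself
```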

src/peft/tuners/lora/dora.py

Lines changed: 13 additions & 16 deletions

```diff
@@ -62,13 +62,11 @@ def update_layer(self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=False
             weight_norm = weight_norm.to("cpu")
         self.weight = nn.Parameter(weight_norm, requires_grad=True)
 
-    def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
+    def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_result=None):
         """
         For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
         output.
         """
-        lora_result = lora_B(lora_A(x))
-
         # Don't use `lora_weight = lora_B.weight @ lora_A.weight` because this causes errors with FSDP. Instead,
         # calculate the same but using forward.
         x_eye = torch.eye(lora_A.weight.shape[1], device=lora_A.weight.device, dtype=x.dtype)
@@ -86,19 +84,18 @@ def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
         # during backpropagation"
         weight_norm = weight_norm.detach()
         mag_norm_scale = (magnitude / weight_norm).view(1, -1)
-        result_dora = (mag_norm_scale - 1) * (
-            F.linear(x, transpose(weight, self.fan_in_fan_out))
-        ) + mag_norm_scale * lora_result * scaling
-
-        # Note: Computation could potentially be accelerated by using the code below instead of calculating X@W again.
-        # This is only correct if dropout=0, otherwise results will differ:
-        # https://github.com/huggingface/peft/pull/1474#issuecomment-1964682771
-        # bias = self.get_base_layer().bias
-        # if bias is not None:
-        #     result = result - bias
-        # result = mag_norm_scale * result + mag_norm_scale * lora_B(lora_A(x)) * scaling
-        # if bias is not None:
-        #     result = result + bias
+
+        lora_result = lora_B(lora_A(x))
+
+        bias = None
+        if base_result is not None:
+            bias = base_layer.bias
+            if bias is not None:
+                base_result = base_result - bias
+        else:
+            base_result = F.linear(x, transpose(weight, self.fan_in_fan_out))
+
+        result_dora = (mag_norm_scale - 1) * base_result + mag_norm_scale * lora_result * scaling
 
         return result_dora
```
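The reuse is valid because, with no dropout, the already computed base output minus the layer bias equals the `x @ W.T` term that `dora.py` otherwise recomputes via `F.linear`. A quick standalone check of that identity in plain PyTorch (illustrative only, not code from the commit):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
base = nn.Linear(16, 8, bias=True)
x = torch.randn(4, 16)

# What dora.py recomputes when no base_result is passed in.
recomputed = F.linear(x, base.weight)

# What the new fast path reuses: the base layer output with the bias removed.
reused = base(x) - base.bias

assert torch.allclose(recomputed, reused, atol=1e-6)
print("base result reuse matches recomputation")
```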

src/peft/tuners/lora/hqq.py

Lines changed: 16 additions & 5 deletions

```diff
@@ -218,13 +218,24 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                         x = x.to(compute_dtype)
 
                 if not self.use_dora[active_adapter]:
-                    output = lora_B(lora_A(dropout(x))) * scaling
+                    result = result + lora_B(lora_A(dropout(x))) * scaling
                 else:
-                    output = self._apply_dora(x, lora_A, lora_B, scaling, active_adapter)
+                    if isinstance(dropout, torch.nn.Identity) or not self.training:
+                        base_result = result
+                    else:
+                        x = dropout(x)
+                        base_result = None
+
+                    result = result + self.lora_magnitude_vector[active_adapter](
+                        x,
+                        lora_A=lora_A,
+                        lora_B=lora_B,
+                        scaling=scaling,
+                        base_layer=self.get_base_layer(),
+                        base_result=base_result,
+                    )
                 if requires_conversion:
-                    output = output.to(expected_dtype)
-
-                result = result + output
+                    result = result.to(expected_dtype)
 
         return result
```

src/peft/tuners/lora/layer.py

Lines changed: 7 additions & 1 deletion

```diff
@@ -585,13 +585,19 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                 if not self.use_dora[active_adapter]:
                     result = result + lora_B(lora_A(dropout(x))) * scaling
                 else:
-                    x = dropout(x)
+                    if isinstance(dropout, nn.Identity) or not self.training:
+                        base_result = result
+                    else:
+                        x = dropout(x)
+                        base_result = None
+
                     result = result + self.lora_magnitude_vector[active_adapter](
                         x,
                         lora_A=lora_A,
                         lora_B=lora_B,
                         scaling=scaling,
                         base_layer=self.get_base_layer(),
+                        base_result=base_result,
                     )
 
             result = result.to(torch_result_dtype)
```

src/peft/tuners/lora/tp_layer.py

Lines changed: 8 additions & 10 deletions

```diff
@@ -201,23 +201,21 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
                 x = x.to(lora_A.weight.dtype)
 
                 if not self.use_dora[active_adapter]:
-                    lora_result = lora_A(dropout(x))
-                    if isinstance(lora_result, tuple):
-                        lora_result = lora_result[0]
-                    lora_result = lora_B(lora_result)
-                    if isinstance(lora_result, tuple):
-                        lora_result = lora_result[0]
-                    lora_result = lora_result * scaling
-
-                    result = result + lora_result
+                    result = result + lora_B(lora_A(dropout(x))) * scaling
                 else:
-                    x = dropout(x)
+                    if isinstance(dropout, torch.nn.Identity) or not self.training:
+                        base_result = result
+                    else:
+                        x = dropout(x)
+                        base_result = None
+
                     result = result + self.lora_magnitude_vector[active_adapter](
                         x,
                         lora_A=lora_A,
                         lora_B=lora_B,
                         scaling=scaling,
                         base_layer=self.get_base_layer(),
+                        base_result=base_result,
                     )
 
             result = result.to(torch_result_dtype)
```
