diff --git a/docs/source/developer_guides/lora.md b/docs/source/developer_guides/lora.md
index e4b6ec28e0..ebb60f133f 100644
--- a/docs/source/developer_guides/lora.md
+++ b/docs/source/developer_guides/lora.md
@@ -138,9 +138,21 @@ from peft import PeftModel
 model = PeftModel.from_pretrained(base_model, peft_model_id, ephemeral_gpu_offload=True)
 ```
 
+DoRA is optimized (it computes faster and takes less memory) when the model is in evaluation mode, or when dropout is set to 0. In these
+cases, the base layer result is reused instead of being recomputed, which is where the speedup comes from.
+Running the [DoRA fine-tuning example](https://github.com/huggingface/peft/blob/main/examples/dora_finetuning/dora_finetuning.py)
+with `CUDA_VISIBLE_DEVICES=0 time python examples/dora_finetuning/dora_finetuning.py --quantize --lora_dropout 0 --batch_size 16 --eval_step 2 --use_dora`
+on a 4090 with gradient accumulation set to 2 and max steps set to 20 resulted in the following observations:
+
+| Metric | Without Optimization | With Optimization |
+| :--: | :--: | :--: |
+| train_runtime | 359.7298 | **279.2676** |
+| train_samples_per_second | 1.779 | **2.292** |
+| train_steps_per_second | 0.056 | **0.072** |
+
 #### Caveats
 
-- DoRA only supports linear and Conv2d layers at the moment.
+- DoRA only supports embedding, linear, and Conv2d layers at the moment.
 - DoRA introduces a bigger overhead than pure LoRA, so it is recommended to merge weights for inference, see [`LoraModel.merge_and_unload`].
 - DoRA should work with weights quantized with bitsandbytes ("QDoRA"). However, issues have been reported when using QDoRA with DeepSpeed Zero2.
 
diff --git a/src/peft/tuners/lora/bnb.py b/src/peft/tuners/lora/bnb.py
index 3b37faf74b..7f51b0ba54 100644
--- a/src/peft/tuners/lora/bnb.py
+++ b/src/peft/tuners/lora/bnb.py
@@ -235,20 +235,24 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                         x = x.to(compute_dtype)
 
                 if not self.use_dora[active_adapter]:
-                    output = lora_B(lora_A(dropout(x))) * scaling
+                    result = result + lora_B(lora_A(dropout(x))) * scaling
                 else:
-                    x = dropout(x)
-                    output = self.lora_magnitude_vector[active_adapter](
+                    if isinstance(dropout, torch.nn.Identity) or not self.training:
+                        base_result = result
+                    else:
+                        x = dropout(x)
+                        base_result = None
+
+                    result = result + self.lora_magnitude_vector[active_adapter](
                         x,
                         lora_A=lora_A,
                         lora_B=lora_B,
                         scaling=scaling,
                         base_layer=self.get_base_layer(),
+                        base_result=base_result,
                     )
                 if requires_conversion:
-                    output = output.to(expected_dtype)
-
-                result = result + output
+                    result = result.to(expected_dtype)
 
             return result
 
@@ -486,20 +490,24 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                     x = x.to(lora_A.weight.dtype)
 
                 if not self.use_dora[active_adapter]:
-                    output = lora_B(lora_A(dropout(x))) * scaling
+                    result = result + lora_B(lora_A(dropout(x))) * scaling
                 else:
-                    x = dropout(x)
-                    output = self.lora_magnitude_vector[active_adapter](
+                    if isinstance(dropout, torch.nn.Identity) or not self.training:
+                        base_result = result
+                    else:
+                        x = dropout(x)
+                        base_result = None
+
+                    result = result + self.lora_magnitude_vector[active_adapter](
                         x,
                         lora_A=lora_A,
                         lora_B=lora_B,
                         scaling=scaling,
                         base_layer=self.get_base_layer(),
+                        base_result=base_result,
                     )
                 if requires_conversion:
-                    output = output.to(expected_dtype)
-
-                result = result + output
+                    result = result.to(expected_dtype)
 
             return result
 
diff --git a/src/peft/tuners/lora/dora.py b/src/peft/tuners/lora/dora.py
index 3125e3c716..9d8cd9a02f 100644
--- a/src/peft/tuners/lora/dora.py
+++ b/src/peft/tuners/lora/dora.py
@@ -62,13 +62,11 @@ def update_layer(self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=False) -> None:
             weight_norm = weight_norm.to("cpu")
         self.weight = nn.Parameter(weight_norm, requires_grad=True)
 
-    def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
+    def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_result=None):
         """
         For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base
         layer output.
         """
-        lora_result = lora_B(lora_A(x))
-
         # Don't use `lora_weight = lora_B.weight @ lora_A.weight` because this causes errors with FSDP. Instead,
         # calculate the same but using forward.
         x_eye = torch.eye(lora_A.weight.shape[1], device=lora_A.weight.device, dtype=x.dtype)
@@ -86,19 +84,18 @@ def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
             # during backpropagation"
             weight_norm = weight_norm.detach()
         mag_norm_scale = (magnitude / weight_norm).view(1, -1)
-        result_dora = (mag_norm_scale - 1) * (
-            F.linear(x, transpose(weight, self.fan_in_fan_out))
-        ) + mag_norm_scale * lora_result * scaling
-
-        # Note: Computation could potentially be accelerated by using the code below instead of calculating X@W again.
-        # This is only correct if dropout=0, otherwise results will differ:
-        # https://github.com/huggingface/peft/pull/1474#issuecomment-1964682771
-        # bias = self.get_base_layer().bias
-        # if bias is not None:
-        #     result = result - bias
-        # result = mag_norm_scale * result + mag_norm_scale * lora_B(lora_A(x)) * scaling
-        # if bias is not None:
-        #     result = result + bias
+
+        lora_result = lora_B(lora_A(x))
+
+        bias = None
+        if base_result is not None:
+            bias = base_layer.bias
+            if bias is not None:
+                base_result = base_result - bias
+        else:
+            base_result = F.linear(x, transpose(weight, self.fan_in_fan_out))
+
+        result_dora = (mag_norm_scale - 1) * base_result + mag_norm_scale * lora_result * scaling
 
         return result_dora
 
diff --git a/src/peft/tuners/lora/hqq.py b/src/peft/tuners/lora/hqq.py
index c3a51e7b27..7e997275d8 100644
--- a/src/peft/tuners/lora/hqq.py
+++ b/src/peft/tuners/lora/hqq.py
@@ -218,13 +218,24 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                         x = x.to(compute_dtype)
 
                 if not self.use_dora[active_adapter]:
-                    output = lora_B(lora_A(dropout(x))) * scaling
+                    result = result + lora_B(lora_A(dropout(x))) * scaling
                 else:
-                    output = self._apply_dora(x, lora_A, lora_B, scaling, active_adapter)
+                    if isinstance(dropout, torch.nn.Identity) or not self.training:
+                        base_result = result
+                    else:
+                        x = dropout(x)
+                        base_result = None
+
+                    result = result + self.lora_magnitude_vector[active_adapter](
+                        x,
+                        lora_A=lora_A,
+                        lora_B=lora_B,
+                        scaling=scaling,
+                        base_layer=self.get_base_layer(),
+                        base_result=base_result,
+                    )
                 if requires_conversion:
-                    output = output.to(expected_dtype)
-
-                result = result + output
+                    result = result.to(expected_dtype)
 
             return result
 
diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py
index f3359ca9a8..ec9ddb58db 100644
--- a/src/peft/tuners/lora/layer.py
+++ b/src/peft/tuners/lora/layer.py
@@ -585,13 +585,19 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                 if not self.use_dora[active_adapter]:
                     result = result + lora_B(lora_A(dropout(x))) * scaling
                 else:
-                    x = dropout(x)
+                    if isinstance(dropout, nn.Identity) or not self.training:
+                        base_result = result
+                    else:
+                        x = dropout(x)
+                        base_result = None
+
                     result = result + self.lora_magnitude_vector[active_adapter](
                         x,
                         lora_A=lora_A,
                         lora_B=lora_B,
                         scaling=scaling,
                         base_layer=self.get_base_layer(),
+                        base_result=base_result,
                     )
 
             result = result.to(torch_result_dtype)
diff --git a/src/peft/tuners/lora/tp_layer.py b/src/peft/tuners/lora/tp_layer.py
index 394f3af2dd..8c33889374 100644
--- a/src/peft/tuners/lora/tp_layer.py
+++ b/src/peft/tuners/lora/tp_layer.py
@@ -201,23 +201,21 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
                 x = x.to(lora_A.weight.dtype)
 
                 if not self.use_dora[active_adapter]:
-                    lora_result = lora_A(dropout(x))
-                    if isinstance(lora_result, tuple):
-                        lora_result = lora_result[0]
-                    lora_result = lora_B(lora_result)
-                    if isinstance(lora_result, tuple):
-                        lora_result = lora_result[0]
-                    lora_result = lora_result * scaling
-
-                    result = result + lora_result
+                    result = result + lora_B(lora_A(dropout(x))) * scaling
                 else:
-                    x = dropout(x)
+                    if isinstance(dropout, torch.nn.Identity) or not self.training:
+                        base_result = result
+                    else:
+                        x = dropout(x)
+                        base_result = None
+
                     result = result + self.lora_magnitude_vector[active_adapter](
                         x,
                         lora_A=lora_A,
                         lora_B=lora_B,
                         scaling=scaling,
                         base_layer=self.get_base_layer(),
+                        base_result=base_result,
                     )
 
             result = result.to(torch_result_dtype)
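
The equivalence that the new `base_result` path relies on can be sanity-checked with a minimal, self-contained sketch. This is illustrative only, not the library implementation: the layer sizes, the `scaling` value, and the randomly initialized magnitude vector are assumptions, and the weight norm is computed with a direct matmul here, whereas `dora.py` deliberately goes through `lora_A`/`lora_B` forward calls for FSDP compatibility.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

in_features, out_features, rank, batch = 16, 8, 4, 2
x = torch.randn(batch, in_features)

base = torch.nn.Linear(in_features, out_features, bias=True)
lora_A = torch.nn.Linear(in_features, rank, bias=False)
lora_B = torch.nn.Linear(rank, out_features, bias=False)
scaling = 0.5                               # assumed lora_alpha / r
magnitude = torch.rand(out_features) + 0.5  # stands in for the DoRA magnitude vector

# Per-output-channel norm of the combined weight (base + scaled LoRA delta), detached as in DoRA
weight = base.weight + scaling * (lora_B.weight @ lora_A.weight)
weight_norm = weight.norm(p=2, dim=1).detach()
mag_norm_scale = (magnitude / weight_norm).view(1, -1)

lora_result = lora_B(lora_A(x))

# Old path: recompute x @ W inside the DoRA layer
dora_recomputed = (mag_norm_scale - 1) * F.linear(x, base.weight) + mag_norm_scale * lora_result * scaling

# New path: reuse the base layer output the caller already has, subtracting the bias first
base_result = base(x)
dora_reused = (mag_norm_scale - 1) * (base_result - base.bias) + mag_norm_scale * lora_result * scaling

print(torch.allclose(dora_recomputed, dora_reused, atol=1e-6))  # True
```

With dropout active in training mode, `dropout(x)` no longer matches the input that produced the caller's `result`, so the shortcut would change the math. That is why each caller only passes `base_result` when `dropout` is an `nn.Identity` or `self.training` is `False`, and otherwise falls back to `base_result=None` and the extra matmul.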