30 changes: 25 additions & 5 deletions docs/source/developer_guides/lora.md
@@ -200,17 +200,37 @@ from peft import PeftModel
model = PeftModel.from_pretrained(base_model, peft_model_id, ephemeral_gpu_offload=True)
```

#### Optimization

DoRA is optimized (it computes faster and uses less memory) when the model is in evaluation mode or when dropout is set to 0. In those cases, the result of the base layer is reused, which is what provides the speedup; a configuration sketch is shown after the table below.
Running [dora finetuning](https://github.com/huggingface/peft/blob/main/examples/dora_finetuning/dora_finetuning.py) with `CUDA_VISIBLE_DEVICES=0 ZE_AFFINITY_MASK=0 time python examples/dora_finetuning/dora_finetuning.py --quantize --lora_dropout 0 --batch_size 16 --eval_step 2 --use_dora` on a 4090 with gradient accumulation set to 2 and max steps set to 20 resulted in the following observations:

| | Without Optimization | With Optimization |
| :--: | :--: | :--: |
| train runtime (sec) | 359.7298 | **279.2676** |
| train samples per second | 1.779 | **2.292** |
| train steps per second | 0.056 | **0.072** |
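
To make these conditions concrete, here is a minimal sketch of a DoRA setup that benefits from the optimization; the model id is only an illustrative assumption and any PEFT-compatible causal LM works the same way:

```py
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model

# example base model; replace with the checkpoint you actually use
base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")

# with lora_dropout=0 the optimized DoRA path can be used even during training;
# with a non-zero dropout it is only used once the model is in eval mode
config = LoraConfig(use_dora=True, lora_dropout=0.0)
model = get_peft_model(base_model, config)

model.eval()  # evaluation mode also enables the optimization
```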

Moreover, it is possible to further increase the runtime performance of DoRA inference by using the [`DoraCaching`] helper context manager. This requires the model to be in `eval` mode:

```py
from peft.helpers import DoraCaching

model.eval()
with DoraCaching():
output = model(inputs)
```
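
The `DoraCaching` helper can also be called on an instance to toggle caching persistently instead of scoping it to a `with` block (see its docstring); a short sketch, assuming `model` and `inputs` are defined as above:

```py
from peft.helpers import DoraCaching

dora_caching = DoraCaching()
dora_caching(enabled=True)  # permanently enable caching
output = model(inputs)
dora_caching(enabled=False)  # permanently disable caching again
output = model(inputs)
```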

For [`meta-llama/Llama-3.1-8B`](https://huggingface.co/meta-llama/Llama-3.1-8B), the [DoRA caching benchmark script](https://github.com/huggingface/peft/blob/main/examples/dora_finetuning/dora-caching.py) shows that, compared to LoRA:

- DoRA without caching requires 139% more time
- DoRA without caching requires 4% more memory
- DoRA with caching requires 17% more time
- DoRA with caching requires 41% more memory

Caching can thus make inference with DoRA significantly faster, but it also requires significantly more memory. Ideally, if the use case allows it, just merge the DoRA adapter to avoid both the memory and the runtime overhead.
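
If merging is an option for your use case, a minimal sketch using PEFT's `merge_and_unload` API could look like this; it assumes `model` is the DoRA-adapted model from above:

```py
# merge the DoRA weights into the base model; the returned model contains no
# adapter layers anymore, so there is no extra runtime or memory overhead
merged_model = model.merge_and_unload()
output = merged_model(inputs)
```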

#### Caveats

5 changes: 5 additions & 0 deletions docs/source/package_reference/helpers.md
@@ -20,3 +20,8 @@ A collection of helper functions for PEFT.

[[autodoc]] helpers.disable_input_dtype_casting
- all

## Context manager to enable DoRA caching (faster at inference time but requires more memory)

[[autodoc]] helpers.DoraCaching
- all
126 changes: 126 additions & 0 deletions examples/dora_finetuning/dora-caching.py
@@ -0,0 +1,126 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Small script to measure DoRA caching efficiency
"""

import argparse
import time
from contextlib import contextmanager

import torch
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model
from peft.helpers import DoraCaching
from peft.utils import infer_device


device = infer_device()
# check for CPU
if device == "cpu":
    raise ValueError("This benchmark requires a hardware accelerator, only found CPU")
torch_accelerator_module = getattr(torch, device, torch.cuda)


@contextmanager
def timeit(logs):
    start = time.perf_counter()
    yield
    end = time.perf_counter()
    dur = end - start
    logs["time"].append(dur)


def run_benchmark(model, num_runs):
    logs = {
        "time": [],
    }

    mem_start = torch_accelerator_module.max_memory_reserved()
    for _ in range(num_runs + 1):
        with timeit(logs):
            for i in range(3):
                x = torch.randint(10, 100, (1, 50)).to(device)
                model(x)
    mem_end = torch_accelerator_module.max_memory_reserved()
    logs["memory"] = (mem_end - mem_start) / 1024**2

    # remove the first run (warm up)
    del logs["time"][0]
    return logs


def main(model_id, num_runs):
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map=device)
    base_memory = torch_accelerator_module.max_memory_reserved() / 1024**2

    # LORA
    config = LoraConfig(init_lora_weights=False, use_dora=False)
    model = get_peft_model(model, config)
    model.eval()
    torch_accelerator_module.reset_peak_memory_stats()
    logs_lora = run_benchmark(model, num_runs)
    avg_duration_lora = sum(logs_lora["time"]) / num_runs
    max_memory_lora = logs_lora["memory"] + base_memory

    # DORA
    del model
    torch_accelerator_module.empty_cache()

    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map=device)
    base_memory = torch_accelerator_module.max_memory_reserved() / 1024**2
    config = LoraConfig(init_lora_weights=False, use_dora=True)
    model = get_peft_model(model, config)
    model.eval()

    # WITHOUT CACHING
    torch_accelerator_module.reset_peak_memory_stats()
    logs_dora_no_caching = run_benchmark(model, num_runs)
    avg_duration_no_caching = sum(logs_dora_no_caching["time"]) / num_runs
    max_memory_no_caching = logs_dora_no_caching["memory"] + base_memory

    # WITH CACHING
    torch_accelerator_module.reset_peak_memory_stats()
    with DoraCaching():
        logs_dora_caching = run_benchmark(model, num_runs)
    avg_duration_caching = sum(logs_dora_caching["time"]) / num_runs
    max_memory_caching = logs_dora_caching["memory"] + base_memory

    print(
        f"Benchmark results for model {model_id} with {num_runs} runs:\n\n"
        f"avg time LoRA: {avg_duration_lora:.4f} sec\n"
        f"avg time DoRA no caching: {avg_duration_no_caching:.4f} sec\n"
        f"avg time DoRA with caching: {avg_duration_caching:.4f} sec\n"
        f"\n"
        f"memory LoRA: {max_memory_lora:.2f} MB\n"
        f"memory DoRA no caching: {max_memory_no_caching:.2f} MB\n"
        f"memory DoRA with caching: {max_memory_caching:.2f} MB\n"
        f"\n"
        f"DoRA time overhead no caching: {(avg_duration_no_caching - avg_duration_lora) / avg_duration_lora * 100:.2f}%\n"
        f"DoRA time overhead with caching: {(avg_duration_caching - avg_duration_lora) / avg_duration_lora * 100:.2f}%\n"
        f"\n"
        f"DoRA memory overhead no caching: {(max_memory_no_caching - max_memory_lora) / max_memory_lora * 100:.2f}%\n"
        f"DoRA memory overhead with caching: {(max_memory_caching - max_memory_lora) / max_memory_lora * 100:.2f}%"
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark DoRA caching efficiency")
    parser.add_argument("--model_id", type=str, default="meta-llama/Llama-3.1-8B", help="Model ID to benchmark")
    parser.add_argument("--num_runs", type=int, default=10, help="Number of runs for the benchmark")
    args = parser.parse_args()

    main(args.model_id, args.num_runs)
43 changes: 42 additions & 1 deletion src/peft/helpers.py
@@ -21,7 +21,7 @@
from torch import nn

from .peft_model import PeftConfig, PeftModel
from .tuners.lora import LoraLayer, dora
from .tuners.tuners_utils import BaseTunerLayer


@@ -249,3 +249,44 @@ def disable_input_dtype_casting(model: nn.Module, active: bool = True):
                continue
            if name in original_values:
                module.cast_input_dtype_enabled = original_values[name]


class DoraCaching:
"""Context manager to enable DoRA caching, which improves speed of DoRA inference at the expense of memory.

Even within the caching context, if the model is in training mode, caching is disabled. When the model switches to
training mode, the cache will be cleared.

Example:

```py
>>> from peft.helpers import enable_dora_scaling

>>> model.eval() # put in eval model for caching to work

>>> with DoraCaching(): # use as a context manager
... output = model(inputs)

>>> dora_caching = DoraCaching()
>>> dora_caching(enabled=True) # permanently enable caching
>>> output = model(inputs)
>>> dora_caching(enabled=False) # permanently disable caching
>>> output = model(inputs)
```

"""

def __init__(self, enabled: bool = True) -> None:
self.enabled = enabled
self.prev_value = None

def __enter__(self):
self.prev_value = dora.ENABLE_DORA_CACHING
dora.ENABLE_DORA_CACHING = self.enabled

def __exit__(self, type, value, traceback):
dora.ENABLE_DORA_CACHING = self.prev_value
self.prev_value = None

def __call__(self, enabled: bool = True):
dora.ENABLE_DORA_CACHING = enabled