You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Hotswap allow different alpha scalings and ranks (#2177)
Hotswapping of LoRA adapters is already implemented, but when alpha
scalings or ranks differ, this triggers recompilation of the model is
compiled, which is inefficient. Users can now call
prepare_model_for_compiled_hotswap to prevent recompilation in many
cases (see the doc update for caveats).
Copy file name to clipboardExpand all lines: docs/source/package_reference/hotswap.md
+35-4Lines changed: 35 additions & 4 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -8,6 +8,8 @@ The idea of hotswapping an adapter is the following: We can already load multipl
8
8
9
9
In general, this should be faster than deleting one adapter and loading the adapter in its place, which would be the how to achieve the same final outcome without hotswapping. Another advantage of hotswapping is that it prevents re-compilation in case the PEFT model is already compiled using `torch.compile`. This can save quite a lot of time.
10
10
11
+
## Example without `torch.compile`
12
+
11
13
```python
12
14
import torch
13
15
from transformers import AutoModelForCausalLM
@@ -21,7 +23,6 @@ model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
21
23
22
24
# load lora 0
23
25
model = PeftModel.from_pretrained(model, <path-adapter-0>)
24
-
model = torch.compile(model) # optionally compile the model
25
26
with torch.inference_mode():
26
27
output_adapter_0 = model(inputs)
27
28
@@ -31,12 +32,42 @@ with torch.inference_mode():
31
32
output_adapter_1 = model(inputs).logits
32
33
```
33
34
35
+
## Example with `torch.compile`
36
+
37
+
```python
38
+
import torch
39
+
from transformers import AutoModelForCausalLM
40
+
from peft import PeftModel
41
+
from peft.utils.hotswap import hotswap_adapter, prepare_model_for_compiled_hotswap
42
+
43
+
model_id =...
44
+
inputs =...
45
+
device =...
46
+
max_rank =...# maximum rank among all LoRA adapters that will be used
47
+
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
48
+
49
+
# load lora 0
50
+
model = PeftModel.from_pretrained(model, <path-adapter-0>)
51
+
# Prepare the model to allow hotswapping even if ranks/scalings of 2nd adapter differ.
52
+
# You can skip this step if all ranks and scalings are identical.
Hotswapping works with transformers models and diffusers models. However, there are some caveats:
35
67
36
-
- It only works for the same PEFT method, so no swapping LoRA and LoHa, for example.
37
68
- Right now, only LoRA is properly supported.
38
-
-The adapters must be compatible (e.g. same LoRA alpha, same target modules).
39
-
-If you use `torch.compile` and want to avoid recompilation, the LoRA rank must be the same.
69
+
-It only works for the same PEFT method, so no swapping LoRA and LoHa, for example.
70
+
-The adapter that is being swapped in must target the same layers as the previous adapter or a subset of those layers. It cannot target new layers. Therefore, if possible, start with the adapter that targets most layers.
0 commit comments