
Commit f51203f

Authored by githubnemo, Marcus Gawronsky, and Benjamin Bossan
Standalone Custom Tokens Tuner and integrated into LoRA (#2376)

This change is based on the nifty addition of @marcusinthesky from #1541. When adding tokens or fine-tuning the representation of specific tokens, we currently have little choice but to retrain the whole embedding matrix, which can be huge and adds to the memory footprint (in RAM but also on disk). This method creates a sparse matrix of shape (n, embed_dim), where n is the number of tokens to be customized, and only trains these few values.

This change introduces two ways of using it:

```
peft_config = TrainableTokensConfig(target_modules=['embed_tokens'], token_indices=[0, 1, 2])
peft_model = get_peft_model(model, peft_config)
```

and with LoRA:

```
peft_config = LoraConfig(
    target_modules='all-linear',
    trainable_token_indices={'embed_tokens': [0, 1, 2]},
)
peft_model = get_peft_model(model, peft_config)
```

Adding this feature to adapters other than LoRA should be relatively easy, mostly adding the `trainable_token_indices` config option and some debugging.

To make this change it was necessary to rework the `modules_to_save` infrastructure, as combining this feature with LoRA is quite similar. This refactoring entailed moving most of the basic functionality of `ModulesToSave` to the `AuxiliaryTrainingWrapper` class. It also changes how `modules_to_save` is loaded from and saved to the state dict, so there could still be bugs here.

This implementation does not yet support weight-tied layers; that will follow in a future change.

---

Notable commits in this squash:

* Use unload_and_optionally_merge_module protocol

  With `AuxiliaryTrainingWrapper` as an abstraction it is probably a good idea to support `unload_and_optionally_merge_module`. Since the wrapper is more akin to a PEFT layer than a model, the name semantics are fine and it does basically the same job.

* Trainable tokens are also trained in certain adapters

  Before, the assumption was that `modules_to_save` was the only thing trained alongside an adapter's parameters. Now there are also the token_adapter delta tokens via `NewTokensWrapper`.

* Remove old modules_to_save handling

  This is now all handled via the `AuxiliaryTrainingWrapper`.

* Fix modules_to_save module overwriting

  The state dict implementation of ModulesToSaveWrapper was incorrect in that it did not include its own parameters, just the parameters it needs to overwrite in the end. I.e., if layer `lin1` is modules-to-save wrapped, `lin1.{weight,bias}` is saved and overwritten but `lin1.modules_to_save.<adapter_name>.[...]` is not saved.

* Introduce a load key map for aux. train wrapper

  Before this change it was only possible to remove a key prefix from the wrapper's state dict (e.g., `modules_to_save.default.weight` -> `weight`); now it is possible to restore such a reduced value by mapping the key back (i.e., `weight` -> `modules_to_save.default.weight`).

* Replace sparse matrix with dense + index_copy

  This change is mostly because sparse matrices are not that beneficial in this case (at least not from what we can see right now) and they do not solve the problem of having to change the new tokens in-place to avoid outdated deltas when new token vectors are initialized randomly after loading the deltas.

* Make peft_config.layers_to_transform optional

  Before this change the base tuner class was forcing this attribute to be present on the config class even though the attribute is not specified in the base config.

* Implement missing key logic in `_set_trainable`

  Before this, it was not checked whether the module targeted by `modules_to_save` or `trainable_token_indices` exists (when used in conjunction with a PEFT method). Now an error message similar to the `inject_adapter` error is raised when no module is found.

---------

Co-authored-by: Marcus Gawronsky <[email protected]>
Co-authored-by: Benjamin Bossan <[email protected]>
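For illustration, a minimal sketch of the dense-delta-plus-`index_copy` idea described above, assuming a plain PyTorch embedding. `TokenDeltaSketch` is a made-up name for this sketch, not the actual PEFT wrapper: only the rows for the selected token indices are trainable and they are overlaid onto the frozen embedding matrix at lookup time.

```python
import torch
import torch.nn as nn


class TokenDeltaSketch(nn.Module):
    """Train only the rows of the embedding matrix given by `token_indices` (illustrative sketch)."""

    def __init__(self, embedding: nn.Embedding, token_indices: list[int]):
        super().__init__()
        self.embedding = embedding                      # frozen base embedding
        self.embedding.weight.requires_grad_(False)
        self.register_buffer("token_indices", torch.tensor(token_indices))
        # dense (n, embed_dim) matrix holding only the trainable rows
        self.trainable_rows = nn.Parameter(embedding.weight[self.token_indices].clone())

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # overlay the trained rows onto the frozen matrix before the lookup
        weight = self.embedding.weight.index_copy(0, self.token_indices, self.trainable_rows)
        return nn.functional.embedding(input_ids, weight)


emb = nn.Embedding(10, 4)
wrapped = TokenDeltaSketch(emb, token_indices=[0, 1, 2])
out = wrapped(torch.tensor([[0, 5, 2]]))
print(out.shape)  # torch.Size([1, 3, 4])
```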
1 parent 3dd2668 commit f51203f

File tree: 22 files changed (+1872, -119 lines)


docs/source/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -122,6 +122,8 @@
   title: CPT
 - local: package_reference/bone
   title: Bone
+- local: package_reference/trainable_tokens
+  title: Trainable Tokens
 
   title: Adapters
 - sections:

docs/source/developer_guides/lora.md

Lines changed: 53 additions & 9 deletions
@@ -41,7 +41,7 @@ config = LoraConfig(init_lora_weights=False, ...)
 ```
 
 ### PiSSA
-[PiSSA](https://arxiv.org/abs/2404.02948) initializes the LoRA adapter using the principal singular values and singular vectors. This straightforward modification allows PiSSA to converge more rapidly than LoRA and ultimately attain superior performance. Moreover, PiSSA reduces the quantization error compared to QLoRA, leading to further enhancements.
+[PiSSA](https://arxiv.org/abs/2404.02948) initializes the LoRA adapter using the principal singular values and singular vectors. This straightforward modification allows PiSSA to converge more rapidly than LoRA and ultimately attain superior performance. Moreover, PiSSA reduces the quantization error compared to QLoRA, leading to further enhancements.
 
 Configure the initialization method to "pissa", which may take several minutes to execute SVD on the pre-trained model:
 ```python
@@ -50,18 +50,18 @@ config = LoraConfig(init_lora_weights="pissa", ...)
 ```
 Alternatively, execute fast SVD, which takes only a few seconds. The number of iterations determines the trade-off between the error and computation time:
 ```python
-lora_config = LoraConfig(init_lora_weights="pissa_niter_[number of iters]", ...)
+lora_config = LoraConfig(init_lora_weights="pissa_niter_[number of iters]", ...)
 ```
 For detailed instruction on using PiSSA, please follow [these instructions](https://github.com/huggingface/peft/tree/main/examples/pissa_finetuning).
 
 ### CorDA
 
 [CorDA](https://arxiv.org/pdf/2406.05223) builds task-aware LoRA adapters from weight decomposition oriented by the context of downstream task to learn (instruction-previewed mode, IPM) or world knowledge to maintain (knowledge-preserved mode, KPM).
 The KPM not only achieves better performance than LoRA on fine-tuning tasks, but also mitigates the catastrophic forgetting of pre-trained world knowledge.
-When preserving pre-trained knowledge is not a concern,
-the IPM is favored because it can further accelerate convergence and enhance the fine-tuning performance.
+When preserving pre-trained knowledge is not a concern,
+the IPM is favored because it can further accelerate convergence and enhance the fine-tuning performance.
 
-You need to configure the initialization method to "corda", and specify the mode of IPM or KPM and the dataset to collect covariance matrices.
+You need to configure the initialization method to "corda", and specify the mode of IPM or KPM and the dataset to collect covariance matrices.
 
 ```py
 @torch.no_grad()
@@ -201,7 +201,7 @@ model = PeftModel.from_pretrained(base_model, peft_model_id, ephemeral_gpu_offlo
 ```
 
 DoRA is optimized (computes faster and takes less memory) for models in the evaluation mode, or when dropout is set to 0. We reuse the
-base result at those times to get the speedup.
+base result at those times to get the speedup.
 Running [dora finetuning](https://github.com/huggingface/peft/blob/main/examples/dora_finetuning/dora_finetuning.py)
 with `CUDA_VISIBLE_DEVICES=0 time python examples/dora_finetuning/dora_finetuning.py --quantize --lora_dropout 0 --batch_size 16 --eval_step 2 --use_dora`
 on a 4090 with gradient accumulation set to 2 and max step to 20 resulted with the following observations:
@@ -215,7 +215,7 @@ on a 4090 with gradient accumulation set to 2 and max step to 20 resulted with t
 #### Caveats
 
 - DoRA only supports embedding, linear, and Conv2d layers at the moment.
-- DoRA introduces a bigger overhead than pure LoRA, so it is recommended to merge weights for inference, see [`LoraModel.merge_and_unload`].
+- DoRA introduces a bigger overhead than pure LoRA, so it is recommended to merge weights for inference, see [`LoraModel.merge_and_unload`].
 - DoRA should work with weights quantized with bitsandbytes ("QDoRA"). However, issues have been reported when using QDoRA with DeepSpeed Zero2.
 
 ### QLoRA-style training
@@ -272,6 +272,50 @@ trainer = Trainer(
 )
 ```
 
+## Efficiently train tokens alongside LoRA
+
+Sometimes it is necessary not only to change some layers' weights but also to add new tokens. With larger models this can be a memory-costly endeavour. PEFT LoRA adapters support the `trainable_token_indices` parameter, which allows you to tune specific tokens alongside the layers fine-tuned with LoRA. This method trains only the tokens you specify and leaves all other tokens untouched. This saves memory and, in contrast to training the whole embedding matrix, doesn't throw away the learned context of the existing token embeddings. Under the hood this method uses the layer implementation of [`TrainableTokensModel`].
+
+```py
+# for layer 'embed_tokens'
+config = LoraConfig(trainable_token_indices=[idx_1, idx_2, ...], ...)
+
+# specific embedding layer
+config = LoraConfig(trainable_token_indices={'emb_tokens': [idx_1, idx_2, ...]}, ...)
+```
+
+In the snippet below we show how to add new tokens to the model and how to train them alongside the other layers in the model.
+
+```py
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from peft import get_peft_model, LoraConfig
+
+base_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
+tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
+
+# we define our new tokens and add them to the tokenizer as special tokens
+special_tokens = ['<|start_think|>', '<|stop_think|>']
+tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})
+
+# make room for new tokens in the embedding matrix if it isn't big enough already
+base_model.resize_token_embeddings(max(len(tokenizer), base_model.model.embed_tokens.num_embeddings))
+
+# typical LoRA config with `trainable_token_indices` targeting embedding layer `embed_tokens`
+# and specifically our new tokens we just added
+lora_config = LoraConfig(
+    target_modules='all-linear',
+    trainable_token_indices={'embed_tokens': tokenizer.convert_tokens_to_ids(special_tokens)},
+)
+peft_model = get_peft_model(base_model, lora_config)
+
+# proceed to train the model like normal
+[...]
+```
+
+The token weights are part of your adapter state dict and saved alongside the LoRA weights.
+If we had used full fine-tuning with `modules_to_save=['embed_tokens']`, the full embedding matrix would have been stored in the checkpoint, leading to a much bigger file.
+
+
 ## Merge LoRA weights into the base model
 
 While LoRA is significantly smaller and faster to train, you may encounter latency issues during inference due to separately loading the base model and the LoRA adapter. To eliminate latency, use the [`~LoraModel.merge_and_unload`] function to merge the adapter weights with the base model. This allows you to use the newly merged model as a standalone model. The [`~LoraModel.merge_and_unload`] function doesn't keep the adapter weights in memory.
@@ -323,7 +367,7 @@ base_model = AutoModelForCausalLM.from_pretrained(
 )
 ```
 
-Then we load the first adapter:
+Then we load the first adapter:
 
 ```python
 peft_model_id = "alignment-handbook/zephyr-7b-sft-lora"
@@ -443,7 +487,7 @@ output = peft_model.generate(**inputs, adapter_names=adapter_names, max_new_toke
 
 Note that the order does not matter here, i.e. the samples in the batch don't need to be grouped by adapter as in the example above. We just need to ensure that the `adapter_names` argument is aligned correctly with the samples.
 
-Additionally, the same approach also works with the `modules_to_save` feature, which allows for saving and reusing specific neural network layers, such as custom heads for classification tasks, across different LoRA adapters.
+Additionally, the same approach also works with the `modules_to_save` feature, which allows for saving and reusing specific neural network layers, such as custom heads for classification tasks, across different LoRA adapters.
 
 ### Caveats
 
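As a follow-up to the documentation hunk above, a hedged sketch of how the combined adapter might be saved and reloaded. The directory name is a placeholder, `peft_model`, `tokenizer`, and the base model are the objects from the snippet above, and the exact reload flow (resizing the embeddings before loading) is an assumption, not something spelled out in this commit:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# only the LoRA weights and the selected token rows are written, not the full embedding matrix
peft_model.save_pretrained("./mistral-lora-think-tokens")
tokenizer.save_pretrained("./mistral-lora-think-tokens")  # keep the extended tokenizer alongside

# reloading later: start from a fresh base model, make room for the added tokens again,
# then load the adapter (LoRA weights plus the trained token rows)
tokenizer = AutoTokenizer.from_pretrained("./mistral-lora-think-tokens")
base_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
base_model.resize_token_embeddings(len(tokenizer))
peft_model = PeftModel.from_pretrained(base_model, "./mistral-lora-think-tokens")
```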

docs/source/package_reference/trainable_tokens.md

Lines changed: 42 additions & 0 deletions

@@ -0,0 +1,42 @@
+<!--Copyright 2025 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+
+⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
+rendered properly in your Markdown viewer.
+
+-->
+
+# Trainable Tokens
+
+The Trainable Tokens method provides a way to target specific token embeddings for fine-tuning without resorting to
+training the full embedding matrix or using an adapter on the embedding matrix. It is based on the initial implementation from
+[here](https://github.com/huggingface/peft/pull/1541).
+
+The method only targets specific tokens and selectively trains the token indices you specify. Consequently, the
+required RAM is lower, and the on-disk size is significantly smaller than storing the full fine-tuned embedding matrix.
+
+Some preliminary benchmarks acquired with [this script](https://github.com/huggingface/peft/blob/main/scripts/train_memory.py)
+suggest that for `gemma-2-2b` (which has a rather large embedding matrix) you can save 4.8GiB VRAM with Trainable Tokens
+over fully fine-tuning the embedding matrix. While LoRA will use even less memory (-6.3GiB total over fine-tuning) it might also target
+tokens you don't want to be changed. With less extreme embedding matrices the difference may be smaller as well.
+
+Note that this method does not add tokens for you; you have to add the tokens to the tokenizer yourself and resize the
+embedding matrix of the model accordingly. This method will only re-train the embeddings for the tokens you specify.
+This method can also be used in conjunction with LoRA layers! See [the LoRA developer guide](../developer_guides/lora#efficiently-train-tokens-alongside-lora).
+
+## TrainableTokensConfig
+
+[[autodoc]] tuners.trainable_tokens.config.TrainableTokensConfig
+
+## TrainableTokensModel
+
+[[autodoc]] tuners.trainable_tokens.model.TrainableTokensModel
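To complement the reference page above, a short standalone usage sketch mirroring the example from the commit message; the model name and token indices are placeholders:

```python
from transformers import AutoModelForCausalLM
from peft import TrainableTokensConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")

# train only the embedding rows for the given token indices
peft_config = TrainableTokensConfig(target_modules=["embed_tokens"], token_indices=[0, 1, 2])
peft_model = get_peft_model(model, peft_config)
peft_model.print_trainable_parameters()
```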

src/peft/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -87,6 +87,8 @@
     PromptEncoderReparameterizationType,
     PromptTuningConfig,
     PromptTuningInit,
+    TrainableTokensConfig,
+    TrainableTokensModel,
     VBLoRAConfig,
     VBLoRAModel,
     VeraConfig,
@@ -177,6 +179,8 @@
     "PromptTuningConfig",
     "PromptTuningInit",
     "TaskType",
+    "TrainableTokensConfig",
+    "TrainableTokensModel",
     "VBLoRAConfig",
     "VBLoRAConfig",
     "VBLoRAModel",

src/peft/mixed_model.py

Lines changed: 1 addition & 1 deletion
@@ -251,7 +251,7 @@ def set_modules_to_save(self, peft_config: PeftConfig, adapter_name: str) -> Non
             self.modules_to_save = set(modules_to_save)
         else:
             self.modules_to_save.update(modules_to_save)
-        _set_trainable(self, adapter_name, modules_to_save=peft_config.modules_to_save)
+        _set_trainable(self, adapter_name, module_names=peft_config.modules_to_save)
 
     def set_adapter(self, adapter_name: Union[str, list[str]]) -> None:
         """

src/peft/peft_model.py

Lines changed: 45 additions & 5 deletions
@@ -41,6 +41,7 @@
 from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer
 from peft.utils.constants import DUMMY_MODEL_CONFIG
 from peft.utils.integrations import init_empty_weights
+from peft.utils.other import TrainableTokensWrapper
 
 from . import __version__
 from .config import PeftConfig
@@ -128,7 +129,8 @@ def __init__(
         ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
         with ctx():
             self.base_model = cls(model, {adapter_name: peft_config}, adapter_name)
-            self.set_additional_trainable_modules(peft_config, adapter_name)
+
+        self.set_additional_trainable_modules(peft_config, adapter_name)
 
         if hasattr(self.base_model, "_cast_adapter_dtype"):
             self.base_model._cast_adapter_dtype(
@@ -950,7 +952,45 @@ def set_additional_trainable_modules(self, peft_config, adapter_name):
             else:
                 self.modules_to_save.update(peft_config.modules_to_save)
             # this may add a new ModulesToSaveWrapper
-            _set_trainable(self, adapter_name, modules_to_save=peft_config.modules_to_save)
+            _set_trainable(self, adapter_name, module_names=peft_config.modules_to_save)
+
+        if getattr(peft_config, "trainable_token_indices", None) is not None:
+            if isinstance(peft_config.trainable_token_indices, dict):
+                target_layers = peft_config.trainable_token_indices
+            else:
+                target_layers = {"embed_tokens": peft_config.trainable_token_indices}
+
+            if self.modules_to_save:
+                for target_layer in target_layers:
+                    if target_layer in self.modules_to_save:
+                        raise ValueError(
+                            "The embedding layer is already marked to be trained fully, either specify "
+                            f'`modules_to_save=[..., "{target_layer}", ...]` or '
+                            f"`trainable_tokens={{'{target_layer}': x}}` but not both."
+                        )
+
+            # we are not adding these module names to `self.modules_to_save` as this is strictly reserved for the
+            # `ModulesToSaveWrapper`.
+
+            for target_layer, token_indices in target_layers.items():
+                new_training_modules = _set_trainable(
+                    self,
+                    adapter_name,
+                    module_names=[target_layer],
+                    strict_module_check=True,
+                    wrapper_cls=TrainableTokensWrapper,
+                    token_indices=token_indices,
+                )
+
+            # Handle weight-tying of output and input embeddings. Currently this only consists of failing.
+            model_config = BaseTuner.get_model_config(self)
+            if model_config.get("tie_word_embeddings", False) and isinstance(
+                self.model.get_input_embeddings(), TrainableTokensWrapper
+            ):
+                raise ValueError(
+                    "The model uses weight-tying which is currently not supported with `trainable_token_indices`. "
+                    "You can try disabling weight-tying but you must expect an increased memory usage."
+                )
 
     def get_layer_status(self) -> list[TunerLayerStatus]:
         """Get the status of each adapter layer in the model.
@@ -1447,7 +1487,7 @@ def __init__(
                 break
 
         # to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper
-        _set_trainable(self, adapter_name, modules_to_save=peft_config.modules_to_save)
+        _set_trainable(self, adapter_name, module_names=peft_config.modules_to_save)
 
     def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
         """
@@ -2238,7 +2278,7 @@ def __init__(
                 break
 
         # to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper
-        _set_trainable(self, adapter_name, modules_to_save=peft_config.modules_to_save)
+        _set_trainable(self, adapter_name, module_names=peft_config.modules_to_save)
 
     def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
         """
@@ -2459,7 +2499,7 @@ def __init__(
                 break
 
         # to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper
-        _set_trainable(self, adapter_name, modules_to_save=peft_config.modules_to_save)
+        _set_trainable(self, adapter_name, module_names=peft_config.modules_to_save)
 
     def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
         """

src/peft/tuners/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -39,6 +39,7 @@
 from .poly import PolyConfig, PolyModel
 from .prefix_tuning import PrefixEncoder, PrefixTuningConfig
 from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit
+from .trainable_tokens import TrainableTokensConfig, TrainableTokensModel
 from .vblora import VBLoRAConfig, VBLoRAModel
 from .vera import VeraConfig, VeraModel
 from .xlora import XLoraConfig, XLoraModel
@@ -88,6 +89,8 @@
     "PromptEncoderReparameterizationType",
     "PromptTuningConfig",
     "PromptTuningInit",
+    "TrainableTokensConfig",
+    "TrainableTokensModel",
     "VBLoRAConfig",
     "VBLoRAModel",
     "VeraConfig",

src/peft/tuners/lora/config.py

Lines changed: 22 additions & 0 deletions
@@ -273,6 +273,14 @@ class LoraConfig(PeftConfig):
             parameter when you want to apply LoRA to the ColumnParallelLinear and RowParallelLinear layers of megatron.
         megatron_core (`Optional[str]`):
             The core module from Megatron to use, defaults to `"megatron.core"`.
+        trainable_token_indices (`Optional[Union[List[int], dict[str, List[int]]]]`):
+            Lets you specify which token indices to selectively fine-tune without requiring to re-train the whole
+            embedding matrix using the `peft.TrainableTokensModel` method. You can either specify a list of indices
+            which will then target the `embed_tokens` layer, or, if your model is using a different layer for
+            embedding, you can specify a dictionary where the key is the name of the embedding module and the values
+            are the list of token indices, e.g. `{'embed_tokens': [0, 1, ...]}`. Note that training with FSDP/DeepSpeed
+            might not yet be fully supported with this option enabled. Also note that models using weight-tying are
+            currently not supported.
         loftq_config (`Optional[LoftQConfig]`):
             The configuration of LoftQ. If this is not None, then LoftQ will be used to quantize the backbone weights
             and initialize Lora layers. Also pass `init_lora_weights='loftq'`. Note that you should not pass a
@@ -431,6 +439,20 @@ class LoraConfig(PeftConfig):
             )
         },
     )
+    trainable_token_indices: Optional[Union[list[int], dict[str, list[int]]]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Lets you specify which token indices to selectively fine-tune without requiring to re-train the "
+                "whole embedding matrix using the `peft.TrainableTokensModel` method. You can either specify a list "
+                "of indices which will then target the `embed_tokens` layer, or, if your model is using a different "
+                "layer for embedding, you can specify a dictionary where the key is the name of the embedding module "
+                "and the values are the list of token indices, e.g. `{'embed_tokens': [0, 1, ...]}`. "
+                "Note that training with FSDP/DeepSpeed might not yet be fully supported with this option enabled. "
+                "Also note that models using weight-tying are currently not supported."
+            )
+        },
+    )
     # dict type is used when loading config.json
     loftq_config: Union[LoftQConfig, dict] = field(
         default_factory=dict,
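A small sketch of the two accepted forms of the new field, following the docstring above; layer names and indices are placeholders:

```python
from peft import LoraConfig

# list form: targets the layer named `embed_tokens` by default
cfg_list = LoraConfig(target_modules="all-linear", trainable_token_indices=[0, 1, 2])

# dict form: name the embedding module explicitly
cfg_dict = LoraConfig(
    target_modules="all-linear",
    trainable_token_indices={"embed_tokens": [0, 1, 2]},
)
```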
