Standalone Custom Tokens Tuner and integrated into LoRA (#2376)
This change is based on the nifty addition of @marcusinthesky from #1541.
When adding tokens or fine-tuning the representation of specific tokens, we currently have little choice but to retrain the whole embedding matrix, which can be huge and adds to the memory footprint (in RAM but also on disk). This method creates a sparse matrix of shape (n, embed_dim), where n is the number of tokens to be customized, and only trains these few values.
This change introduces two ways of using it:
```python
peft_config = TrainableTokensConfig(target_modules=['embed_tokens'], token_indices=[0, 1, 2])
peft_model = get_peft_model(model, peft_config)
```
and with LoRA
```python
peft_config = LoraConfig(
target_modules='all-linear',
trainable_token_indices={'embed_tokens': [0, 1, 2]},
)
peft_model = get_peft_model(model, peft_config)
```
Adding this feature to adapters other than LoRA should be relatively easy, mostly adding the `trainable_token_indices` config option and some debugging.
To make this change it was necessary to modify the `modules_to_save` infrastructure, since combining this feature with LoRA works quite similarly. This refactoring entailed moving most of the basic functionality of `ModulesToSave` to the `AuxiliaryTrainingWrapper` class. It also changes the logic for how `modules_to_save` is loaded/saved from the state dict, so there could still be bugs here.
This implementation does not yet support weight-tied layers; that will follow in a future change.
---
Notable commits in this squash:
* Use unload_and_optionally_merge_module protocol
With `AuxiliaryTrainingWrapper` as the abstraction, it is probably a good idea to
support `unload_and_optionally_merge_module`.
Since the wrapper is more akin to a PEFT layer than a model, the name's semantics
are fine and it does basically the same job.
* trainable tokens is also trained in certain adapters
Before, the assumption was that `modules_to_save` was the only thing that
is trained alongside an adapter's parameters. Now there are also the
token adapter's delta tokens via `NewTokensWrapper`.
* Remove old modules_to_save handling
This is now all handled via the `AuxiliaryTrainingWrapper`.
* Fix modules_to_save module overwriting
The state dict implementation of ModulesToSaveWrapper was incorrect in that
it did not include its own parameters, just the parameters it needs to overwrite
in the end. I.e., if layer `lin1` is wrapped with modules_to_save,
`lin1.{weight,bias}` is saved and overwritten but `lin1.modules_to_save.<adapter_name>.[...]`
is not saved.
* Introduce a load key map for aux. train wrapper
Before this change it was only possible to remove a key prefix from the wrapper's
state dict (e.g., `modules_to_save.default.weight` -> `weight`); now it is possible
to restore such a reduced key by mapping it back
(i.e., `weight` -> `modules_to_save.default.weight`).
* Replace sparse matrix with dense + index_copy
This change is mostly because sparse matrices are not that beneficial in this case
(at least not from what we can see right now) and they do not solve the problem
of having to change the new tokens in-place to avoid outdated deltas when new token
vectors are initialized randomly after loading the deltas.
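For illustration, here is a minimal sketch of the dense-delta-plus-`index_copy` idea described above. The class and attribute names are made up for this example and do not reflect the actual PEFT implementation:

```python
import torch
import torch.nn as nn

class TokenDeltaSketch(nn.Module):
    """Illustrative only: override a handful of embedding rows with trained deltas."""

    def __init__(self, embedding: nn.Embedding, token_indices: list[int]):
        super().__init__()
        self.embedding = embedding  # frozen base embedding
        self.register_buffer("token_indices", torch.tensor(token_indices, dtype=torch.long))
        # dense (n, embed_dim) matrix holding only the customized token vectors
        self.token_deltas = nn.Parameter(embedding.weight[self.token_indices].detach().clone())

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # copy the trained rows over the frozen weight, then do a regular lookup
        weight = self.embedding.weight.detach()
        weight = weight.index_copy(0, self.token_indices, self.token_deltas)
        return nn.functional.embedding(input_ids, weight)
```

Training then updates only the `(n, embed_dim)` delta parameter while the base embedding matrix stays frozen.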
* Make peft_config.layers_to_transform optional
Before this change the base tuner class was forcing this attribute
to be present on the config class even though the attribute is not
specified in the base config.
* Implement missing key logic in `_set_trainable`
Before this, it was not checked whether the module targeted by `modules_to_save` or `trainable_token_indices` existed
(when used in conjunction with a PEFT method). Now an error message similar to the `inject_adapter`
error is raised when no matching module is found.
---------
Co-authored-by: Marcus Gawronsky <[email protected]>
Co-authored-by: Benjamin Bossan <[email protected]>
[PiSSA](https://arxiv.org/abs/2404.02948) initializes the LoRA adapter using the principal singular values and singular vectors. This straightforward modification allows PiSSA to converge more rapidly than LoRA and ultimately attain superior performance. Moreover, PiSSA reduces the quantization error compared to QLoRA, leading to further enhancements.
Configure the initialization method to "pissa", which may take several minutes to execute SVD on the pre-trained model:
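```python
lora_config = LoraConfig(init_lora_weights="pissa", ...)
```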
Alternatively, execute fast SVD, which takes only a few seconds. The number of iterations determines the trade-off between the error and computation time:
```python
lora_config = LoraConfig(init_lora_weights="pissa_niter_[number of iters]", ...)
```
For detailed instruction on using PiSSA, please follow [these instructions](https://github.com/huggingface/peft/tree/main/examples/pissa_finetuning).
### CorDA
[CorDA](https://arxiv.org/pdf/2406.05223) builds task-aware LoRA adapters from weight decomposition oriented by the context of downstream task to learn (instruction-previewed mode, IPM) or world knowledge to maintain (knowledge-preserved mode, KPM).
The KPM not only achieves better performance than LoRA on fine-tuning tasks, but also mitigates the catastrophic forgetting of pre-trained world knowledge.
When preserving pre-trained knowledge is not a concern, the IPM is favored because it can further accelerate convergence and enhance the fine-tuning performance.
You need to configure the initialization method to "corda", and specify the mode of IPM or KPM and the dataset to collect covariance matrices.
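A rough sketch of how this can be wired up, modeled on the CorDA example in the PEFT repository (the import paths, config fields, and calibration function below should be treated as assumptions; consult the CorDA example for the authoritative version):

```python
import torch
from peft import LoraConfig, get_peft_model
# NOTE: these import locations are assumptions based on the CorDA example
from peft.tuners.lora.config import CordaConfig
from peft.tuners.lora.corda import preprocess_corda

@torch.no_grad()
def run_model():
    # run a few batches of the calibration dataset to collect covariance matrices
    for batch in calib_dataloader:
        model(**batch)

corda_config = CordaConfig(corda_method="kpm")  # or "ipm"
lora_config = LoraConfig(init_lora_weights="corda", corda_config=corda_config)
preprocess_corda(model, lora_config, run_model=run_model)
peft_model = get_peft_model(model, lora_config)
```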
### DoRA
DoRA is optimized (computes faster and takes less memory) for models in evaluation mode, or when dropout is set to 0. We reuse the DoRA fine-tuning example script and run it with `CUDA_VISIBLE_DEVICES=0 time python examples/dora_finetuning/dora_finetuning.py --quantize --lora_dropout 0 --batch_size 16 --eval_step 2 --use_dora` on a 4090 with gradient accumulation set to 2 and max steps set to 20 to benchmark these optimizations.
#### Caveats
- DoRA only supports embedding, linear, and Conv2d layers at the moment.
- DoRA introduces a bigger overhead than pure LoRA, so it is recommended to merge weights for inference, see [`LoraModel.merge_and_unload`].
- DoRA should work with weights quantized with bitsandbytes ("QDoRA"). However, issues have been reported when using QDoRA with DeepSpeed Zero2.
### QLoRA-style training
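QLoRA-style training applies LoRA to all linear layers of the model rather than only the attention projections. With PEFT this is typically expressed through the `target_modules` option, as in the sketch below (the `...` stands for the rest of your LoRA configuration):

```python
# apply LoRA to every linear layer, as QLoRA does
config = LoraConfig(target_modules="all-linear", ...)
```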
## Efficiently train tokens alongside LoRA
Sometimes it is necessary to not only change some layer's weights but to add new tokens as well. With larger models this can be a memory-costly endeavour. PEFT LoRA adapters support the `trainable_token_indices` parameter, which allows tuning of selected tokens alongside fine-tuning of specific layers with LoRA. This method only trains the tokens you specify and leaves all other tokens untouched. This saves memory and, in contrast to training the whole embedding matrix, doesn't throw away the learned context of existing token embeddings. Under the hood this method uses the layers of [`TrainableTokensModel`].
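For example, a combined LoRA plus trainable tokens setup might look like the following sketch (the model checkpoint and the added tokens are placeholders, and the embedding layer is assumed to be named `embed_tokens`):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

model_id = "Qwen/Qwen2.5-0.5B"  # placeholder; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# add the new tokens and resize the embedding matrix accordingly
new_tokens = ["<|user|>", "<|assistant|>"]
tokenizer.add_tokens(new_tokens)
model.resize_token_embeddings(len(tokenizer))

lora_config = LoraConfig(
    target_modules="all-linear",
    # only the embedding rows of the newly added tokens are trained
    trainable_token_indices={"embed_tokens": tokenizer.convert_tokens_to_ids(new_tokens)},
)
peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()
```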
The token weights are part of your adapter state dict and saved alongside the LoRA weights. If we had used full fine-tuning with `modules_to_save=['embed_tokens']`, we would have stored the full embedding matrix in the checkpoint, leading to a much bigger file.
## Merge LoRA weights into the base model
While LoRA is significantly smaller and faster to train, you may encounter latency issues during inference due to separately loading the base model and the LoRA adapter. To eliminate latency, use the [`~LoraModel.merge_and_unload`] function to merge the adapter weights with the base model. This allows you to use the newly merged model as a standalone model. The [`~LoraModel.merge_and_unload`] function doesn't keep the adapter weights in memory.
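In its simplest form this is a single call on the trained PEFT model (a brief sketch; `peft_model` stands for any LoRA-wrapped model such as the ones created above):

```python
# merge the LoRA weights into the base weights and get back a plain transformers model
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("path/to/merged-model")  # optionally persist the merged model
```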
Note that the order does not matter here, i.e. the samples in the batch don't need to be grouped by adapter as in the example above. We just need to ensure that the `adapter_names` argument is aligned correctly with the samples.
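As a sketch of what that alignment means in practice (the adapter names and inputs below are illustrative; the special name `"__base__"` selects the base model without any adapter):

```python
# one adapter name per sample, in the same order as the samples in the batch
output = peft_model(**inputs, adapter_names=["adapter_a", "adapter_a", "__base__", "adapter_b"])
```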
Additionally, the same approach also works with the `modules_to_save` feature, which allows for saving and reusing specific neural network layers, such as custom heads for classification tasks, across different LoRA adapters.
The method only targets specific tokens and selectively trains the token indices you specify. Consequently, the required RAM will be lower, and the disk footprint is also significantly smaller than storing the fully fine-tuned embedding matrix.
Some preliminary benchmarks acquired with [this script](https://github.com/huggingface/peft/blob/main/scripts/train_memory.py) suggest that for `gemma-2-2b` (which has a rather large embedding matrix) you can save 4.8GiB VRAM with Trainable Tokens over fully fine-tuning the embedding matrix. While LoRA will use even less memory (-6.3GiB total over fine-tuning), it might also target tokens you don't want to be changed. With less extreme embedding matrices the difference may also come out smaller.
Note that this method does not add tokens for you; you have to add the new tokens to the tokenizer yourself and resize the model's embedding matrix accordingly. This method will only re-train the embeddings for the tokens you specify.
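For example, a minimal standalone setup (without LoRA) might look like the following sketch; the model checkpoint and the added token are placeholders, and `TrainableTokensConfig` is assumed to be importable from the top-level `peft` namespace:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import TrainableTokensConfig, get_peft_model

model_id = "Qwen/Qwen2.5-0.5B"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# you have to add the tokens and resize the embedding matrix yourself
new_tokens = ["<|copyright|>"]
tokenizer.add_tokens(new_tokens)
model.resize_token_embeddings(len(tokenizer))

peft_config = TrainableTokensConfig(
    target_modules=["embed_tokens"],
    token_indices=tokenizer.convert_tokens_to_ids(new_tokens),
)
peft_model = get_peft_model(model, peft_config)
```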
This method can also be used in conjunction with LoRA layers! See [the LoRA developer guide](../developer_guides/lora#efficiently-train-tokens-alongside-lora).