Skip to content

Commit 4192101

Browse files
Method comparison evaluation suite (#2395)
Introduction of a method evaluation suite. We generally face the problem that there is little knowledge on what PEFT methods perform best. To this end we decided to build an evaluation suite that has defined tasks, shared hyper-parameters and can be extended with new tasks and new method configurations over time. For the sake of comparison we've not decided to incorporate user-submitted results but we encourage users to inspect the results, suggest new experiments and improve the configuration of methods if they're deemed unfavorable. As of now there's only one task based on the MetaMathQA dataset which has the benefit of being complex while still fitting on a consumer GPU. Notable changes in this squash: * Add default training params The experiment specific training params use the default training params but can override any parameter from it if needed. However, this way it's easier to make a change to all experiments (say, I want to change the base model, I don't need to change each individual training_parameters.json). * Add possibility to change attn implementation However, both flash attention 2 and flex attention are slower on my system. Thus, stay with default None (-> SDPA). * Refactor to use GenerationConfig Allows to more easily use, say, static cache, which is the new default, as it's faster (apart from the first pass) * Better parsing of answers E.g. 1/2 == 0.5 * Keep adapter file by default after train run But add --clean to delete it. Keeping the adapter can be useful if the user wants to run further tests with the trained model. --------- Co-authored-by: Benjamin Bossan <[email protected]>
1 parent 7279a9f commit 4192101

File tree

47 files changed

+2567
-0
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+2567
-0
lines changed

.gitignore

+4
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,7 @@ dmypy.json
139139

140140
# More test things
141141
wandb
142+
143+
# method_comparison logs
144+
method_comparison/MetaMathQA/cancelled_results/
145+
method_comparison/MetaMathQA/temporary_results/

method_comparison/Makefile

Whitespace-only changes.
+213
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
# PEFT method comparison on the MetaMathQA and GSM8K datasets
2+
3+
## Goal
4+
5+
This goal is to provide a benchmarking framework for the different PEFT methods that are implemented. It is important that evaluating different PEFT methods is reproducible, idempotent, and version-controlled. Results for more PEFT methods can be added over time.
6+
7+
## Dataset
8+
9+
This task trains on the [MetaMathQA]((https://huggingface.co/datasets/meta-math/MetaMathQA)) dataset and validates/tests on the [GSM8K](https://huggingface.co/datasets/openai/gsm8k) dataset ("main").
10+
11+
For the model to attain good accuracy, it needs to learn to adhere to the output format and it must express basic chain of thought reasoning capabilities to get to the correct result in the first place. The task is challenging for models in the sub 7B parameter range.
12+
13+
The train set uses the whole of MetaMathQA. The validation set is a random sample from the train set of GSM8K. The test set is the whole of the GSM8K test set.
14+
15+
## Running
16+
17+
Create an experiment in the `experiment/<peft-method>` folder of your choice and give it a name (the name itself does not matter but helps identify the experiment). An example would be `experiments/lora/llama-3.2-3B-rank32/`. Inside that directory, create 2 files:
18+
19+
- `adapter_config.json`
20+
- Optional: `training_parameters.json`
21+
22+
### `adapter_config.json`
23+
24+
This must be a valid PEFT configuration. It is easiest to create it programmatically, e.g.:
25+
26+
```python
27+
from peft import LoraConfig
28+
29+
config = LoraConfig(...)
30+
config.save_pretrained(<path-to-experiment>)
31+
```
32+
33+
### `training_parameters.json`
34+
35+
There is a default file for the non-PEFT parameters: `default_training_params.json`. This contains all the other parameters that are relevant for training, e.g. the base model id, number of steps, batch size, learning rate, etc. If parameters that differ from the defaults are needed for a specific experiment, place a `training_parameters.json` into the experiment directory and adjust the parameters that need changing. The other parametes are taken from the aforementioned default config.
36+
37+
For an overview of all possible arguments, you can also check the `TrainConfig` `dataclass` in `utils.py`.
38+
39+
### Runtime performance
40+
41+
Several factors should be considered to achieve a fast runtime performance. Besides the obvious factors like `max_steps` or the base model size, we found the following factors to have a significant impact:
42+
43+
#### Eval batch size
44+
45+
Regarding the `batch_size_eval` parameter, it is quite critical since evaluation takes up a significant portion of the training time and batching helps with reducing that. It should be possible to choose a value that is multiple times higher than the batch size used for training (`batch_size`). You should also pay attention to the size of the validation set -- e.g. if it's 50, don't choose a `batch_size_eval` of 40, as that results in a large batch of 30 and a small batch of 10. 25 might be a better choice. Also, ensure via a quick train run that the batch size does not lead to out of memory errors -- getting this error at the very end on evaluating the test set would be quite a loss of time.
46+
47+
#### Generation length
48+
49+
During testing, we discovered that the validation time is greatly inflated by just a few very long generations. Those can inflate the validation time by a factor of 3 or more. At the same time, we discovered that these long generations do not help with accuracy -- in fact, if they exceed the maximum configured length, they're just cut off mid sentence and would thus produce an accuracy of 0 anyway.
50+
51+
To remedy this, we now set both `max_length` and `max_new_tokens` for the generation kwargs in the default training parameters. Normally, this is not possible when using transformers, as the latter argument overrides the former. However, we have added special logic inside of `get_generation_config` which takes both and chooses the smaller of the two. This way, we can get rid of these excessively long generations, thus considerably reducing eval times, while still guaranteeing a maximum total generation length to guard against OOM errors. Testing showed that this does not hamper test accuracy. It is therefore recommended not to change these settings.
52+
53+
#### Bucketing
54+
55+
The length of the sequences in the training data can vary a lot. Therefore, if samples are taken randomly from the training dataset, we will end up with batches containing very short and very long sequences. This is bad because the batch will be padded to the longest sequence, slowing down training. The obvious solution would be to sort the whole dataset by sequence length, but this is also bad because it introduces an order bias (e.g. first training on only short and then on only long answers).
56+
57+
The solution is to find a trade off between the two factors. This is achieved by the `BucketIterator`. It first creates buckets that contain multiple batches, e.g. 20x the batch size. The bucket is then sorted by sequence length and then batches are yielded from the bucket. Therefore, we have a small order bias within a bucket but not between buckets, stricking a good balance between training speed and training loss.
58+
59+
From practical experiments, for a batch size of 4, a bucket size of 80 provides a good balance with only slightly lower training loss but cutting training time by 25%. For eval, we don't use the iterator since there, the batch size is relatively big and thus there is little upside.
60+
61+
### Start a run
62+
63+
Once everything is set up properly, start a run by using the `run.py` script. Pass `-v` for verbose output to the console (recommended if observing the progress is desired). As an example, for `experiments/lora/llama-3.2-3B-rank32/` the invocation would be:
64+
65+
```sh
66+
python run.py -v experiments/lora/llama-3.2-3B-rank32/
67+
```
68+
69+
By default, the adapter will be saved in a temporary file for further inspection if needed. The prevent this, add the `--clean` flag to the call.
70+
71+
### Run status
72+
73+
The run can be categorized 3 different states:
74+
75+
1. Main run: You are on the `main` branch and the run ended successfully. The results are stored in the `results` folder and are used for further analysis.
76+
2. Test run: You are not on the `main` branch and the run ended successfully. The results are stored in the `temporary_results` folder and are not used for further analysis.
77+
3. The run was cancelled (`ctrl + c`). The results are stored in the `cancelled_results` folder and are not used for further analysis.
78+
79+
## Outputs
80+
81+
Results are stored in one of the result directories. An example output could look like so:
82+
83+
```js
84+
{
85+
"run_info": {
86+
"created_at": "2025-03-05T13:50:05+00:00",
87+
"total_time": 2711.0915009640157,
88+
"experiment_name": "ia3/lr_0.001",
89+
"peft_branch": "ben-method-comparison",
90+
"train_config": {
91+
"model_id": "meta-llama/Llama-3.2-3B",
92+
"dtype": "bfloat16",
93+
"max_seq_length": 768,
94+
"batch_size": 4,
95+
"batch_size_eval": 51,
96+
"max_steps": 5000,
97+
"eval_steps": 250,
98+
"compile": false,
99+
"query_template": "Question: {query} Think step by step.\nAnswer:",
100+
"seed": 0,
101+
"grad_norm_clip": 1.0,
102+
"optimizer_kwargs": {
103+
"lr": 0.001
104+
},
105+
"lr_scheduler": "cosine",
106+
"use_amp": false,
107+
"generation_kwargs": {
108+
"max_length": 800
109+
},
110+
"attn_implementation": null
111+
},
112+
"peft_config": {
113+
"task_type": null,
114+
"peft_type": "IA3",
115+
"auto_mapping": null,
116+
"base_model_name_or_path": "meta-llama/Llama-3.2-3B",
117+
"revision": null,
118+
"inference_mode": false,
119+
"target_modules": [
120+
"v_proj",
121+
"k_proj",
122+
"down_proj"
123+
],
124+
"exclude_modules": null,
125+
"feedforward_modules": [
126+
"down_proj"
127+
],
128+
"fan_in_fan_out": false,
129+
"modules_to_save": null,
130+
"init_ia3_weights": true
131+
}
132+
},
133+
"train_info": {
134+
"cuda_memory_reserved_avg": 14229219940,
135+
"cuda_memory_max": 24847056896,
136+
"cuda_memory_reserved_99th": 19115624366,
137+
"train_time": 2238.65277833899,
138+
"file_size": 1157064,
139+
"status": "success",
140+
"metrics": [
141+
{
142+
"step": 250,
143+
"valid accuracy": 0.0784313725490196,
144+
"train loss": 1.1336498007774354,
145+
"train samples": 1000
146+
},
147+
[...]
148+
{
149+
"step": 5000,
150+
"valid accuracy": 0.21568627450980393,
151+
"train loss": 0.6345920492410659,
152+
"train samples": 20000
153+
},
154+
{
155+
"step": 5000,
156+
"test accuracy": 0.35129740518962077,
157+
"train loss": 0.6345920492410659,
158+
"train samples": 20000,
159+
"train total tokens": 4197579
160+
}
161+
]
162+
},
163+
"meta_info": {
164+
"model_sha": "13afe5124825b4f3751f836b40dafda64c1ed062",
165+
"model_created_at": "2024-09-18T15:23:48+00:00",
166+
"dataset_sha": "aa4f34d3d2d3231299b5b03d9b3e5a20da45aa18",
167+
"dataset_created_at": "2023-09-21T17:22:46+00:00",
168+
"package_info": {
169+
"transformers-version": "4.50.0.dev0",
170+
"transformers-commit-hash": "752ef3fd4e70869626ec70657a770a85c0ad9219",
171+
"peft-version": "0.14.1.dev0",
172+
"peft-commit-hash": "a447a4e5ecd87b7d57733f4df9616a328cf130f4",
173+
"datasets-version": "3.3.2",
174+
"datasets-commit-hash": null,
175+
"bitsandbytes-version": "0.45.2",
176+
"bitsandbytes-commit-hash": null,
177+
"torch-version": "2.6.0+cu124",
178+
"torch-commit-hash": null
179+
},
180+
"system_info": {
181+
"system": "Linux",
182+
"release": "6.11.0-17-generic",
183+
"version": "#17~24.04.2-Ubuntu SMP PREEMPT_DYNAMIC Mon Jan 20 22:48:29 UTC 2",
184+
"machine": "x86_64",
185+
"processor": "x86_64",
186+
"gpu": "NVIDIA GeForce RTX 4090"
187+
},
188+
"pytorch_info": "PyTorch built with: [...]"
189+
}
190+
}
191+
```
192+
193+
## Dependencies
194+
195+
Apart from the normal PEFT dependencies, ensure that the packages in the `requirements.txt` are installed, e.g. via:
196+
197+
```sh
198+
python -m pip install -r requirements.txt
199+
```
200+
201+
Python 3.12+ is required.
202+
203+
## Open tasks
204+
205+
- consider using `DataLoader`
206+
- consider adding https://github.com/huggingface/Math-Verify
207+
- consider adding `weight` argument to cross entropy calculation to downweight the EOS token, but it would require calculating the loss manually instead of relying on transformers (see https://github.com/huggingface/transformers/blob/6a876462c308bd7cd7d3ca8e93abaa7d5b02e90e/src/transformers/loss/loss_utils.py#L24-L48)
208+
- do a sanity check against/comparison with transformers Trainer
209+
- consider using vLLM to potentially speed up generations, at least for the test set
210+
- using `torch.compile` leads to a huge slowdown, investigate (maybe recompiles), although it does save memory
211+
- AMP does not appear to help, investigate
212+
- packing of sequences (but this probably requires adjusting the attention matrix)
213+
- clean up what gets printed and where (stdout, stderr)

method_comparison/MetaMathQA/cancelled_results/.gitkeep

Whitespace-only changes.

method_comparison/MetaMathQA/data.py

+109
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Copyright 2025-present the HuggingFace Inc. team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
All utilities related to data handling.
17+
"""
18+
19+
from functools import partial
20+
from typing import Callable
21+
22+
import datasets
23+
import numpy as np
24+
from datasets import Dataset, load_dataset
25+
26+
27+
# with a token limit of 768 for query + response, we have to exclude all texts with length > 1304; this leaves 93.8% of
28+
# the dataset
29+
CHAR_LIMIT = 1300
30+
# train/valid/test split -- note that evaluation takes quite long, so don't choose too large sizes for the valid set,
31+
# since it's run multiple times during training; test is only run once at the end and thus can be larger
32+
VALID_SIZE = 50
33+
34+
35+
def get_filtered_dataset(*, ds: datasets.Dataset, print_fn: Callable[..., None]) -> Dataset:
36+
"""Return the filtered dataset, with long queries removed.
37+
38+
We determined that 99% of queries have 529 or fewer characters. Characters roughly correspond to tokens, so this is
39+
a good proxy. We cannot use tokens directly, as that depends on the tokenizer, which can be different for each
40+
model, but we want the same filter for each model.
41+
42+
"""
43+
char_lengths = [len(f"{q} {r}") for q, r in zip(ds["query"], ds["response"])]
44+
idx_filtered = [i for i, length in enumerate(char_lengths) if length <= CHAR_LIMIT]
45+
print_fn(f"Filtered dataset: {100 * len(idx_filtered) / len(ds):.1f}% of the original dataset")
46+
return ds.select(idx_filtered)
47+
48+
49+
def get_train_valid_test_datasets(
50+
*, tokenizer, query_template: str, print_fn: Callable[..., None]
51+
) -> tuple[Dataset, Dataset, Dataset]:
52+
"""
53+
Return the indices of the train, valid, and test splits of the dataset.
54+
55+
We cannot use ds.train_test_split(..., stratify_by_column="type") as it gives:
56+
57+
> ValueError: Stratifying by column is only supported for ClassLabel column, and column type is Value.
58+
59+
even after calling ds_filtered.class_encode_column("type"). Thus, using sklearn's StratifiedKFold instead.
60+
"""
61+
metamath = load_dataset("meta-math/MetaMathQA")["train"]
62+
metamath = get_filtered_dataset(ds=metamath, print_fn=print_fn)
63+
64+
# gsmk8k does not need to be filtered as query and response are short enough
65+
gsm8k = load_dataset("openai/gsm8k", "main")
66+
gsm8k = gsm8k.rename_columns({"question": "query", "answer": "response"})
67+
gsm8k_train = gsm8k["train"]
68+
gsm8k_test = gsm8k["test"]
69+
70+
np.random.seed(0)
71+
indices = np.arange(len(gsm8k_train))
72+
np.random.shuffle(indices)
73+
idx_valid = indices[:VALID_SIZE]
74+
75+
ds_train = metamath
76+
ds_valid = gsm8k_train.select(idx_valid)
77+
ds_test = gsm8k_test
78+
79+
print_fn(f"Train size: {len(ds_train)}")
80+
print_fn(f"Valid size: {len(ds_valid)}")
81+
print_fn(f"Test size: {len(ds_test)}")
82+
83+
tokenize_with_answer_ = partial(tokenize_with_answer, tokenizer=tokenizer, template=query_template)
84+
tokenize_wo_answer_ = partial(tokenize_wo_answer, tokenizer=tokenizer, template=query_template)
85+
ds_train = ds_train.map(tokenize_with_answer_, batched=True).remove_columns(["type", "query", "original_question"])
86+
ds_valid = ds_valid.map(tokenize_wo_answer_, batched=True).remove_columns(["query"])
87+
ds_test = ds_test.map(tokenize_wo_answer_, batched=True).remove_columns(["query"])
88+
89+
return ds_train, ds_valid, ds_test
90+
91+
92+
def tokenize_with_answer(samples, tokenizer, template):
93+
queries = [template.format(query=sample) + answer for sample, answer in zip(samples["query"], samples["response"])]
94+
tokenized = tokenizer(queries)
95+
tokenized["input_ids"] = [input_ids[: tokenizer.model_max_length] for input_ids in tokenized["input_ids"]]
96+
tokenized["attention_mask"] = [
97+
input_ids[: tokenizer.model_max_length] for input_ids in tokenized["attention_mask"]
98+
]
99+
return tokenized
100+
101+
102+
def tokenize_wo_answer(samples, tokenizer, template):
103+
queries = [template.format(query=sample) for sample in samples["query"]]
104+
tokenized = tokenizer(queries)
105+
tokenized["input_ids"] = [input_ids[: tokenizer.model_max_length] for input_ids in tokenized["input_ids"]]
106+
tokenized["attention_mask"] = [
107+
input_ids[: tokenizer.model_max_length] for input_ids in tokenized["attention_mask"]
108+
]
109+
return tokenized
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
{
2+
"model_id": "meta-llama/Llama-3.2-3B",
3+
"dtype": "bfloat16",
4+
"max_seq_length": 768,
5+
"batch_size": 4,
6+
"batch_size_eval": 50,
7+
"max_steps": 5000,
8+
"eval_steps": 250,
9+
"compile": false,
10+
"seed": 0,
11+
"grad_norm_clip": 1.0,
12+
"optimizer_kwargs": {
13+
"lr": 1e-4,
14+
"weight_decay": 0.1
15+
},
16+
"lr_scheduler": "cosine",
17+
"use_amp": false,
18+
"autocast_adapter_dtype": true,
19+
"attn_implementation": null,
20+
"generation_kwargs": {
21+
"max_length": 800,
22+
"max_new_tokens": 300
23+
},
24+
"query_template": "Question: {query} Think step by step.\nAnswer:"
25+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
{
2+
"alpha_pattern": {},
3+
"auto_mapping": null,
4+
"base_model_name_or_path": null,
5+
"beta1": 0.85,
6+
"beta2": 0.85,
7+
"bias": "none",
8+
"corda_config": null,
9+
"deltaT": 1,
10+
"eva_config": null,
11+
"exclude_modules": null,
12+
"fan_in_fan_out": false,
13+
"inference_mode": false,
14+
"init_lora_weights": true,
15+
"init_r": 64,
16+
"layer_replication": null,
17+
"layers_pattern": null,
18+
"layers_to_transform": null,
19+
"loftq_config": {},
20+
"lora_alpha": 8,
21+
"lora_bias": false,
22+
"lora_dropout": 0.0,
23+
"megatron_config": null,
24+
"megatron_core": "megatron.core",
25+
"modules_to_save": null,
26+
"orth_reg_weight": 0.5,
27+
"peft_type": "ADALORA",
28+
"r": 8,
29+
"rank_pattern": null,
30+
"revision": null,
31+
"target_modules": null,
32+
"target_r": 32,
33+
"task_type": null,
34+
"tfinal": 500,
35+
"tinit": 200,
36+
"total_step": 5000,
37+
"use_dora": false,
38+
"use_rslora": false
39+
}

0 commit comments

Comments
 (0)