-
Notifications
You must be signed in to change notification settings - Fork 23
/
finetune.py
441 lines (390 loc) · 15.3 KB
/
finetune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
import os
import sys
from typing import List, Union
import fire
import numpy as np
import torch
import transformers
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, set_peft_model_state_dict
from transformers import (
AutoModelForCausalLM,
BitsAndBytesConfig,
LlamaTokenizer,
Trainer,
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from utils.prompter import AlpacaPrompter, PromptSelector
from utils.text import load_text_file
class PeftTrainer(Trainer):
    """``Trainer`` whose checkpoints hold only the PEFT adapter.

    The stock ``_save_checkpoint`` serializes full model weights plus optimizer,
    scheduler and RNG state. This override only creates the
    ``checkpoint-<global_step>`` directory, updates best-metric bookkeeping and
    rotates old checkpoints; the adapter weights are expected to be written into
    that directory by a callback (``PeftSavingCallback`` in this file).
    """

    def _save_checkpoint(self, _, trial, metrics=None):
        """Create the checkpoint folder without saving base model/optimizer state.

        Args:
            _: Model argument supplied by ``Trainer``; deliberately unused.
            trial: Hyperparameter-search trial (or ``None``), used to resolve
                the run directory.
            metrics: Metrics from the latest evaluation, or ``None``.
        """
        step_dir_name = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
        run_dir = self._get_output_dir(trial=trial)
        ckpt_dir = os.path.join(run_dir, step_dir_name)

        key = self.args.metric_for_best_model
        if metrics is not None and key is not None:
            # Trainer prefixes evaluation metrics with "eval_".
            if not key.startswith("eval_"):
                key = f"eval_{key}"
            current = metrics[key]
            is_better = np.greater if self.args.greater_is_better else np.less
            state = self.state
            first_or_better = (
                state.best_metric is None
                or state.best_model_checkpoint is None
                or is_better(current, state.best_metric)
            )
            if first_or_better:
                state.best_metric = current
                state.best_model_checkpoint = ckpt_dir

        os.makedirs(ckpt_dir, exist_ok=True)

        if self.args.should_save:
            # Enforce save_total_limit, ordering checkpoints by modification time.
            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
class PeftSavingCallback(TrainerCallback):
    """Save only the PEFT adapter (not the full model) at each checkpoint.

    On every save event the adapter of the trained model is written to
    ``<args.output_dir>/checkpoint-<global_step>/adapter_model``.
    """

    def _save(self, model, folder):
        """Write ``model``'s adapter via ``save_pretrained`` to ``folder/adapter_model``."""
        peft_model_path = os.path.join(folder, "adapter_model")
        model.save_pretrained(peft_model_path)

    def on_train_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Intentionally a no-op.

        This callback does not write a separate final or "best" adapter at the
        end of training; only the per-checkpoint saves in ``on_save`` happen here.
        """

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Save the adapter into the checkpoint folder for ``state.global_step``.

        ``Trainer`` fires this event on every process; only processes with
        ``args.should_save`` write, so distributed ranks do not race to write
        the same files (mirrors the ``should_save`` gate used by the trainer).
        """
        if not args.should_save:
            return
        # NOTE(review): uses args.output_dir directly; during a hyperparameter
        # search the trainer's run directory may be a sub-folder of it.
        folder = os.path.join(
            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        )
        self._save(kwargs["model"], folder)
def load_hf_model(
    base_model,
    lora_config=None,
    mode=8,
    gradient_checkpointing=False,
    device_map="auto",
):
    """Load a causal LM plus its Llama tokenizer, optionally wrapped with LoRA.

    Args:
        base_model: Model name or local path passed to ``from_pretrained``.
        lora_config: Optional ``peft.LoraConfig``. When given, the model is
            wrapped with ``get_peft_model`` and the trainable-parameter summary
            is printed.
        mode: ``8`` -> bitsandbytes 8-bit; ``4`` -> 4-bit NF4 with double
            quantization and bf16 compute; ``16`` -> fp16 weights. Any other
            value loads with ``from_pretrained``'s default dtype.
        gradient_checkpointing: Whether to enable gradient checkpointing.
        device_map: Forwarded to ``from_pretrained``.

    Returns:
        ``(model, tokenizer)``.
    """
    from peft import prepare_model_for_kbit_training

    quantized = mode in (4, 8)
    kwargs = {"device_map": device_map}
    if mode == 8:
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_8bit=True,
            llm_int8_threshold=0.0,
        )
    elif mode == 4:
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    elif mode == 16:
        kwargs["torch_dtype"] = torch.float16
    model = AutoModelForCausalLM.from_pretrained(base_model, **kwargs)

    # setup tokenizer
    tokenizer = LlamaTokenizer.from_pretrained(base_model)

    if gradient_checkpointing:
        model.gradient_checkpointing_enable()
    # prepare_model_for_kbit_training defaults to use_gradient_checkpointing=True,
    # which turned checkpointing on for 4/8-bit models even when the caller asked
    # for it to be off; pass the flag through so it is honored.
    # NOTE: this helper also freezes the base parameters and (per PEFT) upcasts
    # fp16/bf16 parameters to fp32 — that applies to mode 16 as well.
    model = prepare_model_for_kbit_training(
        model, use_gradient_checkpointing=gradient_checkpointing
    )
    if gradient_checkpointing and not quantized:
        # For non-quantized models the PEFT helper does not hook the embeddings.
        # With frozen base weights, checkpointed segments would then see inputs
        # without requires_grad and the LoRA layers inside them get no gradient.
        model.enable_input_require_grads()

    if lora_config:
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()  # Be more transparent about the % of trainable params.
    return model, tokenizer
# noinspection PyTypeChecker
def train(
    # model/data params
    base_model: str = "",  # the only required argument
    data_path: str = "yahma/alpaca-cleaned",
    output_dir: str = "./lora-alpaca",
    # training hyperparams
    batch_size: int = 128,
    micro_batch_size: int = 4,
    num_epochs: int = 3,
    learning_rate: float = 3e-4,
    cutoff_len: int = 256,
    val_set_size: int = 2000,
    eval_steps: int = 100,
    save_steps: int = 10,
    logging_steps: int = 10,
    # lora hyperparams
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    lora_target_modules: List[str] = (
        "q_proj",
        "v_proj",
    ),
    # llm hyperparams
    train_on_inputs: bool = True,  # if False, masks out inputs in loss
    add_eos_token: bool = False,
    group_by_length: bool = False,  # faster, but produces an odd training loss curve
    # wandb params
    wandb_project: str = "",
    wandb_run_name: str = "",
    wandb_watch: str = "",  # options: false | gradients | all
    wandb_log_model: str = "",  # options: false | true
    resume_from_checkpoint: Union[str, None] = None,  # either training checkpoint or final adapter
    prompt_template: str = "vicuna",  # The prompt template to use, will default to alpaca.
    # memory optimization params
    mode: Union[int, str] = 8,  # training floating point mode
    gradient_checkpointing: bool = False,
    # GPTQ specific params
    gptq_backend: str = "cuda",  # GPTQ backend "cuda" or "triton"
    gptq_groupsize: int = 128,
    # evaluation flag
    eval: bool = False,
):
    """Fine-tune ``base_model`` with LoRA and save the adapter to ``output_dir``.

    Exposed as a CLI via ``fire``; every keyword argument is a command-line flag.

    Key arguments:
        batch_size / micro_batch_size: Effective and per-device batch sizes; their
            ratio (divided by WORLD_SIZE under DDP, floored at 1) gives the
            gradient-accumulation steps.
        val_set_size: Rows split off the training set for validation. When 0,
            the dataset's own "test" split is used if it exists, otherwise
            training runs without a validation set.
        eval_steps: Evaluation interval; 0 disables in-training evaluation
            (it is also disabled when there is no validation set).
        mode: int (4, 8, 16, ...) loads through Hugging Face/bitsandbytes,
            ``"gptq"`` through the GPTQ loader; anything else raises
            NotImplementedError.
        resume_from_checkpoint: Directory holding ``pytorch_model.bin`` or
            ``adapter_model.bin`` (int modes), or a LoRA path (GPTQ mode).
            Only adapter weights are restored, not optimizer state.
        eval: When True, run ``trainer.evaluate()`` instead of training.
    """
    is_main_process = int(os.environ.get("LOCAL_RANK", 0)) == 0
    if is_main_process:
        print(
            f"Training Alpaca-LoRA model with params:\n"
            f"base_model: {base_model}\n"
            f"data_path: {data_path}\n"
            f"output_dir: {output_dir}\n"
            f"batch_size: {batch_size}\n"
            f"micro_batch_size: {micro_batch_size}\n"
            f"mode: {mode}\n"
            f"eval: {eval}\n"
            f"gradient_checkpointing: {gradient_checkpointing}\n"
            f"num_epochs: {num_epochs}\n"
            f"learning_rate: {learning_rate}\n"
            f"cutoff_len: {cutoff_len}\n"
            f"val_set_size: {val_set_size}\n"
            f"eval_steps: {eval_steps}\n"
            f"logging_steps: {logging_steps}\n"
            f"save_steps: {save_steps}\n"
            f"lora_r: {lora_r}\n"
            f"lora_alpha: {lora_alpha}\n"
            f"lora_dropout: {lora_dropout}\n"
            f"lora_target_modules: {lora_target_modules}\n"
            f"train_on_inputs: {train_on_inputs}\n"
            f"add_eos_token: {add_eos_token}\n"
            f"group_by_length: {group_by_length}\n"
            f"wandb_project: {wandb_project}\n"
            f"wandb_run_name: {wandb_run_name}\n"
            f"wandb_watch: {wandb_watch}\n"
            f"wandb_log_model: {wandb_log_model}\n"
            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
            f"prompt_template: {prompt_template}\n"
        )
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
    gradient_accumulation_steps = batch_size // micro_batch_size

    prompter = PromptSelector.from_template_name(prompt_template, verbose=False)

    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        # One full model replica per process, pinned to this process's GPU.
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = gradient_accumulation_steps // world_size
    # Integer division can reach 0 (e.g. batch_size < micro_batch_size * world_size),
    # which would make the Trainer divide by zero; always accumulate at least once.
    gradient_accumulation_steps = max(1, gradient_accumulation_steps)

    # Check if parameter passed or if set within environ
    use_wandb = len(wandb_project) > 0 or (
        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
    )
    print("use_wandb", use_wandb)
    # Only overwrite environ if wandb param passed
    if len(wandb_project) > 0:
        os.environ["WANDB_PROJECT"] = wandb_project
    if len(wandb_watch) > 0:
        os.environ["WANDB_WATCH"] = wandb_watch
    if len(wandb_log_model) > 0:
        os.environ["WANDB_LOG_MODEL"] = wandb_log_model

    # setup model and tokenizer
    lora_config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )

    if isinstance(mode, int):
        # use HF loader for normal model loading with bitsandbytes quantization
        model, tokenizer = load_hf_model(
            base_model,
            lora_config,
            mode=mode,
            gradient_checkpointing=gradient_checkpointing,
            device_map=device_map,
        )
        # setup model checkpoint if needed
        if resume_from_checkpoint:
            # Check the available weights and load them
            checkpoint_name = os.path.join(
                resume_from_checkpoint, "pytorch_model.bin"
            )  # Full checkpoint
            if not os.path.exists(checkpoint_name):
                checkpoint_name = os.path.join(
                    resume_from_checkpoint, "adapter_model.bin"
                )  # only LoRA model - LoRA config above has to fit
            # The two files above have a different name depending on how they were saved, but are actually the same.
            if os.path.exists(checkpoint_name):
                print(f"Restarting from {checkpoint_name}")
                adapters_weights = torch.load(checkpoint_name, map_location="cpu")
                set_peft_model_state_dict(model, adapters_weights)
            else:
                # NOTE: training continues from scratch in this case.
                print(f"Checkpoint {checkpoint_name} not found")
    elif mode == "gptq":
        from utils.loader.gptq_loader import load_model_gptq

        kwargs = {
            "gradient_checkpointing": gradient_checkpointing,
            "device_map": device_map,
            "group_size": gptq_groupsize,
            "backend": gptq_backend,
        }
        if resume_from_checkpoint:
            kwargs.update(
                {
                    "lora_path": resume_from_checkpoint,
                    "load_lora": True,
                    "lora_trainable": True,
                }
            )
            print(f"Restarting from {resume_from_checkpoint}")
        model, tokenizer = load_model_gptq(
            base_model,
            lora_config,
            **kwargs,
        )
    else:
        raise NotImplementedError(f"Mode '{mode}' is not supported.")

    tokenizer.pad_token_id = 0  # unk. we want this to be different from the eos token
    tokenizer.padding_side = "left"  # Allow batched inference

    def tokenize(prompt, add_eos_token=True):
        """Tokenize ``prompt`` (truncated to cutoff_len), set labels = input_ids.

        An EOS token is appended when requested, not already present, and the
        sequence is shorter than cutoff_len. Note: this parameter shadows the
        outer ``add_eos_token`` flag.
        """
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < cutoff_len
            and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        result["labels"] = result["input_ids"].copy()

        return result

    def generate_and_tokenize_prompt(data_point):
        """Build the full prompt for one example and tokenize it.

        When ``train_on_inputs`` is False, label positions covering the user
        prompt (instruction + input) are set to -100 so only the response
        contributes to the loss.
        """
        full_prompt = prompter.generate_prompt(**data_point)
        tokenized_full_prompt = tokenize(full_prompt)
        if not train_on_inputs:
            assert isinstance(prompter, AlpacaPrompter)
            user_prompt = prompter.generate_prompt(
                data_point["instruction"], data_point["input"]
            )
            tokenized_user_prompt = tokenize(user_prompt, add_eos_token=add_eos_token)
            user_prompt_len = len(tokenized_user_prompt["input_ids"])

            if add_eos_token:
                user_prompt_len -= 1

            tokenized_full_prompt["labels"] = [
                -100
            ] * user_prompt_len + tokenized_full_prompt["labels"][
                user_prompt_len:
            ]  # could be sped up, probably
        return tokenized_full_prompt

    # data preparation
    # check if using raw text format (prompter is None)
    if prompter is None:
        train_data = load_text_file(
            data_path, tokenizer, cutoff_len=cutoff_len, overlap_len=cutoff_len // 2
        )
        if val_set_size > 0:
            train_val = train_data.train_test_split(
                test_size=val_set_size, shuffle=True, seed=42
            )
            train_data = train_val["train"].shuffle()
            val_data = train_val["test"]
        else:
            val_data = None
    else:
        if data_path.endswith(".json") or data_path.endswith(".jsonl"):
            data = load_dataset("json", data_files=data_path)
        else:
            if os.path.exists(data_path):
                data = load_dataset(
                    "json",
                    data_files={
                        "train": data_path + "/train.json",
                        "test": data_path + "/test.json",
                    },
                )
            else:
                data = load_dataset(data_path)

        if val_set_size > 0:
            train_val = data["train"].train_test_split(
                test_size=val_set_size, shuffle=True, seed=42
            )
            train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
            val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
        else:
            train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
            # A single JSON file (and many hub datasets) only has a "train"
            # split; run without validation instead of failing with KeyError.
            val_data = (
                data["test"].map(generate_and_tokenize_prompt)
                if "test" in data
                else None
            )

    if not ddp and torch.cuda.device_count() > 1:
        # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
        model.is_parallelizable = True
        model.model_parallel = True

    model.config.use_cache = False

    # sanity check of model saving process
    if is_main_process:
        model.save_pretrained(output_dir)

    # Trainer refuses step-wise evaluation without an eval dataset.
    evaluate_during_training = eval_steps > 0 and val_data is not None

    trainer = PeftTrainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            per_device_eval_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=100,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            fp16=True,
            logging_steps=logging_steps,
            optim="paged_adamw_8bit" if mode in [4, 8] else "adamw_torch",
            evaluation_strategy="steps" if evaluate_during_training else "no",
            save_strategy="steps",
            eval_steps=eval_steps if evaluate_during_training else None,
            save_steps=save_steps,
            output_dir=output_dir,
            save_total_limit=3,
            load_best_model_at_end=False,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
            # report_to=None means "all installed integrations" in transformers;
            # "none" really disables reporting when wandb is not requested.
            report_to="wandb" if use_wandb else "none",
            run_name=wandb_run_name if use_wandb else None,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
        callbacks=[PeftSavingCallback],
    )

    # NOTE(review): the trainer already holds the uncompiled model, so this
    # compiled wrapper is only used for the final save_pretrained call below.
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    if eval:
        eval_results = trainer.evaluate()
        print(eval_results)
    else:
        trainer.train()
        if is_main_process:
            model.save_pretrained(output_dir)

    print(
        "\n If there's a warning about missing keys above, please disregard :)"
    )
if __name__ == "__main__":
    # Command-line entry point: fire turns train()'s keyword arguments into flags.
    fire.Fire(component=train)