Commit 9063c3b

committed Mar 13, 2023
add sft training codes
1 parent dc1670a · commit 9063c3b

2 files changed: +233 -0 lines changed

‎config/sft_deepspeed_config.json

+58

{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },

    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },

    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "total_num_steps": "auto",
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto"
        }
    },

    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}
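
A note on the "auto" values above: when this JSON is handed to the Hugging Face Trainer through its deepspeed argument, each "auto" entry is resolved from the corresponding TrainingArguments, which keeps the learning rate, batch sizes, and scheduler steps consistent between the two configs. A minimal sketch of that pairing (the values and output path are illustrative, not from this commit):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",                   # illustrative output path
    learning_rate=1e-5,                 # fills optimizer.params.lr
    per_device_train_batch_size=8,      # fills train_micro_batch_size_per_gpu
    gradient_accumulation_steps=2,      # fills gradient_accumulation_steps
    warmup_ratio=0.1,                   # used to derive the scheduler warmup steps
    bf16=True,
    deepspeed="config/sft_deepspeed_config.json",  # this file
)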

‎script/train_sft.py

+175

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    HfArgumentParser,
    AdamW,
    default_data_collator,
)
from deepspeed.runtime.lr_schedules import WarmupDecayLR
from typing import Optional
import evaluate
from dataclasses import dataclass, field
import wandb
import multiprocessing

cpu_cores = multiprocessing.cpu_count()

# python -m torch.distributed.launch --nproc_per_node=8 train_sft.py \
# --per_device_train_batch_size=8 --per_device_eval_batch_size=8 --gradient_accumulation_steps=2 \
# --model_name=facebook/xglm-1.7B --bf16 --deepspeed=../config/sft_deepspeed_config.json

# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    These arguments vary depending on how many GPUs you have, what their capacity and features are, and what size model you want to train.
    """
    num_train_epochs: Optional[int] = field(default=2)
    resume_from_checkpoint: Optional[bool] = field(default=False)
    # multi-GPU settings
    local_rank: Optional[int] = field(default=0)
    deepspeed: Optional[str] = field(default=None)
    per_device_train_batch_size: Optional[int] = field(default=8)
    per_device_eval_batch_size: Optional[int] = field(default=8)
    gradient_accumulation_steps: Optional[int] = field(default=2)
    # learning-rate settings
    max_learning_rate: Optional[float] = field(default=1e-5)
    min_learning_rate: Optional[float] = field(default=0.)
    weight_decay: Optional[float] = field(default=0.001)
    warmup_ratio: Optional[float] = field(default=0.1)
    # logging settings
    wandb_project: Optional[str] = field(default="php_sft_model")
    logging_steps: Optional[int] = field(default=50)
    # evaluation settings
    eval_steps: Optional[int] = field(default=500)
    # model and dataset
    model_name: Optional[str] = field(default="facebook/xglm-1.7B")
    dataset_name: Optional[str] = field(default="pythainlp/php_sft")
    question_column: Optional[str] = field(default="question")
    answer_column: Optional[str] = field(default="answer")
    train_split_name: Optional[str] = field(default="train")
    eval_split_name: Optional[str] = field(default="test")
    # tokenizer settings
    max_length: Optional[int] = field(default=512)
    # half-precision settings
    bf16: Optional[bool] = field(default=True)

parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

# Initialize wandb with project and run names.
wandb.init(project=script_args.wandb_project,
           name=f"{script_args.wandb_project}_{wandb.util.generate_id()}")

# Load the question-answer dataset used for supervised fine-tuning.
ds = load_dataset(script_args.dataset_name)
# Debug: train and evaluate on small subsets.
ds[script_args.train_split_name] = ds[script_args.train_split_name].select(range(1000))
ds[script_args.eval_split_name] = ds[script_args.eval_split_name].select(range(100))

# Define the training args. This needs to happen before the model is loaded when using DeepSpeed.
training_args = TrainingArguments(
    output_dir=f"{script_args.model_name}_sft_model",
    learning_rate=script_args.max_learning_rate,
    per_device_train_batch_size=script_args.per_device_train_batch_size,
    per_device_eval_batch_size=script_args.per_device_eval_batch_size,
    num_train_epochs=script_args.num_train_epochs,
    weight_decay=script_args.weight_decay,
    warmup_ratio=script_args.warmup_ratio,
    evaluation_strategy="steps",
    eval_steps=script_args.eval_steps,
    metric_for_best_model="rougeL",  # compute_metrics below reports ROUGE, not accuracy
    greater_is_better=True,
    logging_steps=script_args.logging_steps,
    save_strategy="steps",
    save_steps=script_args.eval_steps,
    load_best_model_at_end=True,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    bf16=script_args.bf16,
    deepspeed=script_args.deepspeed,
    local_rank=script_args.local_rank,
    label_names=[],
    remove_unused_columns=False,
)

# Load the pretrained causal LM and its tokenizer.
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name)
model = AutoModelForCausalLM.from_pretrained(script_args.model_name)

# Some tokenizers (e.g. GPT-2) have no official pad token, so fall back to the EOS token.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id

# Tokenize the dataset: questions become the inputs, answers become the labels.
def preprocess_function(examples):
    tokenized_question = tokenizer(examples[script_args.question_column],
                                   truncation=True,
                                   padding="max_length",
                                   max_length=script_args.max_length)
    tokenized_answer = tokenizer(examples[script_args.answer_column],
                                 truncation=True,
                                 padding="max_length",
                                 max_length=script_args.max_length)
    return {
        "input_ids": tokenized_question["input_ids"],
        "attention_mask": tokenized_question["attention_mask"],
        "labels": tokenized_answer["input_ids"],
    }

tokenized_ds = ds.map(preprocess_function,
                      batched=True,
                      num_proc=cpu_cores,
                      remove_columns=ds[script_args.train_split_name].column_names)

# Use ROUGE as the evaluation metric; not fully representative of answer quality, but a reasonable proxy.
rouge = evaluate.load("rouge")

def compute_metrics(eval_preds):
    labels_ids = eval_preds.label_ids
    pred_ids = eval_preds.predictions
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
    result = rouge.compute(predictions=pred_str, references=label_str)
    return result

# Reduce the logits to predicted token ids before metric computation to keep memory usage down.
def preprocess_logits_for_metrics(logits, labels):
    if isinstance(logits, tuple):
        logits = logits[0]
    return logits.argmax(dim=-1)

class SFTTrainer(Trainer):
    # Build AdamW plus DeepSpeed's WarmupDecayLR schedule. Note that when a DeepSpeed config
    # that already defines an optimizer and scheduler is passed (as above), DeepSpeed builds
    # them from that config and this override is effectively unused.
    def create_optimizer_and_scheduler(self, num_training_steps: int):
        params = self.model.parameters()
        optimizer = AdamW(params,
                          lr=script_args.max_learning_rate,
                          weight_decay=self.args.weight_decay,
                          correct_bias=True)
        warmup_steps = int(self.args.warmup_ratio * num_training_steps)
        scheduler = WarmupDecayLR(optimizer,
                                  total_num_steps=num_training_steps,
                                  warmup_min_lr=script_args.min_learning_rate,
                                  warmup_max_lr=script_args.max_learning_rate,
                                  warmup_num_steps=warmup_steps)
        # The Trainer reads these attributes rather than the return value.
        self.optimizer, self.lr_scheduler = optimizer, scheduler
        return optimizer, scheduler

# Train the model.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds[script_args.train_split_name],
    eval_dataset=tokenized_ds[script_args.eval_split_name],
    data_collator=default_data_collator,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)

trainer.train(resume_from_checkpoint=script_args.resume_from_checkpoint)

# Push to the Hub so you can share the fine-tuned model.
# model.push_to_hub(f"{script_args.model_name}_sft_model")
# tokenizer.push_to_hub(f"{script_args.model_name}_sft_model")
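
For completeness, a minimal usage sketch that is not part of this commit: the Trainer writes checkpoints under the output_dir set in the script, and one of them can be loaded back for generation. The checkpoint path and prompt below are assumptions for illustration; the tokenizer is loaded from the base model name because the script does not pass it to the Trainer, so it is not saved alongside the checkpoints.

from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "facebook/xglm-1.7B_sft_model/checkpoint-500"  # assumed checkpoint dir under the script's output_dir
tokenizer = AutoTokenizer.from_pretrained("facebook/xglm-1.7B")  # base tokenizer; not saved with the checkpoint
model = AutoModelForCausalLM.from_pretrained(checkpoint)

prompt = "Your question here"  # placeholder prompt
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))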
