from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    PreTrainedTokenizerBase,
    HfArgumentParser,
    AdamW,
    default_data_collator,
)
from transformers.utils import PaddingStrategy
from deepspeed.runtime.lr_schedules import WarmupDecayLR
from typing import Optional, Union, List, Dict, Any
import evaluate
from dataclasses import dataclass, field
import torch.nn as nn
import numpy as np
import wandb
import multiprocessing

cpu_cores = multiprocessing.cpu_count()

# python -m torch.distributed.launch --nproc_per_node=8 train_sft.py \
# --per_device_train_batch_size=8 --per_device_eval_batch_size=8 --gradient_accumulation_steps=2 \
# --model_name=facebook/xglm-1.7B --bf16 --deepspeed=../config/sft_deepspeed_config.json
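
# The launch command above points at ../config/sft_deepspeed_config.json, which is not shown here.
# As a rough sketch (an assumption, not the repo's actual file), a compatible ZeRO-2 + bf16 config
# could look something like:
# {
#   "bf16": {"enabled": true},
#   "zero_optimization": {"stage": 2},
#   "train_micro_batch_size_per_gpu": "auto",
#   "gradient_accumulation_steps": "auto",
#   "gradient_clipping": "auto"
# }
# The optimizer and scheduler are created in SFTTrainer below, so they are left out of this sketch.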

# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    These arguments vary depending on how many GPUs you have, what their capacity and features are, and what size model you want to train.
    """
    num_train_epochs: Optional[int] = field(default=2)
    resume_from_checkpoint: Optional[bool] = field(default=False)
    # multi-GPU stuff
    local_rank: Optional[int] = field(default=0)
    deepspeed: Optional[str] = field(default=None)
    per_device_train_batch_size: Optional[int] = field(default=8)
    per_device_eval_batch_size: Optional[int] = field(default=8)
    gradient_accumulation_steps: Optional[int] = field(default=2)
    # learning rate stuff
    max_learning_rate: Optional[float] = field(default=1e-5)
    min_learning_rate: Optional[float] = field(default=0.)
    weight_decay: Optional[float] = field(default=0.001)
    warmup_ratio: Optional[float] = field(default=0.1)
    # logging stuff
    wandb_project: Optional[str] = field(default="php_sft_model")
    logging_steps: Optional[int] = field(default=50)
    # eval stuff
    eval_steps: Optional[int] = field(default=500)
    # model and dataset
    model_name: Optional[str] = field(default="facebook/xglm-1.7B")
    dataset_name: Optional[str] = field(default="pythainlp/php_sft")
    question_column: Optional[str] = field(default="question")
    answer_column: Optional[str] = field(default="answer")
    train_split_name: Optional[str] = field(default="train")
    eval_split_name: Optional[str] = field(default="test")
    # tokenizer stuff
    max_length: Optional[int] = field(default=512)
    # half precision stuff
    bf16: Optional[bool] = field(default=True)

parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
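
# HfArgumentParser turns every field above into a --flag of the same name,
# e.g. --max_learning_rate=2e-5 or --max_length=1024 (values here are just illustrations).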

# Initialize wandb with project and run names.
wandb.init(project=script_args.wandb_project,
           name=f"{script_args.wandb_project}_{wandb.util.generate_id()}")

# Load the SFT dataset (question/answer pairs).
ds = load_dataset(script_args.dataset_name)
# Debug: train and eval on small subsets; remove these two lines for a full run.
ds['train'] = ds['train'].select(range(1000))
ds['test'] = ds['test'].select(range(100))

# Define the training args. Needs to be done before the model is loaded if you are using deepspeed.
training_args = TrainingArguments(
    output_dir=f"{script_args.model_name}_sft_model",
    learning_rate=script_args.max_learning_rate,
    per_device_train_batch_size=script_args.per_device_train_batch_size,
    per_device_eval_batch_size=script_args.per_device_eval_batch_size,
    num_train_epochs=script_args.num_train_epochs,
    weight_decay=script_args.weight_decay,
    warmup_ratio=script_args.warmup_ratio,
    evaluation_strategy="steps",
    eval_steps=script_args.eval_steps,
    # compute_metrics below reports ROUGE scores, so select checkpoints on one of those keys.
    metric_for_best_model="rougeL",
    greater_is_better=True,
    logging_steps=script_args.logging_steps,
    save_strategy="steps",
    save_steps=script_args.eval_steps,
    load_best_model_at_end=True,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    deepspeed=script_args.deepspeed,
    local_rank=script_args.local_rank,
    bf16=script_args.bf16,
    # "labels" must be listed so the Trainer passes them to compute_metrics during evaluation.
    label_names=["labels"],
    remove_unused_columns=False,
)

# Load the base model and tokenizer.
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name)
model = AutoModelForCausalLM.from_pretrained(script_args.model_name)

# Some models (e.g. gpt2) don't have an official pad token, so fall back to the EOS token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.eos_token_id
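
# xglm already ships a pad token, so the fallback above only kicks in for gpt2-style models.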

# Tokenize the dataset.
def preprocess_function(examples):
    tokenized_question = tokenizer(examples[script_args.question_column],
                                   truncation=True,
                                   padding="max_length",
                                   max_length=script_args.max_length)
    tokenized_answer = tokenizer(examples[script_args.answer_column],
                                 truncation=True,
                                 padding="max_length",
                                 max_length=script_args.max_length)
    return {
        "input_ids": tokenized_question["input_ids"],
        "attention_mask": tokenized_question["attention_mask"],
        "labels": tokenized_answer["input_ids"],
    }

tokenized_ds = ds.map(preprocess_function,
                      batched=True,
                      num_proc=cpu_cores,
                      remove_columns=ds[script_args.train_split_name].column_names)
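
# Note: this feeds the tokenized question as input_ids and the tokenized answer as labels,
# so the loss compares answer tokens against logits produced at the question's positions.
# A common alternative for causal-LM SFT (not what this script does) is to concatenate
# question + answer into one sequence and mask the question positions in the labels with -100.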

# Use ROUGE as the eval metric; not really representative, but ok.
rouge = evaluate.load("rouge")

def compute_metrics(eval_preds):
    labels_ids = eval_preds.label_ids
    pred_ids = eval_preds.predictions
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
    result = rouge.compute(predictions=pred_str, references=label_str)
    return result
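
# rouge.compute returns a dict of scores keyed by "rouge1", "rouge2", "rougeL" and "rougeLsum";
# metric_for_best_model="rougeL" above selects checkpoints on one of these keys.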

# Create a preprocessing function to extract the proper logits from the model output.
def preprocess_logits_for_metrics(logits, labels):
    if isinstance(logits, tuple):
        logits = logits[0]
    return logits.argmax(dim=-1)

class SFTTrainer(Trainer):
    def create_optimizer_and_scheduler(self, num_training_steps: int):
        params = self.model.parameters()
        optimizer = AdamW(params, lr=self.args.learning_rate,
                          weight_decay=self.args.weight_decay,
                          correct_bias=True)
        total_steps = num_training_steps
        warmup_steps = int(self.args.warmup_ratio * total_steps)
        scheduler = WarmupDecayLR(optimizer, total_num_steps=total_steps,
                                  warmup_min_lr=script_args.min_learning_rate,
                                  warmup_max_lr=script_args.max_learning_rate,
                                  warmup_num_steps=warmup_steps)
        # The Trainer reads these attributes, so set them rather than only returning.
        self.optimizer = optimizer
        self.lr_scheduler = scheduler
        return optimizer, scheduler
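
# The override exists to use DeepSpeed's WarmupDecayLR, which ramps the learning rate from
# min_learning_rate up to max_learning_rate over the warmup steps and then decays it linearly
# over the remaining training steps.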

# Train the model, woohoo.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds[script_args.train_split_name],
    eval_dataset=tokenized_ds[script_args.eval_split_name],
    data_collator=default_data_collator,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)

trainer.train(resume_from_checkpoint=script_args.resume_from_checkpoint)
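
# Optional sanity check after training (a rough sketch; the prompt below is just a placeholder):
# if trainer.is_world_process_zero():
#     sample = tokenizer("your question here", return_tensors="pt").to(model.device)
#     generated = model.generate(**sample, max_new_tokens=128)
#     print(tokenizer.decode(generated[0], skip_special_tokens=True))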

# Push to the hub so you can share it with people :D
# model.push_to_hub(f"{script_args.model_name}_sft_model")
# tokenizer.push_to_hub(f"{script_args.model_name}_sft_model")