train.py
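"""Training entry point: trains a GPT-style model with rotary position embeddings
(models.gpt2_rope.GPTModel) on Indic-language data via PyTorch Lightning, with
Weights & Biases logging and checkpointing."""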
import os
import warnings

import lightning.pytorch as pl
import sentencepiece as spm
import torch
import wandb
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger

from configs import Config
from data.data_loader import IndicGPTDataModule
from keys import WANDB_KEY
from models.gpt2_rope import GPTModel

warnings.simplefilter('ignore')
wandb.login(key=WANDB_KEY)
if __name__ == "__main__":
    # Init config and resolve paths
    config = Config()
    train_file_path = config.train_file_path
    val_file_path = config.val_file_path
    tokenizer_path = config.tokenizer_path
    checkpoint_dir = config.checkpoint_dir

    # Expand the train/val directories into file lists and build id -> name maps
    files = os.listdir(val_file_path)
    val_file_path = [val_file_path + i for i in files]
    val_id_to_name = {i: files[i].split('.')[0] for i in range(len(files))}
    files = os.listdir(train_file_path)
    train_file_path = [train_file_path + i for i in files]
    train_id_to_name = {i: files[i].split('.')[0] for i in range(len(files))}
    print('Val id', val_id_to_name)
    print('Train id', train_id_to_name)

    # Attach the maps to the Config class so they are visible wherever Config is imported
    Config.val_id_to_name = val_id_to_name
    Config.train_id_to_name = train_id_to_name
    # W&B logger
    wandb_logger = WandbLogger(
        name=config.wandb_name,
        project=config.wandb_project,
        job_type='train',
        offline=True,
        dir=f'./wandb-{config.wandb_name}/'
    )

    # Select SDP kernels so scaled-dot-product attention uses FlashAttention
    if config.use_flashattn:
        torch.backends.cuda.enable_flash_sdp(True)            # enable the FlashAttention SDP kernel
        torch.backends.cuda.enable_mem_efficient_sdp(False)   # disable the memory-efficient SDP kernel
        torch.backends.cuda.enable_math_sdp(False)            # disable the math (fallback) SDP kernel
        # Print which SDP backends are active
        print(torch.backends.cuda.flash_sdp_enabled())
        print(torch.backends.cuda.mem_efficient_sdp_enabled())
        print(torch.backends.cuda.math_sdp_enabled())
    # Tokenizer
    tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path)

    # Define the model and dataset
    model = GPTModel(config, tokenizer)
    data = IndicGPTDataModule(
        config=config,
        train_file=train_file_path,
        val_file=val_file_path,
        tokenizer=tokenizer,
        batch_size=config.bs
    )
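    # Count trainable parameters, excluding token-embedding and lm_head weights,
    # and print every parameter tensor for inspection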
    total_params = sum(
        p.numel()
        for name, p in model.named_parameters()
        if p.requires_grad and ('embedding' not in name and 'lm_head' not in name)
    )
    for name, p in model.named_parameters():
        print(name, p.numel() if p.requires_grad else "None")
    print(f"Total parameters in the model excluding embedding and lm_head layers: {total_params}")
    # Checkpoint callback: keep the top-k checkpoints by train_loss and always save the last
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename="minilm-{epoch:02d}-{val_loss:.5f}",
        save_top_k=config.save_top_k,
        save_last=True,
        monitor="train_loss",
        every_n_train_steps=config.checkpoint_every_n_steps,
        mode="min",
        save_on_train_epoch_end=True
    )
    lr_monitor = LearningRateMonitor(logging_interval='step')
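    # Trainer assembled from Config: devices/strategy, precision, gradient clipping and accumulation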
    trainer = pl.Trainer(
        accelerator=config.accelerator,
        devices=config.NUM_DEVICES,
        strategy=config.strategy,
        num_nodes=config.NUM__NODES,
        max_epochs=config.epochs,
        detect_anomaly=True,
        enable_checkpointing=True,
        val_check_interval=config.val_every,
        logger=wandb_logger,
        log_every_n_steps=config.log_every_n_steps,
        callbacks=[checkpoint_callback, lr_monitor],
        precision=config.precision,
        gradient_clip_val=config.gradient_clip_val,
        accumulate_grad_batches=config.accumulate_grad_batches
    )
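    # Run training; the best checkpoint path (per the monitored metric) is available after fit()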
    trainer.fit(model, data)
    print('Best model path', checkpoint_callback.best_model_path)
    wandb.finish()
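# For reference, the Config attributes this script reads (defined in configs.Config):
#   train_file_path, val_file_path, tokenizer_path, checkpoint_dir, use_flashattn,
#   wandb_name, wandb_project, bs, save_top_k, checkpoint_every_n_steps,
#   accelerator, NUM_DEVICES, strategy, NUM__NODES, epochs, val_every,
#   log_every_n_steps, precision, gradient_clip_val, accumulate_grad_batches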