Skip to content

Commit

Permalink
Feat(QA): Check init model weights (InternLM#502)
Browse files Browse the repository at this point in the history
* check_init

* check_init

* check_init

* check_init
  • Loading branch information
li126com authored Nov 16, 2023
1 parent be5b9ea commit e8cf27b
Show file tree
Hide file tree
Showing 2 changed files with 234 additions and 31 deletions.
176 changes: 176 additions & 0 deletions tests/test_training/7B_check_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
JOB_NAME = "7b_train"
DO_ALERT = False

SEQ_LEN = 2048
HIDDEN_SIZE = 4096
NUM_ATTENTION_HEAD = 32
MLP_RATIO = 8 / 3
NUM_LAYER = 32
VOCAB_SIZE = 103168

CHECK_INIT = 1

# MODEL_ONLY_FOLDER = "llm_ckpts_test_3/2"
# Ckpt folder format:
# fs: 'local:/mnt/nfs/XXX'
# SAVE_CKPT_FOLDER = "local:llm_ckpts_test_3"
# LOAD_CKPT_FOLDER = "local:llm_ckpts_test_3"

# boto3 Ckpt folder format:
# import os
# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
CHECKPOINT_EVERY = 50
ckpt = dict(
enable_save_ckpt=False, # enable ckpt save.
auto_resume=False,
# save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
# load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"),
# load_ckpt_folder="local:llm_ckpts/",
# 'load_ckpt_info' setting guide:
# 1. the 'path' indicate ckpt path,
# 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
# 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, now only 'normal' type is supported.
# load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("all",), ckpt_type="internlm"),
# 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
# training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm)
# with an automatic restart mechanism upon training reboot.
# Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint
# path specified in `load_ckpt_info` by default.
# If you want to initialize your model weights from another model, you must set `auto_resume` to False.
# If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None.
# auto_resume=False,
checkpoint_every=CHECKPOINT_EVERY,
# async_upload=True, # async ckpt upload. (only work for boto3 ckpt)
# async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload.
# oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
)

TRAIN_FOLDER = "/path/to/dataset"
VALID_FOLDER = "/path/to/dataset"
data = dict(
seq_len=SEQ_LEN,
# micro_num means the number of micro_batch contained in one gradient update
micro_num=4,
# packed_length = micro_bsz * SEQ_LEN
micro_bsz=2,
# defaults to the value of micro_num
valid_micro_num=4,
# defaults to 0, means disable evaluate
valid_every=50,
pack_sample_into_one=False,
total_steps=0,
skip_batches="",
rampup_batch_size="",
# Datasets with less than 50 rows will be discarded
min_length=50,
# train_folder=TRAIN_FOLDER,
# valid_folder=VALID_FOLDER,
empty_cache_and_diag_interval=10,
diag_outlier_ratio=1.1,
)

grad_scaler = dict(
fp16=dict(
# the initial loss scale, defaults to 2**16
initial_scale=2**16,
# the minimum loss scale, defaults to None
min_scale=1,
# the number of steps to increase loss scale when no overflow occurs
growth_interval=1000,
),
# the multiplication factor for increasing loss scale, defaults to 2
growth_factor=2,
# the multiplication factor for decreasing loss scale, defaults to 0.5
backoff_factor=0.5,
# the maximum loss scale, defaults to None
max_scale=2**24,
# the number of overflows before decreasing loss scale, defaults to 2
hysteresis=2,
)

hybrid_zero_optimizer = dict(
# Enable low_level_optimzer overlap_communication
overlap_sync_grad=True,
overlap_sync_param=True,
# bucket size for nccl communication params
reduce_bucket_size=512 * 1024 * 1024,
# grad clipping
clip_grad_norm=1.0,
)

loss = dict(
label_smoothing=0,
)

adam = dict(
lr=1e-4,
adam_beta1=0.9,
adam_beta2=0.95,
adam_beta2_c=0,
adam_eps=1e-8,
weight_decay=0.01,
)

lr_scheduler = dict(
total_steps=data["total_steps"],
init_steps=0, # optimizer_warmup_step
warmup_ratio=0.01,
eta_min=1e-5,
last_epoch=-1,
)

beta2_scheduler = dict(
init_beta2=adam["adam_beta2"],
c=adam["adam_beta2_c"],
cur_iter=-1,
)

model = dict(
checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1]
num_attention_heads=NUM_ATTENTION_HEAD,
embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
embed_grad_scale=1,
parallel_output=True,
hidden_size=HIDDEN_SIZE,
num_layers=NUM_LAYER,
mlp_ratio=MLP_RATIO,
apply_post_layer_norm=False,
dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
use_flash_attn=True,
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
)
"""
zero1 parallel:
1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group,
so parameters will be divided within the range of dp.
2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
pipeline parallel (dict):
1. size: int, the size of pipeline parallel.
2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler.
tensor parallel: tensor parallel size, usually the number of GPUs per node.
"""
parallel = dict(
zero1=dict(size=-1, fsdp=False),
tensor=4,
pipeline=dict(size=2, interleaved_overlap=True),
sequence_parallel=False,
)

cudnn_deterministic = False
cudnn_benchmark = False

monitor = dict(
# feishu alert configs
alert=dict(
enable_feishu_alert=DO_ALERT,
feishu_alert_address=None, # feishu webhook to send alert message
light_monitor_address=None, # light_monitor address to send heartbeat
),
)
89 changes: 58 additions & 31 deletions tests/test_training/train_CI.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,56 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
import sys

script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(script_dir, "../../"))
sys.path.append(project_root)

# pylint: disable=C0413,W0612,W0611
import socket
import sys
import time
import traceback
from functools import partial

import torch
import torch.distributed as dist

import internlm
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.scheduler import SchedulerMetricHook
from internlm.core.trainer import TrainState
from internlm.initialize import initialize_distributed_env
from internlm.model.loss import FlashGPTLMLoss
from internlm.model.metrics import AccPerplex
from internlm.monitor import initialize_monitor_manager, send_alert_message
from internlm.monitor.monitor import monitor_manager as mm
from internlm.train import (
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(script_dir, "../../"))
sys.path.append(project_root)

import internlm # noqa: E402
from internlm.core.context import ParallelMode # noqa: E402
from internlm.core.context import global_context as gpc # noqa: E402
from internlm.core.scheduler import SchedulerMetricHook # noqa: E402
from internlm.core.trainer import TrainState # noqa: E402
from internlm.initialize import initialize_distributed_env # noqa: E402
from internlm.model.loss import FlashGPTLMLoss # noqa: E402
from internlm.model.metrics import AccPerplex # noqa: E402
from internlm.monitor import ( # noqa: E402
initialize_monitor_manager,
send_alert_message,
)
from internlm.monitor.monitor import monitor_manager as mm # noqa: E402
from internlm.train import ( # noqa: E402
get_train_data_loader,
get_validation_data_loader,
initialize_llm_profile,
initialize_model,
initialize_optimizer,
load_new_batch,
record_current_batch_training_metrics,
)
from internlm.utils.common import (
from internlm.utils.common import ( # noqa: E402
BatchSkipper,
get_megatron_flops,
launch_time,
parse_args,
)
from internlm.utils.evaluation import evaluate_on_val_dls
from internlm.utils.gputest import empty_cache_and_diag
from internlm.utils.logger import get_logger, initialize_uniscale_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.model_checkpoint import CheckpointManager
from internlm.utils.parallel import get_parallel_log_file_name
from internlm.utils.simple_memory_profiler import SimpleMemoryProfiler
from internlm.utils.writer import Writer
from internlm.utils.evaluation import evaluate_on_val_dls # noqa: E402
from internlm.utils.gputest import empty_cache_and_diag # noqa: E402
from internlm.utils.logger import get_logger, initialize_uniscale_logger # noqa: E402
from internlm.utils.megatron_timers import megatron_timer as timer # noqa: E402
from internlm.utils.model_checkpoint import CheckpointManager # noqa: E402
from internlm.utils.parallel import get_parallel_log_file_name # noqa: E402
from internlm.utils.simple_memory_profiler import SimpleMemoryProfiler # noqa: E402
from internlm.utils.writer import Writer # noqa: E402

# global llm logger
logger = get_logger(__file__)
Expand All @@ -74,15 +76,24 @@ def initialize_llm_logger(start_time: str):
return uniscale_logger


def check_model_weights(model, ckpt_path):
def check_model_weights(model, ckpt_path, total_equal=False):
model1_dict = torch.load(ckpt_path, map_location="cuda")
model2_dict = model.state_dict()

for key in model1_dict.keys():
if key in model2_dict:
tensor1 = model1_dict[key]
tensor2 = model2_dict[key]
assert torch.allclose(tensor1, tensor2, rtol=3e-2, atol=3e-2)
if total_equal:
assert torch.equal(tensor1, tensor2), "model weights are not equal"
else:
assert torch.allclose(tensor1, tensor2, rtol=3e-2, atol=3e-2), "model weights are not close"
else:
if gpc.is_rank_for_log():
logger.warning(f"The old key {key} no longer exists!")

if gpc.is_rank_for_log():
logger.info("Weight check passed")


def main(args):
Expand Down Expand Up @@ -207,7 +218,15 @@ def main(args):
trainer.train()

# transfer the train data loader into train data iterator
train_iter = iter(train_dl)
# train_iter = iter(train_dl)

# check model init weights
if hasattr(gpc.config, "CHECK_INIT") and gpc.config.CHECK_INIT == 1:
ckpt_name = (
f"model_tp{gpc.get_local_rank(ParallelMode.TENSOR)}_pp{gpc.get_local_rank(ParallelMode.PIPELINE)}.pt"
)
ckpt_path = os.path.join(os.environ["share_path"], "quailty_assurance/7B_init_8_tp=4_pp=2_ckpt", ckpt_name)
check_model_weights(model, ckpt_path, total_equal=True)

with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof:
# start iterating the train data and begin training
Expand All @@ -223,7 +242,11 @@ def main(args):
if batch_index == 0:
data_local_rank = gpc.get_local_rank(ParallelMode.DATA)
batch_step = (batch_count // 1000 + 1) * 1000
data_path = f"/mnt/petrelfs/share/quailty_assurance/debug_Qiansanqiang_7B_v16/dp-11{data_local_rank}/batch-{batch_step}.pt"
data_path = os.path.join(
os.environ["share_path"],
"quailty_assurance/debug_Qiansanqiang_7B_v16",
f"dp-11{data_local_rank}/batch-{batch_step}.pt",
)
data_1000 = torch.load(data_path, map_location=torch.device("cpu"))
batch = data_1000[batch_index]

Expand Down Expand Up @@ -311,9 +334,13 @@ def main(args):
update_panel=uniscale_logger is not None,
)

# check model weights
if batch_count > 0 and batch_count % 100 == 0:
ckpt_path = os.path.join(
"/mnt/petrelfs/share/quailty_assurance/7B_model_weights_ckpt", str(batch_count), "model_tp0_pp0.pt"
os.environ["share_path"],
"quailty_assurance/7B_model_weights_ckpt",
str(batch_count),
"model_tp0_pp0.pt",
)
check_model_weights(model, ckpt_path)

Expand Down

0 comments on commit e8cf27b

Please sign in to comment.