Commit fb6a587

Feat(QA): temp no fa (#75)
1 parent e8f5118 commit fb6a587

File tree

3 files changed: +136 −2 lines changed


tests/common_fixture.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 from internlm.data.utils import unpack_data
 from internlm.initialize.launch import args_sanity_check
 
-config = Config(
+config_7B = Config(
     dict(
         parallel=dict(
             zero1=dict(size=-1),
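
For context, the tests touched by this commit import the renamed config_7B fixture and bind it to a module-level config that they then adjust before launching training. A minimal sketch of that usage, assembled from the diffs below (the attribute names come from the new test file added in this commit):

    # sketch assembled from this commit's diffs, not a new file in the repository
    from tests.common_fixture import config_7B

    config = config_7B                   # shared base config for the 7B test model
    config.model.use_flash_attn = False  # the "temp no fa" change: run without flash attention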
Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
+import multiprocessing as mp
+
+import pytest
+import torch
+
+import internlm
+from internlm.core.context import ParallelMode
+from internlm.core.context import global_context as gpc
+from internlm.model.loss import FlashGPTLMLoss
+from internlm.model.metrics import AccPerplex
+from internlm.train import (
+    get_scheduler_hooks,
+    get_train_data_loader,
+    initialize_isp_communicator,
+    initialize_model,
+    initialize_optimizer,
+)
+from internlm.utils.logger import get_logger
+from tests.common_fixture import (
+    build_environment,
+    config_7B,
+    find_free_port,
+    load_new_batch,
+    seed_all,
+)
+
+logger = get_logger(__file__)
+
+# init config
+config = config_7B
+total_steps = 5
+config.data.total_steps = total_steps
+config.lr_scheduler.total_steps = total_steps
+config.model.use_flash_attn = False
+config.parallel.pipeline = dict(size=2, interleaved_overlap=True)
+
+
+def train_check(args):
+    # init
+    rank, world_size, free_port, mode, num_chunks = args
+    config.model.num_chunks = num_chunks
+    config.parallel.tensor = dict(size=2, mode=f"{mode}")
+    if mode == "isp":
+        config.parallel.weight = dict(size=4, overlap=True, memory_pool=True)
+
+    build_environment(rank, world_size, free_port, config)
+
+    # set seed
+    seed_all(1024)
+
+    # initialize model
+    model = initialize_model()
+
+    # initialize isp communicator
+    isp_communicator = initialize_isp_communicator(model)
+
+    # initialize loss function
+    criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing)
+
+    optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model, isp_communicator)
+
+    train_dl, dataset_types = get_train_data_loader(num_worker=0)
+
+    metric = AccPerplex(
+        device=torch.cuda.current_device(),
+        tp_pg=gpc.get_group(ParallelMode.TENSOR),
+        dp_pg=gpc.get_group(ParallelMode.DATA),
+        dataset_types=dataset_types,
+    )
+
+    trainer, train_dl, _, _ = internlm.initialize_trainer(
+        model=model,
+        optimizer=optimizer,
+        criterion=criterion,
+        train_dataloader=train_dl,
+        lr_scheduler=lr_scheduler,
+        beta2_scheduler=beta2_scheduler,
+        scheduler_hooks=get_scheduler_hooks(metric, optimizer, isp_communicator),
+    )
+
+    # transfer the train data loader into train data iterator
+    trainer.train()
+
+    train_iter = iter(train_dl)
+
+    for batch_count in range(total_steps):
+        if gpc.is_rank_for_log():
+            print(f"{mode}: {batch_count}", flush=True)
+
+        # load batch data
+        batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter)
+
+        # zero the grads of parameters
+        trainer.zero_grad()
+
+        # process data
+        if batch[0].get("type_ids", None) is not None:
+            metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None))
+
+        # run the forward-backward schedule for this batch
+        _, _, _ = trainer.execute_schedule(
+            batch,
+            forward_only=False,
+            return_loss=True,
+            return_output_label=False,
+        )
+
+        if isp_communicator and isp_communicator.enable_memory_pool:
+            isp_communicator.memory_pool.reset_lazy_pools()
+
+        trainer.step()
+        torch.cuda.reset_peak_memory_stats()
+
+
+mode_list = ["mtp"]
+num_chunks = [1, 2]
+
+
+@pytest.mark.parametrize("mode", mode_list)
+@pytest.mark.parametrize("num_chunks", num_chunks)
+def test_train(mode, num_chunks):
+    free_port = find_free_port()
+    ctx = mp.get_context("spawn")
+    with ctx.Pool(processes=8) as pool:
+        pool.map(
+            train_check,
+            [[rank, 8, free_port, mode, num_chunks] for rank in range(8)],
+        )
+        pool.close()
+        pool.join()
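
As a usage note: test_train fans out over a spawn-context multiprocessing pool of 8 workers, one per rank, so running it assumes a machine with 8 GPUs and the distributed environment set up by build_environment. A typical invocation would be the standard pytest command; the new file's path is not captured on this page, so the path below is only a placeholder:

    pytest -s -q <path-to-the-new-test-file>.py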

tests/test_training/test_norm_weight.py

Lines changed: 5 additions & 1 deletion
@@ -20,13 +20,14 @@
 from internlm.utils.logger import get_logger
 from tests.common_fixture import (
     build_environment,
-    config,
+    config_7B,
     find_free_port,
     load_new_batch,
     seed_all,
 )
 
 logger = get_logger(__file__)
+config = config_7B
 
 
 def compute_rotol(tensor1, tensor2):
@@ -109,6 +110,9 @@ def train_check_norm_weight(args):
         # load batch data
         batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter)
 
+        # zero the grads of parameters
+        trainer.zero_grad()
+
         # process data
         if batch[0].get("type_ids", None) is not None:
             metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None))
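
The second hunk adds a trainer.zero_grad() call at the top of each training step in train_check_norm_weight; without it, gradients from the previous step would accumulate and distort the norm-weight comparison. The resulting per-step pattern, simplified from the code in this commit (argument lists shortened), is roughly:

    trainer.zero_grad()        # clear gradients left over from the previous step
    trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False)
    trainer.step()             # apply the optimizer update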
