Fix (unitest, interleaved pp and other bugs): re-adapt unitest for isp and adapt interleaved pp for no flash_attention (#52)
li126com authored Mar 1, 2024
1 parent 7ab7764 commit 20a2832
Showing 9 changed files with 70 additions and 47 deletions.
9 changes: 9 additions & 0 deletions internlm/core/scheduler/pipeline_scheduler.py
@@ -808,6 +808,15 @@ def load_micro_batch(self, model_chunk_id):
             offset=self.microbatch_offset[model_chunk_id],
             bsz_stride=self.bsz_stride,
         )
+        if self.data_process_func:
+            micro_batch_data["input_ids"] = self.data_process_func(
+                micro_batch_data["input_ids"], micro_batch_data["cu_seqlens"]
+            )
+            micro_batch_label = self.data_process_func(micro_batch_label, micro_batch_data["cu_seqlens"])
+
+            micro_batch_data.pop("cu_seqlens")
+            micro_batch_data.pop("indexes")
+
         micro_batch_data["label"] = micro_batch_label
         self.microbatch_offset[model_chunk_id] += self.bsz_stride
         return move_to_device(micro_batch_data)
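For context: with flash attention disabled, the scheduler receives packed batches, and data_process_func is expected to split them back into per-sample sequences using cu_seqlens (which is why cu_seqlens and indexes are popped afterwards). A minimal sketch of such an unpack step, assuming cu_seqlens is a 1-D tensor of cumulative boundaries over a packed 1-D token tensor; the names here are illustrative, not the repository's actual helper:

    import torch

    def unpack_by_cu_seqlens(packed: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
        # cu_seqlens = [0, len0, len0+len1, ...]; split the packed stream and
        # right-pad each sequence to the longest one in the micro batch.
        lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        max_len = max(lengths)
        rows = []
        for start, length in zip(cu_seqlens[:-1].tolist(), lengths):
            seq = packed[start : start + length]
            rows.append(torch.cat([seq, seq.new_zeros(max_len - length)]))
        return torch.stack(rows)  # (num_seqs, max_len)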
2 changes: 1 addition & 1 deletion internlm/data/packed_dataset.py
@@ -177,7 +177,7 @@ def build_unpack(self, index):

         if cu_seqlens[-1] != self.packed_length:
             pack = pack + [0] * (self.packed_length - cu_seqlens[-1])
-            labels = labels + [0] * (self.packed_length - cu_seqlens[-1])
+            labels = labels + [-100] * (self.packed_length - cu_seqlens[-1])
             type_ids = type_ids + [0] * (self.packed_length - cu_seqlens[-1])
             indexes.extend(list(range(self.packed_length - cu_seqlens[-1])))
             cu_seqlens.append(self.packed_length)
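The switch from 0 to -100 matters because -100 is the default ignore_index of torch.nn.CrossEntropyLoss: padded tail positions are now excluded from the loss instead of being scored as token id 0. A quick illustration:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 10)                # 4 positions, vocab size 10
    labels = torch.tensor([3, 7, -100, -100])  # last two positions are padding
    # Positions labeled -100 are skipped (ignore_index defaults to -100),
    # so only the first two tokens contribute to the loss.
    loss = F.cross_entropy(logits, labels)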
1 change: 1 addition & 0 deletions internlm/initialize/initialize_trainer.py
@@ -96,6 +96,7 @@ def initialize_trainer(

     communication_overlap = gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
     scheduler = InterleavedPipelineScheduler(
+        data_process_func=data_fn,
         num_microbatches=gpc.config.NUM_MICRO_BATCHES,
         num_chunks=gpc.config.model.num_chunks,
         dtype=gpc.config.model["dtype"],
3 changes: 3 additions & 0 deletions internlm/utils/model_checkpoint.py
@@ -934,6 +934,9 @@ def __init__(
         self.storage_manager = get_storage_manager()
         self.snapshot_counter = 0
 
+        if hasattr(model, "model"):
+            model = model.model
+
         self.model = model
         self.optimizer = optimizer
         self.lr_scheduler = lr_scheduler
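This unwrap mirrors the train.py change at the end of this commit: callers now pass the wrapped model, and CheckpointManager peels off the inner module itself. A toy sketch of the pattern; Wrapper here is hypothetical, standing in for whatever AMP or parallel wrapper exposes the real module as .model:

    class Wrapper:
        """Stand-in for a wrapper object that holds the real module as .model."""
        def __init__(self, model):
            self.model = model

    def unwrap(model):
        # Accept either the wrapper or the bare module.
        return model.model if hasattr(model, "model") else model

    bare = object()
    assert unwrap(Wrapper(bare)) is bare
    assert unwrap(bare) is bare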
44 changes: 23 additions & 21 deletions tests/test_model/test_model_internlm.py
@@ -128,29 +128,31 @@ def check_block(args):
     )
 
     hidden_states = hidden_states.squeeze(0).to(device).requires_grad_()
+
+    hid2 = hidden_states
+    output_list = []
+    for i in range(10):
+        hidden_states = hid2
+        # forward
+        for _, block in enumerate(blocks):
+            block = block.to(torch.bfloat16)
+            block = block.to(device)
+            hidden_states = block(
+                hidden_states,
+                cu_seqlens=cu_seqlens,
+                indexes=indexes,
+                inference_params=None,
+                max_seqlen=max_seqlen,
+            )
+        result = hidden_states
+        output_list.append(result)
-    hidden_states2 = hidden_states.clone()
-
-    # forward
-    for _, block in enumerate(blocks):
-        block = block.to(torch.bfloat16)
-        block = block.to(device)
-        hidden_states = block(
-            hidden_states,
-            cu_seqlens=cu_seqlens,
-            indexes=indexes,
-            inference_params=None,
-            max_seqlen=max_seqlen,
-        )
-        hidden_states2 = block(
-            hidden_states2,
-            cu_seqlens=cu_seqlens,
-            indexes=indexes,
-            inference_params=None,
-            max_seqlen=max_seqlen,
-        )
-    result = hidden_states
-    result2 = hidden_states2
 
     # check only forward logits
+    first_output = output_list[0]
+    for i in range(1, 10):
+        assert torch.equal(first_output, output_list[i])
-    assert torch.equal(result, result2)
 
     standard_result = torch.tensor(
         [
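The rewritten test no longer compares two clones within a single pass; it re-runs the same block stack ten times from identical inputs and requires bit-exact logits. Reduced to a sketch, with forward_fn standing in for the block loop above:

    import torch

    def check_deterministic(forward_fn, x, runs=10):
        # Repeat the forward pass and require bitwise-identical outputs.
        outputs = [forward_fn(x) for _ in range(runs)]
        for out in outputs[1:]:
            assert torch.equal(outputs[0], out)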
22 changes: 14 additions & 8 deletions tests/test_solver/test_optimizer.py
@@ -30,10 +30,10 @@ def forward(self, x):
 config = Config(
     dict(
         parallel=dict(
-            zero1=dict(size=1, fsdp=False),
-            pipeline=dict(size=1, interleaved_overlap=False),
-            sequence_parallel=False,
-            tensor=1,
+            zero1=dict(size=1),
+            tensor=dict(size=1, mode="mtp"),
+            pipeline=dict(size=1, interleaved_overlap=True),
+            weight=dict(size=1, overlap=True, memory_pool=True),
         ),
         model_type="INTERNLM",
         data=dict(seq_len=2048, micro_num=1, micro_bsz=1, pack_sample_into_one=False, min_length=0, total_steps=9999),
@@ -106,22 +106,25 @@ def init_optimizer_grouped_parameters(check_group, model):
     if check_group:
         optimizer_grouped_parameters = [
             {
+                "name": "default",
                 "params": list(model.parameters())[:2],
                 "weight_decay": config.adam.weight_decay,
-                "dp_mode": ParallelMode.DATA,
+                "optimizer_mode": ParallelMode.ZERO1,
             },
             {
+                "name": "default",
                 "params": list(model.parameters())[2:],
                 "weight_decay": config.adam.weight_decay,
-                "dp_mode": ParallelMode.DATA,
+                "optimizer_mode": ParallelMode.ZERO1,
             },
         ]
     else:
         optimizer_grouped_parameters = [
             {
-                "params": model.parameters(),
+                "name": "default",
+                "params": list(model.parameters())[:],
                 "weight_decay": config.adam.weight_decay,
-                "dp_mode": ParallelMode.DATA,
+                "optimizer_mode": ParallelMode.ZERO1,
             }
         ]
 
@@ -166,6 +169,9 @@ def exam_hybrid_zero_optim_with_ddp(args):
     torch_model = MlpModel().cuda()
     zero_model = copy.deepcopy(torch_model).to(dtype)
     torch_model = DDP(torch_model.cuda(), static_graph=True).cuda()
+    IS_TENSOR_ZERO_PARALLEL = "is_tensor_zero_parallel"
+    for param in zero_model.parameters():
+        setattr(param, IS_TENSOR_ZERO_PARALLEL, True)
 
     # create optimizer
     if config.hybrid_zero_optimizer.overlap_sync_param:
29 changes: 16 additions & 13 deletions tests/test_training/test_forward_output_no_fa.py
@@ -23,10 +23,10 @@
 config = Config(
     dict(
         parallel=dict(
-            zero1=dict(size=-1, fsdp=False),
-            pipeline=dict(size=1, interleaved_overlap=False),
-            sequence_parallel=False,
-            tensor=1,
+            zero1=dict(size=-1),
+            tensor=dict(size=1, mode="mtp"),
+            pipeline=dict(size=1, interleaved_overlap=True),
+            weight=dict(size=1, overlap=True, memory_pool=True),
         ),
         data=dict(
             seq_len=2048,
@@ -81,7 +81,7 @@
         ),
         hybrid_zero_optimizer=dict(
             overlap_sync_grad=True,
-            overlap_sync_param=True,
+            overlap_sync_param=False,
             reduce_bucket_size=512 * 1024 * 1024,
             clip_grad_norm=1.0,
         ),
@@ -179,7 +179,7 @@ def train_check_output(args):
             SchedulerMetricHook(
                 metric=metric,
                 skip=(
-                    gpc.is_using_pp()
+                    gpc.is_using_parallel_mode(ParallelMode.PIPELINE)
                     and hasattr(gpc.config.model, "num_chunks")
                     and gpc.config.model.num_chunks > 1
                     and gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
@@ -207,7 +207,7 @@ def train_check_output(args):
     # zero the grads of parameters
     output, _, _ = trainer.execute_schedule(
         batch,
-        forward_only=False,
+        forward_only=True,
         return_loss=True,
         return_output_label=True,
     )
@@ -216,18 +216,21 @@ def train_check_output(args):
     standard_output_with_fa = torch.load(
         os.path.join(share_path, "quailty_assurance/7B_no_flash_attention/output_with_fa.pt")
     )
-    tensor1 = standard_output_with_fa[0][0]
+    tensor1 = standard_output_with_fa
     tensor2 = output[0][0][0]
 
     if torch.equal(tensor1, tensor2):
         logger.info("Outputs are totally equal")
     else:
         logger.warning("Outputs are not totally equal")
-        for rtol in [1, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 1e-2, 1e-3, 1e-4, 1e-5]:
-            assert torch.allclose(
-                tensor1, tensor2, atol=0, rtol=rtol
-            ), f"{(tensor1 - tensor2).abs().max()} is over rtol {rtol}"
-            logger.info(f"Check for rotol={rtol} has passed")
+        max_diff, index_max_diff = (tensor1 - tensor2).abs().max(dim=0)
+        max_diff = max_diff.item()
+        index_max_diff = index_max_diff.item()
+        rtol = max_diff / abs(tensor2[index_max_diff])
+        logger.info(
+            f"The relative error is {rtol}. Between {tensor1[index_max_diff]} and {tensor2[index_max_diff]}"
+        )
+        assert rtol < 1e-5, f"The relative error is {rtol}"


def test_output():
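The old loop probed a fixed ladder of rtol values; the new check computes the worst-case relative error directly. The same idea as a self-contained sketch, assuming 1-D tensors as the indexing above does:

    import torch

    def max_relative_error(reference: torch.Tensor, actual: torch.Tensor) -> float:
        diff = (reference - actual).abs()
        max_diff, idx = diff.max(dim=0)  # worst absolute difference and its position
        return (max_diff / actual[idx].abs()).item()

    a = torch.tensor([1.0, 2.0, 3.0])
    b = torch.tensor([1.0, 2.0, 3.0 + 3e-6])
    assert max_relative_error(a, b) < 1e-5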
5 changes: 2 additions & 3 deletions tests/test_training/test_load_ckpt_loss.py
@@ -13,11 +13,10 @@
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.context.parallel_context import Config
-from internlm.model.metrics import SchedulerMetricHook
 from internlm.core.trainer import TrainState
 from internlm.initialize.launch import args_sanity_check
 from internlm.model.loss import FlashGPTLMLoss
-from internlm.model.metrics import AccPerplex
+from internlm.model.metrics import AccPerplex, SchedulerMetricHook
 from internlm.train import (
     get_train_data_loader,
     initialize_model,
@@ -225,7 +224,7 @@ def train_model(args):
             SchedulerMetricHook(
                 metric=metric,
                 skip=(
-                    gpc.is_using_pp()
+                    gpc.is_using_parallel_mode(ParallelMode.PIPELINE)
                     and hasattr(gpc.config.model, "num_chunks")
                     and gpc.config.model.num_chunks > 1
                     and gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
2 changes: 1 addition & 1 deletion train.py
@@ -122,7 +122,7 @@ def main(args):

     ckpt_manager = CheckpointManager(
         ckpt_config=gpc.config.ckpt,
-        model=model.model,
+        model=model,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
         train_dl=train_dl,
