Commit 2bae28f

feat(moe): support group mlp for moe (#345)

blankde authored Oct 21, 2024
1 parent 001a97a commit 2bae28f
Showing 13 changed files with 1,006 additions and 93 deletions.
2 changes: 1 addition & 1 deletion configs/1.8B_MoE16_sft.py
@@ -160,7 +160,7 @@
     num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
     num_experts=16,
     moe_use_residual=False,
-    moe_type="GShard", # Support: "GShard", "MegaBlock", "MegaBlock-D"
+    moe_type="GShard", # Support: "GShard", "MegaBlock", "MegaBlock-Dropless", "Dropless"
 )
 """
 zero1 parallel (dict):
2 changes: 1 addition & 1 deletion configs/7B_MoE4_sft.py
@@ -156,7 +156,7 @@
     # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...]
     qk_interleaved=False,
     num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
-    moe_type="GShard", # Support: "GShard", "MegaBlock", "MegaBlock-D", "Dropless"
+    moe_type="GShard", # Support: "GShard", "MegaBlock", "MegaBlock-Dropless", "Dropless"
     num_experts=4,
     top_k=2,
 )
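For context, this is how the option reads in a config once the renamed variants are in place; a minimal sketch, with the surrounding values chosen for illustration rather than taken from this commit:

    model = dict(
        num_experts=4,
        top_k=2,
        moe_type="Dropless",  # one of: "GShard", "MegaBlock", "MegaBlock-Dropless", "Dropless"
    )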
67 changes: 40 additions & 27 deletions internlm/checkpoint/components.py
@@ -10,7 +10,7 @@
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.trainer import TrainState
-from internlm.model.moe.moe import MoE
+from internlm.model.moe import MoE
 from internlm.solver.optimizer import HybridZeroOptimizer, HybridZeroOptimizer_v2
 from internlm.utils.common import get_current_device
 from internlm.utils.logger import get_logger
@@ -28,17 +28,23 @@
 internlm_accelerator = get_accelerator()


-def try_load_moe_checkpoint(folder, model, state_dict, tp_rank, pp_rank):
-    pipeline_stage_size = gpc.config.model.num_layers // gpc.get_world_size(ParallelMode.PIPELINE)
-    moe_layer_id = pp_rank * pipeline_stage_size
+# only support auto resume
+def try_load_moe_checkpoint(folder, model, state_dict, expert_mp_rank, pp_rank):
+    """Load MoE layer parameters from separate files if the model has MoE layers."""
+    # Calculate the stage size and rank within the pipeline parallelism
+    pp_stage_size = gpc.config.model.num_layers // gpc.get_world_size(ParallelMode.PIPELINE)
+    moe_layer_id = pp_rank * pp_stage_size
+    mode = "wp" if is_using_isp() else "tp"
+
+    # Iterate over all modules in the model to find MoE layers
     for _, module in model.named_modules():
         if isinstance(module, MoE):
-            num_local_experts = module.moe_layer.num_local_experts
+            num_local_wrapped_experts = len(module.moe_layer.experts.wrapped_experts)
             expp_rank = gpc.get_local_rank(ParallelMode.EXPERT)
             # loop all local_experts
-            for local_expert_id in range(num_local_experts):
-                global_expert_id = expp_rank * num_local_experts + local_expert_id
-                fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}_tp{tp_rank}.pt"
+            for local_expert_id in range(num_local_wrapped_experts):
+                global_expert_id = expp_rank * num_local_wrapped_experts + local_expert_id
+                fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}_{mode}{expert_mp_rank}.pt"
                 fp = os.path.join(folder, fn)
                 expert_state_dict = llm_load(fp, map_location=get_current_device())
                 # Updating global -> local expert ids
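With this change, the per-expert checkpoint files are keyed by an expert model-parallel rank plus a mode prefix ("wp" under ISP, "tp" otherwise) instead of always using the tensor-parallel rank. A hypothetical resulting filename, assuming moe_layer_id=0, global_expert_id=5, and expert_mp_rank=3:

    # mode = "wp" if is_using_isp() else "tp"
    # under ISP:  model_moe_layer0_expert5_wp3.pt
    # otherwise:  model_moe_layer0_expert5_tp3.pt  (the only form before this commit)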
@@ -50,13 +56,14 @@ def try_load_moe_checkpoint(folder, model, state_dict, tp_rank, pp_rank):
             moe_layer_id += 1


-def try_save_moe_checkpoint(folder, model, tp_rank, pp_rank):
+def try_save_moe_checkpoint(folder, model, expert_mp_rank, pp_rank):
     # Using layer_#_expert_# to save the model's expert state_dict, a hack.
     pipeline_stage_size = gpc.config.model.num_layers // gpc.get_world_size(ParallelMode.PIPELINE)
     moe_layer_id = pp_rank * pipeline_stage_size
+    mode = "wp" if is_using_isp() else "tp"
     for n_module, module in model.named_modules():
         if isinstance(module, MoE):
-            num_local_experts = module.moe_layer.num_local_experts
+            num_local_wrapped_experts = len(module.moe_layer.experts.wrapped_experts)
             expp_rank = gpc.get_local_rank(ParallelMode.EXPERT)

             # get all moe parameters
@@ -76,7 +83,7 @@ def try_save_moe_checkpoint(folder, model, tp_rank, pp_rank):
                 else:
                     local_expert_id = m.group(1)

-                global_expert_id = expp_rank * num_local_experts + int(local_expert_id)
+                global_expert_id = expp_rank * num_local_wrapped_experts + int(local_expert_id)
                 expert_key = key.replace(f"{moe_str_prefix}{local_expert_id}", f"{moe_str_prefix}{global_expert_id}")

                 # truncating extra tensor (shared) storage
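As a worked example of the id mapping (hypothetical values): with num_local_wrapped_experts=2 and expp_rank=3, local expert 1 is saved under global expert id 3 * 2 + 1 = 7, and the state-dict key is rewritten to match:

    # hypothetical values, for illustration only
    expp_rank, num_local_wrapped_experts, local_expert_id = 3, 2, 1
    global_expert_id = expp_rank * num_local_wrapped_experts + local_expert_id  # 7
    # the key substring f"{moe_str_prefix}1" is replaced by f"{moe_str_prefix}7"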
@@ -86,7 +93,7 @@
             # let save the moe parameters
             for global_expert_id, expert_state_dict in experts_state_dict.items():
                 # save the moe parameters
-                fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}_tp{tp_rank}.pt"
+                fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}_{mode}{expert_mp_rank}.pt"
                 fp = os.path.join(folder, fn)
                 llm_save(fp, saved_obj=expert_state_dict)
             moe_layer_id += 1
@@ -179,10 +186,12 @@ def load_model_checkpoint(folder, model):
             states[key] = states[key].float()
             print("load: ", states[key].float(), flush=True)
     """

     # try to load expert parameter to separate files if model have moe layer
-    expert_tp_rank = 0 if gpc.config.parallel.expert.no_tp else tp_rank
-    try_load_moe_checkpoint(folder, model, states, expert_tp_rank, pp_rank)
+    if is_using_isp():
+        expert_wp_rank = gpc.get_local_rank(ParallelMode.EXPERT_WEIGHT)
+        try_load_moe_checkpoint(folder, model, states, expert_wp_rank, pp_rank)
+    else:
+        expert_tp_rank = 0 if gpc.config.parallel.expert.no_tp else tp_rank
+        try_load_moe_checkpoint(folder, model, states, expert_tp_rank, pp_rank)

     if gpc.config.parallel.zero1.fsdp:
         missing_k, unexpected_keys = load_shard_state_dict(model, states, strict=False)
@@ -252,6 +261,10 @@ def save_model_checkpoint(folder, model):
         topo_fn = f"topo_wp{wp_rank}_pp{pp_rank}.json"
         topo_fp = os.path.join(folder, topo_fn)
         llm_save(topo_fp, saved_obj=topo)
+        expert_wp_rank = gpc.get_local_rank(ParallelMode.EXPERT_WEIGHT)
+        expert_wdp_rank = gpc.get_local_rank(ParallelMode.EXPERT_DATA)
+        if expert_wdp_rank == 0:
+            try_save_moe_checkpoint(folder, model, expert_wp_rank, pp_rank)
     else:
         # for tensor parallel mode with mtp/msp/fsp
         for i in range(tp_size):
@@ -271,17 +284,17 @@
                 topo_fp = os.path.join(folder, topo_fn)
                 llm_save(topo_fp, saved_obj=topo)

-    # try to save expert parameter to separate files if model have moe layer
-    expert_dp_size = gpc.get_world_size(ParallelMode.EXPERT_DATA)
-    expert_tp_size = 1 if gpc.config.parallel.expert.no_tp else tp_size
-    expert_dp_rank = gpc.get_local_rank(ParallelMode.EXPERT_DATA)
-    expert_tp_rank = 0 if gpc.config.parallel.expert.no_tp else tp_rank
-    should_save_rank_pair.clear()
-    for i in range(expert_tp_size):
-        should_save_rank_pair.add((i, i % expert_dp_size))
-
-    if (expert_tp_rank, expert_dp_rank) in should_save_rank_pair:
-        try_save_moe_checkpoint(folder, model, expert_tp_rank, pp_rank)
+        # try to save expert parameter to separate files if model have moe layer
+        expert_dp_size = gpc.get_world_size(ParallelMode.EXPERT_DATA)
+        expert_tp_size = 1 if gpc.config.parallel.expert.no_tp else tp_size
+        expert_dp_rank = gpc.get_local_rank(ParallelMode.EXPERT_DATA)
+        expert_tp_rank = 0 if gpc.config.parallel.expert.no_tp else tp_rank
+        should_save_rank_pair.clear()
+        for i in range(expert_tp_size):
+            should_save_rank_pair.add((i, i % expert_dp_size))
+
+        if (expert_tp_rank, expert_dp_rank) in should_save_rank_pair:
+            try_save_moe_checkpoint(folder, model, expert_tp_rank, pp_rank)

     torch.distributed.barrier()
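The rank-pair set exists so that each expert tensor-parallel shard is written by exactly one expert data-parallel rank. A small worked example with hypothetical sizes expert_tp_size=4 and expert_dp_size=2:

    should_save_rank_pair = set()
    for i in range(4):  # expert_tp_size
        should_save_rank_pair.add((i, i % 2))  # i % expert_dp_size
    # {(0, 0), (1, 1), (2, 0), (3, 1)}: every tp shard has a single designated writer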

6 changes: 6 additions & 0 deletions internlm/core/parallel/shard.py
@@ -168,6 +168,12 @@ def get_parallel_strategies_split_mode(linear_name: str) -> str:
         return "column"
     elif linear_name in ("wo", "out_proj", "w2"):
         return "row"
+    elif linear_name in ("grouped_w1", "grouped_w2", "grouped_w3") and tp_mode == "isp":
+        return "grouped_wp"
+    elif linear_name in ("grouped_w1", "grouped_w3"):
+        return "grouped_column"
+    elif linear_name in ("grouped_w2",):
+        return "grouped_row"
     else:
         return "unknown"

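A sketch of how the dispatch resolves the new grouped MLP weight names (tp_mode is assumed to be read from the parallel config by the surrounding module, as the hunk above suggests):

    get_parallel_strategies_split_mode("grouped_w1")  # "grouped_wp" when tp_mode == "isp", else "grouped_column"
    get_parallel_strategies_split_mode("grouped_w2")  # "grouped_wp" when tp_mode == "isp", else "grouped_row"
    get_parallel_strategies_split_mode("w2")          # "row", unchanged by this commit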
(The remaining changed files of this commit are not shown in this excerpt.)