From 2bae28fa3b56cca7eaad73c962f1c8bec051f331 Mon Sep 17 00:00:00 2001 From: Wenwen Qu Date: Mon, 21 Oct 2024 13:49:51 +0800 Subject: [PATCH] feat(moe): support group mlp for moe (#345) --- configs/1.8B_MoE16_sft.py | 2 +- configs/7B_MoE4_sft.py | 2 +- internlm/checkpoint/components.py | 67 ++-- internlm/core/parallel/shard.py | 6 + internlm/model/modules/linear.py | 464 ++++++++++++++++++++++++++- internlm/model/modules/mlp.py | 107 ++++++ internlm/model/moe/dropless_layer.py | 29 +- internlm/model/moe/experts.py | 4 +- internlm/model/moe/gshard_layer.py | 59 +++- internlm/model/moe/utils.py | 10 + internlm/model/ops/linear.py | 55 ++++ internlm/train/pipeline.py | 78 +++-- tools/moe_group_ckpt_converter.py | 216 +++++++++++++ 13 files changed, 1006 insertions(+), 93 deletions(-) create mode 100644 tools/moe_group_ckpt_converter.py diff --git a/configs/1.8B_MoE16_sft.py b/configs/1.8B_MoE16_sft.py index c3f03c3b..f8530277 100644 --- a/configs/1.8B_MoE16_sft.py +++ b/configs/1.8B_MoE16_sft.py @@ -160,7 +160,7 @@ num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. num_experts=16, moe_use_residual=False, - moe_type="GShard", # Support: "GShard", "MegaBlock", "MegaBlock-D" + moe_type="GShard", # Support: "GShard", "MegaBlock", "MegaBlock-Dropless", "Dropless" ) """ zero1 parallel (dict): diff --git a/configs/7B_MoE4_sft.py b/configs/7B_MoE4_sft.py index 116f2bd5..c558427c 100644 --- a/configs/7B_MoE4_sft.py +++ b/configs/7B_MoE4_sft.py @@ -156,7 +156,7 @@ # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] qk_interleaved=False, num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. - moe_type="GShard", # Support: "GShard", "MegaBlock", "MegaBlock-D", "Dropless" + moe_type="GShard", # Support: "GShard", "MegaBlock", "MegaBlock-Dropless", "Dropless" num_experts=4, top_k=2, ) diff --git a/internlm/checkpoint/components.py b/internlm/checkpoint/components.py index edf39b55..eee92c9c 100644 --- a/internlm/checkpoint/components.py +++ b/internlm/checkpoint/components.py @@ -10,7 +10,7 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.trainer import TrainState -from internlm.model.moe.moe import MoE +from internlm.model.moe import MoE from internlm.solver.optimizer import HybridZeroOptimizer, HybridZeroOptimizer_v2 from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger @@ -28,17 +28,23 @@ internlm_accelerator = get_accelerator() -def try_load_moe_checkpoint(folder, model, state_dict, tp_rank, pp_rank): - pipeline_stage_size = gpc.config.model.num_layers // gpc.get_world_size(ParallelMode.PIPELINE) - moe_layer_id = pp_rank * pipeline_stage_size +# only support auto resume +def try_load_moe_checkpoint(folder, model, state_dict, expert_mp_rank, pp_rank): + """Load MoE layer parameters from separate files if the model has MoE layers.""" + # Calculate the stage size and rank within the pipeline parallelism + pp_stage_size = gpc.config.model.num_layers // gpc.get_world_size(ParallelMode.PIPELINE) + moe_layer_id = pp_rank * pp_stage_size + mode = "wp" if is_using_isp() else "tp" + + # Iterate over all modules in the model to find MoE layers for _, module in model.named_modules(): if isinstance(module, MoE): - num_local_experts = module.moe_layer.num_local_experts + num_local_wrapped_experts = len(module.moe_layer.experts.wrapped_experts) expp_rank = gpc.get_local_rank(ParallelMode.EXPERT) # loop all 
local_experts - for local_expert_id in range(num_local_experts): - global_expert_id = expp_rank * num_local_experts + local_expert_id - fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}_tp{tp_rank}.pt" + for local_expert_id in range(num_local_wrapped_experts): + global_expert_id = expp_rank * num_local_wrapped_experts + local_expert_id + fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}_{mode}{expert_mp_rank}.pt" fp = os.path.join(folder, fn) expert_state_dict = llm_load(fp, map_location=get_current_device()) # Updating global -> local expert ids @@ -50,13 +56,14 @@ def try_load_moe_checkpoint(folder, model, state_dict, tp_rank, pp_rank): moe_layer_id += 1 -def try_save_moe_checkpoint(folder, model, tp_rank, pp_rank): +def try_save_moe_checkpoint(folder, model, expert_mp_rank, pp_rank): # Using layer_#_expert_# to save the model's expert state_dict,a hack. pipeline_stage_size = gpc.config.model.num_layers // gpc.get_world_size(ParallelMode.PIPELINE) moe_layer_id = pp_rank * pipeline_stage_size + mode = "wp" if is_using_isp() else "tp" for n_module, module in model.named_modules(): if isinstance(module, MoE): - num_local_experts = module.moe_layer.num_local_experts + num_local_wrapped_experts = len(module.moe_layer.experts.wrapped_experts) expp_rank = gpc.get_local_rank(ParallelMode.EXPERT) # get all moe parameters @@ -76,7 +83,7 @@ def try_save_moe_checkpoint(folder, model, tp_rank, pp_rank): else: local_expert_id = m.group(1) - global_expert_id = expp_rank * num_local_experts + int(local_expert_id) + global_expert_id = expp_rank * num_local_wrapped_experts + int(local_expert_id) expert_key = key.replace(f"{moe_str_prefix}{local_expert_id}", f"{moe_str_prefix}{global_expert_id}") # truncating extra tensor (shared) storage @@ -86,7 +93,7 @@ def try_save_moe_checkpoint(folder, model, tp_rank, pp_rank): # let save the moe parameters for global_expert_id, expert_state_dict in experts_state_dict.items(): # save the moe parameters - fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}_tp{tp_rank}.pt" + fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}_{mode}{expert_mp_rank}.pt" fp = os.path.join(folder, fn) llm_save(fp, saved_obj=expert_state_dict) moe_layer_id += 1 @@ -179,10 +186,12 @@ def load_model_checkpoint(folder, model): states[key] = states[key].float() print("load: ", states[key].float(),flush=True) """ - - # try to load expert parameter to separate files if model have moe layer - expert_tp_rank = 0 if gpc.config.parallel.expert.no_tp else tp_rank - try_load_moe_checkpoint(folder, model, states, expert_tp_rank, pp_rank) + if is_using_isp(): + expert_wp_rank = gpc.get_local_rank(ParallelMode.EXPERT_WEIGHT) + try_load_moe_checkpoint(folder, model, states, expert_wp_rank, pp_rank) + else: + expert_tp_rank = 0 if gpc.config.parallel.expert.no_tp else tp_rank + try_load_moe_checkpoint(folder, model, states, expert_tp_rank, pp_rank) if gpc.config.parallel.zero1.fsdp: missing_k, unexpected_keys = load_shard_state_dict(model, states, strict=False) @@ -252,6 +261,10 @@ def save_model_checkpoint(folder, model): topo_fn = f"topo_wp{wp_rank}_pp{pp_rank}.json" topo_fp = os.path.join(folder, topo_fn) llm_save(topo_fp, saved_obj=topo) + expert_wp_rank = gpc.get_local_rank(ParallelMode.EXPERT_WEIGHT) + expert_wdp_rank = gpc.get_local_rank(ParallelMode.EXPERT_DATA) + if expert_wdp_rank == 0: + try_save_moe_checkpoint(folder, model, expert_wp_rank, pp_rank) else: # for tensor parallel mode with mtp/msp/fsp for i in range(tp_size): @@ -271,17 +284,17 @@ def 
save_model_checkpoint(folder, model): topo_fp = os.path.join(folder, topo_fn) llm_save(topo_fp, saved_obj=topo) - # try to save expert parameter to separate files if model have moe layer - expert_dp_size = gpc.get_world_size(ParallelMode.EXPERT_DATA) - expert_tp_size = 1 if gpc.config.parallel.expert.no_tp else tp_size - expert_dp_rank = gpc.get_local_rank(ParallelMode.EXPERT_DATA) - expert_tp_rank = 0 if gpc.config.parallel.expert.no_tp else tp_rank - should_save_rank_pair.clear() - for i in range(expert_tp_size): - should_save_rank_pair.add((i, i % expert_dp_size)) - - if (expert_tp_rank, expert_dp_rank) in should_save_rank_pair: - try_save_moe_checkpoint(folder, model, expert_tp_rank, pp_rank) + # try to save expert parameter to separate files if model have moe layer + expert_dp_size = gpc.get_world_size(ParallelMode.EXPERT_DATA) + expert_tp_size = 1 if gpc.config.parallel.expert.no_tp else tp_size + expert_dp_rank = gpc.get_local_rank(ParallelMode.EXPERT_DATA) + expert_tp_rank = 0 if gpc.config.parallel.expert.no_tp else tp_rank + should_save_rank_pair.clear() + for i in range(expert_tp_size): + should_save_rank_pair.add((i, i % expert_dp_size)) + + if (expert_tp_rank, expert_dp_rank) in should_save_rank_pair: + try_save_moe_checkpoint(folder, model, expert_tp_rank, pp_rank) torch.distributed.barrier() diff --git a/internlm/core/parallel/shard.py b/internlm/core/parallel/shard.py index fa79ddc9..a0e90578 100644 --- a/internlm/core/parallel/shard.py +++ b/internlm/core/parallel/shard.py @@ -168,6 +168,12 @@ def get_parallel_strategies_split_mode(linear_name: str) -> str: return "column" elif linear_name in ("wo", "out_proj", "w2"): return "row" + elif linear_name in ("grouped_w1", "grouped_w2", "grouped_w3") and tp_mode == "isp": + return "grouped_wp" + elif linear_name in ("grouped_w1", "grouped_w3"): + return "grouped_column" + elif linear_name in ("grouped_w2"): + return "grouped_row" else: return "unknown" diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py index a3c684f6..cfbc8f85 100644 --- a/internlm/model/modules/linear.py +++ b/internlm/model/modules/linear.py @@ -4,6 +4,7 @@ from __future__ import annotations +import math from typing import TYPE_CHECKING, Optional, Union import torch @@ -18,7 +19,12 @@ get_parallel_strategies_split_mode, get_tensor_split_parallel_mode, ) -from internlm.model.ops.linear import linear_backward_op, linear_forward_op +from internlm.model.ops.linear import ( + gmm_backward_op, + gmm_forward_op, + linear_backward_op, + linear_forward_op, +) from internlm.utils.logger import get_logger if TYPE_CHECKING: @@ -294,6 +300,191 @@ def backward(ctx, grad_output, *args): return grad_input, grad_weight, grad_bias, None, None, None, None +class GroupedGemmSPFusedDenseFunc(torch.autograd.Function): + "Grouped Gemm FusedDenseFunc for tensor parallel" + + @staticmethod + @custom_fwd + def forward( + ctx, + x: torch.Tensor, + weight: torch.Tensor, + batch_sizes: torch.Tensor, + backend: str, + ): + + if backend == "bmm": + assert x.dim() == 3, f"bmm only support 3d input (e, c, m), but got: {x.shape}" + elif backend == "gmm": + assert x.dim() == 2, f"gmm only support 2d input (s, m), but got: {x.shape}" + assert batch_sizes is not None, "batch_sizes should be provided for gmm" + else: + raise NotImplementedError(f"Invalid backend: {backend}") + + input_numel = x.numel() + if input_numel == 0: + backend = "bmm" + + ctx.compute_weight_gradient = weight.requires_grad + ctx.backend = backend + + if torch.is_autocast_enabled(): + x = 
x.to(dtype=torch.get_autocast_gpu_dtype()) + x = x.contiguous() + + if backend == "gmm": + output = gmm_forward_op(x, weight, batch_sizes) + else: + if input_numel == 0: + # if inp is empty, reshape to make grad flow. + # inp shape: (0, hdim) + weight = weight.view(x.shape[-1], -1) + + output = torch.matmul(x, weight) + + saved_x = None if ctx.compute_weight_gradient is False else x + ctx.save_for_backward(saved_x, weight, batch_sizes) + + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + backend = ctx.backend + + grad_output = grad_output.contiguous() + x, weight, batch_sizes = ctx.saved_tensors + grad_input, grad_weight = None, None + + if ctx.needs_input_grad[1]: + assert ctx.compute_weight_gradient + if backend == "gmm": + grad_input, grad_weight = gmm_backward_op(x, grad_output, batch_sizes, input_weight=weight) + else: + grad_weight = torch.matmul(x.transpose(-1, -2), grad_output) + + if ctx.needs_input_grad[0]: + if backend == "gmm": + if grad_input is None: + grad_input, _ = gmm_backward_op(grad_output, weight, batch_sizes, is_grad_input=True) + else: + grad_input = torch.matmul(grad_output, weight.transpose(-1, -2)) + + return grad_input, grad_weight, None, None, None, None, None + + +class GroupedGemmWPFusedDenseFunc(torch.autograd.Function): + "Grouped Gemm FusedDenseFunc for weigth parallel." + + @staticmethod + @custom_fwd + def forward( + ctx, + x: torch.Tensor, + weight: torch.Tensor, + module: nn.Module, + communicator: WPCommunicator, + batch_sizes: torch.Tensor, + backend: str, + full_weight_shape: torch.Size, + ): + assert full_weight_shape is not None, "full_weight_shape should be provided" + if backend == "bmm": + assert x.dim() == 3, f"bmm only support 3d input (e, c, m), but got: {x.shape}" + elif backend == "gmm": + assert x.dim() == 2, f"gmm only support 2d input (s, m), but got: {x.shape}" + assert batch_sizes is not None, "batch_sizes should be provided for gmm" + else: + raise NotImplementedError(f"Invalid backend: {backend}") + + input_numel = x.numel() + if input_numel == 0: + backend = "bmm" + + ctx.compute_weight_gradient = weight.requires_grad + ctx.module = module + ctx.communicator = communicator + ctx.backend = backend + ctx.full_weight_shape = full_weight_shape + + if torch.is_autocast_enabled(): + x = x.to(dtype=torch.get_autocast_gpu_dtype()) + x = x.contiguous() + + total_weight = communicator.weight_hook(weight, module=module) + total_weight = total_weight.reshape(full_weight_shape) + + if torch.is_autocast_enabled(): + total_weight = total_weight.to(dtype=torch.get_autocast_gpu_dtype()) + total_weight = total_weight.contiguous() + + if backend == "gmm": + output = gmm_forward_op(x, total_weight, batch_sizes) + else: + if input_numel == 0: + # if inp is empty, reshape to make grad flow. 
+ # inp shape: (0, hdim) + total_weight = total_weight.view(x.shape[-1], -1) + + output = torch.matmul(x, total_weight) + + # release memory + del total_weight + + saved_x = None if ctx.compute_weight_gradient is False else x + ctx.save_for_backward(saved_x, weight, batch_sizes) + + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + module: nn.Module = ctx.module + communicator: WPCommunicator = ctx.communicator + x, weight, batch_sizes = ctx.saved_tensors + backend = ctx.backend + full_weight_shape = ctx.full_weight_shape + + grad_output = grad_output.contiguous() + + total_weight = communicator.weight_hook(weight, module=module) + total_weight = total_weight.reshape(full_weight_shape) + grad_input, grad_weight = None, None + if grad_output.numel() == 0: + if ctx.needs_input_grad[0]: + grad_input = torch.zeros_like(x) + if ctx.needs_input_grad[1]: + grad_weight = torch.zeros_like(total_weight).reshape(-1, full_weight_shape[-1]) + grad_weight, _ = communicator.grad_hook(grad_weight, async_op=False, module=module, is_bias=False) + + return grad_input, grad_weight, None, None, None, None, None + + if ctx.needs_input_grad[1]: + assert ctx.compute_weight_gradient + if backend == "gmm": + grad_input, grad_weight = gmm_backward_op(x, grad_output, batch_sizes, input_weight=total_weight) + else: + grad_weight = torch.matmul(x.transpose(-1, -2), grad_output) + grad_weight = grad_weight.view(-1, grad_weight.shape[-1]) + grad_weight, grad_weight_sync = communicator.grad_hook( + grad_weight, async_op=True, module=module, is_bias=False + ) + + if ctx.needs_input_grad[0]: + if backend == "gmm": + if grad_input is None: + grad_input, _ = gmm_backward_op(grad_output, total_weight, batch_sizes, is_grad_input=True) + else: + grad_input = torch.matmul(grad_output, total_weight.transpose(-1, -2)) + + del total_weight + + if ctx.needs_input_grad[1]: + grad_weight_sync.wait() + + return grad_input, grad_weight, None, None, None, None, None + + def fused_dense_func( x: torch.Tensor, weight: torch.Tensor, @@ -301,24 +492,51 @@ def fused_dense_func( module: Optional[nn.Module] = None, bias: Optional[torch.Tensor] = None, return_residual: bool = False, + use_grouped_linear: bool = False, + **kwargs, ): if communicator.communication_mode() == "wp": - return WPFusedDenseFunc.apply( - x, - weight, - bias, - module, - communicator, - return_residual, - ) + if not use_grouped_linear: + return WPFusedDenseFunc.apply( + x, + weight, + bias, + module, + communicator, + return_residual, + ) + else: + batch_sizes = kwargs.pop("batch_sizes", None) + backend = kwargs.pop("backend", "gmm") + full_weight_shape = kwargs.pop("full_weight_shape", None) + return GroupedGemmWPFusedDenseFunc.apply( + x, + weight, + module, + communicator, + batch_sizes, + backend, + full_weight_shape, + ) else: # mtp, msp, and fsp - return SPFusedDenseFunc.apply( - x, - weight, - bias, - communicator, - return_residual, - ) + if not use_grouped_linear: + return SPFusedDenseFunc.apply( + x, + weight, + bias, + communicator, + return_residual, + ) + else: + # TODO: support grouped linear for mtp, msp, and fsp + batch_sizes = kwargs.pop("batch_sizes", None) + backend = kwargs.pop("backend", "gmm") + return GroupedGemmSPFusedDenseFunc.apply( + x, + weight, + batch_sizes, + backend, + ) class ParallelLinearWithCommExt(nn.Linear): @@ -380,16 +598,27 @@ def __init__( else: super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) - def forward(self, input: torch.Tensor) -> torch.Tensor: # pylint: 
disable=W0622 + def forward(self, input: torch.Tensor, batch_sizes: torch.Tensor = None) -> torch.Tensor: # pylint: disable=W0622 _class_name = self.__class__.__name__ assert self._communicator is not None, f"{_class_name} should register with a communicator first." + mixer_kwargs = {} + use_grouped_linear = getattr(self, "is_grouped_linear", False) + if use_grouped_linear: + mixer_kwargs = { + "batch_sizes": batch_sizes, + "backend": self.backend, + "full_weight_shape": self.full_weight_shape if hasattr(self, "full_weight_shape") else None, + } + return fused_dense_func( input, self.weight, communicator=self._communicator, module=self, bias=self.bias, + use_grouped_linear=use_grouped_linear, + **mixer_kwargs, ) @@ -571,6 +800,177 @@ def __init__( dist.broadcast(self.bias, gpc.get_ranks_in_group(parallel_mode)[0], process_group) +class GroupedParallelLinearWithCommExt(ParallelLinearWithCommExt): + """ + Parallel linear with commuication extention. + + Args: + in_features (int): size of each input sample + out_features (int): size of each output sample + bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False + in the config. + device (Optional[Union[str, torch.device]]): The device will be used. + dtype (Optional[torch.dtype]): The type of data. + split_mode (str): The split mode. It can be "none", "column", or "row". + """ + + def __init__( # pylint: disable=W0231, W0233 + self, + in_features: int, + out_features: int, + num_groups: int, + parallel_mode: ParallelMode, + backend: str = "gmm", + multiple_of: int = 1, + device: torch.device = None, + dtype: torch.dtype = None, + split_mode: str = "none", + ) -> None: + nn.Module.__init__(self) + + assert split_mode in ("none", "column", "row", "weight"), f"unknown split_mode {split_mode}" + + world_size = gpc.get_world_size(parallel_mode) + rank = gpc.get_local_rank(parallel_mode) + + split_features_dict = {"column": out_features, "row": in_features, "weight": num_groups * in_features} + if split_mode != "none": + split_features = split_features_dict[split_mode] + multiple = split_features // multiple_of + # We want to split @multiple across world_size, but it could be an uneven split + div = multiple // world_size + mod = multiple % world_size + # The first @mod ranks get @div + 1 copies, the rest get @div copies + local_multiple = div + int(rank < mod) + + if split_mode == "column": + self.weight = nn.Parameter( + torch.empty(num_groups, in_features, local_multiple * multiple_of, device=device, dtype=dtype) + ) + elif split_mode == "row": + self.weight = nn.Parameter( + torch.empty(num_groups, local_multiple * multiple_of, out_features, device=device, dtype=dtype) + ) + elif split_mode == "weight": + self.weight = nn.Parameter( + torch.empty(local_multiple * multiple_of, out_features, device=device, dtype=dtype) + ) + else: # none + self.weight = nn.Parameter(torch.empty(num_groups, in_features, out_features, device=device, dtype=dtype)) + + self.register_parameter("bias", None) + torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + self.is_grouped_linear = True + self.backend = backend + + +class GroupedColumnLinear(GroupedParallelLinearWithCommExt): + """ + GroupedSPLinear + Args: + in_features (int): size of each input sample + out_features (int): size of each output sample + num_groups (int): number of groups. + backend (str): backend used for the grouped linear. It can be "gmm" or "bmm". + device (Optional[Union[str, torch.device]]): The device will be used. 
+ dtype (Optional[torch.dtype]): The type of data. + """ + + def __init__( + self, + in_features: int, + out_features: int, + num_groups: int, + backend: str = "gmm", + device: torch.device = None, + dtype: torch.dtype = None, + is_expert: bool = True, + ): + parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert) + super().__init__( + in_features, + out_features, + num_groups, + parallel_mode, + backend, + device=device, + dtype=dtype, + split_mode="column", + ) + + world_size = gpc.get_world_size(parallel_mode) + assert world_size == 1, "GroupedSPLinear not support tensor parallel yet." + + +class GroupedRowLinear(GroupedParallelLinearWithCommExt): + """ + GroupedSPLinear + Args: + in_features (int): size of each input sample + out_features (int): size of each output sample + num_groups (int): number of groups. + backend (str): backend used for the grouped linear. It can be "gmm" or "bmm". + device (Optional[Union[str, torch.device]]): The device will be used. + dtype (Optional[torch.dtype]): The type of data. + """ + + def __init__( + self, + in_features: int, + out_features: int, + num_groups: int, + backend: str = "gmm", + device: torch.device = None, + dtype: torch.dtype = None, + is_expert: bool = True, + ): + parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert) + super().__init__( + in_features, out_features, num_groups, parallel_mode, backend, device=device, dtype=dtype, split_mode="row" + ) + + world_size = gpc.get_world_size(parallel_mode) + assert world_size == 1, "GroupedSPLinear not support tensor parallel yet." + + +class GroupedWPLinear(GroupedParallelLinearWithCommExt): + """ + GroupedWPLinear + Args: + in_features (int): size of each input sample + out_features (int): size of each output sample + num_groups (int): number of groups. + backend (str): backend used for the grouped linear. It can be "gmm" or "bmm". + device (Optional[Union[str, torch.device]]): The device will be used. + dtype (Optional[torch.dtype]): The type of data. + """ + + def __init__( + self, + in_features: int, + out_features: int, + num_groups: int, + backend: str = "gmm", + device: torch.device = None, + dtype: torch.dtype = None, + is_expert: bool = True, + ): + parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert) + super().__init__( + in_features, + out_features, + num_groups, + parallel_mode, + backend, + device=device, + dtype=dtype, + split_mode="weight", + ) + + self.full_weight_shape = torch.Size((num_groups, in_features, out_features)) + + def new_linear( name: str, in_features: int, @@ -639,6 +1039,36 @@ def new_linear( dtype, is_expert, ) + elif split_mode == "grouped_wp": + return GroupedWPLinear( + in_features, + out_features, + kwargs["num_groups"], + kwargs["backend"], + device, + dtype, + is_expert, + ) + elif split_mode == "grouped_column": + return GroupedColumnLinear( + in_features, + out_features, + kwargs["num_groups"], + kwargs["backend"], + device, + dtype, + is_expert, + ) + elif split_mode == "grouped_row": + return GroupedRowLinear( + in_features, + out_features, + kwargs["num_groups"], + kwargs["backend"], + device, + dtype, + is_expert, + ) elif split_mode == "gate": return nn.Linear( in_features, diff --git a/internlm/model/modules/mlp.py b/internlm/model/modules/mlp.py index 94d9654e..6e74d6b6 100644 --- a/internlm/model/modules/mlp.py +++ b/internlm/model/modules/mlp.py @@ -122,6 +122,94 @@ def forward(self, x): return out +class GroupedFeedForward(nn.Module): + """ + Base FeedForward in flash implementation. 
+ Args: + in_features (int): size of each input sample + hidden_features (int): size of hidden state of FFN + out_features (int): size of each output sample + bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False + in the config. + device (Optional[Union[str, torch.device]]): The device will be used. + dtype (Optional[torch.dtype]): The type of data. + multiple_of (int): For efficient training. Reset the size of hidden feature. 256 by default. + mlp_layer_fusion (Optional[Bool]): Some linears without bias in FFN can be fused to reduce the comm cost of SP. + activation_type (str): the activation function used for feed forward, "swiglu" by default. + """ + + def __init__( + self, + in_features: int, + hidden_features: int, + out_features: int = None, + bias: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + multiple_of: int = 256, + mlp_layer_fusion: Optional[bool] = False, + activation_type: str = "swiglu", + num_groups: int = 1, + backend: str = "bmm", + is_expert: bool = False, + ): + super().__init__() + + # TODO: support gelu... + assert activation_type in ("swiglu"), f"Unsupported activation type: {activation_type}" + assert bias is False, "Grouped FeedForward only support bias is False." + + self.mlp_layer_fusion = mlp_layer_fusion + + hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of) + + if self.mlp_layer_fusion: + assert False, "do not support for grouped mlp." + else: + self.w1 = new_linear( + "grouped_w1", + in_features, + hidden_features, + bias, + device=device, + dtype=dtype, + num_groups=num_groups, + backend=backend, + is_expert=is_expert, + ) + self.w2 = new_linear( + "grouped_w2", + hidden_features, + out_features, + bias, + device=device, + dtype=dtype, + num_groups=num_groups, + backend=backend, + is_expert=is_expert, + ) + self.w3 = new_linear( + "grouped_w3", + in_features, + hidden_features, + bias, + device=device, + dtype=dtype, + num_groups=num_groups, + backend=backend, + is_expert=is_expert, + ) + + def forward(self, x, batch_sizes=None): + if not self.mlp_layer_fusion: + w1_o = self.w1(x, batch_sizes) + w3_o = self.w3(x, batch_sizes) + else: + assert False + out = self.w2(Silu(w1_o, w3_o), batch_sizes) + return out + + def new_feed_forward( in_features: int, hidden_features: int, @@ -133,7 +221,26 @@ def new_feed_forward( mlp_layer_fusion: Optional[bool] = False, activation_type: str = "swiglu", is_expert: bool = False, + use_grouped_mlp: bool = False, + **kwargs, ) -> FeedForward: + if use_grouped_mlp: + num_groups = kwargs.pop("num_groups", 1) + backend = kwargs.pop("backend", "bmm") + return GroupedFeedForward( + in_features, + hidden_features, + out_features, + bias, + device, + dtype, + multiple_of, + mlp_layer_fusion, + activation_type, + num_groups=num_groups, + backend=backend, + is_expert=is_expert, + ) return FeedForward( in_features, hidden_features, diff --git a/internlm/model/moe/dropless_layer.py b/internlm/model/moe/dropless_layer.py index d0342430..f5881dfb 100644 --- a/internlm/model/moe/dropless_layer.py +++ b/internlm/model/moe/dropless_layer.py @@ -149,9 +149,9 @@ def __init__( drop_policy="probs", capacity_factor: float = None, noisy_gate_policy: str = None, - moe_grouped_mlp: bool = True, enable_fused_permute: bool = True, token_dispatch_policy: str = "alltoall", + use_grouped_mlp: bool = True, deterministic_mode: bool = False, ) -> None: assert noisy_gate_policy is None or noisy_gate_policy in ["None", "Jitter", 
"RSample"], ( @@ -161,8 +161,23 @@ def __init__( num_experts % ep_size == 0 ), f"Number of experts ({num_experts}) should be divisible by expert parallel size ({ep_size})" - if moe_grouped_mlp: - assert False, "not support yet" + backend = "bmm" if drop_and_pad else "gmm" + if use_grouped_mlp: + experts = new_feed_forward( + in_features, + hidden_features, + out_features, + bias=False, + device=device, + dtype=dtype, + mlp_layer_fusion=mlp_layer_fusion, + multiple_of=multiple_of, + activation_type=activation_type, + is_expert=True, + use_grouped_mlp=True, + num_groups=num_experts // ep_size, + backend=backend, + ) else: experts = torch.nn.ModuleList( [ @@ -200,7 +215,7 @@ def __init__( self.local_expert_indices = [local_expert_indices_offset + i for i in range(self.num_local_experts)] assert len(self.local_expert_indices) > 0, "Expected at least one local expert index" self.topk = top_k - self.moe_grouped_mlp = moe_grouped_mlp + self.use_grouped_mlp = use_grouped_mlp self.deterministic_mode = deterministic_mode self.drop_and_pad = drop_and_pad @@ -265,7 +280,7 @@ def forward(self, *inputs: Tensor) -> Tensor: (dispatched_input, tokens_per_expert) = self.token_permutation_func( reshaped_inputs, expert_weights, indices, tokens_per_expert_before_capacity ) - if self.moe_grouped_mlp: + if self.use_grouped_mlp: expert_output = self.experts(dispatched_input, batch_sizes=tokens_per_expert) else: expert_output = self.experts(dispatched_input, split_size_or_sections=tokens_per_expert, split_dim=0) @@ -424,7 +439,7 @@ def preprocess(self, indices, expert_weight, tokens_per_expert_before_capacity) num_global_tokens_per_local_expert = num_local_tokens_per_expert.reshape(self.num_experts) num_tokens_per_local_expert = num_local_tokens_per_expert - if self.moe_grouped_mlp: + if self.use_grouped_mlp: num_tokens_per_local_expert = num_tokens_per_local_expert.to(torch.device("cpu"), non_blocking=True) if self.num_local_experts > 1 and self.ep_size > 1: @@ -765,7 +780,7 @@ def token_permutation_by_all_gather( tokens_per_expert = torch.histc(local_indices.view(-1), bins=self.num_experts, min=0, max=self.num_experts) if self.num_local_experts < self.num_experts: tokens_per_expert = tokens_per_expert[self.local_expert_indices[0] : self.local_expert_indices[-1] + 1] - if self.moe_grouped_mlp: + if self.use_grouped_mlp: tokens_per_expert = tokens_per_expert.cpu().to(torch.long) # Stage2: permute the tokens locally so that they are grouped by their expert assignment diff --git a/internlm/model/moe/experts.py b/internlm/model/moe/experts.py index e6d8b8e6..c48a2265 100644 --- a/internlm/model/moe/experts.py +++ b/internlm/model/moe/experts.py @@ -42,7 +42,7 @@ def forward(self, inputs, split_size_or_sections=None, split_dim=0, **kwargs): kwargs: args used for expert's forward pass other than input tokens """ - if self.num_local_experts == 1: + if len(self.wrapped_experts) == 1: return self.wrapped_experts[0](inputs, **kwargs) # The following code is designed for multiple experts. @@ -50,7 +50,7 @@ def forward(self, inputs, split_size_or_sections=None, split_dim=0, **kwargs): # 2. 
do for-loop for experts's computation if split_size_or_sections is None: # chunk can be faster than split - chunks = inputs.chunk(self.num_local_experts, dim=split_dim) + chunks = inputs.chunk(len(self.wrapped_experts), dim=split_dim) else: if isinstance(split_size_or_sections, torch.Tensor): split_size_or_sections = split_size_or_sections.tolist() diff --git a/internlm/model/moe/gshard_layer.py b/internlm/model/moe/gshard_layer.py index c8cb8c4c..3aba8d1a 100644 --- a/internlm/model/moe/gshard_layer.py +++ b/internlm/model/moe/gshard_layer.py @@ -475,6 +475,7 @@ def __init__( drop_tokens: bool = True, use_rts: bool = True, use_fused_gating: bool = True, + use_grouped_mlp: bool = True, ) -> None: assert noisy_gate_policy is None or noisy_gate_policy in ["None", "Jitter", "RSample"], ( "Unsupported noisy_gate_policy: " + noisy_gate_policy @@ -482,20 +483,24 @@ def __init__( assert ( num_experts % ep_size == 0 ), f"Number of experts ({num_experts}) should be divisible by expert parallel size ({ep_size})" - super().__init__( - TopKGate( + if use_grouped_mlp: + experts = new_feed_forward( in_features, - num_experts, - top_k, - capacity_factor, - eval_capacity_factor, - min_capacity, - noisy_gate_policy, - drop_tokens, - use_rts, - use_fused_gating, - ), - torch.nn.ModuleList( + hidden_features, + out_features, + bias=False, + device=device, + dtype=dtype, + mlp_layer_fusion=mlp_layer_fusion, + multiple_of=multiple_of, + activation_type=activation_type, + is_expert=True, + use_grouped_mlp=True, + num_groups=num_experts // ep_size, + backend="bmm", + ) + else: + experts = torch.nn.ModuleList( [ new_feed_forward( in_features, @@ -511,12 +516,28 @@ def __init__( ) for _ in range(num_experts // ep_size) ] + ) + super().__init__( + TopKGate( + in_features, + num_experts, + top_k, + capacity_factor, + eval_capacity_factor, + min_capacity, + noisy_gate_policy, + drop_tokens, + use_rts, + use_fused_gating, ), + experts, ep_group, ep_size, num_experts // ep_size, ) + self.use_grouped_mlp = use_grouped_mlp + self.time_falltoall = 0.0 self.time_salltoall = 0.0 self.time_moe = 0.0 @@ -552,8 +573,20 @@ def forward(self, *inputs: Tensor) -> Tensor: # Re-shape after all-to-all: ecm -> gecm dispatched_inputs = dispatched_inputs.reshape(self.ep_size, self.num_local_experts, -1, d_model) + if self.use_grouped_mlp: + # (g,e,c,m) -> (e, g*c, m) + dispatched_inputs = ( + dispatched_inputs.transpose(0, 1).reshape(self.num_local_experts, -1, d_model).contiguous() + ) + expert_output = self.experts(dispatched_inputs, split_dim=1) + if self.use_grouped_mlp: + # (e, g*c, m) -> (e, g, c, m) -> (g, e, c, m) + expert_output = ( + expert_output.reshape(self.num_local_experts, self.ep_size, -1, d_model).transpose(0, 1).contiguous() + ) + if self.wall_clock_breakdown: timer("salltoall").start() diff --git a/internlm/model/moe/utils.py b/internlm/model/moe/utils.py index 08caa63a..b4745909 100644 --- a/internlm/model/moe/utils.py +++ b/internlm/model/moe/utils.py @@ -29,6 +29,11 @@ def forward( ctx.input_split_sizes = input_split_sizes ctx.group = group + world_size = torch.distributed.get_world_size(group=group) + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return inputs, None + inputs = inputs.contiguous() out = ( torch.empty_like(inputs) @@ -50,6 +55,11 @@ def forward( @staticmethod def backward(ctx: Any, grad_output: Tensor, _) -> Tuple[None, Tensor]: if ctx.needs_input_grad[0]: + # Bypass the function if we are using only 1 GPU. 
+ world_size = torch.distributed.get_world_size(group=ctx.group) + if world_size == 1: + return grad_output, None, None, None, None + grad_output = grad_output.contiguous() out = torch.empty(ctx.input_shape, device=grad_output.device, dtype=grad_output.dtype) torch.distributed.all_to_all_single( diff --git a/internlm/model/ops/linear.py b/internlm/model/ops/linear.py index eeffddc0..fa4c93c6 100644 --- a/internlm/model/ops/linear.py +++ b/internlm/model/ops/linear.py @@ -21,6 +21,18 @@ except (ModuleNotFoundError, ImportError): flash_attn_impl = False +try: + # grouped_gemm on GPU + from grouped_gemm.backend import gmm as gmm_ops +except (ModuleNotFoundError, ImportError): + # grouped_gemm on NPU + try: + from mindspeed.op_builder import GMMOpBuilder + + gmm_ops = GMMOpBuilder().load() + except (ModuleNotFoundError, ImportError): + gmm_ops = None + internlm_accelerator = get_accelerator() @@ -61,3 +73,46 @@ def linear_backward_op( _, _backward_op = _select_ops_binding(_input.dtype, _is_cuda) return _backward_op(_input, weight, has_d_bias) + + +def _gmm_forward_op_npu(_input: torch.Tensor, weight: torch.Tensor, batch_sizes: torch.Tensor) -> torch.Tensor: + group_list = torch.cumsum(batch_sizes, dim=-1) + group_list = group_list.tolist() + group_type = 0 + + return gmm_ops.npu_gmm([_input], [weight], [], group_list, group_type)[0] + + +def _gmm_backward_op_npu( + _input: torch.Tensor, weight: torch.Tensor, batch_sizes: torch.Tensor, grad_output: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + group_list = torch.cumsum(batch_sizes, dim=-1) + group_list = group_list.tolist() + + dx, dw, _ = gmm_ops.npu_gmm_backward([grad_output], [_input], [weight], group_list) + + return dx[0], dw[0] + + +def gmm_forward_op(_input: torch.Tensor, weight: torch.Tensor, batch_sizes: torch.Tensor) -> torch.Tensor: + if internlm_accelerator.get_accelerator_backend() is AcceleratorType.GPU: + return gmm_ops(_input, weight, batch_sizes) + elif internlm_accelerator.get_accelerator_backend() is AcceleratorType.NPU: + return _gmm_forward_op_npu(_input, weight, batch_sizes) + + +def gmm_backward_op( + _input: torch.Tensor, weight: torch.Tensor, batch_sizes: torch.Tensor, **kwargs +) -> Tuple[torch.Tensor, torch.Tensor]: + if internlm_accelerator.get_accelerator_backend() is AcceleratorType.GPU: + if kwargs.get("is_grad_input", False): + trans_a, trans_b = (False, True) + return gmm_ops(_input, weight, batch_sizes, trans_a=trans_a, trans_b=trans_b), None + else: + trans_a, trans_b = (True, False) + return None, gmm_ops(_input, weight, batch_sizes, trans_a=trans_a, trans_b=trans_b) + + elif internlm_accelerator.get_accelerator_backend() is AcceleratorType.NPU: + input_weight = kwargs.get("input_weight", None) + grad_output = weight + return _gmm_backward_op_npu(_input, input_weight, batch_sizes, grad_output) diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index 34c1479b..53057128 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -52,6 +52,9 @@ from internlm.model.modules.embedding import Embedding1D from internlm.model.modules.linear import ( ColumnParallelLinear, + GroupedColumnLinear, + GroupedRowLinear, + GroupedWPLinear, ParallelLinearWithCommExt, RewardModelLinear, RowParallelLinear, @@ -349,7 +352,7 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]): is_moe=True, ) for moe in _submodule_filter(model, Experts): - for column_linear in _submodule_filter(moe, (ColumnParallelLinear)): + for column_linear in _submodule_filter(moe, 
(ColumnParallelLinear, GroupedWPLinear)): column_linear.register_communicator(moe_isp_communicator) for row_linear in _submodule_filter(moe, RowParallelLinear): row_linear.register_communicator(None) @@ -369,21 +372,30 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]): TensorParallelCommunicator(process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.ROW) ) - if gpc.config.model.get("num_experts", 1) > 1 and gpc.config.parallel.expert.no_tp: - _column_communicator = TensorParallelCommunicator( - process_group=gpc.get_group(ParallelMode.EXPERT_TENSOR), role=LinearRole.COLUMN + if gpc.config.model.get("num_experts", 1) > 1: + GroupedColumnLinear.register_cls_communicator( + TensorParallelCommunicator(process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.COLUMN) ) - _row_communicator = TensorParallelCommunicator( - process_group=gpc.get_group(ParallelMode.EXPERT_TENSOR), role=LinearRole.ROW + GroupedRowLinear.register_cls_communicator( + TensorParallelCommunicator(process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.ROW) ) - for moe in _submodule_filter(model, MoE): - # 1. the linear in MoE degrades as no tp communication pattern - for column_linear in _submodule_filter(moe, ColumnParallelLinear): - column_linear.register_communicator(_column_communicator) - for row_linear in _submodule_filter(moe, RowParallelLinear): - row_linear.register_communicator(_row_communicator) - # 2. register MoESequenceParallelCommunicator for MoE layer - MoESequenceParallelCommunicator(ParallelMode.TENSOR, reverse=True).register_module_hook(moe) + GroupedWPLinear.register_cls_communicator(None) + # treat as sequence paralle if no_tp + if gpc.config.parallel.expert.no_tp: + _column_communicator = TensorParallelCommunicator( + process_group=gpc.get_group(ParallelMode.EXPERT_TENSOR), role=LinearRole.COLUMN + ) + _row_communicator = TensorParallelCommunicator( + process_group=gpc.get_group(ParallelMode.EXPERT_TENSOR), role=LinearRole.ROW + ) + for moe in _submodule_filter(model, MoE): + # 1. the linear in MoE degrades as no tp communication pattern + for column_linear in _submodule_filter(moe, ColumnParallelLinear): + column_linear.register_communicator(_column_communicator) + for row_linear in _submodule_filter(moe, RowParallelLinear): + row_linear.register_communicator(_row_communicator) + # 2. 
register MoESequenceParallelCommunicator for MoE layer + MoESequenceParallelCommunicator(ParallelMode.TENSOR, reverse=True).register_module_hook(moe) _head_communicator = HeadTensorParallelCommunicator(ParallelMode.TENSOR, _retain_out_sharded) _embedding_communicator = EmbeddingTensorParallelCommunicator(ParallelMode.TENSOR) @@ -405,19 +417,35 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]): save_total_input_as_activation=save_total_input_as_activation, ) ) - if gpc.config.model.get("num_experts", 1) > 1 and gpc.config.parallel.expert.no_tp: - _column_communicator = TensorParallelCommunicator( - process_group=gpc.get_group(ParallelMode.EXPERT_TENSOR), role=LinearRole.COLUMN + if gpc.config.model.get("num_experts", 1) > 1: + GroupedColumnLinear.register_cls_communicator( + SequenceParallelCommunicator( + process_group=gpc.get_group(ParallelMode.TENSOR), + role=LinearRole.COLUMN, + save_total_input_as_activation=save_total_input_as_activation, + ) ) - _row_communicator = TensorParallelCommunicator( - process_group=gpc.get_group(ParallelMode.EXPERT_TENSOR), role=LinearRole.ROW + GroupedRowLinear.register_cls_communicator( + SequenceParallelCommunicator( + gpc.get_group(ParallelMode.TENSOR), + role=LinearRole.ROW, + save_total_input_as_activation=save_total_input_as_activation, + ) ) - for moe in _submodule_filter(model, MoE): - # 1. the linear in MoE degrades as no tp communication pattern - for column_linear in _submodule_filter(moe, ColumnParallelLinear): - column_linear.register_communicator(_column_communicator) - for row_linear in _submodule_filter(moe, RowParallelLinear): - row_linear.register_communicator(_row_communicator) + GroupedWPLinear.register_cls_communicator(None) + if gpc.config.parallel.expert.no_tp: + _column_communicator = TensorParallelCommunicator( + process_group=gpc.get_group(ParallelMode.EXPERT_TENSOR), role=LinearRole.COLUMN + ) + _row_communicator = TensorParallelCommunicator( + process_group=gpc.get_group(ParallelMode.EXPERT_TENSOR), role=LinearRole.ROW + ) + for moe in _submodule_filter(model, MoE): + # 1. 
the linear in MoE degrades as no tp communication pattern + for column_linear in _submodule_filter(moe, ColumnParallelLinear): + column_linear.register_communicator(_column_communicator) + for row_linear in _submodule_filter(moe, RowParallelLinear): + row_linear.register_communicator(_row_communicator) _head_communicator = HeadSequenceParallelCommunicator( ParallelMode.TENSOR, _retain_out_sharded, save_total_input_as_activation diff --git a/tools/moe_group_ckpt_converter.py b/tools/moe_group_ckpt_converter.py new file mode 100644 index 00000000..d3fefb7c --- /dev/null +++ b/tools/moe_group_ckpt_converter.py @@ -0,0 +1,216 @@ +import argparse +import os +import re +import shutil +import sys + +import torch +from tqdm import tqdm + +sys.path.append(".") +import internlm # noqa: E402,F401 # pylint: disable=W0611,C0413 + +moe_str_prefix = None +weight_key_suffix = ".weight" + + +def load(fp): + with open(fp, "rb") as f: + pt_data = torch.load(f, map_location="cpu") + return pt_data + + +def list_to_group_ckpt_tp(src, tgt, ep_size, num_layer, num_local_experts, max_tp): + print("Converting checkpoints from sequence module list to group mlp...") + + for layer_id in tqdm(range(num_layer)): + for tp_rank in range(max_tp): + for expp_rank in range(ep_size): + # merge local experts into grouped mlp + expert_state_dict = dict() + # expert_w_state[key][expert] = weight + expert_w_state = {"w1": [], "w2": [], "w3": []} + expert_ids = range(num_local_experts * expp_rank, num_local_experts * (expp_rank + 1)) + for global_expert_id in expert_ids: + fn = f"model_moe_layer{layer_id}_expert{global_expert_id}_tp{tp_rank}.pt" + fp = os.path.join(src, fn) + origin_state = load(fp) + pattern = r"(.*?\.moe_layer\.experts\.wrapped_experts)\.\d+\.(w\d+)(?:\.weight)?" + for key, weight in origin_state.items(): + moe_str_prefix, w_i = re.search(pattern, key).group(1), re.search(pattern, key).group(2) + # [d2, d1] -> [d1, d2] + expert_w_state[w_i].append(weight.T) + # k*[d1, d2] -> [k, d1, d2] + for key, weights in expert_w_state.items(): + local_key = f"{moe_str_prefix}.{expp_rank}.{key}{weight_key_suffix}" + expert_state_dict[local_key] = torch.stack(weights, dim=0) + + torch.save( + expert_state_dict, os.path.join(tgt, f"model_moe_layer{layer_id}_expert{expp_rank}_tp{tp_rank}.pt") + ) + + +def group_to_list_ckpt_tp(src, tgt, ep_size, num_layer, num_local_experts, max_tp): + print("Converting checkpoints from group mlp list to sequence module...") + + for layer_id in tqdm(range(num_layer)): + for tp_rank in range(max_tp): + for expp_rank in range(ep_size): + # split group mlp to local experts, expert_w_state[key][expert] = weight + expert_w_state = {"w1": [], "w2": [], "w3": []} + fn = f"model_moe_layer{layer_id}_expert{expp_rank}_tp{tp_rank}.pt" + fp = os.path.join(src, fn) + origin_state = load(fp) + pattern = r"(.*?\.moe_layer\.experts\.wrapped_experts)\.\d+\.(w\d+)(?:\.weight)?" 
+ for local_expert_id in range(num_local_experts): + expert_state_dict = dict() + global_expert_id = expp_rank * num_local_experts + local_expert_id + for key, weight in origin_state.items(): + moe_str_prefix, w_i = re.search(pattern, key).group(1), re.search(pattern, key).group(2) + # [k, d1, d2] -> k * [d1, d2] + expert_w_state[w_i] = weight.chunk(num_local_experts) + local_key = key.replace(f"{moe_str_prefix}.{expp_rank}", f"{moe_str_prefix}.{global_expert_id}") + # [d2, d1] -> [d1, d2] + value = expert_w_state[w_i][local_expert_id].squeeze().T + expert_state_dict[local_key] = value + torch.save( + expert_state_dict, + os.path.join(tgt, f"model_moe_layer{layer_id}_expert{global_expert_id}_tp{tp_rank}.pt"), + ) + + +def list_to_group_ckpt_wp(src, tgt, ep_size, num_layer, num_local_experts, max_wp): + print("Converting checkpoints from sequence module list to group mlp...") + + for layer_id in tqdm(range(num_layer)): + for expp_rank in range(ep_size): + # expert_w_state[key][expert][wp]=weight + expert_w_state = { + "w1": [[] for _ in range(num_local_experts)], + "w2": [[] for _ in range(num_local_experts)], + "w3": [[] for _ in range(num_local_experts)], + } + expert_ids = range(num_local_experts * expp_rank, num_local_experts * (expp_rank + 1)) + for local_expert_id, global_expert_id in enumerate(expert_ids): + for wp_rank in range(max_wp): + fn = f"model_moe_layer{layer_id}_expert{global_expert_id}_wp{wp_rank}.pt" + fp = os.path.join(src, fn) + origin_state = load(fp) + pattern = r"(.*?\.moe_layer\.experts\.wrapped_experts)\.\d+\.(w\d+)(?:\.weight)?" + for key, weight in origin_state.items(): + moe_str_prefix, w_i = re.search(pattern, key).group(1), re.search(pattern, key).group(2) + # [d2/2, d1] -> [d1, d2/w] + expert_w_state[w_i][local_expert_id].append(weight.T) + # expert_state_dict[wp][key] = value + expert_state_dict = [{} for _ in range(max_wp)] + # k*w*[d1,d2/w] -> k*[d1, d2] -> [k*d1, d2] -> w*[k/w*d1, w*d2] + for key, weights in expert_w_state.items(): + flat_weights = [torch.cat(row, dim=1) for row in weights] + full_weights = torch.cat(flat_weights, dim=0).chunk(max_wp, dim=0) + local_key = f"{moe_str_prefix}.{expp_rank}.{key}{weight_key_suffix}" + for wp_rank in range(max_wp): + expert_state_dict[wp_rank][local_key] = full_weights[wp_rank] + + for wp_rank in range(max_wp): + torch.save( + expert_state_dict[wp_rank], + os.path.join(tgt, f"model_moe_layer{layer_id}_expert{expp_rank}_wp{wp_rank}.pt"), + ) + + +def group_to_list_ckpt_wp(src, tgt, ep_size, num_layer, num_local_experts, max_wp): + print("Converting checkpoints from group mlp list to sequence module...") + + for layer_id in tqdm(range(num_layer)): + for expp_rank in range(ep_size): + # expert_w_state[key][wp]=weight + expert_w_state = { + "w1": [None for _ in range(max_wp)], + "w2": [None for _ in range(max_wp)], + "w3": [None for _ in range(max_wp)], + } + for wp_rank in range(max_wp): + fn = f"model_moe_layer{layer_id}_expert{expp_rank}_wp{wp_rank}.pt" + fp = os.path.join(src, fn) + origin_state = load(fp) + pattern = r"(.*?\.moe_layer\.experts\.wrapped_experts)\.\d+\.(w\d+)(?:\.weight)?" 
+ for key, weight in origin_state.items(): + moe_str_prefix, w_i = re.search(pattern, key).group(1), re.search(pattern, key).group(2) + expert_w_state[w_i][wp_rank] = weight + + # expert_state_dict[expert][wp][key] = value + expert_state_dict = [[{} for _ in range(max_wp)] for _ in range(num_local_experts)] + for key, weight in expert_w_state.items(): + # w*[k*d1/w, d2] -> [k*d1, d2] -> k*[d1, d2/w] + full_weight = torch.cat(weight, dim=0).chunk(num_local_experts, dim=0) + flat_weight = [row.chunk(max_wp, dim=1) for row in full_weight] + for local_expert_id in range(num_local_experts): + for wp_rank in range(max_wp): + global_expert_id = expp_rank * num_local_experts + local_expert_id + local_key = f"{moe_str_prefix}.{global_expert_id}.{key}{weight_key_suffix}" + value = flat_weight[local_expert_id][wp_rank].T + expert_state_dict[local_expert_id][wp_rank][local_key] = value + + for local_expert_id in range(num_local_experts): + global_expert_id = expp_rank * num_local_experts + local_expert_id + for wp_rank in range(max_wp): + torch.save( + expert_state_dict[local_expert_id][wp_rank], + os.path.join(tgt, f"model_moe_layer{layer_id}_expert{global_expert_id}_wp{wp_rank}.pt"), + ) + + +def print_args(args): + print("-------------- Arguments --------------") + print(f"Source Path: {args.src}") + print(f"Target Path: {args.tgt}") + print(f"Expert Number: {args.num_experts}") + print(f"EP Size: {args.ep_size}") + print(f"Convert Mode: {'list to group' if args.convert_mode == 0 else 'group to list'}") + print("---------------------------------------") + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--src", type=str, help="Input folder") + parser.add_argument("--tgt", type=str, help="Output folder") + parser.add_argument("--num-experts", type=int, help="Number of experts") + parser.add_argument("--ep-size", type=int, help="expert parallel size") + parser.add_argument("--convert-mode", type=int, help="parallel mode: 0. list to group, 1.group to list") + + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_args() + print_args(args) + + fns = list(os.listdir(args.src)) + moe_fns = [] + for fn in fns: + if fn.startswith("model_moe") and not fn.endswith("md5"): + moe_fns.append(fn) + elif (fn.startswith("model_t") or fn.startswith("model_w")) and not fn.endswith("md5"): + shutil.copyfile(os.path.join(args.src, fn), os.path.join(args.tgt, fn)) + num_layer, max_mp = -1, -1 + mode = None + for fn in moe_fns: + _, _, layer_info, _, mp_info = os.path.splitext(fn)[0].split("_") + num_layer = max(num_layer, int(layer_info[5:]) + 1) + max_mp = max(max_mp, int(mp_info[2:]) + 1) + mode = mp_info[:2] + num_local_experts = args.num_experts // args.ep_size + + if mode == "tp" and args.convert_mode == 0: + list_to_group_ckpt_tp(args.src, args.tgt, args.ep_size, num_layer, num_local_experts, max_mp) + elif mode == "tp" and args.convert_mode == 1: + group_to_list_ckpt_tp(args.src, args.tgt, args.ep_size, num_layer, num_local_experts, max_mp) + elif mode == "wp" and args.convert_mode == 0: + list_to_group_ckpt_wp(args.src, args.tgt, args.ep_size, num_layer, num_local_experts, max_mp) + elif mode == "wp" and args.convert_mode == 1: + group_to_list_ckpt_wp(args.src, args.tgt, args.ep_size, num_layer, num_local_experts, max_mp) + else: + assert False, "unsupport convert mode"
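
A minimal pure-PyTorch sketch of the grouped-GEMM ("gmm") contract relied on by the new GroupedGemmSPFusedDenseFunc / GroupedGemmWPFusedDenseFunc autograd functions above. This is a reference only, not the code added by the patch: the real kernels come from grouped_gemm on GPU or MindSpeed's GMM op builder on NPU. Assumed shapes follow the patch: x is (num_tokens, in_features), weight is (num_experts, in_features, out_features), and batch_sizes holds the per-expert token counts produced by the dropless token permutation (moved to CPU as torch.long).

    import torch

    def grouped_matmul_reference(x: torch.Tensor,
                                 weight: torch.Tensor,
                                 batch_sizes: torch.Tensor) -> torch.Tensor:
        # Loop-of-matmuls equivalent of gmm_forward_op(x, weight, batch_sizes).
        assert x.dim() == 2 and weight.dim() == 3, "x: (num_tokens, in), weight: (num_experts, in, out)"
        assert int(batch_sizes.sum()) == x.shape[0], "batch_sizes must cover every routed token"
        outputs, start = [], 0
        for expert_id, n_tokens in enumerate(batch_sizes.tolist()):
            chunk = x[start:start + n_tokens]          # tokens routed to this expert (may be empty)
            outputs.append(chunk @ weight[expert_id])  # (n_tokens, out_features)
            start += n_tokens
        return torch.cat(outputs, dim=0)

    if __name__ == "__main__":
        num_experts, in_features, out_features = 4, 8, 16
        batch_sizes = torch.tensor([3, 0, 5, 2], dtype=torch.long)  # per-expert token counts, zeros allowed
        x = torch.randn(int(batch_sizes.sum()), in_features)
        w = torch.randn(num_experts, in_features, out_features)
        out = grouped_matmul_reference(x, w, batch_sizes)
        assert out.shape == (int(batch_sizes.sum()), out_features)

The "bmm" backend selected for the drop-and-pad (GShard) path instead takes a 3-D (num_local_experts, capacity, model_dim) tensor and reduces to a plain batched torch.matmul, which is why the autograd functions fall back to it when the routed input is empty.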