
Commit

Merge branch 'cherry-pick-c4d12e26' into 'core_r0.7.0'
Merge branch 'xuwenc/moe_gmm_infer_fix' into 'main'

See merge request ADLR/megatron-lm!1519
jaredcasper committed May 31, 2024
2 parents a967adf + 015e427 commit a645f89
Showing 4 changed files with 18 additions and 7 deletions.
4 changes: 2 additions & 2 deletions megatron/core/tensor_parallel/layers.py
@@ -679,7 +679,7 @@ def __init__(
         self.disable_grad_reduce = disable_grad_reduce
 
         self.explicit_expert_comm = self.is_expert and (
-            config.sequence_parallel or self.expert_parallel
+            config.tensor_model_parallel_size > 1 or self.expert_parallel
         )
         if self.explicit_expert_comm and config.moe_extended_tp:
             world_size = get_tensor_and_expert_parallel_world_size()
@@ -941,7 +941,7 @@ def __init__(
             raise RuntimeError("To enable `sequence_parallel`, `input_is_parallel` must be `True`")
 
         self.explicit_expert_comm = self.is_expert and (
-            config.sequence_parallel or self.expert_parallel
+            config.tensor_model_parallel_size > 1 or self.expert_parallel
         )
 
         # Divide the weight matrix along the last dimension.
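In both ColumnParallelLinear and RowParallelLinear the change is the same: explicit expert communication is now keyed off tensor model parallelism instead of sequence parallelism, so it also kicks in at inference time when sequence parallelism is switched off. A minimal sketch of the updated predicate (the helper below is hypothetical, not part of Megatron-LM):

# Hypothetical helper mirroring the updated condition; not Megatron-LM API.
def needs_explicit_expert_comm(is_expert: bool,
                               tensor_model_parallel_size: int,
                               expert_parallel: bool) -> bool:
    return is_expert and (tensor_model_parallel_size > 1 or expert_parallel)

# Before this commit the first operand was `config.sequence_parallel`, so an
# inference run with TP > 1 but sequence parallelism disabled skipped the
# explicit expert communication path.
assert needs_explicit_expert_comm(True, 2, False)      # TP > 1, SP off: now True
assert not needs_explicit_expert_comm(True, 1, False)  # no TP, no expert parallel: False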
10 changes: 10 additions & 0 deletions megatron/core/transformer/moe/moe_layer.py
@@ -90,6 +90,16 @@ def __init__(
         self.moe_layer_recompute = config.moe_layer_recompute
 
     def forward(self, hidden_states: torch.Tensor):
+        if (
+            self.training
+            and self.config.tensor_model_parallel_size > 1
+            and not self.config.sequence_parallel
+        ):
+            raise ValueError(
+                "During training, performance may degrade if MoE and tensor parallelism "
+                "are enabled without also enabling sequence parallelism."
+            )
+
         # process MoE
         def custom_forward(hidden_states):
             probs, indices = self.router(hidden_states)
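The new guard fires only during training; inference with TP > 1 and sequence parallelism disabled, the case this merge request fixes, proceeds normally. A rough stand-alone sketch of the intended behavior, using a hypothetical helper rather than Megatron's MoELayer:

# Hypothetical stand-in for the check added to MoELayer.forward; not Megatron API.
def check_moe_tp_sp(training: bool, tensor_model_parallel_size: int,
                    sequence_parallel: bool) -> None:
    if training and tensor_model_parallel_size > 1 and not sequence_parallel:
        raise ValueError(
            "During training, performance may degrade if MoE and tensor parallelism "
            "are enabled without also enabling sequence parallelism."
        )

check_moe_tp_sp(training=False, tensor_model_parallel_size=2,
                sequence_parallel=False)       # inference: passes
try:
    check_moe_tp_sp(training=True, tensor_model_parallel_size=2,
                    sequence_parallel=False)
except ValueError:
    pass                                       # training: raises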
8 changes: 6 additions & 2 deletions megatron/core/transformer/moe/token_dispatcher.py
@@ -107,7 +107,9 @@ def token_permutation(
         hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
 
         # Permute the tokens across the expert parallel devices.
-        if self.config.sequence_parallel or (self.config.expert_model_parallel_size > 1):
+        if (self.config.tensor_model_parallel_size > 1) or (
+            self.config.expert_model_parallel_size > 1
+        ):
             with torch.no_grad():
                 global_indices = tensor_parallel.gather_from_sequence_parallel_region_to_moe(
                     max_ind
@@ -214,7 +216,9 @@ def token_unpermutation(
             output_bias_total = unpermuted_local_bias
 
         # Unpermute the tokens across expert parallel devices.
-        if self.config.sequence_parallel or (self.config.expert_model_parallel_size > 1):
+        if (self.config.tensor_model_parallel_size > 1) or (
+            self.config.expert_model_parallel_size > 1
+        ):
             assert (
                 self.global_local_map is not None
             ), "global_local_map is necessary for `AllGather`."
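Both token_permutation and token_unpermutation now use the same condition, so the gather across the tensor-and-expert-parallel group before expert computation and the corresponding reduction afterwards stay symmetric. A toy round-trip in plain PyTorch (illustrative only, not the real dispatcher, which uses Megatron's gather and reduce-scatter collectives):

import torch

# Toy illustration: whatever condition gates the gather before the experts must
# also gate the inverse operation afterwards, or the shapes no longer match.
def roundtrip(hidden: torch.Tensor, tp_size: int, ep_size: int) -> torch.Tensor:
    needs_comm = (tp_size > 1) or (ep_size > 1)
    world = tp_size * ep_size

    # Stand-in for the all-gather across the tensor-and-expert-parallel group.
    gathered = hidden.repeat(world, 1) if needs_comm else hidden

    # ... expert computation would run on `gathered` here ...

    # Stand-in for the reduction/scatter back to the local shard.
    return gathered[: hidden.shape[0]] if needs_comm else gathered

x = torch.randn(4, 8)
assert roundtrip(x, tp_size=2, ep_size=1).shape == x.shape
assert roundtrip(x, tp_size=1, ep_size=1).shape == x.shape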
3 changes: 0 additions & 3 deletions megatron/training/arguments.py
@@ -498,9 +498,6 @@ def validate_args(args, defaults={}):
     # MoE Spec check
     if args.num_experts is not None:
         assert args.spec is None, "Model Spec must be None when using MoEs"
-        if args.tensor_model_parallel_size > 1:
-            assert args.sequence_parallel, \
-                "When using MoE and tensor parallelism, sequence parallelism must be used."
 
     # Expert parallelism check
     if args.expert_model_parallel_size > 1:
