Merge branch 'fix_overlap_param_gather' into 'main'
fix EP distopt with overlap param gather

See merge request ADLR/megatron-lm!1345

(cherry picked from commit ccfeda4)

ac93d847 fix EP distopt with overlap param gather
bb7b4307 change golden metrics
0ff731ff Minor fix to thrown value error
ericharper committed Apr 18, 2024
1 parent b26d3e3 commit cac60ce
Showing 3 changed files with 18 additions and 12 deletions.
26 changes: 16 additions & 10 deletions megatron/core/optimizer/optimizer.py
@@ -754,21 +754,27 @@ def load_state_dict(self, state_dict):
             self.param_groups += optimizer.param_groups

     def disable_pre_hook(self):
-        if not self.config.use_distributed_optimizer or not self.config.overlap_param_gather:
-            raise ValueError(
-                "disable_pre_hook should only be called with 'use_distributed_optimizer' "
-                "and 'overlap_param_gather' are both enabled."
-            )
         for optimizer in self.chained_optimizers:
+            if (
+                not optimizer.config.use_distributed_optimizer
+                or not optimizer.config.overlap_param_gather
+            ):
+                raise ValueError(
+                    "disable_pre_hook should only be called with 'use_distributed_optimizer' "
+                    "and 'overlap_param_gather' both enabled."
+                )
             optimizer.disable_pre_hook()

     def enable_pre_hook(self):
-        if not self.config.use_distributed_optimizer or not self.config.overlap_param_gather:
-            raise ValueError(
-                "enable_pre_hook should only be called with 'use_distributed_optimizer' "
-                "and 'overlap_param_gather' are both enabled."
-            )
         for optimizer in self.chained_optimizers:
+            if (
+                not optimizer.config.use_distributed_optimizer
+                or not optimizer.config.overlap_param_gather
+            ):
+                raise ValueError(
+                    "enable_pre_hook should only be called with 'use_distributed_optimizer' "
+                    "and 'overlap_param_gather' both enabled."
+                )
             optimizer.enable_pre_hook()

     def step(self):
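The diff above moves the 'use_distributed_optimizer' / 'overlap_param_gather' validation from a single self.config check into the loop over self.chained_optimizers, so each wrapped optimizer is validated against its own config; it also fixes the wording of the raised error. Below is a minimal, self-contained sketch of that per-optimizer validation pattern. The _OptConfig, _Opt, and _Chained classes are hypothetical stand-ins for illustration, not Megatron-LM's actual classes.

# Illustrative sketch only: _OptConfig, _Opt, and _Chained are hypothetical stand-ins,
# not Megatron-LM classes. The point is the per-optimizer config check from the diff.
from dataclasses import dataclass
from typing import List


@dataclass
class _OptConfig:
    use_distributed_optimizer: bool
    overlap_param_gather: bool


class _Opt:
    def __init__(self, config: _OptConfig):
        self.config = config

    def disable_pre_hook(self):
        print("pre-hook disabled for", self.config)


class _Chained:
    def __init__(self, optimizers: List[_Opt]):
        self.chained_optimizers = optimizers

    def disable_pre_hook(self):
        # Validate each wrapped optimizer against its own config, as in the commit,
        # instead of assuming a single shared self.config on the container.
        for optimizer in self.chained_optimizers:
            if (
                not optimizer.config.use_distributed_optimizer
                or not optimizer.config.overlap_param_gather
            ):
                raise ValueError(
                    "disable_pre_hook should only be called with "
                    "'use_distributed_optimizer' and 'overlap_param_gather' both enabled."
                )
            optimizer.disable_pre_hook()


# Usage: two chained optimizers (e.g. dense and expert-parallel groups), both eligible.
chained = _Chained([_Opt(_OptConfig(True, True)), _Opt(_OptConfig(True, True))])
chained.disable_pre_hook()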
2 changes: 1 addition & 1 deletion tests/functional_tests/jet_recipes/MR-gpt.yaml
@@ -70,7 +70,7 @@ products:
   # - {tp_size: [2], pp_size: [1,2], extra_args: ['"--context-parallel-size 2 --sequence-parallel --hidden-dropout 0.0 --attention-dropout 0.0"']} # TODO: need updated container with TE > 1.0.0
   - {tp_size: [2], pp_size: [1], extra_args: ['"--sequence-parallel --num-experts 8 --expert-model-parallel-size 2 --moe-router-load-balancing-type sinkhorn --moe-router-topk 1"'], args_meta: ["te_8experts2parallel"]}
   - {tp_size: [2], pp_size: [1], extra_args: ['"--sequence-parallel --num-experts 8 --expert-model-parallel-size 2 --use-distributed-optimizer --moe-router-load-balancing-type sinkhorn --moe-router-topk 1"'], args_meta: ["te_8experts2parallel_dist_optimizer"]}
-  - {tp_size: [2], pp_size: [1], extra_args: ['"--moe-grouped-gemm --disable-bias-linear --sequence-parallel --num-experts 8 --expert-model-parallel-size 2 --use-distributed-optimizer --moe-router-load-balancing-type sinkhorn --moe-router-topk 1 --overlap-grad-reduce"'], moe_grouped_gemm: [1], args_meta: ["te_8experts2parallel_overlap_grad_reduce_groupedGEMM"]}
+  - {tp_size: [2], pp_size: [1], extra_args: ['"--moe-grouped-gemm --disable-bias-linear --sequence-parallel --num-experts 8 --expert-model-parallel-size 2 --use-distributed-optimizer --moe-router-load-balancing-type sinkhorn --moe-router-topk 1 --overlap-grad-reduce --overlap-param-gather"'], moe_grouped_gemm: [1], args_meta: ["te_8experts2parallel_overlap_grad_reduce_param_gather_groupedGEMM"]}
   - {tp_size: [2], pp_size: [1], extra_args: ['"--moe-grouped-gemm --disable-bias-linear --sequence-parallel --num-experts 8 --expert-model-parallel-size 2 --moe-router-load-balancing-type sinkhorn --moe-router-topk 1"'], moe_grouped_gemm: [1], args_meta: ["te_8experts2parallel_groupedGEMM"]}
   - {tp_size: [2], pp_size: [1], extra_args: ['"--disable-bias-linear --sequence-parallel --num-experts 8 --expert-model-parallel-size 2 --moe-router-load-balancing-type aux_loss --moe-router-topk 2 --moe-aux-loss-coeff 1e-2"'], moe_grouped_gemm: [1], args_meta: ["te_8experts2parallel_top2router"]}
   - {tp_size: [1], pp_size: [1], extra_args: ["--use-distributed-optimizer"], args_meta: ["dist_optimizer"]}
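The recipe change above replaces the te_8experts2parallel_overlap_grad_reduce_groupedGEMM case with one that additionally passes --overlap-param-gather, exercising the distributed-optimizer path that the optimizer.py fix targets. As a rough sketch of the flag combination involved (a hypothetical stand-in parser, not Megatron-LM's real argument handling; the cross-flag check is an assumption made for illustration):

# Hypothetical stand-in parser for the flags exercised by the new recipe; this is not
# Megatron-LM's real argument handling, and the dependency check below is an assumption.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use-distributed-optimizer", action="store_true")
parser.add_argument("--overlap-grad-reduce", action="store_true")
parser.add_argument("--overlap-param-gather", action="store_true")
parser.add_argument("--num-experts", type=int, default=None)
parser.add_argument("--expert-model-parallel-size", type=int, default=1)

args = parser.parse_args(
    "--use-distributed-optimizer --overlap-grad-reduce --overlap-param-gather "
    "--num-experts 8 --expert-model-parallel-size 2".split()
)

# Overlapped param gather only makes sense on top of the distributed optimizer, which is
# what the per-optimizer pre-hook checks in optimizer.py enforce at runtime.
if args.overlap_param_gather and not args.use_distributed_optimizer:
    raise ValueError("--overlap-param-gather requires --use-distributed-optimizer")
print(vars(args))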
@@ -1 +1 @@
-{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.80961, 10.86088, 10.86703, 10.80386, 10.71988, 10.64698, 10.21161, 10.32003, 10.22052, 9.92363]}, "num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [31228.0, 37860.0, 38327.0, 36135.0, 33138.0, 34687.0, 30217.0, 34984.0, 35952.0, 37036.0]}, "iteration_timing_avg": 0.18751352941176463}
+{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.80961, 10.86088, 10.86703, 10.80386, 10.71988, 10.64698, 10.21161, 10.32003, 10.22052, 9.92363]}, "num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [31228.0, 37860.0, 38327.0, 36135.0, 33138.0, 34687.0, 30217.0, 34984.0, 35952.0, 37036.0]}, "iteration_timing_avg": 0.17911029411764712}
