From c580a9d0125d40cd3a2e53c25350634daa4e0c35 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 27 Apr 2025 15:07:23 +0800 Subject: [PATCH 01/14] support megatron moe --- swift/megatron/argument/megatron_args.py | 19 ++++++++++++ swift/megatron/model/config.py | 7 +++++ swift/megatron/model/gpt/__init__.py | 1 + swift/megatron/model/gpt/config.py | 5 ++- swift/megatron/model/gpt/hf2mcore.py | 39 ++++++++++++++++-------- swift/megatron/model/gpt/mcore2hf.py | 31 +++++++++++++------ swift/megatron/model/gpt/model.py | 2 ++ tests/megatron/test_align/test_llm.py | 13 ++++++-- 8 files changed, 92 insertions(+), 25 deletions(-) diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index a31df06cd9..9c3dcb3047 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -107,6 +107,21 @@ class MegatronArguments(ExtraMegatronArguments): qk_layernorm: bool = False transformer_impl: Literal['local', 'transformer_engine'] = 'transformer_engine' + # moe + expert_model_parallel_size: int = 1 + num_experts: Optional[int] = None + moe_ffn_hidden_size: Optional[int] = None + moe_shared_expert_intermediate_size: Optional[int] = None + moe_router_topk: int = 2 + moe_token_dispatcher_type: Literal['allgather', 'alltoall', 'alltoall_seq'] = 'alltoall' + moe_grouped_gemm: bool = False + moe_router_load_balancing_type: Literal['aux_loss', 'seq_aux_loss', 'sinkhorn', 'none'] = 'aux_loss' + moe_aux_loss_coeff: float = 0. + moe_z_loss_coeff: Optional[float] = None + moe_router_pre_softmax: bool = False + moe_expert_capacity_factor: Optional[float] = None + moe_shared_expert_overlap: bool = False + # mixed precision fp16: Optional[bool] = None bf16: Optional[bool] = None @@ -154,6 +169,10 @@ def __post_init__(self): self.seq_length = self.max_position_embeddings if self.tensorboard_dir is None and self.save is not None: self.tensorboard_dir = f'{self.save}/runs' + if self.moe_ffn_hidden_size is None: + self.moe_ffn_hidden_size = self.ffn_hidden_size + else: + self.ffn_hidden_size = self.moe_ffn_hidden_size self._init_mixed_precision() self.tensorboard_dir = to_abspath(self.tensorboard_dir) diff --git a/swift/megatron/model/config.py b/swift/megatron/model/config.py index f7a640beef..f8cdf94eea 100644 --- a/swift/megatron/model/config.py +++ b/swift/megatron/model/config.py @@ -21,6 +21,13 @@ 'disable_bias_linear': ['mlp_bias'], 'kv_channels': ['head_dim'], 'model_type': ['model_type'], + # moe + 'moe_ffn_hidden_size': ['moe_intermediate_size'], + 'moe_shared_expert_intermediate_size': ['shared_expert_intermediate_size'], + 'moe_router_topk': ['num_experts_per_tok'], + 'num_experts': ['num_experts'], + 'moe_router_pre_softmax': ['norm_topk_prob'], + 'moe_aux_loss_coeff': ['router_aux_loss_coef'], } diff --git a/swift/megatron/model/gpt/__init__.py b/swift/megatron/model/gpt/__init__.py index ab3756aa51..319306a551 100644 --- a/swift/megatron/model/gpt/__init__.py +++ b/swift/megatron/model/gpt/__init__.py @@ -35,4 +35,5 @@ ModelType.ziya, ModelType.mengzi3, ModelType.qwen3, + ModelType.qwen2_moe, ], model_provider, convert_gpt_hf_config, convert_mcore2hf, convert_hf2mcore)) diff --git a/swift/megatron/model/gpt/config.py b/swift/megatron/model/gpt/config.py index 5adb46b63b..2b13191fb1 100644 --- a/swift/megatron/model/gpt/config.py +++ b/swift/megatron/model/gpt/config.py @@ -5,6 +5,9 @@ def convert_gpt_hf_config(config) -> Dict[str, Any]: res = convert_hf_config(config) - if res.get('model_type') == 'qwen3': + model_type = 
res.get('model_type') + if model_type == 'qwen3': res['qk_layernorm'] = True + elif model_type in {'qwen2_moe', 'qwen3_moe'}: + res.pop('ffn_hidden_size', None) return res diff --git a/swift/megatron/model/gpt/hf2mcore.py b/swift/megatron/model/gpt/hf2mcore.py index 8725dd4b2b..441aa052b8 100644 --- a/swift/megatron/model/gpt/hf2mcore.py +++ b/swift/megatron/model/gpt/hf2mcore.py @@ -3,10 +3,7 @@ from megatron.training import get_args -def set_attn_state(args, mg_layer, hf_layer): - mg_attn = mg_layer.self_attention - hf_attn = hf_layer.self_attn - +def set_attn_state(args, mg_attn, hf_attn): num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads) # Copy weights @@ -33,19 +30,37 @@ def set_attn_state(args, mg_layer, hf_layer): mg_attn.k_layernorm.weight.data.copy_(hf_attn.k_norm.weight) -def set_mlp_state(args, mg_layer, hf_layer): - mg_layer.mlp.linear_fc1.weight.data.copy_( - torch.cat([hf_layer.mlp.gate_proj.weight, hf_layer.mlp.up_proj.weight], dim=0)) - mg_layer.mlp.linear_fc2.weight.data.copy_(hf_layer.mlp.down_proj.weight) +def _set_mlp_state(mg_mlp, hf_mlp): + mg_mlp.linear_fc1.weight.data.copy_(torch.cat([hf_mlp.gate_proj.weight, hf_mlp.up_proj.weight], dim=0)) + mg_mlp.linear_fc2.weight.data.copy_(hf_mlp.down_proj.weight) + + +def set_mlp_state(args, mg_mlp, hf_mlp): + if args.num_experts: + mg_mlp.router.weight.data.copy_(hf_mlp.gate.weight) + if mg_mlp.shared_experts is not None: + mg_mlp.shared_experts.gate_weight.data.copy_(hf_mlp.shared_expert_gate.weight) + for expert_idx in range(args.num_experts): + _set_mlp_state(mg_mlp.experts.local_experts[expert_idx], hf_mlp.experts[expert_idx]) + + if mg_mlp.shared_experts is not None: + _set_mlp_state(mg_mlp.shared_experts, mg_mlp.shared_experts) + else: + _set_mlp_state(mg_mlp, hf_mlp) def set_layer_state(args, mg_model, hf_model, layer_idx): mg_layer = mg_model.decoder.layers[layer_idx] hf_layer = hf_model.model.layers[layer_idx] - # self-attention - set_attn_state(args, mg_layer, hf_layer) - set_mlp_state(args, mg_layer, hf_layer) - mg_layer.mlp.linear_fc1.layer_norm_weight.data.copy_(hf_layer.post_attention_layernorm.weight) + + set_attn_state(args, mg_layer.self_attention, hf_layer.self_attn) + set_mlp_state(args, mg_layer.mlp, hf_layer.mlp) + + post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight + if args.num_experts: + mg_layer.pre_mlp_layernorm.weight.data.copy_(post_attention_layernorm_weight) + else: + mg_layer.mlp.linear_fc1.layer_norm_weight.data.copy_(post_attention_layernorm_weight) mg_layer.self_attention.linear_qkv.layer_norm_weight.data.copy_(hf_layer.input_layernorm.weight) diff --git a/swift/megatron/model/gpt/mcore2hf.py b/swift/megatron/model/gpt/mcore2hf.py index 4ada232e55..ac572f5f0e 100644 --- a/swift/megatron/model/gpt/mcore2hf.py +++ b/swift/megatron/model/gpt/mcore2hf.py @@ -2,10 +2,7 @@ from megatron.training import get_args -def set_attn_state(args, mg_layer, hf_layer): - mg_attn = mg_layer.self_attention - hf_attn = hf_layer.self_attn - +def set_attn_state(args, mg_attn, hf_attn): num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads) # Copy weights mg_attn_weight = mg_attn.linear_qkv.weight.reshape((num_query_groups, -1, args.hidden_size)) @@ -28,17 +25,31 @@ def set_attn_state(args, mg_layer, hf_layer): hf_attn.k_norm.weight.data.copy_(mg_attn.k_layernorm.weight) -def set_mlp_state(args, mg_layer, hf_layer): - 
hf_layer.mlp.gate_proj.weight.data.copy_(mg_layer.mlp.linear_fc1.weight[:args.ffn_hidden_size]) - hf_layer.mlp.up_proj.weight.data.copy_(mg_layer.mlp.linear_fc1.weight[args.ffn_hidden_size:]) - hf_layer.mlp.down_proj.weight.data.copy_(mg_layer.mlp.linear_fc2.weight) +def _set_mlp_state(mg_mlp, hf_mlp): + hf_mlp.gate_proj.weight.data.copy_(mg_mlp.linear_fc1.weight[:args.ffn_hidden_size]) + hf_mlp.up_proj.weight.data.copy_(mg_mlp.linear_fc1.weight[args.ffn_hidden_size:]) + hf_mlp.down_proj.weight.data.copy_(mg_mlp.linear_fc2.weight) + + +def set_mlp_state(args, mg_mlp, hf_mlp): + if args.num_experts: + hf_mlp.gate.weight.data.copy_(mg_mlp.router.weight) + if mg_mlp.shared_experts is not None: + hf_mlp.shared_expert_gate.weight.data.copy_(mg_mlp.shared_experts.gate_weight) + for expert_idx in range(args.num_experts): + _set_mlp_state(mg_mlp.experts.local_experts[expert_idx], hf_mlp.experts[expert_idx]) + + if mg_mlp.shared_experts is not None: + _set_mlp_state(mg_mlp.shared_experts, hf_mlp.shared_expert) + else: + _set_mlp_state(mg_mlp, hf_mlp) def set_layer_state(args, mg_model, hf_model, layer_idx): mg_layer = mg_model.decoder.layers[layer_idx] hf_layer = hf_model.model.layers[layer_idx] - set_attn_state(args, mg_layer, hf_layer) - set_mlp_state(args, mg_layer, hf_layer) + set_attn_state(args, mg_layer.self_attention, hf_layer.self_attn) + set_mlp_state(args, mg_layer.mlp, hf_layer.mlp) hf_layer.post_attention_layernorm.weight.data.copy_(mg_layer.mlp.linear_fc1.layer_norm_weight) hf_layer.input_layernorm.weight.data.copy_(mg_layer.self_attention.linear_qkv.layer_norm_weight) diff --git a/swift/megatron/model/gpt/model.py b/swift/megatron/model/gpt/model.py index e23515b6fc..a5382f5e0f 100644 --- a/swift/megatron/model/gpt/model.py +++ b/swift/megatron/model/gpt/model.py @@ -13,6 +13,8 @@ def model_provider(pre_process=True, post_process=True): config.variable_seq_lengths = True transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm, args.multi_latent_attention) + if args.num_experts and args.moe_shared_expert_intermediate_size: + transformer_layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True} model = GPTModel( config=config, transformer_layer_spec=transformer_layer_spec, diff --git a/tests/megatron/test_align/test_llm.py b/tests/megatron/test_align/test_llm.py index e1c9814530..2d712c5a8a 100644 --- a/tests/megatron/test_align/test_llm.py +++ b/tests/megatron/test_align/test_llm.py @@ -8,6 +8,10 @@ def _test_model(model_id): export_main(ExportArguments(model=model_id, to_mcore=True, exist_ok=True, test_convert_precision=True)) +def test_qwen2(): + _test_model('Qwen/Qwen2-0.5B-Instruct') + + def test_llama2(): _test_model('modelscope/Llama-2-7b-chat-ms') @@ -21,7 +25,6 @@ def test_marco_o1(): def test_deepseek_r1_llama(): - # TODO: FIX rope _test_model('deepseek-ai/DeepSeek-R1-Distill-Llama-8B') @@ -49,7 +52,12 @@ def test_qwen3(): _test_model('Qwen/Qwen3-0.6B-Base') +def test_qwen2_moe(): + _test_model('Qwen/Qwen1.5-MoE-A2.7B-Chat') + + if __name__ == '__main__': + # test_qwen2() # test_llama2() # test_llama3() # test_marco_o1() @@ -59,4 +67,5 @@ def test_qwen3(): # test_megrez() # test_llama3_1() # test_llama3_2() - test_qwen3() + # test_qwen3() + test_qwen2_moe() From b8bb97b891b253694e59d03577e2d93beaeb60a2 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 27 Apr 2025 16:07:50 +0800 Subject: [PATCH 02/14] update --- ...13\345\222\214\346\225\260\346\215\256\351\233\206.md" | 8 ++++---- 
.../Instruction/Supported-models-and-datasets.md | 8 ++++---- swift/megatron/model/config.py | 2 +- swift/megatron/model/gpt/hf2mcore.py | 2 +- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" index d7558d3304..28ccb11c30 100644 --- "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" +++ "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" @@ -173,11 +173,11 @@ |[Qwen/Qwen2.5-Math-1.5B](https://modelscope.cn/models/Qwen/Qwen2.5-Math-1.5B)|qwen2_5_math|qwen2_5_math|transformers>=4.37|✔|math|[Qwen/Qwen2.5-Math-1.5B](https://huggingface.co/Qwen/Qwen2.5-Math-1.5B)| |[Qwen/Qwen2.5-Math-7B](https://modelscope.cn/models/Qwen/Qwen2.5-Math-7B)|qwen2_5_math|qwen2_5_math|transformers>=4.37|✔|math|[Qwen/Qwen2.5-Math-7B](https://huggingface.co/Qwen/Qwen2.5-Math-7B)| |[Qwen/Qwen2.5-Math-72B](https://modelscope.cn/models/Qwen/Qwen2.5-Math-72B)|qwen2_5_math|qwen2_5_math|transformers>=4.37|✔|math|[Qwen/Qwen2.5-Math-72B](https://huggingface.co/Qwen/Qwen2.5-Math-72B)| -|[Qwen/Qwen1.5-MoE-A2.7B-Chat](https://modelscope.cn/models/Qwen/Qwen1.5-MoE-A2.7B-Chat)|qwen2_moe|qwen|transformers>=4.40|✘|-|[Qwen/Qwen1.5-MoE-A2.7B-Chat](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B-Chat)| -|[Qwen/Qwen1.5-MoE-A2.7B](https://modelscope.cn/models/Qwen/Qwen1.5-MoE-A2.7B)|qwen2_moe|qwen|transformers>=4.40|✘|-|[Qwen/Qwen1.5-MoE-A2.7B](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B)| +|[Qwen/Qwen1.5-MoE-A2.7B-Chat](https://modelscope.cn/models/Qwen/Qwen1.5-MoE-A2.7B-Chat)|qwen2_moe|qwen|transformers>=4.40|✔|-|[Qwen/Qwen1.5-MoE-A2.7B-Chat](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B-Chat)| +|[Qwen/Qwen1.5-MoE-A2.7B](https://modelscope.cn/models/Qwen/Qwen1.5-MoE-A2.7B)|qwen2_moe|qwen|transformers>=4.40|✔|-|[Qwen/Qwen1.5-MoE-A2.7B](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B)| |[Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4](https://modelscope.cn/models/Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4)|qwen2_moe|qwen|transformers>=4.40|✘|-|[Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4)| -|[Qwen/Qwen2-57B-A14B-Instruct](https://modelscope.cn/models/Qwen/Qwen2-57B-A14B-Instruct)|qwen2_moe|qwen|transformers>=4.40|✘|-|[Qwen/Qwen2-57B-A14B-Instruct](https://huggingface.co/Qwen/Qwen2-57B-A14B-Instruct)| -|[Qwen/Qwen2-57B-A14B](https://modelscope.cn/models/Qwen/Qwen2-57B-A14B)|qwen2_moe|qwen|transformers>=4.40|✘|-|[Qwen/Qwen2-57B-A14B](https://huggingface.co/Qwen/Qwen2-57B-A14B)| +|[Qwen/Qwen2-57B-A14B-Instruct](https://modelscope.cn/models/Qwen/Qwen2-57B-A14B-Instruct)|qwen2_moe|qwen|transformers>=4.40|✔|-|[Qwen/Qwen2-57B-A14B-Instruct](https://huggingface.co/Qwen/Qwen2-57B-A14B-Instruct)| +|[Qwen/Qwen2-57B-A14B](https://modelscope.cn/models/Qwen/Qwen2-57B-A14B)|qwen2_moe|qwen|transformers>=4.40|✔|-|[Qwen/Qwen2-57B-A14B](https://huggingface.co/Qwen/Qwen2-57B-A14B)| |[Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4](https://modelscope.cn/models/Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4)|qwen2_moe|qwen|transformers>=4.40|✘|-|[Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4](https://huggingface.co/Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4)| 
|[Qwen/QwQ-32B-Preview](https://modelscope.cn/models/Qwen/QwQ-32B-Preview)|qwq_preview|qwq_preview|transformers>=4.37|✔|-|[Qwen/QwQ-32B-Preview](https://huggingface.co/Qwen/QwQ-32B-Preview)| |[Qwen/QwQ-32B](https://modelscope.cn/models/Qwen/QwQ-32B)|qwq|qwq|transformers>=4.37|✔|-|[Qwen/QwQ-32B](https://huggingface.co/Qwen/QwQ-32B)| diff --git a/docs/source_en/Instruction/Supported-models-and-datasets.md b/docs/source_en/Instruction/Supported-models-and-datasets.md index adaace369f..d5c2e475a5 100644 --- a/docs/source_en/Instruction/Supported-models-and-datasets.md +++ b/docs/source_en/Instruction/Supported-models-and-datasets.md @@ -173,11 +173,11 @@ The table below introduces the models integrated with ms-swift: |[Qwen/Qwen2.5-Math-1.5B](https://modelscope.cn/models/Qwen/Qwen2.5-Math-1.5B)|qwen2_5_math|qwen2_5_math|transformers>=4.37|✔|math|[Qwen/Qwen2.5-Math-1.5B](https://huggingface.co/Qwen/Qwen2.5-Math-1.5B)| |[Qwen/Qwen2.5-Math-7B](https://modelscope.cn/models/Qwen/Qwen2.5-Math-7B)|qwen2_5_math|qwen2_5_math|transformers>=4.37|✔|math|[Qwen/Qwen2.5-Math-7B](https://huggingface.co/Qwen/Qwen2.5-Math-7B)| |[Qwen/Qwen2.5-Math-72B](https://modelscope.cn/models/Qwen/Qwen2.5-Math-72B)|qwen2_5_math|qwen2_5_math|transformers>=4.37|✔|math|[Qwen/Qwen2.5-Math-72B](https://huggingface.co/Qwen/Qwen2.5-Math-72B)| -|[Qwen/Qwen1.5-MoE-A2.7B-Chat](https://modelscope.cn/models/Qwen/Qwen1.5-MoE-A2.7B-Chat)|qwen2_moe|qwen|transformers>=4.40|✘|-|[Qwen/Qwen1.5-MoE-A2.7B-Chat](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B-Chat)| -|[Qwen/Qwen1.5-MoE-A2.7B](https://modelscope.cn/models/Qwen/Qwen1.5-MoE-A2.7B)|qwen2_moe|qwen|transformers>=4.40|✘|-|[Qwen/Qwen1.5-MoE-A2.7B](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B)| +|[Qwen/Qwen1.5-MoE-A2.7B-Chat](https://modelscope.cn/models/Qwen/Qwen1.5-MoE-A2.7B-Chat)|qwen2_moe|qwen|transformers>=4.40|✔|-|[Qwen/Qwen1.5-MoE-A2.7B-Chat](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B-Chat)| +|[Qwen/Qwen1.5-MoE-A2.7B](https://modelscope.cn/models/Qwen/Qwen1.5-MoE-A2.7B)|qwen2_moe|qwen|transformers>=4.40|✔|-|[Qwen/Qwen1.5-MoE-A2.7B](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B)| |[Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4](https://modelscope.cn/models/Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4)|qwen2_moe|qwen|transformers>=4.40|✘|-|[Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4)| -|[Qwen/Qwen2-57B-A14B-Instruct](https://modelscope.cn/models/Qwen/Qwen2-57B-A14B-Instruct)|qwen2_moe|qwen|transformers>=4.40|✘|-|[Qwen/Qwen2-57B-A14B-Instruct](https://huggingface.co/Qwen/Qwen2-57B-A14B-Instruct)| -|[Qwen/Qwen2-57B-A14B](https://modelscope.cn/models/Qwen/Qwen2-57B-A14B)|qwen2_moe|qwen|transformers>=4.40|✘|-|[Qwen/Qwen2-57B-A14B](https://huggingface.co/Qwen/Qwen2-57B-A14B)| +|[Qwen/Qwen2-57B-A14B-Instruct](https://modelscope.cn/models/Qwen/Qwen2-57B-A14B-Instruct)|qwen2_moe|qwen|transformers>=4.40|✔|-|[Qwen/Qwen2-57B-A14B-Instruct](https://huggingface.co/Qwen/Qwen2-57B-A14B-Instruct)| +|[Qwen/Qwen2-57B-A14B](https://modelscope.cn/models/Qwen/Qwen2-57B-A14B)|qwen2_moe|qwen|transformers>=4.40|✔|-|[Qwen/Qwen2-57B-A14B](https://huggingface.co/Qwen/Qwen2-57B-A14B)| |[Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4](https://modelscope.cn/models/Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4)|qwen2_moe|qwen|transformers>=4.40|✘|-|[Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4](https://huggingface.co/Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4)| 
|[Qwen/QwQ-32B-Preview](https://modelscope.cn/models/Qwen/QwQ-32B-Preview)|qwq_preview|qwq_preview|transformers>=4.37|✔|-|[Qwen/QwQ-32B-Preview](https://huggingface.co/Qwen/QwQ-32B-Preview)| |[Qwen/QwQ-32B](https://modelscope.cn/models/Qwen/QwQ-32B)|qwq|qwq|transformers>=4.37|✔|-|[Qwen/QwQ-32B](https://huggingface.co/Qwen/QwQ-32B)| diff --git a/swift/megatron/model/config.py b/swift/megatron/model/config.py index f8cdf94eea..bd9c9656cf 100644 --- a/swift/megatron/model/config.py +++ b/swift/megatron/model/config.py @@ -39,7 +39,7 @@ def convert_hf_config(config) -> Dict[str, Any]: hf_v = getattr(config, hf_k) if k == 'rotary_base': megatron_config[k] = int(hf_v) - elif k in {'untie_embeddings_and_output_weights', 'disable_bias_linear'}: + elif k in {'untie_embeddings_and_output_weights', 'disable_bias_linear', 'moe_router_pre_softmax'}: megatron_config[k] = not hf_v elif k == 'swiglu': if hf_v == 'silu': diff --git a/swift/megatron/model/gpt/hf2mcore.py b/swift/megatron/model/gpt/hf2mcore.py index 441aa052b8..46525df3c7 100644 --- a/swift/megatron/model/gpt/hf2mcore.py +++ b/swift/megatron/model/gpt/hf2mcore.py @@ -44,7 +44,7 @@ def set_mlp_state(args, mg_mlp, hf_mlp): _set_mlp_state(mg_mlp.experts.local_experts[expert_idx], hf_mlp.experts[expert_idx]) if mg_mlp.shared_experts is not None: - _set_mlp_state(mg_mlp.shared_experts, mg_mlp.shared_experts) + _set_mlp_state(mg_mlp.shared_experts, hf_mlp.shared_expert) else: _set_mlp_state(mg_mlp, hf_mlp) From f0a2da7e50b9938a84493f911a3f9eea057a566e Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 27 Apr 2025 16:55:42 +0800 Subject: [PATCH 03/14] update --- ...\273\244\350\241\214\345\217\202\346\225\260.md" | 1 + .../Instruction/Command-line-parameters.md | 1 + swift/megatron/argument/megatron_args.py | 13 +++++++++---- swift/megatron/model/gpt/__init__.py | 1 + swift/megatron/model/gpt/config.py | 2 +- 5 files changed, 13 insertions(+), 5 deletions(-) diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index d779d31cd1..2da5acafe8 100644 --- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -48,6 +48,7 @@ - 注意:CPT/SFT的随机包括两个部分:数据集的随机,由`dataset_shuffle`控制;train_dataloader中的随机,由`train_dataloader_shuffle`控制。 - val_dataset_shuffle: 是否对val_dataset进行随机操作。默认为False。 - 🔥streaming: 流式读取并处理数据集,默认False。通常在处理大型数据集时,设置为True + - 注意:需要额外设置`--max_steps`,因为流式数据集无法获得其长度 - interleave_prob: 默认值为 None。在组合多个数据集时,默认使用 `concatenate_datasets` 函数;如果设置了该参数,则会使用 `interleave_datasets` 函数。该参数通常用于流式数据集的组合,并会作为参数传入 `interleave_datasets` 函数中 - stopping_strategy: 可选为"first_exhausted", "all_exhausted",默认为"first_exhausted"。传入interleave_datasets函数中 - shuffle_buffer_size: 该参数用于指定流式数据集的随机buffer大小,默认为1000 diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index 3b28807c86..f5238c1537 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -49,6 +49,7 @@ Hints: - Note: The shuffling in CPT/SFT consists of two parts: dataset shuffling, controlled by `dataset_shuffle`; and shuffling in the train_dataloader, controlled by `train_dataloader_shuffle`. - val_dataset_shuffle: Whether to perform shuffling on the val_dataset. Default is False. 
- 🔥streaming: Stream reading and processing of the dataset, default is False. It is typically set to True when handling large datasets. + - Note: It is necessary to set `--max_steps` additionally, as the length of the streaming dataset cannot be obtained. - interleave_prob: Defaults to None. When combining multiple datasets, the `concatenate_datasets` function is used by default. If this parameter is set, the `interleave_datasets` function will be used instead. This parameter is typically used when combining streaming datasets and is passed to the `interleave_datasets` function. - stopping_strategy: Can be either "first_exhausted" or "all_exhausted", with the default being "first_exhausted". This parameter is passed to the `interleave_datasets` function. - shuffle_buffer_size: This parameter is used to specify the shuffle buffer size for streaming datasets. Defaults to 1000. diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 9c3dcb3047..4a7d18fe2c 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -157,6 +157,14 @@ def _init_mixed_precision(self): if self.apply_query_key_layer_scaling: os.environ['NVTE_APPLY_QK_LAYER_SCALING'] = '1' + def _init_moe(self): + if self.moe_shared_expert_intermediate_size == 0: + self.moe_shared_expert_intermediate_size = None + if self.moe_ffn_hidden_size is None: + self.moe_ffn_hidden_size = self.ffn_hidden_size + else: + self.ffn_hidden_size = self.moe_ffn_hidden_size + def __post_init__(self): from swift.llm.argument.base_args.model_args import ModelArguments os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' @@ -169,10 +177,7 @@ def __post_init__(self): self.seq_length = self.max_position_embeddings if self.tensorboard_dir is None and self.save is not None: self.tensorboard_dir = f'{self.save}/runs' - if self.moe_ffn_hidden_size is None: - self.moe_ffn_hidden_size = self.ffn_hidden_size - else: - self.ffn_hidden_size = self.moe_ffn_hidden_size + self._init_moe() self._init_mixed_precision() self.tensorboard_dir = to_abspath(self.tensorboard_dir) diff --git a/swift/megatron/model/gpt/__init__.py b/swift/megatron/model/gpt/__init__.py index 319306a551..8d9af9f711 100644 --- a/swift/megatron/model/gpt/__init__.py +++ b/swift/megatron/model/gpt/__init__.py @@ -36,4 +36,5 @@ ModelType.mengzi3, ModelType.qwen3, ModelType.qwen2_moe, + ModelType.qwen3_moe, ], model_provider, convert_gpt_hf_config, convert_mcore2hf, convert_hf2mcore)) diff --git a/swift/megatron/model/gpt/config.py b/swift/megatron/model/gpt/config.py index 2b13191fb1..5183239c2f 100644 --- a/swift/megatron/model/gpt/config.py +++ b/swift/megatron/model/gpt/config.py @@ -6,7 +6,7 @@ def convert_gpt_hf_config(config) -> Dict[str, Any]: res = convert_hf_config(config) model_type = res.get('model_type') - if model_type == 'qwen3': + if model_type in {'qwen3', 'qwen3_moe'}: res['qk_layernorm'] = True elif model_type in {'qwen2_moe', 'qwen3_moe'}: res.pop('ffn_hidden_size', None) From 6f36d9ee087f331aa1e0a65cca9c2d2f654516ba Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 27 Apr 2025 19:14:30 +0800 Subject: [PATCH 04/14] update --- swift/megatron/argument/megatron_args.py | 50 ++++++++++++++++++------ swift/megatron/argument/train_args.py | 3 +- swift/megatron/model/gpt/hf2mcore.py | 8 ++-- swift/megatron/model/gpt/mcore2hf.py | 8 ++-- swift/megatron/train/patcher.py | 7 ++-- 5 files changed, 52 insertions(+), 24 deletions(-) diff --git a/swift/megatron/argument/megatron_args.py 
b/swift/megatron/argument/megatron_args.py index 4a7d18fe2c..74749d84c5 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -90,35 +90,36 @@ class MegatronArguments(ExtraMegatronArguments): ffn_hidden_size: Optional[int] = None num_attention_heads: Optional[int] = None group_query_attention: Optional[bool] = None - num_query_groups: int = 1 + num_query_groups: Optional[int] = None max_position_embeddings: Optional[int] = None position_embedding_type: Literal['learned_absolute', 'rope', 'relative', 'none'] = 'rope' - rotary_base: int = 10000 + rotary_base: Optional[int] = None rotary_percent: float = 1. normalization: Literal['LayerNorm', 'RMSNorm'] = 'RMSNorm' - norm_epsilon: float = 1e-5 - swiglu: bool = True - untie_embeddings_and_output_weights: bool = True - disable_bias_linear: bool = True - add_qkv_bias: bool = True - attention_dropout: float = 0. + norm_epsilon: Optional[float] = None + swiglu: Optional[bool] = None + untie_embeddings_and_output_weights: Optional[bool] = None + disable_bias_linear: Optional[bool] = None + add_qkv_bias: Optional[bool] = None + attention_dropout: Optional[float] = None hidden_dropout: float = 0. kv_channels: Optional[int] = None qk_layernorm: bool = False transformer_impl: Literal['local', 'transformer_engine'] = 'transformer_engine' # moe - expert_model_parallel_size: int = 1 num_experts: Optional[int] = None moe_ffn_hidden_size: Optional[int] = None moe_shared_expert_intermediate_size: Optional[int] = None - moe_router_topk: int = 2 + moe_router_topk: Optional[int] = None + moe_router_pre_softmax: Optional[bool] = None + moe_aux_loss_coeff: Optional[float] = None + + expert_model_parallel_size: int = 1 moe_token_dispatcher_type: Literal['allgather', 'alltoall', 'alltoall_seq'] = 'alltoall' moe_grouped_gemm: bool = False moe_router_load_balancing_type: Literal['aux_loss', 'seq_aux_loss', 'sinkhorn', 'none'] = 'aux_loss' - moe_aux_loss_coeff: float = 0. moe_z_loss_coeff: Optional[float] = None - moe_router_pre_softmax: bool = False moe_expert_capacity_factor: Optional[float] = None moe_shared_expert_overlap: bool = False @@ -149,6 +150,30 @@ class MegatronArguments(ExtraMegatronArguments): num_workers: int = 4 no_create_attention_mask_in_dataloader: bool = True + def _set_default(self): + if self.num_query_groups is None: + self.num_query_groups = 1 + if self.norm_epsilon is None: + self.norm_epsilon = 1e-5 + if self.rotary_base is None: + self.rotary_base = 10000 + if self.attention_dropout is None: + self.attention_dropout = 0. + if self.untie_embeddings_and_output_weights is None: + self.untie_embeddings_and_output_weights = True + if self.swiglu is None: + self.swiglu = True + if self.add_qkv_bias is None: + self.add_qkv_bias = True + if self.disable_bias_linear is None: + self.disable_bias_linear = True + if self.moe_router_topk is None: + self.moe_router_topk = 2 + if self.moe_router_pre_softmax is None: + self.moe_router_pre_softmax = False + if self.moe_aux_loss_coeff is None: + self.moe_aux_loss_coeff = 0. 
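# An illustrative, standalone sketch of the precedence that the None-valued
# defaults above establish: a CLI value wins, then the value read from the HF
# config.json (applied by init_model_args in train_args.py, changed later in
# this patch), and _set_default backfills whatever is still unset. The
# hf_config dict below is a hypothetical stand-in for the parsed config.json.
from dataclasses import dataclass
from typing import Optional


@dataclass
class _ArgsSketch:
    rotary_base: Optional[int] = None  # None means "not set on the command line"
    moe_router_topk: Optional[int] = None

    def _set_default(self):
        # Backfill Megatron defaults only for fields still unset.
        if self.rotary_base is None:
            self.rotary_base = 10000
        if self.moe_router_topk is None:
            self.moe_router_topk = 2


def _init_model_args(args, hf_config):
    for k, v in hf_config.items():
        if getattr(args, k) is None:  # mirrors train_args.py in this patch
            setattr(args, k, v)
    args._set_default()


_args = _ArgsSketch()
_init_model_args(_args, {'rotary_base': 1000000})  # from config.json
assert _args.rotary_base == 1000000 and _args.moe_router_topk == 2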
+ def _init_mixed_precision(self): from swift.llm.argument.base_args.model_args import ModelArguments ModelArguments._init_mixed_precision(self) @@ -168,6 +193,7 @@ def _init_moe(self): def __post_init__(self): from swift.llm.argument.base_args.model_args import ModelArguments os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' + self._set_default() self.group_query_attention = self.num_query_groups > 1 if self.rope_scaling is not None: self.rope_scaling = ModelArguments.parse_to_dict(self.rope_scaling) diff --git a/swift/megatron/argument/train_args.py b/swift/megatron/argument/train_args.py index 3c89b2eb48..7627515c40 100644 --- a/swift/megatron/argument/train_args.py +++ b/swift/megatron/argument/train_args.py @@ -24,7 +24,8 @@ def init_model_args(self, config): self.megatron_model_meta = get_megatron_model_meta(self.model_type) kwargs = self.megatron_model_meta.convert_hf_config(config) for k, v in kwargs.items(): - setattr(self, k, v) + if getattr(self, k) is None: + setattr(self, k, v) MegatronArguments.__post_init__(self) self.extra_args = self.parse_to_megatron() diff --git a/swift/megatron/model/gpt/hf2mcore.py b/swift/megatron/model/gpt/hf2mcore.py index 46525df3c7..2cbe6dc320 100644 --- a/swift/megatron/model/gpt/hf2mcore.py +++ b/swift/megatron/model/gpt/hf2mcore.py @@ -30,7 +30,7 @@ def set_attn_state(args, mg_attn, hf_attn): mg_attn.k_layernorm.weight.data.copy_(hf_attn.k_norm.weight) -def _set_mlp_state(mg_mlp, hf_mlp): +def _set_mlp_state(args, mg_mlp, hf_mlp): mg_mlp.linear_fc1.weight.data.copy_(torch.cat([hf_mlp.gate_proj.weight, hf_mlp.up_proj.weight], dim=0)) mg_mlp.linear_fc2.weight.data.copy_(hf_mlp.down_proj.weight) @@ -41,12 +41,12 @@ def set_mlp_state(args, mg_mlp, hf_mlp): if mg_mlp.shared_experts is not None: mg_mlp.shared_experts.gate_weight.data.copy_(hf_mlp.shared_expert_gate.weight) for expert_idx in range(args.num_experts): - _set_mlp_state(mg_mlp.experts.local_experts[expert_idx], hf_mlp.experts[expert_idx]) + _set_mlp_state(args, mg_mlp.experts.local_experts[expert_idx], hf_mlp.experts[expert_idx]) if mg_mlp.shared_experts is not None: - _set_mlp_state(mg_mlp.shared_experts, hf_mlp.shared_expert) + _set_mlp_state(args, mg_mlp.shared_experts, hf_mlp.shared_expert) else: - _set_mlp_state(mg_mlp, hf_mlp) + _set_mlp_state(args, mg_mlp, hf_mlp) def set_layer_state(args, mg_model, hf_model, layer_idx): diff --git a/swift/megatron/model/gpt/mcore2hf.py b/swift/megatron/model/gpt/mcore2hf.py index ac572f5f0e..caf7047e5a 100644 --- a/swift/megatron/model/gpt/mcore2hf.py +++ b/swift/megatron/model/gpt/mcore2hf.py @@ -25,7 +25,7 @@ def set_attn_state(args, mg_attn, hf_attn): hf_attn.k_norm.weight.data.copy_(mg_attn.k_layernorm.weight) -def _set_mlp_state(mg_mlp, hf_mlp): +def _set_mlp_state(args, mg_mlp, hf_mlp): hf_mlp.gate_proj.weight.data.copy_(mg_mlp.linear_fc1.weight[:args.ffn_hidden_size]) hf_mlp.up_proj.weight.data.copy_(mg_mlp.linear_fc1.weight[args.ffn_hidden_size:]) hf_mlp.down_proj.weight.data.copy_(mg_mlp.linear_fc2.weight) @@ -37,12 +37,12 @@ def set_mlp_state(args, mg_mlp, hf_mlp): if mg_mlp.shared_experts is not None: hf_mlp.shared_expert_gate.weight.data.copy_(mg_mlp.shared_experts.gate_weight) for expert_idx in range(args.num_experts): - _set_mlp_state(mg_mlp.experts.local_experts[expert_idx], hf_mlp.experts[expert_idx]) + _set_mlp_state(args, mg_mlp.experts.local_experts[expert_idx], hf_mlp.experts[expert_idx]) if mg_mlp.shared_experts is not None: - _set_mlp_state(mg_mlp.shared_experts, hf_mlp.shared_expert) + _set_mlp_state(args, 
mg_mlp.shared_experts, hf_mlp.shared_expert)
     else:
-        _set_mlp_state(mg_mlp, hf_mlp)
+        _set_mlp_state(args, mg_mlp, hf_mlp)
 
 
 def set_layer_state(args, mg_model, hf_model, layer_idx):
diff --git a/swift/megatron/train/patcher.py b/swift/megatron/train/patcher.py
index d1f09b98af..76a9862421 100644
--- a/swift/megatron/train/patcher.py
+++ b/swift/megatron/train/patcher.py
@@ -26,9 +26,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
         if isinstance(v, torch.Tensor):
             v = v.item()
         logs[k] = round(v, 8)
-    logs['grad_norm'] = round(grad_norm, 8)
-    logs['params_norm'] = round(params_norm, 8)
-    logs['learning_rate'] = round(learning_rate, 8)
+    for k in {'grad_norm', 'params_norm', 'learning_rate'}:
+        v = locals()[k]
+        if v is not None:
+            logs[k] = round(v, 8)
     logs['consumed_samples'] = args.consumed_train_samples
     logs['global_step/max_steps'] = f'{iteration}/{args.train_iters}'
     if jsonl_writer is None:

From 4aff7013cb888b693917d2c060827d15667e60a7 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Sun, 27 Apr 2025 22:29:07 +0800
Subject: [PATCH 05/14] update

---
 ...Megatron-SWIFT\350\256\255\347\273\203.md" | 26 +++++++++++++++---
 .../Instruction/Megatron-SWIFT-Training.md    | 26 +++++++++++++++---
 swift/megatron/argument/megatron_args.py      |  4 ++-
 tests/megatron/test_align/test_llm.py         | 27 +++++++++++++++++--
 4 files changed, 74 insertions(+), 9 deletions(-)

diff --git "a/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" "b/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md"
index 8b54ecb2af..d63bbef43e 100644
--- "a/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md"
+++ "b/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md"
@@ -187,7 +187,7 @@ I am a language model developed by swift, you can call me swift-robot. How can I
 - overlap_param_gather: 启用分布式优化器中参数all-gather的重叠（降低DP通信耗时）。默认为False。
 - distributed_timeout_minutes: torch.distributed的timeout时间（单位为分钟），默认为60分钟。
 
-**日志参数**
+**日志参数**:
 - log_params_norm: 记录参数的norm。默认为True。
 - log_throughput: 记录每个GPU的吞吐量。默认为True。
   - 注意：在非packing情况下，log_throughput并不准确，因为`seq_length`并不等于真实序列长度。
@@ -199,11 +199,11 @@
 - log_memory_to_tensorboard: 将内存日志写入tensorboard。默认为True。
 - logging_level: 日志级别。默认为None。
 
-**评估参数**
+**评估参数**:
 - 🔥eval_iters: 评估的迭代次数，默认为100。
 - 🔥eval_interval: 评估的间隔（steps），默认为None，即设置为save_interval。
 
-**混合精度参数**
+**混合精度参数**:
 - fp16: fp16模式。默认为None，会根据模型的torch_dtype进行设置。torch_dtype默认读取config.json。
 - bf16: bf16模式。默认为None，会根据模型的torch_dtype进行设置。
 - apply_query_key_layer_scaling: 将`Q * K^T` 缩放为 `1 / 层数`（例如：第layer_num层则除以layer_num）。这对fp16训练很有帮助。默认为None，即若使用`--fp16`，则设置为True。
@@ -228,9 +228,29 @@ I am a language model developed by swift, you can call me swift-robot. 
How can I
 - add_qkv_bias: 仅在QKV的linear中增加bias，默认为True。
 - attention_dropout: 默认为0.。
 - hidden_dropout: 默认为0.。
+- kv_channels: 默认为None，设置为`args.hidden_size // args.num_attention_heads`。
+- qk_layernorm: 是否对Q和K进行层归一化。
 - transformer_impl: 使用哪种transformer实现，可选项为'local'和'transformer_engine'。默认为transformer_engine。
 - padded_vocab_size: 完整词表大小，默认为None。
 - rope_scaling: rope_scaling相关参数，默认为None。格式参考[llama3.1 config.json](https://modelscope.cn/models/LLM-Research/Meta-Llama-3.1-8B-Instruct/file/view/master?fileName=config.json&status=1)，传入json字符串。
+- model_type: Huggingface模型权重中config.json中的model_type。
+
+
+**MoE参数**:
+- num_experts: MoE的专家数，默认为None。自动从config.json读取。
+- moe_ffn_hidden_size: 每个专家的前馈网络(ffn)的隐藏层大小。默认为None，设置为ffn_hidden_size。自动从config.json读取。
+- moe_shared_expert_intermediate_size: 共享专家的总FFN隐藏层大小。如果有多个共享专家，它应等于 `num_shared_experts * ffn_size_of_each_shared_expert`。默认为None。自动从config.json读取。
+- moe_router_topk: 每个token路由到的专家数量。默认为None。自动从config.json读取。
+- moe_router_pre_softmax: 为MoE启用预softmax路由，这意味着softmax会在top-k选择之前进行。默认为None。自动从config.json读取。
+- moe_aux_loss_coeff: 辅助损失的缩放系数：建议的初始值为 1e-2。默认为None。自动从config.json读取。
+- expert_model_parallel_size: 专家并行数，默认为1。
+- moe_token_dispatcher_type: 要使用的token分发器类型。可选选项包括 'allgather'、'alltoall' 和 'alltoall_seq'。默认值为 'alltoall'。
+- moe_grouped_gemm: 当每个rank包含多个专家时，通过在多个流中启动多个本地 GEMM 内核，利用 TransformerEngine中的GroupedLinear提高利用率和性能。默认为False。
+- moe_router_load_balancing_type: 确定路由器的负载均衡策略。可选项为"aux_loss"、"seq_aux_loss"、"sinkhorn"、"none"。默认值为 "aux_loss"。
+- moe_z_loss_coeff: z-loss 的缩放系数。默认为None。
+- moe_expert_capacity_factor: 每个专家的容量因子，None表示不会丢弃任何token。默认为None。
+- moe_shared_expert_overlap: 启用共享专家计算与调度器通信之间的重叠。如果不启用此选项，共享专家将在路由专家之后执行。仅在设置了`moe_shared_expert_intermediate_size`时有效。默认为False。
+
 
 ### Megatron训练参数
diff --git a/docs/source_en/Instruction/Megatron-SWIFT-Training.md b/docs/source_en/Instruction/Megatron-SWIFT-Training.md
index 6195281cf0..fbc5a9f784 100644
--- a/docs/source_en/Instruction/Megatron-SWIFT-Training.md
+++ b/docs/source_en/Instruction/Megatron-SWIFT-Training.md
@@ -193,7 +193,7 @@ seq_length: Defaults to None, meaning it is set to `max_length`. To restrict the
 - overlap_param_gather: Overlap all-gather of parameters in the distributed optimizer (to reduce DP communication time). Default is False.
 - distributed_timeout_minutes: Timeout duration for torch.distributed (in minutes), default is 60 minutes.
 
-**Logging Parameters**
+**Logging Parameters**:
 
 - log_params_norm: Logs the norm of parameters. Default is True.
 - log_throughput: Logs throughput per GPU. Default is True.
@@ -206,12 +206,12 @@
 - log_memory_to_tensorboard: Writes memory logs to TensorBoard. Default is True.
 - logging_level: Logging level. Default is None.
 
-**Evaluation Parameters**
+**Evaluation Parameters**:
 
 - 🔥eval_iters: Number of evaluation iterations, default is 100.
 - 🔥eval_interval: Evaluation interval (steps), default is None, meaning it will be set to save_interval.
 
-**Mixed Precision Parameters**
+**Mixed Precision Parameters**:
 
 - fp16: FP16 mode. The default is None, and it will be set according to the model's torch_dtype. The torch_dtype is read from the config.json by default.
 - bf16: BF16 mode. The default is None, and it will be set according to the model's torch_dtype.
@@ -238,9 +238,29 @@
 - add_qkv_bias: Adds bias only to QKV linear layers. Default is True.
 - attention_dropout: Default is 0.
- hidden_dropout: Default is 0. +- kv_channels: Defaults to None, set to `args.hidden_size // args.num_attention_heads`. +- qk_layernorm: Whether to apply layer normalization to Q and K. - transformer_impl: Which transformer implementation to use, options are 'local' and 'transformer_engine'. Default is transformer_engine. - padded_vocab_size: Full vocabulary size, default is None. - rope_scaling: Related parameters for rope_scaling, default is None. Refer to the format in [llama3.1 config.json](https://modelscope.cn/models/LLM-Research/Meta-Llama-3.1-8B-Instruct/file/view/master?fileName=config.json&status=1). Pass the value as a JSON string. +- model_type: The model_type in the config.json of the Huggingface model weights. + + +**MoE Parameters**: + +- num_experts: The number of experts in MoE, default is None. Automatically read from config.json. +- moe_ffn_hidden_size: The hidden layer size of the feed-forward network (ffn) for each expert. Default is None, set to ffn_hidden_size. Automatically read from config.json. +- moe_shared_expert_intermediate_size: The total FFN hidden layer size for shared experts. If there are multiple shared experts, it should equal `num_shared_experts * ffn_size_of_each_shared_expert`. Default is None. Automatically read from config.json. +- moe_router_topk: The number of experts each token is routed to. Default is None. Automatically read from config.json. +- moe_router_pre_softmax: Enable pre-softmax routing for MoE, meaning that softmax will be applied before top-k selection. Default is None. Automatically read from config.json. +- moe_aux_loss_coeff: Scaling coefficient for the auxiliary loss: the recommended initial value is 1e-2. Default is None. Automatically read from config.json. +- expert_model_parallel_size: The degree of expert parallelism, default is 1. +- moe_token_dispatcher_type: The type of token dispatcher to use. Options include 'allgather', 'alltoall', and 'alltoall_seq'. Default is 'alltoall'. +- moe_grouped_gemm: When each rank contains multiple experts, improve utilization and performance by launching multiple local GEMM kernels across multiple streams using GroupedLinear in TransformerEngine. Default is False. +- moe_router_load_balancing_type: Determines the load balancing strategy for the router. Options are "aux_loss", "seq_aux_loss", "sinkhorn", "none". Default is "aux_loss". +- moe_z_loss_coeff: Scaling coefficient for z-loss. Default is None. +- moe_expert_capacity_factor: Capacity factor for each expert, None means no tokens will be dropped. Default is None. +- moe_shared_expert_overlap: Enable overlapping of shared expert computation with scheduler communication. If this option is not enabled, shared experts will execute after the routing experts. Only effective when `moe_shared_expert_intermediate_size` is set. Default is False. ### Megatron Training Parameters diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 74749d84c5..65867d0944 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -104,7 +104,7 @@ class MegatronArguments(ExtraMegatronArguments): attention_dropout: Optional[float] = None hidden_dropout: float = 0. 
kv_channels: Optional[int] = None - qk_layernorm: bool = False + qk_layernorm: Optional[bool] = None transformer_impl: Literal['local', 'transformer_engine'] = 'transformer_engine' # moe @@ -173,6 +173,8 @@ def _set_default(self): self.moe_router_pre_softmax = False if self.moe_aux_loss_coeff is None: self.moe_aux_loss_coeff = 0. + if self.qk_layernorm is None: + self.qk_layernorm = False def _init_mixed_precision(self): from swift.llm.argument.base_args.model_args import ModelArguments diff --git a/tests/megatron/test_align/test_llm.py b/tests/megatron/test_align/test_llm.py index 2d712c5a8a..7c5cdf3946 100644 --- a/tests/megatron/test_align/test_llm.py +++ b/tests/megatron/test_align/test_llm.py @@ -1,11 +1,29 @@ import os +import torch + os.environ['CUDA_VISIBLE_DEVICES'] = '0' def _test_model(model_id): from swift.llm import export_main, ExportArguments - export_main(ExportArguments(model=model_id, to_mcore=True, exist_ok=True, test_convert_precision=True)) + if model_id.endswith('mcore'): + export_main( + ExportArguments( + mcore_model=model_id, + to_hf=True, + exist_ok=True, + test_convert_precision=True, + torch_dtype=torch.bfloat16)) + else: + export_main( + ExportArguments( + model=model_id, + to_mcore=True, + exist_ok=True, + test_convert_precision=True, + torch_dtype=torch.bfloat16, + )) def test_qwen2(): @@ -56,6 +74,10 @@ def test_qwen2_moe(): _test_model('Qwen/Qwen1.5-MoE-A2.7B-Chat') +def test_qwen3_moe(): + _test_model('Qwen/Qwen3-15B-A2B-Base') + + if __name__ == '__main__': # test_qwen2() # test_llama2() @@ -68,4 +90,5 @@ def test_qwen2_moe(): # test_llama3_1() # test_llama3_2() # test_qwen3() - test_qwen2_moe() + # test_qwen2_moe() + test_qwen3_moe() From 76b67556342957a01492638bcea5caea1dd3750e Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 27 Apr 2025 22:34:54 +0800 Subject: [PATCH 06/14] fix --- swift/megatron/model/gpt/mcore2hf.py | 7 ++++++- swift/megatron/model/gpt/model.py | 1 + 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/swift/megatron/model/gpt/mcore2hf.py b/swift/megatron/model/gpt/mcore2hf.py index caf7047e5a..b9a4a9cb1a 100644 --- a/swift/megatron/model/gpt/mcore2hf.py +++ b/swift/megatron/model/gpt/mcore2hf.py @@ -50,7 +50,12 @@ def set_layer_state(args, mg_model, hf_model, layer_idx): hf_layer = hf_model.model.layers[layer_idx] set_attn_state(args, mg_layer.self_attention, hf_layer.self_attn) set_mlp_state(args, mg_layer.mlp, hf_layer.mlp) - hf_layer.post_attention_layernorm.weight.data.copy_(mg_layer.mlp.linear_fc1.layer_norm_weight) + + post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight + if args.num_experts: + post_attention_layernorm_weight.data.copy_(mg_layer.pre_mlp_layernorm.weight) + else: + post_attention_layernorm_weight.data.copy_(mg_layer.mlp.linear_fc1.layer_norm_weight) hf_layer.input_layernorm.weight.data.copy_(mg_layer.self_attention.linear_qkv.layer_norm_weight) diff --git a/swift/megatron/model/gpt/model.py b/swift/megatron/model/gpt/model.py index a5382f5e0f..9bc6bf4fbc 100644 --- a/swift/megatron/model/gpt/model.py +++ b/swift/megatron/model/gpt/model.py @@ -14,6 +14,7 @@ def model_provider(pre_process=True, post_process=True): transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm, args.multi_latent_attention) if args.num_experts and args.moe_shared_expert_intermediate_size: + # qwen2_moe/qwen3_moe transformer_layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True} model = GPTModel( config=config, 
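Taken together, patches 01-06 reduce the HF-to-mcore MoE weight mapping to the shape sketched below. This is a hand-written condensation of the `hf2mcore.py` hunks above, assuming only the module attributes visible in the diffs; it is not the full converter:

```python
import torch


def copy_swiglu_mlp(mg_mlp, hf_mlp):
    # Megatron fuses HF's gate_proj and up_proj into a single linear_fc1.
    mg_mlp.linear_fc1.weight.data.copy_(
        torch.cat([hf_mlp.gate_proj.weight, hf_mlp.up_proj.weight], dim=0))
    mg_mlp.linear_fc2.weight.data.copy_(hf_mlp.down_proj.weight)


def copy_moe_mlp(num_experts, mg_mlp, hf_mlp):
    # Router weight (HF `gate`), then one SwiGLU MLP per routed expert.
    mg_mlp.router.weight.data.copy_(hf_mlp.gate.weight)
    for i in range(num_experts):
        copy_swiglu_mlp(mg_mlp.experts.local_experts[i], hf_mlp.experts[i])
    # Qwen2-MoE-style shared expert, including its extra sigmoid gate.
    if mg_mlp.shared_experts is not None:
        mg_mlp.shared_experts.gate_weight.data.copy_(hf_mlp.shared_expert_gate.weight)
        copy_swiglu_mlp(mg_mlp.shared_experts, hf_mlp.shared_expert)
```

The mcore2hf direction is the same mapping with source and destination swapped, which is why patch 04 threads `args` through both sets of `_set_mlp_state` helpers symmetrically.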
From 1ab0dd517b75b20e4b0c7571a0d380d42582187b Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 28 Apr 2025 10:40:23 +0800 Subject: [PATCH 07/14] update --- .../Megatron-SWIFT\350\256\255\347\273\203.md" | 7 +++++-- .../source_en/Instruction/Megatron-SWIFT-Training.md | 7 +++++-- swift/megatron/argument/megatron_args.py | 6 ++++-- swift/megatron/init.py | 12 ++++++++++++ 4 files changed, 26 insertions(+), 6 deletions(-) diff --git "a/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" "b/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" index d63bbef43e..c18b9d01a4 100644 --- "a/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" +++ "b/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" @@ -16,6 +16,9 @@ pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable git clone https://github.com/NVIDIA/apex cd apex pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ + +# megatron-core +pip install git+https://github.com/NVIDIA/Megatron-LM.git@core_r0.11.0 ``` 或者你也可以使用镜像: @@ -24,7 +27,7 @@ modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:ubunt modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.4.0-py311-torch2.6.0-vllm0.8.3-modelscope1.25.0-swift3.3.0.post1 ``` -依赖库Megatron-LM将会由swift进行git clone并安装,不需要用户手动安装。你也可以通过环境变量`MEGATRON_LM_PATH`指向已经下载好的repo路径(断网环境,[core_r0.11.0分支](https://github.com/NVIDIA/Megatron-LM/tree/core_r0.11.0))。 +依赖库Megatron-LM中的训练模块将由swift进行git clone并安装。你也可以通过环境变量`MEGATRON_LM_PATH`指向已经下载好的repo路径(断网环境,[core_r0.11.0分支](https://github.com/NVIDIA/Megatron-LM/tree/core_r0.12.0))。 ## 快速入门案例 @@ -188,7 +191,7 @@ I am a language model developed by swift, you can call me swift-robot. How can I - distributed_timeout_minutes: torch.distributed的timeout时间(单位为分钟),默认为60分钟。 **日志参数**: -- log_params_norm: 记录参数的norm。默认为True。 +- log_params_norm: 记录参数的norm。默认为False。 - log_throughput: 记录每个GPU的吞吐量。默认为True。 - 注意:在非packing情况下,log_throughput并不准确,因为`seq_length`并不等于真实序列长度。 - tensorboard_log_interval: 记录到tensorboard的间隔(steps),默认为1。 diff --git a/docs/source_en/Instruction/Megatron-SWIFT-Training.md b/docs/source_en/Instruction/Megatron-SWIFT-Training.md index fbc5a9f784..3bb859fc38 100644 --- a/docs/source_en/Instruction/Megatron-SWIFT-Training.md +++ b/docs/source_en/Instruction/Megatron-SWIFT-Training.md @@ -17,6 +17,9 @@ pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable git clone https://github.com/NVIDIA/apex cd apex pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ + +# megatron-core +pip install git+https://github.com/NVIDIA/Megatron-LM.git@core_r0.12.0 ``` Alternatively, you can also use the image: @@ -25,7 +28,7 @@ modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:ubunt modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.4.0-py311-torch2.6.0-vllm0.8.3-modelscope1.25.0-swift3.3.0.post1 ``` -The dependency library Megatron-LM will be git cloned and installed by swift, no manual installation by the user is required. You can also use the environment variable `MEGATRON_LM_PATH` to point to the already downloaded repo path (for offline environments, use the [core_r0.11.0 branch](https://github.com/NVIDIA/Megatron-LM/tree/core_r0.11.0)). 
+The training module in the dependent library Megatron-LM will be cloned and installed by swift via `git clone`. Alternatively, you can use the environment variable `MEGATRON_LM_PATH` to point to the path of an already downloaded repository (in offline environments, use the [core_r0.12.0 branch](https://github.com/NVIDIA/Megatron-LM/tree/core_r0.12.0)). ## Quick Start Example @@ -195,7 +198,7 @@ seq_length: Defaults to None, meaning it is set to `max_length`. To restrict the **Logging Parameters**: -- log_params_norm: Logs the norm of parameters. Default is True. +- log_params_norm: Logs the norm of parameters. Default is False. - log_throughput: Logs throughput per GPU. Default is True. - Note: In non-packing scenarios, log_throughput is not accurate because `seq_length` does not equal the actual sequence length. - tensorboard_log_interval: Interval (steps) for logging to TensorBoard, default is 1. diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 65867d0944..7ca588dbe3 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -3,7 +3,7 @@ import sys from dataclasses import asdict, dataclass from typing import Any, Dict, List, Literal, Optional, Tuple, Union - +from transformers.utils.versions import require_version import torch from swift.llm.argument.base_args import to_abspath @@ -130,7 +130,7 @@ class MegatronArguments(ExtraMegatronArguments): attention_softmax_in_fp32: bool = True # logging - log_params_norm: bool = True + log_params_norm: bool = False log_throughput: bool = True tensorboard_log_interval: int = 1 tensorboard_queue_size: int = 50 @@ -194,6 +194,8 @@ def _init_moe(self): def __post_init__(self): from swift.llm.argument.base_args.model_args import ModelArguments + if self.use_flash_attn: + require_version('flash-attn') os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' self._set_default() self.group_query_attention = self.num_query_groups > 1 diff --git a/swift/megatron/init.py b/swift/megatron/init.py index 4eca7959d4..7f31e3b37a 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -5,6 +5,17 @@ from swift.llm import git_clone_github from swift.utils import is_megatron_available, safe_ddp_context, subprocess_run +def _patch_megatron(): + try: + from transformer_engine.pytorch.attention import FusedRoPEFunc + except ImportError: + try: + import transformer_engine + transformer_engine.pytorch.attention.FusedRoPEFunc = ( + transformer_engine.pytorch.dot_product_attention.rope.FusedRoPEFunc + ) + except (ImportError, AttributeError): + pass def init_megatron_env() -> None: if 'MEGATRON_LM_PATH' not in os.environ: @@ -14,3 +25,4 @@ def init_megatron_env() -> None: if not is_megatron_available(): subprocess_run([sys.executable, '-m', 'pip', 'install', '-e', os.environ['MEGATRON_LM_PATH']]) sys.path.insert(0, os.environ['MEGATRON_LM_PATH']) + _patch_megatron() From 53c5386e376ac78bfce0d517166753dead7b57d3 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 28 Apr 2025 13:03:31 +0800 Subject: [PATCH 08/14] update --- ...Megatron-SWIFT\350\256\255\347\273\203.md" | 11 ++++++- .../Instruction/Megatron-SWIFT-Training.md | 19 ++++++++--- examples/train/megatron/moe.sh | 32 +++++++++++++++++++ 3 files changed, 57 insertions(+), 5 deletions(-) create mode 100644 examples/train/megatron/moe.sh diff --git "a/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" "b/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" index c18b9d01a4..742feaf36f 
100644 --- "a/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" +++ "b/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" @@ -107,13 +107,22 @@ I am a language model developed by swift, you can call me swift-robot. How can I ## Benchmark -使用`megatron sft`和`swift sft`在单机八卡A800环境下进行14B模型全参数训练的速度对比如下,对应脚本参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron/benchmark)。 +使用`megatron sft`和`swift sft`在单机八卡A800环境下进行Dense/MoE模型全参数训练的速度对比如下,对应脚本参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron/benchmark)。 + +**Dense** Qwen2.5-14B: | | Megatron-LM | Deepspeed-ZeRO2 | Deepspeed-ZeRO3 | | -------- | ----------- | ---------- | ---------- | | 训练速度 | 9.04s/it | 10.32s/it | 10.56s/it | | 显存占用 | 8\*64GB | 8\*80GB | 8\*58GB | +**MoE** Qwen1.5-MoE-A2.7B: + +| | Megatron-LM | Deepspeed-ZeRO2 | Deepspeed-ZeRO3 | +| -------- | ----------- | ---------- | ---------- | +| 训练速度 | 3.53s/it | 6.02s/it | 24.30s/it | +| 显存占用 | 8\*66GB | 8\*72GB | 8\*50GB | + ## 命令行参数 diff --git a/docs/source_en/Instruction/Megatron-SWIFT-Training.md b/docs/source_en/Instruction/Megatron-SWIFT-Training.md index 3bb859fc38..4f6cf0bb4f 100644 --- a/docs/source_en/Instruction/Megatron-SWIFT-Training.md +++ b/docs/source_en/Instruction/Megatron-SWIFT-Training.md @@ -112,11 +112,22 @@ I am a language model developed by swift, you can call me swift-robot. How can I - For pretraining, you can use `megatron pt` instead of `megatron sft`, which will use a generative template for training. ## Benchmark +The speed comparison of full-parameter training for Dense/MoE models using `megatron sft` and `swift sft` on a single machine with eight A800 GPUs is shown below. The corresponding scripts can be found [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron/benchmark). 
-| | Megatron-LM | Deepspeed-ZeRO2 | Deepspeed-ZeRO3 | -| ---------------- | --------------- | --------------- | --------- | -| Training Speed | 9.04s/it | 10.32s/it | 10.56s/it | -| GPU Memory Usage | 8\*64GB | 8\*80GB | 8\*58GB | +**Dense** Qwen2.5-14B: + + +| | Megatron-LM | Deepspeed-ZeRO2 | Deepspeed-ZeRO3 | +| ---------------- | ----------- | --------------- | --------------- | +| Training Speed | 9.04s/it | 10.32s/it | 10.56s/it | +| GPU Memory Usage | 8\*64GB | 8\*80GB | 8\*58GB | + +**MoE** Qwen1.5-MoE-A2.7B: + +| | Megatron-LM | Deepspeed-ZeRO2 | Deepspeed-ZeRO3 | +| ---------------- | ----------- | --------------- | --------------- | +| Training Speed | 3.53s/it | 6.02s/it | 24.30s/it | +| GPU Memory Usage | 8\*66GB | 8\*72GB | 8\*50GB | ## Command Line Arguments diff --git a/examples/train/megatron/moe.sh b/examples/train/megatron/moe.sh new file mode 100644 index 0000000000..34e2f8f2c8 --- /dev/null +++ b/examples/train/megatron/moe.sh @@ -0,0 +1,32 @@ +# 8 * 65GiB +NPROC_PER_NODE=8 \ +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ +megatron sft \ + --load Qwen1.5-MoE-A2.7B-mcore \ + --dataset 'liucong/Chinese-DeepSeek-R1-Distill-data-110k-SFT' \ + --tensor_model_parallel_size 2 \ + --expert_model_parallel_size 4 \ + --moe_grouped_gemm true \ + --moe_shared_expert_overlap true \ + --moe_aux_loss_coeff 0.01 \ + --micro_batch_size 1 \ + --global_batch_size 16 \ + --packing true \ + --recompute_granularity selective \ + --train_iters 2000 \ + --eval_iters 50 \ + --finetune true \ + --cross_entropy_loss_fusion true \ + --lr 1e-5 \ + --lr_warmup_iters 100 \ + --min_lr 1e-6 \ + --save megatron_output/Qwen1.5-MoE-A2.7B \ + --eval_interval 200 \ + --save_interval 200 \ + --max_length 8192 \ + --num_workers 8 \ + --dataset_num_proc 8 \ + --no_save_optim true \ + --no_save_rng true \ + --sequence_parallel true \ + --use_flash_attn true From bad8a03ba15bfdfaa7f13fa07219f94bdfb041db Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 28 Apr 2025 13:30:24 +0800 Subject: [PATCH 09/14] update --- swift/megatron/argument/megatron_args.py | 3 ++- swift/megatron/init.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 7ca588dbe3..8b3e398764 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -3,8 +3,9 @@ import sys from dataclasses import asdict, dataclass from typing import Any, Dict, List, Literal, Optional, Tuple, Union -from transformers.utils.versions import require_version + import torch +from transformers.utils.versions import require_version from swift.llm.argument.base_args import to_abspath diff --git a/swift/megatron/init.py b/swift/megatron/init.py index 7f31e3b37a..90016688c0 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -5,6 +5,7 @@ from swift.llm import git_clone_github from swift.utils import is_megatron_available, safe_ddp_context, subprocess_run + def _patch_megatron(): try: from transformer_engine.pytorch.attention import FusedRoPEFunc @@ -12,11 +13,11 @@ def _patch_megatron(): try: import transformer_engine transformer_engine.pytorch.attention.FusedRoPEFunc = ( - transformer_engine.pytorch.dot_product_attention.rope.FusedRoPEFunc - ) + transformer_engine.pytorch.dot_product_attention.rope.FusedRoPEFunc) except (ImportError, AttributeError): pass + def init_megatron_env() -> None: if 'MEGATRON_LM_PATH' not in os.environ: os.environ['MEGATRON_LM_PATH'] = git_clone_github( From 
5b813e941461b516e86bb1cbb501eca6a69c64d3 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 28 Apr 2025 13:31:35 +0800 Subject: [PATCH 10/14] update --- .../Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" | 2 +- docs/source_en/Instruction/Megatron-SWIFT-Training.md | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git "a/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" "b/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" index 742feaf36f..b168045041 100644 --- "a/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" +++ "b/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" @@ -27,7 +27,7 @@ modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:ubunt modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.4.0-py311-torch2.6.0-vllm0.8.3-modelscope1.25.0-swift3.3.0.post1 ``` -依赖库Megatron-LM中的训练模块将由swift进行git clone并安装。你也可以通过环境变量`MEGATRON_LM_PATH`指向已经下载好的repo路径(断网环境,[core_r0.11.0分支](https://github.com/NVIDIA/Megatron-LM/tree/core_r0.12.0))。 +依赖库Megatron-LM中的训练模块将由swift进行git clone并安装。你也可以通过环境变量`MEGATRON_LM_PATH`指向已经下载好的repo路径(断网环境,[core_r0.11.0分支](https://github.com/NVIDIA/Megatron-LM/tree/core_r0.11.0))。 ## 快速入门案例 diff --git a/docs/source_en/Instruction/Megatron-SWIFT-Training.md b/docs/source_en/Instruction/Megatron-SWIFT-Training.md index 4f6cf0bb4f..d14971b511 100644 --- a/docs/source_en/Instruction/Megatron-SWIFT-Training.md +++ b/docs/source_en/Instruction/Megatron-SWIFT-Training.md @@ -19,7 +19,7 @@ cd apex pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ # megatron-core -pip install git+https://github.com/NVIDIA/Megatron-LM.git@core_r0.12.0 +pip install git+https://github.com/NVIDIA/Megatron-LM.git@core_r0.11.0 ``` Alternatively, you can also use the image: @@ -28,7 +28,7 @@ modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:ubunt modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.4.0-py311-torch2.6.0-vllm0.8.3-modelscope1.25.0-swift3.3.0.post1 ``` -The training module in the dependent library Megatron-LM will be cloned and installed by swift via `git clone`. Alternatively, you can use the environment variable `MEGATRON_LM_PATH` to point to the path of an already downloaded repository (in offline environments, use the [core_r0.12.0 branch](https://github.com/NVIDIA/Megatron-LM/tree/core_r0.12.0)). +The training module in the dependent library Megatron-LM will be cloned and installed by swift via `git clone`. Alternatively, you can use the environment variable `MEGATRON_LM_PATH` to point to the path of an already downloaded repository (in offline environments, use the [core_r0.11.0 branch](https://github.com/NVIDIA/Megatron-LM/tree/core_r0.11.0)). 
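For instance, an offline setup could look like this (illustrative commands; the local path is a placeholder):

```shell
# On a machine with network access, clone the core_r0.11.0 branch once.
git clone --branch core_r0.11.0 https://github.com/NVIDIA/Megatron-LM.git /path/to/Megatron-LM
# On the training machine, point swift at the local copy to skip the clone.
export MEGATRON_LM_PATH=/path/to/Megatron-LM
```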
## Quick Start Example From 8c649bbf4b2c8e967dd4c7310a53749ca5df750d Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 28 Apr 2025 13:34:06 +0800 Subject: [PATCH 11/14] update --- .../Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" | 2 +- docs/source_en/Instruction/Megatron-SWIFT-Training.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git "a/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" "b/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" index b168045041..f5e91bb904 100644 --- "a/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" +++ "b/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" @@ -120,7 +120,7 @@ I am a language model developed by swift, you can call me swift-robot. How can I | | Megatron-LM | Deepspeed-ZeRO2 | Deepspeed-ZeRO3 | | -------- | ----------- | ---------- | ---------- | -| 训练速度 | 3.53s/it | 6.02s/it | 24.30s/it | +| 训练速度 | 2.93s/it | 6.02s/it | 24.30s/it | | 显存占用 | 8\*66GB | 8\*72GB | 8\*50GB | diff --git a/docs/source_en/Instruction/Megatron-SWIFT-Training.md b/docs/source_en/Instruction/Megatron-SWIFT-Training.md index d14971b511..f4db6f8495 100644 --- a/docs/source_en/Instruction/Megatron-SWIFT-Training.md +++ b/docs/source_en/Instruction/Megatron-SWIFT-Training.md @@ -126,7 +126,7 @@ The speed comparison of full-parameter training for Dense/MoE models using `mega | | Megatron-LM | Deepspeed-ZeRO2 | Deepspeed-ZeRO3 | | ---------------- | ----------- | --------------- | --------------- | -| Training Speed | 3.53s/it | 6.02s/it | 24.30s/it | +| Training Speed | 2.93s/it | 6.02s/it | 24.30s/it | | GPU Memory Usage | 8\*66GB | 8\*72GB | 8\*50GB | ## Command Line Arguments From fda4d0501085294dee6c8e992233577802af409b Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 28 Apr 2025 13:45:51 +0800 Subject: [PATCH 12/14] update --- swift/llm/template/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index eb0822f2ff..e7cb1b8098 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -1047,7 +1047,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: labels = labels[:self.max_length] if loss_scale is not None: loss_scale = loss_scale[:self.max_length] - else: + elif self.truncation_strategy == 'left': if len(input_ids) > self.max_length: logger.warning_once( 'Input data was left-truncated because its length exceeds `max_length` (input length: ' From a6891978e73ab05b6f82525ee0cc5f0dc9779551 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 28 Apr 2025 14:06:25 +0800 Subject: [PATCH 13/14] update --- .../Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" | 2 ++ docs/source_en/Instruction/Megatron-SWIFT-Training.md | 2 ++ swift/megatron/argument/megatron_args.py | 2 ++ swift/megatron/train/utils.py | 3 ++- 4 files changed, 8 insertions(+), 1 deletion(-) diff --git "a/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" "b/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" index f5e91bb904..30d85a3f8e 100644 --- "a/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" +++ "b/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" @@ -272,3 +272,5 @@ Megatron训练参数继承自Megatron参数和基本参数。基本参数的内 - 🔥packing: 是否使用序列packing,默认为False。 - 🔥streaming: 流式读取并处理数据集,默认False。通常在处理大型数据集时,设置为True。更多流式的参数查看命令行参数文档。 - lazy_tokenize: 
默认为False。若该参数设置为False,则在训练之前对所有的数据集样本进行tokenize(这可以避免在训练中出现报错);设置为True,则在训练中对数据集进行tokenize(这可以节约内存)。 +- dataloader_persistent_workers: 透传入dataloader的参数,默认为True。 +- dataloader_prefetch_factor: 透传入dataloader的参数,默认为10。 diff --git a/docs/source_en/Instruction/Megatron-SWIFT-Training.md b/docs/source_en/Instruction/Megatron-SWIFT-Training.md index f4db6f8495..0772f821d4 100644 --- a/docs/source_en/Instruction/Megatron-SWIFT-Training.md +++ b/docs/source_en/Instruction/Megatron-SWIFT-Training.md @@ -284,3 +284,5 @@ Megatron training parameters inherit from Megatron parameters and basic paramete - 🔥packing: Whether to use sequence packing, defaults to False. - 🔥streaming: Stream reading and processing of the dataset, default is False. It is typically set to True when handling large datasets. For more information on streaming parameters, refer to the command-line parameters documentation. - lazy_tokenize: Default is False. If this parameter is set to False, all dataset samples are tokenized before training (this avoids errors during training); if set to True, tokenization occurs during training (this saves memory). +- dataloader_persistent_workers: A parameter passed directly to the dataloader, with a default value of True. +- dataloader_prefetch_factor: A parameter passed directly to the dataloader, with a default value of 10. diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 8b3e398764..a1cf7f7f6f 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -16,6 +16,8 @@ class ExtraMegatronArguments: rope_scaling: Optional[Union[dict, str]] = None torch_dtype: Optional[torch.dtype] = None model_type: Optional[str] = None + dataloader_persistent_workers: bool = True + dataloader_prefetch_factor: int = 10 @dataclass diff --git a/swift/megatron/train/utils.py b/swift/megatron/train/utils.py index 06fe8d34f0..21833014df 100644 --- a/swift/megatron/train/utils.py +++ b/swift/megatron/train/utils.py @@ -40,7 +40,8 @@ def build_streaming_dataloader(args, dataset, collate_fn): pin_memory=True, collate_fn=collate_fn, batch_size=args.micro_batch_size, - prefetch_factor=10, + prefetch_factor=args.dataloader_prefetch_factor, + persistent_workers=args.dataloader_persistent_workers, ) return iter(cyclic_iter(MegatronDataLoaderDispatcher(base_dataloader))) From b1d5461718ef5d89ae2a4e9fdf34ec6acd481689 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 28 Apr 2025 16:05:44 +0800 Subject: [PATCH 14/14] update --- swift/llm/argument/train_args.py | 3 +++ swift/megatron/model/gpt/config.py | 2 +- swift/megatron/model/gpt/hf2mcore.py | 8 ++++---- swift/megatron/model/gpt/mcore2hf.py | 13 +++++++------ 4 files changed, 15 insertions(+), 11 deletions(-) diff --git a/swift/llm/argument/train_args.py b/swift/llm/argument/train_args.py index 60b07cd0a0..25e18eac64 100644 --- a/swift/llm/argument/train_args.py +++ b/swift/llm/argument/train_args.py @@ -136,6 +136,9 @@ def _init_lazy_tokenize(self): logger.info(f'Setting args.lazy_tokenize: {self.lazy_tokenize}') def __post_init__(self) -> None: + if self.packing and self.attn_impl != 'flash_attn': + logger.warning('The "packing" feature needs to be used in conjunction with "flash_attn". 
' + 'Please specify `--attn_impl flash_attn`.') if self.resume_from_checkpoint: self.resume_from_checkpoint = to_abspath(self.resume_from_checkpoint, True) if self.train_type == 'full': diff --git a/swift/megatron/model/gpt/config.py b/swift/megatron/model/gpt/config.py index 5183239c2f..6658a952ab 100644 --- a/swift/megatron/model/gpt/config.py +++ b/swift/megatron/model/gpt/config.py @@ -8,6 +8,6 @@ def convert_gpt_hf_config(config) -> Dict[str, Any]: model_type = res.get('model_type') if model_type in {'qwen3', 'qwen3_moe'}: res['qk_layernorm'] = True - elif model_type in {'qwen2_moe', 'qwen3_moe'}: + if model_type in {'qwen2_moe', 'qwen3_moe'}: res.pop('ffn_hidden_size', None) return res diff --git a/swift/megatron/model/gpt/hf2mcore.py b/swift/megatron/model/gpt/hf2mcore.py index 2cbe6dc320..46525df3c7 100644 --- a/swift/megatron/model/gpt/hf2mcore.py +++ b/swift/megatron/model/gpt/hf2mcore.py @@ -30,7 +30,7 @@ def set_attn_state(args, mg_attn, hf_attn): mg_attn.k_layernorm.weight.data.copy_(hf_attn.k_norm.weight) -def _set_mlp_state(args, mg_mlp, hf_mlp): +def _set_mlp_state(mg_mlp, hf_mlp): mg_mlp.linear_fc1.weight.data.copy_(torch.cat([hf_mlp.gate_proj.weight, hf_mlp.up_proj.weight], dim=0)) mg_mlp.linear_fc2.weight.data.copy_(hf_mlp.down_proj.weight) @@ -41,12 +41,12 @@ def set_mlp_state(args, mg_mlp, hf_mlp): if mg_mlp.shared_experts is not None: mg_mlp.shared_experts.gate_weight.data.copy_(hf_mlp.shared_expert_gate.weight) for expert_idx in range(args.num_experts): - _set_mlp_state(args, mg_mlp.experts.local_experts[expert_idx], hf_mlp.experts[expert_idx]) + _set_mlp_state(mg_mlp.experts.local_experts[expert_idx], hf_mlp.experts[expert_idx]) if mg_mlp.shared_experts is not None: - _set_mlp_state(args, mg_mlp.shared_experts, hf_mlp.shared_expert) + _set_mlp_state(mg_mlp.shared_experts, hf_mlp.shared_expert) else: - _set_mlp_state(args, mg_mlp, hf_mlp) + _set_mlp_state(mg_mlp, hf_mlp) def set_layer_state(args, mg_model, hf_model, layer_idx): diff --git a/swift/megatron/model/gpt/mcore2hf.py b/swift/megatron/model/gpt/mcore2hf.py index b9a4a9cb1a..6f29abaf0e 100644 --- a/swift/megatron/model/gpt/mcore2hf.py +++ b/swift/megatron/model/gpt/mcore2hf.py @@ -25,9 +25,10 @@ def set_attn_state(args, mg_attn, hf_attn): hf_attn.k_norm.weight.data.copy_(mg_attn.k_layernorm.weight) -def _set_mlp_state(args, mg_mlp, hf_mlp): - hf_mlp.gate_proj.weight.data.copy_(mg_mlp.linear_fc1.weight[:args.ffn_hidden_size]) - hf_mlp.up_proj.weight.data.copy_(mg_mlp.linear_fc1.weight[args.ffn_hidden_size:]) +def _set_mlp_state(mg_mlp, hf_mlp): + ffn_hidden_size = hf_mlp.gate_proj.weight.shape[0] + hf_mlp.gate_proj.weight.data.copy_(mg_mlp.linear_fc1.weight[:ffn_hidden_size]) + hf_mlp.up_proj.weight.data.copy_(mg_mlp.linear_fc1.weight[ffn_hidden_size:]) hf_mlp.down_proj.weight.data.copy_(mg_mlp.linear_fc2.weight) @@ -37,12 +38,12 @@ def set_mlp_state(args, mg_mlp, hf_mlp): if mg_mlp.shared_experts is not None: hf_mlp.shared_expert_gate.weight.data.copy_(mg_mlp.shared_experts.gate_weight) for expert_idx in range(args.num_experts): - _set_mlp_state(args, mg_mlp.experts.local_experts[expert_idx], hf_mlp.experts[expert_idx]) + _set_mlp_state(mg_mlp.experts.local_experts[expert_idx], hf_mlp.experts[expert_idx]) if mg_mlp.shared_experts is not None: - _set_mlp_state(args, mg_mlp.shared_experts, hf_mlp.shared_expert) + _set_mlp_state(mg_mlp.shared_experts, hf_mlp.shared_expert) else: - _set_mlp_state(args, mg_mlp, hf_mlp) + _set_mlp_state(mg_mlp, hf_mlp) def set_layer_state(args, mg_model, hf_model, layer_idx):
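A note on the final `_set_mlp_state` pair above: both converters rely on the same fused-weight convention, `linear_fc1 = cat([gate_proj, up_proj], dim=0)`, and the mcore2hf side now recovers the split point from `gate_proj.weight.shape[0]` instead of `args.ffn_hidden_size`, which is what lets one helper serve both dense MLPs and MoE experts whose `moe_ffn_hidden_size` differs. A self-contained sketch of that round trip with toy tensor shapes (illustrative sizes only, not a real config):

```python
import torch

hidden_size, ffn_hidden_size = 8, 32  # toy sizes for illustration

gate_proj = torch.randn(ffn_hidden_size, hidden_size)
up_proj = torch.randn(ffn_hidden_size, hidden_size)

# hf2mcore direction: Megatron fuses gate_proj and up_proj into linear_fc1
# by concatenating along dim 0, exactly as _set_mlp_state does above.
linear_fc1 = torch.cat([gate_proj, up_proj], dim=0)

# mcore2hf direction: derive the split point from the HF gate_proj shape,
# so the same helper also works for MoE experts with a different ffn size.
split = gate_proj.shape[0]
assert torch.equal(linear_fc1[:split], gate_proj)
assert torch.equal(linear_fc1[split:], up_proj)
```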
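Stepping back to PATCH 13: the two new `ExtraMegatronArguments` fields are passed straight through to PyTorch's `DataLoader` in `build_streaming_dataloader`. A minimal standalone sketch of the equivalent call, with a stand-in dataset and the patch's defaults spelled out (note that both keywords require `num_workers > 0`):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(64, dtype=torch.float32))  # stand-in data

loader = DataLoader(
    dataset,
    batch_size=4,             # plays the role of args.micro_batch_size
    num_workers=2,
    pin_memory=True,
    prefetch_factor=10,       # args.dataloader_prefetch_factor (default 10)
    persistent_workers=True,  # args.dataloader_persistent_workers (default True)
)
batch = next(iter(loader))  # workers stay alive across epochs when persistent
```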