@@ -481,9 +481,28 @@ class TestTensorParallelDense4Proc(TestTensorParallelDenseBase):
 
 
 # ====== MOE MODEL TEST FUNCTIONS ======
+def _get_tiny_moe_config():
+    """Create a tiny MoE config for testing."""
+    from transformers import Qwen3MoeConfig
+
+    model_id = "Qwen/Qwen3-30B-A3B-Base"
+
+    # Create a tiny MoE config to reduce model size for testing
+    config = Qwen3MoeConfig.from_pretrained(model_id)
+    config.num_hidden_layers = 2  # Reduce from 48 to 2
+    config.num_experts = 4  # Reduce from 128 to 4
+    config.num_experts_per_tok = 2  # Reduce from 8 to 2
+    config.hidden_size = 512  # Reduce from 2048 to 512
+    config.intermediate_size = 1536  # Reduce from 6144 to 1536
+    config.moe_intermediate_size = 256  # Reduce from 768 to 256
+    config.num_attention_heads = 8  # Reduce from 32 to 8
+    config.num_key_value_heads = 2  # Reduce from 4 to 2
+
+    return config
+
 def _test_model_moe_forward_impl(rank, mode):
     """Implementation for comparing TP and non-TP MoE model outputs."""
-    model_id = "Qwen/Qwen3-0.6B"
+    model_id = "Qwen/Qwen3-30B-A3B-Base"
 
     # Ensure same random seed for reproducibility
     torch.manual_seed(0)
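
For context, a minimal sketch of how the new helper is meant to be used (assumption: the `from_config` path shown in this diff, which builds a randomly initialized model from the tiny config instead of downloading the 30B checkpoint):

    from transformers import AutoModelForCausalLM

    # Sketch only: instantiate a small, randomly initialized Qwen3 MoE model
    # from the reduced config; no pretrained weights are fetched.
    config = _get_tiny_moe_config()
    model = AutoModelForCausalLM.from_config(config, dtype="auto")
    print(sum(p.numel() for p in model.parameters()))  # small parameter count
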
@@ -493,8 +512,11 @@ def _test_model_moe_forward_impl(rank, mode):
     prompt = "Can I help"
     inputs = tokenizer(prompt, return_tensors="pt")
 
+    # Create a tiny MoE config for testing
+    config = _get_tiny_moe_config()
+
     # Load TP model first to determine device
-    model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
+    model_tp = AutoModelForCausalLM.from_config(config, dtype="auto", tp_plan="auto")
     dist.barrier()
     if mode == "eval":
         model_tp.eval()
@@ -503,7 +525,7 @@ def _test_model_moe_forward_impl(rank, mode):
 
     # Load non-TP model and move to same device as TP model
     device = model_tp.device
-    model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
+    model = AutoModelForCausalLM.from_config(config, dtype="auto")
     model = model.to(device)
 
     if mode == "eval":
@@ -534,16 +556,17 @@ def _test_model_moe_forward_impl(rank, mode):
 
 def _test_model_moe_backward_pass_impl(rank):
     """Implementation for comparing TP and non-TP MoE model backward passes."""
-    model_id = "Qwen/Qwen3-0.6B"
-
     torch.manual_seed(0)
 
-    model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32, tp_plan="auto")
+    # Create a tiny MoE config for testing
+    config = _get_tiny_moe_config()
+
+    model_tp = AutoModelForCausalLM.from_config(config, dtype="auto", tp_plan="auto")
     dist.barrier()
     model_tp.train()
 
     device = model_tp.device
-    model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32)
+    model = AutoModelForCausalLM.from_config(config, dtype="auto")
     model = model.to(device)
     model.train()
 
@@ -590,23 +613,26 @@ def _test_model_moe_backward_pass_impl(rank):
 
 def _test_model_moe_forward_compile_impl(rank, mode):
     """Implementation for comparing TP and non-TP MoE model outputs with torch.compile."""
-    model_id = "Qwen/Qwen3-0.6B"
+    model_id = "Qwen/Qwen3-30B-A3B-Base"
 
     torch.manual_seed(0)
 
     tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
     prompt = "Can I help"
     inputs = tokenizer(prompt, return_tensors="pt")
 
-    model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
+    # Create a tiny MoE config for testing
+    config = _get_tiny_moe_config()
+
+    model_tp = AutoModelForCausalLM.from_config(config, dtype="auto", tp_plan="auto")
     dist.barrier()
     if mode == "eval":
         model_tp.eval()
     else:
         model_tp.train()
 
     device = model_tp.device
-    model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
+    model = AutoModelForCausalLM.from_config(config, dtype="auto")
     model = model.to(device)
 
     if mode == "eval":
@@ -636,16 +662,16 @@ def _test_model_moe_forward_compile_impl(rank, mode):
 
 def _test_model_moe_save_impl(rank, tmp_dir):
     """Implementation of test_model_save for MoE model distributed execution."""
-    model_id = "Qwen/Qwen3-0.6B"
-
+    # Create a tiny MoE config for testing
+    config = _get_tiny_moe_config()
+
     if dist.is_initialized():
         kwargs = {"tp_plan": "auto"}
         result_dir = f"{tmp_dir}/tp"
     else:
         kwargs = {}
         result_dir = f"{tmp_dir}/nontp"
-
-    model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
+    model = AutoModelForCausalLM.from_config(config, dtype="auto", **kwargs)
     model.save_pretrained(result_dir)
 
 
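Taken together, each MoE test now follows the same pattern: build the tiny config once, instantiate a TP model and a plain model from it, and compare their behavior. A minimal sketch of that flow, assuming the script is launched under torchrun with torch.distributed already initialized; the tolerances are illustrative, not the ones used by the suite:

    import torch
    import torch.distributed as dist
    from transformers import AutoModelForCausalLM, AutoTokenizer

    torch.manual_seed(0)
    config = _get_tiny_moe_config()

    # Tensor-parallel model, sharded across the launched processes
    model_tp = AutoModelForCausalLM.from_config(config, dtype="auto", tp_plan="auto")
    dist.barrier()
    model_tp.eval()

    # Reference model on the same device, no sharding
    model = AutoModelForCausalLM.from_config(config, dtype="auto").to(model_tp.device)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-30B-A3B-Base", use_fast=False)
    inputs = tokenizer("Can I help", return_tensors="pt").to(model_tp.device)

    with torch.no_grad():
        logits_tp = model_tp(**inputs).logits
        logits = model(**inputs).logits
    torch.testing.assert_close(logits_tp, logits, rtol=1e-3, atol=1e-3)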