Skip to content

Commit 05172a9

Browse files
committed
create tiny moe model + fix test tensor parallel Moe
eaeaae
1 parent 40b3e2b commit 05172a9

File tree

1 file changed

+40
-14
lines changed

1 file changed

+40
-14
lines changed

tests/tensor_parallel/test_tensor_parallel.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -481,9 +481,28 @@ class TestTensorParallelDense4Proc(TestTensorParallelDenseBase):
481481

482482

483483
# ====== MOE MODEL TEST FUNCTIONS ======
484+
def _get_tiny_moe_config():
    """Build a tiny Qwen3-MoE config for fast tensor-parallel tests.

    Starts from the published ``Qwen/Qwen3-30B-A3B-Base`` configuration
    (fetched from the Hub) and shrinks every size-related hyperparameter so
    the model is cheap to instantiate in CI while keeping the MoE structure
    (router + several experts per layer) intact.

    Returns:
        Qwen3MoeConfig: the shrunken configuration, suitable for
        ``AutoModelForCausalLM.from_config``.
    """
    from transformers import Qwen3MoeConfig

    model_id = "Qwen/Qwen3-30B-A3B-Base"

    # Create a tiny MoE config to reduce model size for testing.
    # NOTE: from_pretrained needs Hub access (or a local cache) for the
    # base config; only hyperparameters are fetched, no weights.
    config = Qwen3MoeConfig.from_pretrained(model_id)
    config.num_hidden_layers = 2  # Reduce from 48 to 2
    config.num_experts = 4  # Reduce from 128 to 4
    config.num_experts_per_tok = 2  # Reduce from 8 to 2
    config.hidden_size = 512  # Reduce from 2048 to 512
    config.intermediate_size = 1536  # Reduce from 6144 to 1536
    config.moe_intermediate_size = 256  # Reduce from 768 to 256
    config.num_attention_heads = 8  # Reduce from 32 to 8
    config.num_key_value_heads = 2  # Reduce from 4 to 2

    return config
502+
484503
def _test_model_moe_forward_impl(rank, mode):
485504
"""Implementation for comparing TP and non-TP MoE model outputs."""
486-
model_id = "Qwen/Qwen3-0.6B"
505+
model_id = "Qwen/Qwen3-30B-A3B-Base"
487506

488507
# Ensure same random seed for reproducibility
489508
torch.manual_seed(0)
@@ -493,8 +512,11 @@ def _test_model_moe_forward_impl(rank, mode):
493512
prompt = "Can I help"
494513
inputs = tokenizer(prompt, return_tensors="pt")
495514

515+
# Create a tiny MoE config for testing
516+
config = _get_tiny_moe_config()
517+
496518
# Load TP model first to determine device
497-
model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
519+
model_tp = AutoModelForCausalLM.from_config(config, dtype="auto", tp_plan="auto")
498520
dist.barrier()
499521
if mode == "eval":
500522
model_tp.eval()
@@ -503,7 +525,7 @@ def _test_model_moe_forward_impl(rank, mode):
503525

504526
# Load non-TP model and move to same device as TP model
505527
device = model_tp.device
506-
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
528+
model = AutoModelForCausalLM.from_config(config, dtype="auto")
507529
model = model.to(device)
508530

509531
if mode == "eval":
@@ -534,16 +556,17 @@ def _test_model_moe_forward_impl(rank, mode):
534556

535557
def _test_model_moe_backward_pass_impl(rank):
536558
"""Implementation for comparing TP and non-TP MoE model backward passes."""
537-
model_id = "Qwen/Qwen3-0.6B"
538-
539559
torch.manual_seed(0)
540560

541-
model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32, tp_plan="auto")
561+
# Create a tiny MoE config for testing
562+
config = _get_tiny_moe_config()
563+
564+
model_tp = AutoModelForCausalLM.from_config(config, dtype="auto", tp_plan="auto")
542565
dist.barrier()
543566
model_tp.train()
544567

545568
device = model_tp.device
546-
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32)
569+
model = AutoModelForCausalLM.from_config(config, dtype="auto")
547570
model = model.to(device)
548571
model.train()
549572

@@ -590,23 +613,26 @@ def _test_model_moe_backward_pass_impl(rank):
590613

591614
def _test_model_moe_forward_compile_impl(rank, mode):
592615
"""Implementation for comparing TP and non-TP MoE model outputs with torch.compile."""
593-
model_id = "Qwen/Qwen3-0.6B"
616+
model_id = "Qwen/Qwen3-30B-A3B-Base"
594617

595618
torch.manual_seed(0)
596619

597620
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
598621
prompt = "Can I help"
599622
inputs = tokenizer(prompt, return_tensors="pt")
600623

601-
model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
624+
# Create a tiny MoE config for testing
625+
config = _get_tiny_moe_config()
626+
627+
model_tp = AutoModelForCausalLM.from_config(config, dtype="auto", tp_plan="auto")
602628
dist.barrier()
603629
if mode == "eval":
604630
model_tp.eval()
605631
else:
606632
model_tp.train()
607633

608634
device = model_tp.device
609-
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
635+
model = AutoModelForCausalLM.from_config(config, dtype="auto")
610636
model = model.to(device)
611637

612638
if mode == "eval":
@@ -636,16 +662,16 @@ def _test_model_moe_forward_compile_impl(rank, mode):
636662

637663
def _test_model_moe_save_impl(rank, tmp_dir):
638664
"""Implementation of test_model_save for MoE model distributed execution."""
639-
model_id = "Qwen/Qwen3-0.6B"
640-
665+
# Create a tiny MoE config for testing
666+
config = _get_tiny_moe_config()
667+
641668
if dist.is_initialized():
642669
kwargs = {"tp_plan": "auto"}
643670
result_dir = f"{tmp_dir}/tp"
644671
else:
645672
kwargs = {}
646673
result_dir = f"{tmp_dir}/nontp"
647-
648-
model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
674+
model = AutoModelForCausalLM.from_config(config, dtype="auto", **kwargs)
649675
model.save_pretrained(result_dir)
650676

651677

0 commit comments

Comments
 (0)