Skip to content

Commit d75f4b8

Browse files
committed
create tiny moe model + fix test tensor parallel Moe
eaeaae: fix tensor parallel MoE test
1 parent 05172a9 commit d75f4b8

File tree

1 file changed

+13
-39
lines changed

1 file changed

+13
-39
lines changed

tests/tensor_parallel/test_tensor_parallel.py

Lines changed: 13 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -481,28 +481,9 @@ class TestTensorParallelDense4Proc(TestTensorParallelDenseBase):
481481

482482

483483
# ====== MOE MODEL TEST FUNCTIONS ======
484-
def _get_tiny_moe_config():
485-
"""Create a tiny MoE config for testing."""
486-
from transformers import Qwen3MoeConfig
487-
488-
model_id = "Qwen/Qwen3-30B-A3B-Base"
489-
490-
# Create a tiny MoE config to reduce model size for testing
491-
config = Qwen3MoeConfig.from_pretrained(model_id)
492-
config.num_hidden_layers = 2 # Reduce from 24 to 2
493-
config.num_experts = 4 # Reduce from 128 to 4
494-
config.num_experts_per_tok = 2 # Reduce from 8 to 2
495-
config.hidden_size = 512 # Reduce from 2048 to 512
496-
config.intermediate_size = 1536 # Reduce from 6144 to 1536
497-
config.moe_intermediate_size = 256 # Reduce from 768 to 256
498-
config.num_attention_heads = 8 # Reduce from 32 to 8
499-
config.num_key_value_heads = 2 # Reduce from 4 to 2
500-
501-
return config
502-
503484
def _test_model_moe_forward_impl(rank, mode):
504485
"""Implementation for comparing TP and non-TP MoE model outputs."""
505-
model_id = "Qwen/Qwen3-30B-A3B-Base"
486+
model_id = "hf-internal-testing/tiny-qwen3-moe"
506487

507488
# Ensure same random seed for reproducibility
508489
torch.manual_seed(0)
@@ -512,11 +493,8 @@ def _test_model_moe_forward_impl(rank, mode):
512493
prompt = "Can I help"
513494
inputs = tokenizer(prompt, return_tensors="pt")
514495

515-
# Create a tiny MoE config for testing
516-
config = _get_tiny_moe_config()
517-
518496
# Load TP model first to determine device
519-
model_tp = AutoModelForCausalLM.from_config(config, dtype="auto", tp_plan="auto")
497+
model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
520498
dist.barrier()
521499
if mode == "eval":
522500
model_tp.eval()
@@ -525,7 +503,7 @@ def _test_model_moe_forward_impl(rank, mode):
525503

526504
# Load non-TP model and move to same device as TP model
527505
device = model_tp.device
528-
model = AutoModelForCausalLM.from_config(config, dtype="auto")
506+
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
529507
model = model.to(device)
530508

531509
if mode == "eval":
@@ -556,17 +534,16 @@ def _test_model_moe_forward_impl(rank, mode):
556534

557535
def _test_model_moe_backward_pass_impl(rank):
558536
"""Implementation for comparing TP and non-TP MoE model backward passes."""
537+
model_id = "hf-internal-testing/tiny-qwen3-moe"
538+
559539
torch.manual_seed(0)
560540

561-
# Create a tiny MoE config for testing
562-
config = _get_tiny_moe_config()
563-
564-
model_tp = AutoModelForCausalLM.from_config(config, dtype="auto", tp_plan="auto")
541+
model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
565542
dist.barrier()
566543
model_tp.train()
567544

568545
device = model_tp.device
569-
model = AutoModelForCausalLM.from_config(config, dtype="auto")
546+
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
570547
model = model.to(device)
571548
model.train()
572549

@@ -613,26 +590,23 @@ def _test_model_moe_backward_pass_impl(rank):
613590

614591
def _test_model_moe_forward_compile_impl(rank, mode):
615592
"""Implementation for comparing TP and non-TP MoE model outputs with torch.compile."""
616-
model_id = "Qwen/Qwen3-30B-A3B-Base"
593+
model_id = "hf-internal-testing/tiny-qwen3-moe"
617594

618595
torch.manual_seed(0)
619596

620597
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
621598
prompt = "Can I help"
622599
inputs = tokenizer(prompt, return_tensors="pt")
623600

624-
# Create a tiny MoE config for testing
625-
config = _get_tiny_moe_config()
626-
627-
model_tp = AutoModelForCausalLM.from_config(config, dtype="auto", tp_plan="auto")
601+
model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
628602
dist.barrier()
629603
if mode == "eval":
630604
model_tp.eval()
631605
else:
632606
model_tp.train()
633607

634608
device = model_tp.device
635-
model = AutoModelForCausalLM.from_config(config, dtype="auto")
609+
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
636610
model = model.to(device)
637611

638612
if mode == "eval":
@@ -662,16 +636,16 @@ def _test_model_moe_forward_compile_impl(rank, mode):
662636

663637
def _test_model_moe_save_impl(rank, tmp_dir):
664638
"""Implementation of test_model_save for MoE model distributed execution."""
665-
# Create a tiny MoE config for testing
666-
config = _get_tiny_moe_config()
639+
model_id = "hf-internal-testing/tiny-qwen3-moe"
667640

668641
if dist.is_initialized():
669642
kwargs = {"tp_plan": "auto"}
670643
result_dir = f"{tmp_dir}/tp"
671644
else:
672645
kwargs = {}
673646
result_dir = f"{tmp_dir}/nontp"
674-
model = AutoModelForCausalLM.from_config(config, dtype="auto", **kwargs)
647+
648+
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", **kwargs)
675649
model.save_pretrained(result_dir)
676650

677651

0 commit comments

Comments
 (0)