Skip to content

Commit 40b3e2b

Browse files
committed
begin Moe test tensor parallel
1 parent 80134e6 commit 40b3e2b

File tree

1 file changed

+277
-7
lines changed

1 file changed

+277
-7
lines changed

tests/tensor_parallel/test_tensor_parallel.py

Lines changed: 277 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
# Run all tests: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py
1616
# Run specific config: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -k "2Proc"
1717
# Run multiple configs: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -k "2Proc or 4Proc"
18-
# Run spefic test: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallel2Proc::test_model_dense_forward_train
19-
# Run tests with a specific prefix: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallel2Proc -k "forward"
18+
# Run spefic test: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallelDense2Proc::test_model_dense_forward_train
19+
# Run tests with a specific prefix: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallelDense2Proc -k "forward"
20+
# Run MoE tests only: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -k "Moe"
21+
# Run dense tests only: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -k "TestTensorParallelDense2Proc or TestTensorParallelDense4Proc"
2022
import os
2123
import tempfile
2224
import warnings
@@ -381,7 +383,7 @@ def _test_model_dense_save_impl(rank, tmp_dir):
381383
model.save_pretrained(result_dir)
382384

383385

384-
class TestTensorParallelBase(TestCasePlus):
386+
class TestTensorParallelDenseBase(TestCasePlus):
385387
"""Base class for tensor parallel tests. Subclasses must set nproc_per_node."""
386388

387389
nproc_per_node = None
@@ -466,13 +468,281 @@ def test_model_dense_save(self):
466468
del non_tp_tensor, tp_tensor
467469

468470

469-
class TestTensorParallel2Proc(TestTensorParallelBase):
470-
"""Test tensor parallel with 2 processes."""
471+
class TestTensorParallelDense2Proc(TestTensorParallelDenseBase):
472+
"""Test tensor parallel dense model with 2 processes."""
471473

472474
nproc_per_node = 2
473475

474476

475-
class TestTensorParallel4Proc(TestTensorParallelBase):
476-
"""Test tensor parallel with 4 processes."""
477+
class TestTensorParallelDense4Proc(TestTensorParallelDenseBase):
478+
"""Test tensor parallel dense model with 4 processes."""
479+
480+
nproc_per_node = 4
481+
482+
483+
# ====== MOE MODEL TEST FUNCTIONS ======
484+
def _test_model_moe_forward_impl(rank, mode):
485+
"""Implementation for comparing TP and non-TP MoE model outputs."""
486+
model_id = "Qwen/Qwen3-0.6B"
487+
488+
# Ensure same random seed for reproducibility
489+
torch.manual_seed(0)
490+
491+
# Load tokenizer and prepare inputs - same for both models
492+
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
493+
prompt = "Can I help"
494+
inputs = tokenizer(prompt, return_tensors="pt")
495+
496+
# Load TP model first to determine device
497+
model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
498+
dist.barrier()
499+
if mode == "eval":
500+
model_tp.eval()
501+
else:
502+
model_tp.train()
503+
504+
# Load non-TP model and move to same device as TP model
505+
device = model_tp.device
506+
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
507+
model = model.to(device)
508+
509+
if mode == "eval":
510+
model.eval()
511+
else:
512+
model.train()
513+
514+
# Prepare inputs on the same device
515+
input_ids = inputs.input_ids.to(device)
516+
517+
# Run forward pass on both models
518+
with torch.no_grad():
519+
# Non-TP model output
520+
outputs = model(input_ids)
521+
logits = outputs.logits
522+
523+
# TP model output
524+
outputs_tp = model_tp(input_ids)
525+
logits_tp = outputs_tp.logits
526+
527+
# Compare outputs - they should match
528+
assert torch.allclose(logits, logits_tp, atol=1e-5, rtol=1e-5), (
529+
f"TP and non-TP MoE model outputs differ. Max diff: {(logits - logits_tp).abs().max().item()} | Min diff: {(logits - logits_tp).abs().min().item()}"
530+
)
531+
532+
dist.barrier()
533+
534+
535+
def _test_model_moe_backward_pass_impl(rank):
536+
"""Implementation for comparing TP and non-TP MoE model backward passes."""
537+
model_id = "Qwen/Qwen3-0.6B"
538+
539+
torch.manual_seed(0)
540+
541+
model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32, tp_plan="auto")
542+
dist.barrier()
543+
model_tp.train()
544+
545+
device = model_tp.device
546+
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32)
547+
model = model.to(device)
548+
model.train()
549+
550+
batch_size, seq_length = 2, 10
551+
torch.manual_seed(42) # Different seed for inputs to ensure they're deterministic
552+
input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_length), device=device)
553+
labels = torch.randint(0, model.config.vocab_size, (batch_size, seq_length), device=device)
554+
555+
outputs = model(input_ids, labels=labels)
556+
loss = outputs.loss
557+
loss.backward()
558+
559+
outputs_tp = model_tp(input_ids, labels=labels)
560+
loss_tp = outputs_tp.loss
561+
loss_tp.backward()
562+
563+
assert torch.allclose(loss, loss_tp, atol=1e-5, rtol=1e-5), (
564+
f"TP and non-TP MoE model losses differ. Non-TP loss: {loss.item()}, TP loss: {loss_tp.item()}, Diff: {(loss - loss_tp).abs().item()}"
565+
)
566+
567+
# Compare gradients for matching parameters
568+
for (name, param), (name_tp, param_tp) in zip(model.named_parameters(), model_tp.named_parameters()):
569+
if param.grad is not None and param_tp.grad is not None:
570+
grad = param.grad
571+
grad_tp = param_tp.grad
572+
573+
if isinstance(param_tp.data, dist.tensor.DTensor):
574+
placement = param_tp.data.placements[0]
575+
if hasattr(placement, "dim") and placement.dim is not None:
576+
grad_shard = get_tensor_shard(grad, grad, param_tp.data.device_mesh, rank, placement.dim)
577+
else:
578+
grad_shard = grad
579+
else:
580+
grad_shard = grad
581+
582+
grad_tp_local = grad_tp.to_local() if isinstance(grad_tp, dist.tensor.DTensor) else grad_tp
583+
584+
assert torch.allclose(grad_shard.cpu(), grad_tp_local.cpu(), atol=1e-5, rtol=1e-5), (
585+
f"Gradients differ for parameter {name}. Max diff: {(grad_shard.cpu() - grad_tp_local.cpu()).abs().max().item()} | Min diff: {(grad_shard.cpu() - grad_tp_local.cpu()).abs().min().item()}"
586+
)
587+
588+
dist.barrier()
589+
590+
591+
def _test_model_moe_forward_compile_impl(rank, mode):
592+
"""Implementation for comparing TP and non-TP MoE model outputs with torch.compile."""
593+
model_id = "Qwen/Qwen3-0.6B"
594+
595+
torch.manual_seed(0)
596+
597+
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
598+
prompt = "Can I help"
599+
inputs = tokenizer(prompt, return_tensors="pt")
600+
601+
model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
602+
dist.barrier()
603+
if mode == "eval":
604+
model_tp.eval()
605+
else:
606+
model_tp.train()
607+
608+
device = model_tp.device
609+
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
610+
model = model.to(device)
611+
612+
if mode == "eval":
613+
model.eval()
614+
else:
615+
model.train()
616+
617+
# Compile both models
618+
model.forward = torch.compile(model.forward)
619+
model_tp.forward = torch.compile(model_tp.forward)
620+
621+
input_ids = inputs.input_ids.to(device)
622+
623+
with torch.no_grad():
624+
outputs = model(input_ids)
625+
logits = outputs.logits
626+
627+
outputs_tp = model_tp(input_ids)
628+
logits_tp = outputs_tp.logits
629+
630+
assert torch.allclose(logits, logits_tp, atol=1e-5, rtol=1e-5), (
631+
f"TP and non-TP MoE model outputs differ. Max diff: {(logits - logits_tp).abs().max().item()} | Min diff: {(logits - logits_tp).abs().min().item()}"
632+
)
633+
634+
dist.barrier()
635+
636+
637+
def _test_model_moe_save_impl(rank, tmp_dir):
638+
"""Implementation of test_model_save for MoE model distributed execution."""
639+
model_id = "Qwen/Qwen3-0.6B"
640+
641+
if dist.is_initialized():
642+
kwargs = {"tp_plan": "auto"}
643+
result_dir = f"{tmp_dir}/tp"
644+
else:
645+
kwargs = {}
646+
result_dir = f"{tmp_dir}/nontp"
647+
648+
model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
649+
model.save_pretrained(result_dir)
650+
651+
652+
class TestTensorParallelMoeBase(TestCasePlus):
653+
"""Base class for MoE tensor parallel tests. Subclasses must set nproc_per_node."""
654+
655+
nproc_per_node = None
656+
657+
@require_torch_multi_accelerator
658+
def test_model_moe_forward_eval(self):
659+
"""Test that TP and non-TP MoE models produce the same outputs in eval mode."""
660+
if self.nproc_per_node is None:
661+
self.skipTest("nproc_per_node not set")
662+
if backend_device_count(torch_device) < self.nproc_per_node:
663+
self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")
664+
665+
init_distributed(tp=self.nproc_per_node)(_test_model_moe_forward_impl)("eval")
666+
667+
@require_torch_multi_accelerator
668+
def test_model_moe_forward_train(self):
669+
"""Test that TP and non-TP MoE models produce the same outputs in train mode."""
670+
if self.nproc_per_node is None:
671+
self.skipTest("nproc_per_node not set")
672+
if backend_device_count(torch_device) < self.nproc_per_node:
673+
self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")
674+
675+
init_distributed(tp=self.nproc_per_node)(_test_model_moe_forward_impl)("train")
676+
677+
@require_torch_multi_accelerator
678+
def test_model_moe_backward_pass(self):
679+
"""Test that TP and non-TP MoE models produce the same gradients."""
680+
if self.nproc_per_node is None:
681+
self.skipTest("nproc_per_node not set")
682+
if backend_device_count(torch_device) < self.nproc_per_node:
683+
self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")
684+
685+
init_distributed(tp=self.nproc_per_node)(_test_model_moe_backward_pass_impl)()
686+
687+
@require_torch_multi_accelerator
688+
def test_model_moe_forward_compile_eval(self):
689+
"""Test that TP and non-TP MoE models produce the same outputs with torch.compile in eval mode."""
690+
if self.nproc_per_node is None:
691+
self.skipTest("nproc_per_node not set")
692+
if backend_device_count(torch_device) < self.nproc_per_node:
693+
self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")
694+
695+
init_distributed(tp=self.nproc_per_node)(_test_model_moe_forward_compile_impl)("eval")
696+
697+
@require_torch_multi_accelerator
698+
def test_model_moe_forward_compile_train(self):
699+
"""Test that TP and non-TP MoE models produce the same outputs with torch.compile in train mode."""
700+
if self.nproc_per_node is None:
701+
self.skipTest("nproc_per_node not set")
702+
if backend_device_count(torch_device) < self.nproc_per_node:
703+
self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")
704+
705+
init_distributed(tp=self.nproc_per_node)(_test_model_moe_forward_compile_impl)("train")
706+
707+
@require_huggingface_hub_greater_or_equal("0.31.4")
708+
@require_torch_multi_accelerator
709+
def test_model_moe_save(self):
710+
"""Test that TP MoE model can be saved and matches non-TP version."""
711+
if self.nproc_per_node is None:
712+
self.skipTest("nproc_per_node not set")
713+
if backend_device_count(torch_device) < self.nproc_per_node:
714+
self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")
715+
716+
with tempfile.TemporaryDirectory() as tmp_dir:
717+
# First run with TP (distributed)
718+
init_distributed(tp=self.nproc_per_node)(_test_model_moe_save_impl)(tmp_dir)
719+
720+
# Then run without TP (non-distributed)
721+
_test_model_moe_save_impl(0, tmp_dir)
722+
723+
non_tp_model_path = os.path.join(tmp_dir, "nontp")
724+
tp_model_path = os.path.join(tmp_dir, "tp")
725+
726+
for filename in os.listdir(non_tp_model_path):
727+
if not filename.endswith(".safetensors"):
728+
continue
729+
730+
non_tp_model = safe_open(os.path.join(non_tp_model_path, filename), device="cpu", framework="pt")
731+
tp_model = safe_open(os.path.join(tp_model_path, filename), device="cpu", framework="pt")
732+
for non_tp_key in non_tp_model.keys():
733+
non_tp_tensor = non_tp_model.get_tensor(non_tp_key)
734+
tp_tensor = tp_model.get_tensor(non_tp_key)
735+
assert torch.allclose(non_tp_tensor, tp_tensor), f"Tensor with key: {non_tp_key} does not match"
736+
del non_tp_tensor, tp_tensor
737+
738+
739+
class TestTensorParallelMoe2Proc(TestTensorParallelMoeBase):
740+
"""Test MoE tensor parallel with 2 processes."""
741+
742+
nproc_per_node = 2
743+
744+
745+
class TestTensorParallelMoe4Proc(TestTensorParallelMoeBase):
746+
"""Test MoE tensor parallel with 4 processes."""
477747

478748
nproc_per_node = 4

0 commit comments

Comments
 (0)