|
15 | 15 | # Run all tests: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py |
16 | 16 | # Run specific config: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -k "2Proc" |
17 | 17 | # Run multiple configs: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -k "2Proc or 4Proc" |
18 | | -# Run spefic test: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallel2Proc::test_model_dense_forward_train |
19 | | -# Run tests with a specific prefix: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallel2Proc -k "forward" |
| 18 | +# Run specific test: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallelDense2Proc::test_model_dense_forward_train |
| 19 | +# Run tests with a specific prefix: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallelDense2Proc -k "forward" |
| 20 | +# Run MoE tests only: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -k "Moe" |
| 21 | +# Run dense tests only: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -k "TestTensorParallelDense2Proc or TestTensorParallelDense4Proc" |
20 | 22 | import os |
21 | 23 | import tempfile |
22 | 24 | import warnings |
@@ -381,7 +383,7 @@ def _test_model_dense_save_impl(rank, tmp_dir): |
381 | 383 | model.save_pretrained(result_dir) |
382 | 384 |
|
383 | 385 |
|
384 | | -class TestTensorParallelBase(TestCasePlus): |
| 386 | +class TestTensorParallelDenseBase(TestCasePlus): |
385 | 387 | """Base class for tensor parallel tests. Subclasses must set nproc_per_node.""" |
386 | 388 |
|
387 | 389 | nproc_per_node = None |
@@ -466,13 +468,281 @@ def test_model_dense_save(self): |
466 | 468 | del non_tp_tensor, tp_tensor |
467 | 469 |
|
468 | 470 |
|
469 | | -class TestTensorParallel2Proc(TestTensorParallelBase): |
470 | | - """Test tensor parallel with 2 processes.""" |
| 471 | +class TestTensorParallelDense2Proc(TestTensorParallelDenseBase): |
| 472 | + """Test tensor parallel dense model with 2 processes.""" |
471 | 473 |
|
472 | 474 | nproc_per_node = 2 |
473 | 475 |
|
474 | 476 |
|
475 | | -class TestTensorParallel4Proc(TestTensorParallelBase): |
476 | | - """Test tensor parallel with 4 processes.""" |
| 477 | +class TestTensorParallelDense4Proc(TestTensorParallelDenseBase): |
| 478 | + """Test tensor parallel dense model with 4 processes.""" |
| 479 | + |
| 480 | + nproc_per_node = 4 |
| 481 | + |
| 482 | + |
| 483 | +# ====== MOE MODEL TEST FUNCTIONS ====== |
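| | +# These helpers mirror the dense-model test implementations above: each compares a |
| | +# tp_plan="auto" model against a single-device reference for forward, backward, |
| | +# torch.compile, and save. |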
| 484 | +def _test_model_moe_forward_impl(rank, mode): |
| 485 | + """Implementation for comparing TP and non-TP MoE model outputs.""" |
| 486 | + model_id = "Qwen/Qwen3-0.6B" |
| 487 | + |
| 488 | + # Ensure same random seed for reproducibility |
| 489 | + torch.manual_seed(0) |
| 490 | + |
| 491 | + # Load tokenizer and prepare inputs - same for both models |
| 492 | + tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) |
| 493 | + prompt = "Can I help" |
| 494 | + inputs = tokenizer(prompt, return_tensors="pt") |
| 495 | + |
| 496 | + # Load TP model first to determine device |
| 497 | + model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto") |
| 498 | + dist.barrier() |
| 499 | + if mode == "eval": |
| 500 | + model_tp.eval() |
| 501 | + else: |
| 502 | + model_tp.train() |
| 503 | + |
| 504 | + # Load non-TP model and move to same device as TP model |
| 505 | + device = model_tp.device |
| 506 | + model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto") |
| 507 | + model = model.to(device) |
| 508 | + |
| 509 | + if mode == "eval": |
| 510 | + model.eval() |
| 511 | + else: |
| 512 | + model.train() |
| 513 | + |
| 514 | + # Prepare inputs on the same device |
| 515 | + input_ids = inputs.input_ids.to(device) |
| 516 | + |
| 517 | + # Run forward pass on both models |
| 518 | + with torch.no_grad(): |
| 519 | + # Non-TP model output |
| 520 | + outputs = model(input_ids) |
| 521 | + logits = outputs.logits |
| 522 | + |
| 523 | + # TP model output |
| 524 | + outputs_tp = model_tp(input_ids) |
| 525 | + logits_tp = outputs_tp.logits |
| 526 | + |
| 527 | + # Compare outputs - they should match |
| 528 | + assert torch.allclose(logits, logits_tp, atol=1e-5, rtol=1e-5), ( |
| 529 | + f"TP and non-TP MoE model outputs differ. Max diff: {(logits - logits_tp).abs().max().item()} | Min diff: {(logits - logits_tp).abs().min().item()}" |
| 530 | + ) |
| 531 | + |
| 532 | + dist.barrier() |
| 533 | + |
| 534 | + |
| 535 | +def _test_model_moe_backward_pass_impl(rank): |
| 536 | + """Implementation for comparing TP and non-TP MoE model backward passes.""" |
| 537 | + model_id = "Qwen/Qwen3-0.6B" |
| 538 | + |
| 539 | + torch.manual_seed(0) |
| 540 | + |
| 541 | + model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32, tp_plan="auto") |
| 542 | + dist.barrier() |
| 543 | + model_tp.train() |
| 544 | + |
| 545 | + device = model_tp.device |
| 546 | + model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32) |
| 547 | + model = model.to(device) |
| 548 | + model.train() |
| 549 | + |
| 550 | + batch_size, seq_length = 2, 10 |
| 551 | +    torch.manual_seed(42)  # Use a different fixed seed for the inputs so they are deterministic |
| 552 | + input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_length), device=device) |
| 553 | + labels = torch.randint(0, model.config.vocab_size, (batch_size, seq_length), device=device) |
| 554 | + |
| 555 | + outputs = model(input_ids, labels=labels) |
| 556 | + loss = outputs.loss |
| 557 | + loss.backward() |
| 558 | + |
| 559 | + outputs_tp = model_tp(input_ids, labels=labels) |
| 560 | + loss_tp = outputs_tp.loss |
| 561 | + loss_tp.backward() |
| 562 | + |
| 563 | + assert torch.allclose(loss, loss_tp, atol=1e-5, rtol=1e-5), ( |
| 564 | + f"TP and non-TP MoE model losses differ. Non-TP loss: {loss.item()}, TP loss: {loss_tp.item()}, Diff: {(loss - loss_tp).abs().item()}" |
| 565 | + ) |
| 566 | + |
| 567 | + # Compare gradients for matching parameters |
| 568 | + for (name, param), (name_tp, param_tp) in zip(model.named_parameters(), model_tp.named_parameters()): |
| 569 | + if param.grad is not None and param_tp.grad is not None: |
| 570 | + grad = param.grad |
| 571 | + grad_tp = param_tp.grad |
| 572 | + |
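| | +            # TP parameters are DTensors sharded along one dimension; slice the reference |
| | +            # model's full gradient down to this rank's shard so shapes match for comparison. |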
| 573 | + if isinstance(param_tp.data, dist.tensor.DTensor): |
| 574 | + placement = param_tp.data.placements[0] |
| 575 | + if hasattr(placement, "dim") and placement.dim is not None: |
| 576 | + grad_shard = get_tensor_shard(grad, grad, param_tp.data.device_mesh, rank, placement.dim) |
| 577 | + else: |
| 578 | + grad_shard = grad |
| 579 | + else: |
| 580 | + grad_shard = grad |
| 581 | + |
| 582 | + grad_tp_local = grad_tp.to_local() if isinstance(grad_tp, dist.tensor.DTensor) else grad_tp |
| 583 | + |
| 584 | + assert torch.allclose(grad_shard.cpu(), grad_tp_local.cpu(), atol=1e-5, rtol=1e-5), ( |
| 585 | + f"Gradients differ for parameter {name}. Max diff: {(grad_shard.cpu() - grad_tp_local.cpu()).abs().max().item()} | Min diff: {(grad_shard.cpu() - grad_tp_local.cpu()).abs().min().item()}" |
| 586 | + ) |
| 587 | + |
| 588 | + dist.barrier() |
| 589 | + |
| 590 | + |
| 591 | +def _test_model_moe_forward_compile_impl(rank, mode): |
| 592 | + """Implementation for comparing TP and non-TP MoE model outputs with torch.compile.""" |
| 593 | + model_id = "Qwen/Qwen3-0.6B" |
| 594 | + |
| 595 | + torch.manual_seed(0) |
| 596 | + |
| 597 | + tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) |
| 598 | + prompt = "Can I help" |
| 599 | + inputs = tokenizer(prompt, return_tensors="pt") |
| 600 | + |
| 601 | + model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto") |
| 602 | + dist.barrier() |
| 603 | + if mode == "eval": |
| 604 | + model_tp.eval() |
| 605 | + else: |
| 606 | + model_tp.train() |
| 607 | + |
| 608 | + device = model_tp.device |
| 609 | + model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto") |
| 610 | + model = model.to(device) |
| 611 | + |
| 612 | + if mode == "eval": |
| 613 | + model.eval() |
| 614 | + else: |
| 615 | + model.train() |
| 616 | + |
| 617 | + # Compile both models |
| 618 | + model.forward = torch.compile(model.forward) |
| 619 | + model_tp.forward = torch.compile(model_tp.forward) |
| 620 | + |
| 621 | + input_ids = inputs.input_ids.to(device) |
| 622 | + |
| 623 | + with torch.no_grad(): |
| 624 | + outputs = model(input_ids) |
| 625 | + logits = outputs.logits |
| 626 | + |
| 627 | + outputs_tp = model_tp(input_ids) |
| 628 | + logits_tp = outputs_tp.logits |
| 629 | + |
| 630 | + assert torch.allclose(logits, logits_tp, atol=1e-5, rtol=1e-5), ( |
| 631 | + f"TP and non-TP MoE model outputs differ. Max diff: {(logits - logits_tp).abs().max().item()} | Min diff: {(logits - logits_tp).abs().min().item()}" |
| 632 | + ) |
| 633 | + |
| 634 | + dist.barrier() |
| 635 | + |
| 636 | + |
| 637 | +def _test_model_moe_save_impl(rank, tmp_dir): |
| 638 | +    """Implementation of test_model_moe_save for distributed execution.""" |
| 639 | + model_id = "Qwen/Qwen3-0.6B" |
| 640 | + |
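| | +    # Called twice by the test: once under torch.distributed to save the TP checkpoint, |
| | +    # and once in a plain process to save the non-TP reference checkpoint. |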
| 641 | + if dist.is_initialized(): |
| 642 | + kwargs = {"tp_plan": "auto"} |
| 643 | + result_dir = f"{tmp_dir}/tp" |
| 644 | + else: |
| 645 | + kwargs = {} |
| 646 | + result_dir = f"{tmp_dir}/nontp" |
| 647 | + |
| 648 | + model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs) |
| 649 | + model.save_pretrained(result_dir) |
| 650 | + |
| 651 | + |
| 652 | +class TestTensorParallelMoeBase(TestCasePlus): |
| 653 | + """Base class for MoE tensor parallel tests. Subclasses must set nproc_per_node.""" |
| 654 | + |
| 655 | + nproc_per_node = None |
| 656 | + |
| 657 | + @require_torch_multi_accelerator |
| 658 | + def test_model_moe_forward_eval(self): |
| 659 | + """Test that TP and non-TP MoE models produce the same outputs in eval mode.""" |
| 660 | + if self.nproc_per_node is None: |
| 661 | + self.skipTest("nproc_per_node not set") |
| 662 | + if backend_device_count(torch_device) < self.nproc_per_node: |
| 663 | + self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}") |
| 664 | + |
| 665 | + init_distributed(tp=self.nproc_per_node)(_test_model_moe_forward_impl)("eval") |
| 666 | + |
| 667 | + @require_torch_multi_accelerator |
| 668 | + def test_model_moe_forward_train(self): |
| 669 | + """Test that TP and non-TP MoE models produce the same outputs in train mode.""" |
| 670 | + if self.nproc_per_node is None: |
| 671 | + self.skipTest("nproc_per_node not set") |
| 672 | + if backend_device_count(torch_device) < self.nproc_per_node: |
| 673 | + self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}") |
| 674 | + |
| 675 | + init_distributed(tp=self.nproc_per_node)(_test_model_moe_forward_impl)("train") |
| 676 | + |
| 677 | + @require_torch_multi_accelerator |
| 678 | + def test_model_moe_backward_pass(self): |
| 679 | + """Test that TP and non-TP MoE models produce the same gradients.""" |
| 680 | + if self.nproc_per_node is None: |
| 681 | + self.skipTest("nproc_per_node not set") |
| 682 | + if backend_device_count(torch_device) < self.nproc_per_node: |
| 683 | + self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}") |
| 684 | + |
| 685 | + init_distributed(tp=self.nproc_per_node)(_test_model_moe_backward_pass_impl)() |
| 686 | + |
| 687 | + @require_torch_multi_accelerator |
| 688 | + def test_model_moe_forward_compile_eval(self): |
| 689 | + """Test that TP and non-TP MoE models produce the same outputs with torch.compile in eval mode.""" |
| 690 | + if self.nproc_per_node is None: |
| 691 | + self.skipTest("nproc_per_node not set") |
| 692 | + if backend_device_count(torch_device) < self.nproc_per_node: |
| 693 | + self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}") |
| 694 | + |
| 695 | + init_distributed(tp=self.nproc_per_node)(_test_model_moe_forward_compile_impl)("eval") |
| 696 | + |
| 697 | + @require_torch_multi_accelerator |
| 698 | + def test_model_moe_forward_compile_train(self): |
| 699 | + """Test that TP and non-TP MoE models produce the same outputs with torch.compile in train mode.""" |
| 700 | + if self.nproc_per_node is None: |
| 701 | + self.skipTest("nproc_per_node not set") |
| 702 | + if backend_device_count(torch_device) < self.nproc_per_node: |
| 703 | + self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}") |
| 704 | + |
| 705 | + init_distributed(tp=self.nproc_per_node)(_test_model_moe_forward_compile_impl)("train") |
| 706 | + |
| 707 | + @require_huggingface_hub_greater_or_equal("0.31.4") |
| 708 | + @require_torch_multi_accelerator |
| 709 | + def test_model_moe_save(self): |
| 710 | + """Test that TP MoE model can be saved and matches non-TP version.""" |
| 711 | + if self.nproc_per_node is None: |
| 712 | + self.skipTest("nproc_per_node not set") |
| 713 | + if backend_device_count(torch_device) < self.nproc_per_node: |
| 714 | + self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}") |
| 715 | + |
| 716 | + with tempfile.TemporaryDirectory() as tmp_dir: |
| 717 | + # First run with TP (distributed) |
| 718 | + init_distributed(tp=self.nproc_per_node)(_test_model_moe_save_impl)(tmp_dir) |
| 719 | + |
| 720 | + # Then run without TP (non-distributed) |
| 721 | + _test_model_moe_save_impl(0, tmp_dir) |
| 722 | + |
| 723 | + non_tp_model_path = os.path.join(tmp_dir, "nontp") |
| 724 | + tp_model_path = os.path.join(tmp_dir, "tp") |
| 725 | + |
| 726 | + for filename in os.listdir(non_tp_model_path): |
| 727 | + if not filename.endswith(".safetensors"): |
| 728 | + continue |
| 729 | + |
| 730 | + non_tp_model = safe_open(os.path.join(non_tp_model_path, filename), device="cpu", framework="pt") |
| 731 | + tp_model = safe_open(os.path.join(tp_model_path, filename), device="cpu", framework="pt") |
| 732 | + for non_tp_key in non_tp_model.keys(): |
| 733 | + non_tp_tensor = non_tp_model.get_tensor(non_tp_key) |
| 734 | + tp_tensor = tp_model.get_tensor(non_tp_key) |
| 735 | + assert torch.allclose(non_tp_tensor, tp_tensor), f"Tensor with key: {non_tp_key} does not match" |
| 736 | + del non_tp_tensor, tp_tensor |
| 737 | + |
| 738 | + |
| 739 | +class TestTensorParallelMoe2Proc(TestTensorParallelMoeBase): |
| 740 | + """Test MoE tensor parallel with 2 processes.""" |
| 741 | + |
| 742 | + nproc_per_node = 2 |
| 743 | + |
| 744 | + |
| 745 | +class TestTensorParallelMoe4Proc(TestTensorParallelMoeBase): |
| 746 | + """Test MoE tensor parallel with 4 processes.""" |
477 | 747 |
|
478 | 748 | nproc_per_node = 4 |