Add reduce_scatter and All_gather benchmark #279
xiaohuguo2023 wants to merge 16 commits into ROCm:main from
Conversation
Pull Request Overview
This PR introduces a new example (example 22) demonstrating a complete multi-GPU tensor processing pipeline using Iris. The pipeline combines reduce-scatter, RMSNorm, FP8 quantization, and all-gather operations for distributed tensor processing on AMD GPUs.
Key changes:
- Implements distributed tensor processing with IRIS remote memory access operations
- Provides both a standalone script and comprehensive benchmark suite
- Includes validation against PyTorch reference implementations
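For orientation, the sketch below expresses the same four-stage pipeline with plain torch.distributed collectives rather than the Iris kernels added in this PR; the helper name, the RMSNorm/quantization arithmetic, and the specific collective calls are illustrative assumptions (only the reduce-scatter → RMSNorm → FP8 → all-gather ordering and the 448.0 FP8 scale come from the example).

```python
# Reference-style sketch only, NOT the PR's implementation: the same pipeline
# written with torch.distributed collectives. Assumes an initialized process
# group; gamma is implicitly all-ones, matching the example's RMSNorm setup.
import torch
import torch.distributed as dist


def reference_pipeline(local_input: torch.Tensor, eps: float = 1e-6, fp8_out: bool = True) -> torch.Tensor:
    world_size = dist.get_world_size()
    M, N = local_input.shape
    M_shard = M // world_size

    # 1. Reduce-scatter along M: sum the per-rank M x N tensors, each rank keeps one shard
    shard = torch.empty(M_shard, N, device=local_input.device, dtype=local_input.dtype)
    dist.reduce_scatter_tensor(shard, local_input.contiguous(), op=dist.ReduceOp.SUM)

    # 2. RMSNorm over the full N dimension of each row
    rms = torch.sqrt(shard.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    normed = (shard.float() / rms).to(local_input.dtype)

    # 3. Optional per-tensor FP8 quantization (448.0 is the e4m3 max used by the example)
    if fp8_out and hasattr(torch, "float8_e4m3fn"):
        scale = (normed.abs().max().float() / 448.0).clamp(min=1e-8)
        normed = (normed.float() / scale).to(torch.float8_e4m3fn)

    # 4. All-gather the shards back into the full M x N tensor
    # (float8 support in collectives varies by backend; shown for illustration)
    full = torch.empty(M, N, device=normed.device, dtype=normed.dtype)
    dist.all_gather_into_tensor(full, normed.contiguous())
    return full
```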
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| examples/22_rs_rmsnorm_fp8quant_ag/reduce_scatter_rmsnorm_quant.py | Main implementation with Triton kernels for reduce-scatter, RMSNorm, FP8 quantization, and all-gather operations |
| examples/22_rs_rmsnorm_fp8quant_ag/benchmark.py | Comprehensive benchmarking suite with multi-process spawning, performance timing, and validation |
| examples/22_rs_rmsnorm_fp8quant_ag/README.md | Documentation with usage examples and pipeline description |
```markdown
1. **Reduce-Scatter**: Sum tensors across all GPUs and distribute shards
2. **RMSNorm**: Apply Root Mean Square normalization to each shard
3. **FP8 Quantization**: Quantize to 8-bit floating point (optional) 4. **All-Gather**: Reconstruct the full tensor across all GPUs (optional)
```
Line 11 contains a formatting issue where item 4 runs onto the same line as item 3 without a line break. This should be split into separate lines for proper markdown list formatting.
Suggested change:
```diff
-3. **FP8 Quantization**: Quantize to 8-bit floating point (optional) 4. **All-Gather**: Reconstruct the full tensor across all GPUs (optional)
+3. **FP8 Quantization**: Quantize to 8-bit floating point (optional)
+4. **All-Gather**: Reconstruct the full tensor across all GPUs (optional)
```
```python
max_val = input_tensor.abs().max().item()
scale = max(max_val / 448.0, 1e-8)
scale_tensor = torch.tensor([scale], device=device, dtype=torch.float32)
```
The run_quantize_fp8 function has a hardcoded num_warps=16 value on line 226, but it should use the user-configurable parameters passed to the function. According to the command-line arguments (lines 91, 451), the default for FP8 quantization should be 4, not 16. The main script also uses num_warps=4 for FP8 quantization (line 522). This function should accept and use the FP8-specific parameters like the benchmark loop does (lines 889-893, 921).
```python
    output = torch.empty(M_shard, N, device=device, dtype=torch.float8_e4m3fn)
else:
    output = torch.empty_like(input_tensor)
```
[nitpick] This function signature is extremely long with 14 parameters on a single line (extending beyond typical line length limits). Consider reformatting with one parameter per line or grouping related parameters for better readability and maintainability.
```python
final_num_warps = num_warps if num_warps is not None else 8
```
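Purely as an illustration of the suggested layout, a hypothetical signature reformatted with one parameter per line and the tuning knobs grouped at the end (the names below are placeholders, not the actual parameters from this file):

```python
# Hypothetical layout sketch; parameter names are illustrative placeholders.
def run_pipeline_stage(
    input_tensor,
    output_tensor,
    m_shard,
    n_cols,
    device,
    *,
    block_m=64,
    block_n=64,
    group_size_m=8,
    num_sms=256,
    num_warps=None,
    num_stages=None,
    waves_per_eu=None,
):
    ...
```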
```python
# Set waves_per_eu (default to 2)
final_waves_per_eu = waves_per_eu if waves_per_eu is not None else 2
```
The run_quantize_fp8 function signature is missing the FP8-specific tuning parameters (num_warps, num_stages, waves_per_eu) that are available in command-line arguments and used in the benchmarking section (lines 889-893). This inconsistency means users cannot configure these parameters when calling this function, limiting its flexibility. Consider adding these parameters with defaults matching the documented values (num_warps=4, num_stages=2, waves_per_eu=0).
```python
    num_warps=8,
    num_stages=3,
    waves_per_eu=2,
)
```
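A rough sketch of what that could look like, reusing the quantize_fp8_kernel launch shown later in this diff; the helper's actual signature in the PR may differ, and only the defaults (num_warps=4, num_stages=2, waves_per_eu=0) and the kernel arguments are taken from the review context:

```python
# Sketch only: assumes torch/triton are imported and quantize_fp8_kernel from
# reduce_scatter_rmsnorm_quant.py is in scope; defaults follow the documented
# FP8 values instead of a hardcoded num_warps=16.
def run_quantize_fp8(
    input_tensor,
    scale_tensor,
    M_shard,
    N,
    device,
    block_m=64,
    block_n=64,
    num_warps=4,
    num_stages=2,
    waves_per_eu=0,
):
    output = torch.empty(M_shard, N, device=device, dtype=torch.float8_e4m3fn)
    grid = (triton.cdiv(M_shard, block_m), triton.cdiv(N, block_n))
    quantize_fp8_kernel[grid](
        input_tensor,
        output,
        scale_tensor,
        M_shard,
        N,
        input_tensor.stride(0),
        input_tensor.stride(1),
        output.stride(0),
        output.stride(1),
        BLOCK_M=block_m,
        BLOCK_N=block_n,
        num_warps=num_warps,
        num_stages=num_stages,
        waves_per_eu=waves_per_eu,
    )
    return output
```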
Variable result is not used.
```python
import argparse
import json
import os
```
Import of 'os' is not used.
Suggested change:
```diff
-import os
```
```python
import json
import os
import random
import sys
```
Import of 'sys' is not used.
Suggested change:
```diff
-import sys
```
```python
import os
import random
import sys
import time
```
Import of 'time' is not used.
Suggested change:
```diff
-import time
```
@xiaohuguo2023 do you want me or @mawad-amd to review/merge this? Or is this still WIP?
it is ready
```python
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_rows", "--m", type=int, default=8192, help="Number of rows (M)")
    parser.add_argument("--num_cols", "--n", type=int, default=7168, help="Number of columns (N)")
    parser.add_argument("--num_ranks", "--world_size", type=int, default=8, help="Number of ranks")
    parser.add_argument("--dtype", type=str, default="fp16", choices=["bf16", "fp16", "fp32"])
    parser.add_argument("--fp8_out", action="store_true", help="Enable FP8 quantization")
    parser.add_argument("--eps", type=float, default=1e-6, help="RMSNorm epsilon")
    parser.add_argument("--all_gather", action="store_true", help="All-gather at the end to reconstruct full M×N")
    parser.add_argument("--verify", action="store_true", help="Verify against PyTorch reference")
    args = parser.parse_args()

    M = args.num_rows
    N = args.num_cols
    world_size = args.num_ranks

    assert M % world_size == 0, f"M ({M}) must be divisible by world_size ({world_size})"
    M_shard = M // world_size

    if args.dtype == "bf16":
        dtype = torch.bfloat16
    elif args.dtype == "fp16":
        dtype = torch.float16
    else:
        dtype = torch.float32

    # Set device
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")

    cur_rank = int(os.environ.get("RANK", "0"))
    actual_world_size = int(os.environ.get("WORLD_SIZE", str(world_size)))

    if actual_world_size != world_size:
        print(f"Warning: WORLD_SIZE ({actual_world_size}) != requested world_size ({world_size})")
        world_size = actual_world_size
        assert M % world_size == 0, f"M ({M}) must be divisible by world_size ({world_size})"
        M_shard = M // world_size

    print(f"Rank {cur_rank}/{world_size}: M={M}, N={N}, M_shard={M_shard}")

    # ================================================================
    # Initialize PyTorch Distributed (required for IRIS)
    # ================================================================
    if not dist.is_initialized():
        # Set up distributed environment
        os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1")
        os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")
        os.environ["RANK"] = str(cur_rank)
        os.environ["WORLD_SIZE"] = str(world_size)

        dist.init_process_group(backend="gloo", rank=cur_rank, world_size=world_size)

    # ================================================================
    # Initialize IRIS for distributed communication
    # ================================================================
    heap_size = 1 << 28  # 256MB
    shmem = iris.iris(heap_size)

    # Get heap base addresses for all ranks
    heap_bases = shmem.get_heap_bases()

    # ================================================================
    # Create input: Each rank has M×N tensor (same position, different values)
    # Must be in IRIS shared memory for remote access via iris.load
    # ================================================================
    torch.manual_seed(42 + cur_rank)  # Different seed per rank for different values
    local_input_temp = torch.randn(M, N, device=device, dtype=dtype) * (cur_rank + 1)

    # Allocate in IRIS shared memory
    local_input = shmem.empty((M, N), dtype=dtype)
    local_input.copy_(local_input_temp)
    del local_input_temp

    print(f"Rank {cur_rank}: Input shape: {local_input.shape}")

    # Barrier to ensure all ranks have allocated their input tensors
    shmem.barrier()

    # Default parameters (can be overridden via tuning)
    BLOCK_M = 16
    BLOCK_N = 64
    GROUP_SIZE_M = 8
    # MI350
    NUM_SMS = 256

    # ================================================================
    # Step 1: Reduce-Scatter along M dimension
    # Sum all M×N tensors and each rank gets (M/world_size)×N piece
    # ================================================================
    print(f"Rank {cur_rank}: Step 1 - Reduce-Scatter along M dimension")

    # Allocate output buffer in IRIS shared memory (must be accessible to all ranks)
    reduced_shard = shmem.zeros((M_shard, N), dtype=dtype)

    grid_rs = (NUM_SMS,)

    # Call kernel once - it will use iris.load() to pull data from all source ranks
    reduce_scatter_m_kernel[grid_rs](
        local_input,
        reduced_shard,
        M,
        M_shard,
        N,
        local_input.stride(0),
        local_input.stride(1),
        reduced_shard.stride(0),
        reduced_shard.stride(1),
        cur_rank,
        world_size,
        heap_bases,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        GROUP_SIZE_M=GROUP_SIZE_M,
        NUM_SMS=NUM_SMS,
        num_warps=16,  # Tuned for better performance
        num_stages=4,
        waves_per_eu=4,
    )

    # Synchronize to ensure all ranks have completed their loads and reductions
    torch.cuda.synchronize()
    shmem.barrier()

    print(f"Rank {cur_rank}: Reduce-scatter complete, shard shape: {reduced_shard.shape}")

    # ================================================================
    # Step 2: RMSNorm on (M_shard)×N with FULL N dimension
    # ================================================================
    print(f"Rank {cur_rank}: Step 2 - RMSNorm on (M_shard)×N")

    gamma = torch.ones(N, device=device, dtype=dtype)
    rmsnorm_output = torch.empty_like(reduced_shard)
    rsigma = torch.empty(M_shard, device=device, dtype=dtype)

    # AITer RMSNorm configuration
    # Note: Tuning found BLOCK_SIZE=1024 optimal for N=7168 (avoid VGPR spills with larger sizes)
    BLOCK_SIZE = 1024
    USE_BLOCKED = False  # Tuned: non-blocked mode is faster for moderate N
    NUM_PRGMS = M_shard  # Full parallelism: each program processes one row

    aiter_rmsnorm[(M_shard,)](
        reduced_shard,
        rmsnorm_output,
        gamma,
        rsigma,
        reduced_shard.stride(0),
        rmsnorm_output.stride(0),
        M_shard,
        N,
        args.eps,
        BLOCK_SIZE=BLOCK_SIZE,
        USE_BLOCKED=USE_BLOCKED,
        NUM_PRGMS=NUM_PRGMS,
        num_warps=8,  # Tuned for better occupancy
        waves_per_eu=2,
    )

    print(f"Rank {cur_rank}: RMSNorm complete, output shape: {rmsnorm_output.shape}")

    # ================================================================
    # Step 3: FP8 Quantization
    # ================================================================
    if args.fp8_out:
        print(f"Rank {cur_rank}: Step 3 - FP8 Quantization")

        # Compute scale
        max_val = rmsnorm_output.abs().max()
        scale = (max_val / 448.0).clamp(min=1e-8)
        scale_tensor = torch.tensor([scale], device=device, dtype=torch.float32)

        # Quantize
        if hasattr(torch, "float8_e4m3fn"):
            quantized_output = torch.empty_like(rmsnorm_output, dtype=torch.float8_e4m3fn)
        else:
            quantized_output = torch.empty_like(rmsnorm_output)

        # FP8 quantization uses medium tile sizes
        FP8_BLOCK_M = 64
        FP8_BLOCK_N = 64
        grid_quant = (triton.cdiv(M_shard, FP8_BLOCK_M), triton.cdiv(N, FP8_BLOCK_N))

        quantize_fp8_kernel[grid_quant](
            rmsnorm_output,
            quantized_output,
            scale_tensor,
            M_shard,
            N,
            rmsnorm_output.stride(0),
            rmsnorm_output.stride(1),
            quantized_output.stride(0),
            quantized_output.stride(1),
            BLOCK_M=FP8_BLOCK_M,
            BLOCK_N=FP8_BLOCK_N,
            num_warps=4,
            num_stages=2,
            waves_per_eu=2,
        )

        final_shard = quantized_output
        print(
            f"Rank {cur_rank}: Quantization complete, shape: {quantized_output.shape}, dtype: {quantized_output.dtype}"
        )
    else:
        final_shard = rmsnorm_output
        print(f"Rank {cur_rank}: No quantization, final shard shape: {final_shard.shape}")

    # ================================================================
    # Step 4 (Optional): All-Gather along M dimension
    # ================================================================
    if args.all_gather:
        print(f"Rank {cur_rank}: Step 4 - All-Gather along M dimension")

        # Determine output dtype
        if args.fp8_out and hasattr(torch, "float8_e4m3fn"):
            out_dtype = torch.float8_e4m3fn
        else:
            out_dtype = dtype

        # Allocate output in IRIS shared memory
        full_output = shmem.zeros((M, N), dtype=out_dtype)

        grid_ag = (NUM_SMS,)

        # All-gather uses similar parameters to reduce-scatter
        AG_BLOCK_M = 64
        AG_BLOCK_N = 64

        all_gather_m_kernel[grid_ag](
            final_shard,
            full_output,
            M,
            M_shard,
            N,
            final_shard.stride(0),
            final_shard.stride(1),
            full_output.stride(0),
            full_output.stride(1),
            cur_rank,
            world_size,
            heap_bases,
            BLOCK_M=AG_BLOCK_M,
            BLOCK_N=AG_BLOCK_N,
            GROUP_SIZE_M=GROUP_SIZE_M,
            NUM_SMS=NUM_SMS,
            num_warps=8,
            num_stages=3,
            waves_per_eu=2,
        )

        # Synchronize to ensure all ranks have completed their puts
        torch.cuda.synchronize()

        print(f"Rank {cur_rank}: All-gather complete, full output shape: {full_output.shape}")
        result = full_output
    else:
        result = final_shard
        print(f"Rank {cur_rank}: Skipping all-gather, result shape: {result.shape}")

    # ================================================================
    # Verification
    # ================================================================
    if args.verify and cur_rank == 0:
        print("\n" + "=" * 60)
        print("Verification against PyTorch reference")
        print("=" * 60)

        import torch.nn as nn

        # Reference computation
        torch.manual_seed(42)
        ref_tensors = []
        for i in range(world_size):
            torch.manual_seed(42 + i)
            tensor = torch.randn(M, N, device=device, dtype=dtype) * (i + 1)
            ref_tensors.append(tensor)

        # Pointwise reduce (sum)
        ref_reduced = torch.zeros(M, N, device=device, dtype=dtype)
        for tensor in ref_tensors:
            ref_reduced += tensor

        print(f"Reference reduced sum: {ref_reduced.sum(dtype=torch.float32):.4f}")

        # Extract this rank's shard
        start_row = cur_rank * M_shard
        end_row = (cur_rank + 1) * M_shard
        ref_shard = ref_reduced[start_row:end_row, :]

        # Compare reduce-scatter result
        rs_diff = torch.abs(ref_shard - reduced_shard)
        print(f"Reduce-scatter max diff: {rs_diff.max().item():.8f}")

        if rs_diff.max().item() < 1e-5:
            print("✅ Reduce-scatter verification PASSED")
        else:
            print("❌ Reduce-scatter verification FAILED")

        # RMSNorm
        rmsnorm_layer = nn.RMSNorm(N, eps=args.eps, device=device, dtype=dtype)
        ref_normed = rmsnorm_layer(ref_shard)

        print(f"\nReference RMSNorm sum: {ref_normed.sum(dtype=torch.float32):.4f}")
        print(f"Triton RMSNorm sum: {rmsnorm_output.sum(dtype=torch.float32):.4f}")

        rms_diff = torch.abs(ref_normed - rmsnorm_output)
        print(f"RMSNorm max diff: {rms_diff.max().item():.8f}")
        print(f"RMSNorm mean diff: {rms_diff.mean().item():.8f}")

        if rms_diff.max().item() < 1e-2:
            print("✅ RMSNorm verification PASSED")
        else:
            print("❌ RMSNorm verification FAILED")

    print(f"\nRank {cur_rank}: Pipeline completed successfully!")

    # Cleanup
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
```
Do we need this main function where kernels are defined?
My main feedback is that benchmark.py needs to be written like some of the other benchmark.py files. You can use iris.do_bench() instead of duplicating the code into warm-up -> benchmark, etc. That way we can promote some code reuse and make the file more readable. For example (a rough sketch follows the links below):
- https://github.com/ROCm/iris/blob/main/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py#L166 This is where the experiment is defined from start to end.
- This is where it gets timed: https://github.com/ROCm/iris/blob/main/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py#L274 Unlike triton.do_bench, it is based on iterations, so you can do warmup in there as well.
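A minimal sketch of that structure, assuming iris.do_bench accepts the experiment as a callable in the spirit of triton.do_bench; the exact keyword arguments for warm-up and iteration counts should be taken from the linked benchmark and are deliberately not shown here:

```python
import iris  # assumes the Iris package from this repository


def run_experiment(shmem, local_input, args):
    # Define the whole experiment (reduce-scatter -> RMSNorm -> FP8 quantize ->
    # all-gather) as a single callable, start to end, mirroring the linked
    # gemm_all_scatter benchmark.
    def run_pipeline():
        ...  # kernel launches currently duplicated in benchmark.py go here

    # Let the shared helper drive warm-up and timed iterations instead of
    # hand-rolled warm-up/benchmark loops. NOTE: passing only the callable is
    # an assumption; configure warm-up/iteration counts as in the linked
    # example's usage of iris.do_bench.
    return iris.do_bench(run_pipeline)
```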
This example implements an alternative All_reduce across multiple GPUs: