
[FEATURE] AWS EFA Support for DeepEP Expert Parallelism - Performance Optimization #198

@dmvevents

Description

Community Note

  • Please vote on this issue by adding a 👍 reaction to the original issue to help the community and maintainers prioritize this request
  • Please do not leave "+1" or other comments that do not add relevant new information or questions; they generate extra noise for issue followers and do not help prioritize the request
  • If you are interested in working on this issue or have submitted a pull request, please leave a comment

What is the outcome that you are trying to reach?

Enable high-performance Expert Parallelism on AWS P5 instances by optimizing the newly developed AWS EFA support in DeepEP. Previously, DeepEP was completely incompatible with AWS infrastructure because EFA lacks the NVSHMEM IBGDA support DeepEP relies on, a critical blocker for AI organizations deploying mixture-of-experts (MoE) models on AWS.

Our team has successfully developed an early prototype that bridges this fundamental incompatibility, but performance optimization is needed to achieve production-ready latency targets.

Current Status:

  • Breakthrough: DeepEP now runs on AWS P5 instances (previously impossible)
  • Compatibility: Three-tier optimization strategy implemented (Local → NVLink → EFA)

Describe the solution you would like

Phase 1: EFA Transport Optimization (This Issue)
Optimize the EFA compatibility layer to reduce communication latencies by 60-75%:

1. Eliminate EFA Fabric Stalls

// Current: Per-operation fencing causes EFA pipeline stalls
nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, bytes, dst_rank, qp_id, lane_id, msg_idx);
nvshmem_fence(); // ❌ Kills EFA performance

// Optimized: Batched fencing for EFA efficiency  
nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, bytes, dst_rank, qp_id, lane_id, msg_idx);
// ... process all tokens ...
if (lane_id == 0) nvshmem_fence(); // ✅ Single fence after batch
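
As a concrete illustration of the pattern, a minimal device-side sketch is below. It is not DeepEP's actual dispatch kernel; the loop structure, the names num_tokens, token_bytes, dst_base and src_base, and the flat pointer math are assumptions used only to show where the single fence lands relative to the per-token puts.

// Sketch only: queue all per-token puts for one destination rank, fence once.
__device__ void dispatch_to_rank_batched(uint64_t dst_base, uint64_t src_base,
                                         int num_tokens, int token_bytes,
                                         int dst_rank, int qp_id, int lane_id) {
    // Issue every non-blocking put without fencing in between.
    for (int token = 0; token < num_tokens; ++token) {
        nvshmemi_ibgda_put_nbi_warp(dst_base + token * token_bytes,
                                    src_base + token * token_bytes,
                                    token_bytes, dst_rank, qp_id, lane_id, token);
    }
    // One fence for the whole batch: orders all queued puts on the EFA QP before
    // any later completion flag is written, instead of stalling after each token.
    if (lane_id == 0)
        nvshmem_fence();
    __syncwarp();
}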

2. EFA-Optimized NVSHMEM Operations

// Replace single-threaded operations with warp-collective
- if (lane_id == 0) nvshmem_putmem_nbi(...);
+ nvshmemx_putmem_nbi_warp(...); // Better EFA fabric utilization

// Use EFA-aware waiting instead of busy polling
- while (ld_acquire_sys_global(flag) == 0); // Inefficient on EFA SRD
+ nvshmem_int_wait_until(flag, NVSHMEM_CMP_NE, 0); // EFA-optimized
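
To show how the two replacements fit together, here is a rough sender/receiver sketch. The names send_chunk, wait_chunk, recv_flag and dst_pe are hypothetical, and the single-flag completion scheme is only one possible protocol, not necessarily the one DeepEP uses.

// Sender warp: all lanes participate in the put so NVSHMEM can drive the EFA
// fabric collectively instead of serializing the transfer on lane 0.
__device__ void send_chunk(void* dst, const void* src, size_t bytes,
                           int dst_pe, int* recv_flag, int lane_id) {
    nvshmemx_putmem_nbi_warp(dst, src, bytes, dst_pe);
    if (lane_id == 0) {
        nvshmem_fence();                      // order the data ahead of the flag
        nvshmem_int_p(recv_flag, 1, dst_pe);  // signal arrival on the receiver
    }
}

// Receiver: wait inside NVSHMEM rather than spinning on acquire loads, which
// lets the runtime poll the EFA SRD completion path efficiently.
__device__ void wait_chunk(int* recv_flag) {
    nvshmem_int_wait_until(recv_flag, NVSHMEM_CMP_NE, 0);
}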

3. AWS EFA Environment Tuning

# Critical EFA settings for P5 instances
- name: NVSHMEM_BOOTSTRAP
  value: "MPI"
- name: NVSHMEM_SYMMETRIC_SIZE  
  value: "4294967296"
- name: FI_EFA_USE_DEVICE_RDMA
  value: "1"
- name: FI_EFA_ENABLE_SHM_TRANSFER  # NEW: Optimize intra-node
  value: "1" 
- name: NVSHMEM_ENABLE_NIC_PE_MAPPING  # NEW: Better EFA PE mapping
  value: "1"
- name: NVSHMEM_DISABLE_CUDA_VMM  # NEW: EFA compatibility
  value: "1"

Technical Background

The EFA Challenge:

  • AWS EFA uses Scalable Reliable Datagram (SRD) transport
  • Lacks native GPU-initiated operation primitives required by NVSHMEM IBGDA
  • Previous assessment: "P5 support not possible"

Our Breakthrough Solution:

  1. Local Operations: Direct memory access for same-GPU communication
  2. NVLink P2P: Direct GPU-to-GPU for intra-node communication
  3. EFA Transport: Custom compatibility layer for inter-node communication
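
A schematic of how the three tiers can be selected per destination PE is sketched below. This is an assumed structure for illustration (the names put_tiered, my_pe and ranks_per_node are hypothetical), not the prototype's actual routing code.

// Sketch: pick the cheapest available path for a put to dst_pe.
__device__ void put_tiered(void* dst, const void* src, size_t bytes,
                           int dst_pe, int my_pe, int ranks_per_node) {
    if (dst_pe == my_pe) {
        // Tier 1: same GPU, plain local copy with no network involvement.
        memcpy(dst, src, bytes);
    } else if (dst_pe / ranks_per_node == my_pe / ranks_per_node) {
        // Tier 2: same node, use the NVLink-mapped peer address when available.
        void* peer = nvshmem_ptr(dst, dst_pe);
        if (peer != nullptr) { memcpy(peer, src, bytes); return; }
        nvshmem_putmem_nbi(dst, src, bytes, dst_pe);  // fallback
    } else {
        // Tier 3: different node, go through the EFA compatibility transport.
        nvshmem_putmem_nbi(dst, src, bytes, dst_pe);
    }
}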

Current Performance Analysis:

  • Send operations: ~180-524 us (efficient)
  • Receive operations: ~12K-27K us (60-70x slower, indicating fabric inefficiency)
  • Root cause: EFA fabric stalls due to excessive synchronization

Describe alternatives you have considered

Alternative 1: Continue with Current Implementation

  • Pros: Works functionally
  • Cons: 60-70x performance gap makes production deployment impractical

Alternative 2: Pure NCCL Fallback

  • Pros: Battle-tested reliability
  • Cons: Not designed for Expert Parallel communication patterns, higher baseline latency

Alternative 3: Custom EFA Protocol Implementation

  • Pros: Maximum performance potential
  • Cons: Significant development effort, maintenance burden

Alternative 4: Wait for Native NVSHMEM EFA Support

  • Pros: Official support
  • Cons: Timeline uncertain, may not address EP-specific optimizations

Expected Impact

Performance Targets:

  • Dispatch receive: 12689 us → ~3000 us (75% reduction)
  • Combine receive: 27486 us → ~8000 us (70% reduction)
  • Overall MoE training: 2-3x throughput improvement

Business Impact:

  • Enables AI organizations to deploy MoE models on AWS P5 infrastructure
  • Unlocks AWS market for Expert Parallelism workloads
  • Provides cost-effective alternative to on-premise InfiniBand clusters

Implementation Plan

Phase 1 (This Issue): EFA Performance Optimization

  • Remove per-operation fencing in dispatch/combine kernels
  • Implement warp-collective NVSHMEM operations
  • Add EFA-optimized waiting primitives
  • Tune EFA environment variables

Phase 2 (Future): Advanced EFA Features

  • EFA memory registration optimization
  • Multi-rail EFA support for bandwidth scaling
  • EFA-aware memory allocation strategies

Testing Environment

Hardware:

  • AWS EC2 p5.48xlarge instances
  • 8x H100 GPUs per node
  • 3200 Gbps EFA network fabric
  • Multi-node Expert Parallel training

Validation Criteria:

  • Functional: All EP operations complete successfully
  • Performance: <5000 us dispatch receive, <10000 us combine receive
  • Scalability: Linear performance scaling across multiple P5 nodes
  • Compatibility: No regression on existing InfiniBand systems
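
One straightforward way to check the latency criteria is a CUDA-event wrapper around the launch under test; the sketch below is generic (time_launch_us and the function-pointer argument are illustrative, not DeepEP's benchmark harness).

// Sketch: time a kernel launch with CUDA events and report microseconds.
#include <cuda_runtime.h>

float time_launch_us(void (*launch)()) {
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);
    cudaEventRecord(start);
    launch();                      // e.g. the dispatch or combine launch under test
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);
    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    return ms * 1000.0f;           // the targets above are expressed in us
}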

Additional Context

This work represents a significant breakthrough for the AWS AI ecosystem. The EFA transport layer was considered fundamentally incompatible with NVSHMEM's memory semantics, making Expert Parallelism impossible on AWS. Our prototype proves this can be overcome, but performance optimization is critical for production adoption.

The current 60-70x performance gap between send and receive operations clearly indicates EFA fabric utilization issues rather than bandwidth limitations, making this optimization both necessary and achievable.

Files Modified:

  • csrc/kernels/internode_ll.cu - Low-latency kernel optimizations
  • csrc/kernels/internode.cu - Internode kernel optimizations
  • csrc/kernels/nvshmem_device.cuh - EFA compatibility layer
  • Environment variable configuration for P5 instances
