Description
When benchmarking lax.ragged_all_to_all on B200s, I observe very low bandwidth utilization within a single host over NVLink, but near-peak utilization across hosts. I would like to confirm whether this is expected and whether any optimizations are available.
Repro
Setting 1: single-host (8× B200, NVLink; maximum bandwidth 1.8 TB/s)
Command:
python ra2a_bandwidth.py --num-experts 256 --dtype bfloat16 --feature-size 7168 --tokens-per-device 4096
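For reference, the per-device payload implied by these flags can be estimated as follows. This is a sketch that assumes each device sends its full [tokens-per-device, feature-size] operand exactly once per call (no top-k duplication); the helper name and dtype table are hypothetical and not taken from ra2a_bandwidth.py.

```python
# Hypothetical helper: estimates bytes sent per device per ragged all-to-all call.
# ASSUMPTION: every token is sent exactly once (no top-k replication) and
# --tokens-per-device / --feature-size are the rows / columns of the operand.
BYTES_PER_ELEMENT = {"bfloat16": 2, "float16": 2, "float32": 4}


def payload_bytes_per_device(tokens_per_device: int, feature_size: int, dtype: str) -> int:
    """Size in bytes of a [tokens_per_device, feature_size] operand of the given dtype."""
    return tokens_per_device * feature_size * BYTES_PER_ELEMENT[dtype]


if __name__ == "__main__":
    n = payload_bytes_per_device(4096, 7168, "bfloat16")
    print(f"{n} bytes = {n / 2**20:.0f} MiB per device")
```

With the flags above this gives 58,720,256 bytes (56 MiB) per device, under the stated assumption.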
Observed:
Aggregate send-only: 16.42 GiB/s
Aggregate bidirectional: 32.85 GiB/s
This corresponds to roughly 35.3 GB/s (32.85 GiB/s) out of 1.8 TB/s, i.e. about 2% utilization.
Setting 2: multi-host (2 nodes, each 8× B200; maximum cross-node bandwidth 100 GB/s)
Command:
python ra2a_bandwidth.py --num-experts 256 --dtype bfloat16 --feature-size 7168 --tokens-per-device 4096
Observed:
Aggregate send-only: 50.06 GiB/s
Aggregate bidirectional: 100.13 GiB/s
Since cross-node bandwidth is the bottleneck, this indicates that ragged all-to-all achieves nearly 100% of the available cross-host bandwidth.
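To put both settings in the same units, the reported GiB/s figures can be converted to GB/s and compared with the nominal peaks quoted above. This is a minimal sketch, assuming the script's "GiB/s" means 2**30 bytes per second and the quoted peaks use decimal units (1 GB = 10**9 bytes); the function name is illustrative only.

```python
# Convert reported GiB/s to GB/s and express each as a fraction of the nominal peak.
# ASSUMPTION: script output is in binary GiB/s; quoted peaks are in decimal GB/s.
GIB = 2**30
GB = 10**9


def utilization(measured_gib_per_s: float, peak_gb_per_s: float) -> float:
    """Fraction of the nominal peak achieved, comparing both values in bytes/s."""
    return measured_gib_per_s * GIB / (peak_gb_per_s * GB)


# (aggregate bidirectional GiB/s from the runs above, nominal peak in GB/s)
settings = {
    "single-host NVLink": (32.85, 1800.0),
    "2-host cross-node": (100.13, 100.0),
}

for name, (measured, peak) in settings.items():
    measured_gb = measured * GIB / GB
    print(f"{name}: {measured_gb:.1f} GB/s of {peak:.0f} GB/s -> {utilization(measured, peak):.1%}")
```

Under this conversion the single-host run is about 35.3 GB/s (2.0% of 1.8 TB/s), while the cross-node run is about 107.5 GB/s, slightly above the quoted 100 GB/s. That may mean the aggregate also counts intra-node traffic or that the peak is specified per direction; I have not verified either.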
Questions
- Is low NVLink-domain utilization expected for lax.ragged_all_to_all?
- Is the high cross-host utilization expected? Will this behavior hold as we scale beyond two hosts?
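As a rough way to frame the scaling question, the share of each device's traffic that must leave its host can be estimated. This is a back-of-the-envelope sketch, assuming tokens are routed uniformly across the 256 experts and that experts are spread evenly over hosts; it ignores load imbalance and says nothing about how XLA actually schedules the transfers.

```python
# Back-of-the-envelope sketch.
# ASSUMPTION: uniform routing over experts, experts spread evenly across H hosts,
# so a fraction (H - 1) / H of each device's payload goes to another host.
from fractions import Fraction


def cross_host_fraction(num_hosts: int) -> Fraction:
    """Expected share of a device's outgoing bytes destined for another host."""
    return Fraction(num_hosts - 1, num_hosts)


for hosts in (1, 2, 4, 8):
    f = cross_host_fraction(hosts)
    print(f"{hosts} host(s): {float(f):.1%} of traffic crosses hosts")
```

Under this assumption the cross-node share grows from 50% at two hosts to 87.5% at eight, so the cross-node links would carry an increasing fraction of the all-to-all traffic as hosts are added.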