[XLA:GPU] Very low utilization for lax.ragged_all_to_all on 8x B200, but near optimal utilization cross-host #33386

@haoliuhl

Description

When benchmarking lax.ragged_all_to_all on B200s, I observe very low bandwidth utilization on a single host with NVLink, but near-peak utilization across hosts. I'm looking for confirmation on whether this is expected and for potential optimizations.

Repro

Code

Setting 1: single-host (8× B200, NVLink; maximum bandwidth 1.8 TB/s)

Command:
python ra2a_bandwidth.py --num-experts 256 --dtype bfloat16 --feature-size 7168 --tokens-per-device 4096
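For scale, here is a rough estimate of the per-device payload implied by these flags. This is only a sketch: it assumes each device contributes `tokens-per-device × feature-size` elements of the given dtype, which may not match how `ra2a_bandwidth.py` actually sizes its buffers. The names below are illustrative and not taken from the script.

```python
# Rough per-device payload implied by the benchmark flags.
# Assumption: one (tokens_per_device, feature_size) buffer per device;
# the real script may size its send/receive buffers differently.
BYTES_PER_ELEMENT = {"bfloat16": 2, "float16": 2, "float32": 4}

tokens_per_device = 4096
feature_size = 7168
dtype = "bfloat16"

payload_bytes = tokens_per_device * feature_size * BYTES_PER_ELEMENT[dtype]
payload_mib = payload_bytes / 2**20

print(f"{payload_bytes} bytes = {payload_mib:.1f} MiB per device")
```

Under that assumption each device holds 56 MiB of token data, so the transfers should be large enough that per-message latency alone is unlikely to explain the numbers.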

Observed:

Aggregate send-only: 16.42 GiB/s
Aggregate bidirectional: 32.85 GiB/s

This puts ragged all-to-all bandwidth at about 32.85 GiB/s against a 1.8 TB/s NVLink peak, or roughly 2% utilization.

Setting 2: multi-host (2 nodes, each 8× B200; maximum cross-node bandwidth 100 GB/s)

Command:
python ra2a_bandwidth.py --num-experts 256 --dtype bfloat16 --feature-size 7168 --tokens-per-device 4096

Observed:

Aggregate send-only: 50.06 GiB/s
Aggregate bidirectional: 100.13 GiB/s

Since cross-node bandwidth is the bottleneck, this indicates ragged all-to-all cross-host bandwidth utilization is nearly 100%.
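The two utilization figures can be sanity-checked with a small calculation. This sketch assumes the benchmark reports binary units (GiB/s) while the quoted peaks are decimal (TB/s, GB/s), and it compares the aggregate bidirectional number directly against the quoted peak, as done above. The helper name `utilization` is illustrative.

```python
# Utilization = measured bandwidth / peak bandwidth, with explicit unit conversion.
# Assumption: measured values are GiB/s (2**30 bytes), peaks are GB/s (10**9 bytes).
GIB = 2**30
GB = 10**9


def utilization(measured_gib_per_s: float, peak_gb_per_s: float) -> float:
    """Return measured/peak as a fraction, after converting both to bytes/s."""
    return (measured_gib_per_s * GIB) / (peak_gb_per_s * GB)


# Setting 1: single host, NVLink peak quoted as 1.8 TB/s = 1800 GB/s.
single_host = utilization(32.85, 1800.0)
# Setting 2: two hosts, cross-node peak quoted as 100 GB/s.
cross_host = utilization(100.13, 100.0)

print(f"single-host: {single_host:.1%}")
print(f"cross-host:  {cross_host:.1%}")
```

With the unit conversion applied, the single-host figure stays at about 2%, while the cross-host figure comes out slightly above 100%. That may mean the 100 GB/s peak is a rounded nominal value, or that the reported aggregate includes some intra-node traffic; the issue does not say which.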

Questions

  • Is low NVLink-domain utilization expected for lax.ragged_all_to_all?
  • Is the high cross-host utilization expected? Will it hold as we scale beyond two hosts?
