
Cannot see multiple GPUs when using Slurm (with proposed fix) #865

@gabeweisz

Description


When using MaxText with Slurm, our jobs only see one GPU per node, because jax.distributed assumes one GPU per process when run under Slurm (see the JAX docs).

This behavior can be overridden by passing local_device_ids to jax.distributed.initialize, so one way to fix this is to change initialize_jax_for_gpu as follows (max_utils.py line 243):
```python
def initialize_jax_for_gpu():
  """Jax distributed initialize for GPUs."""
  if os.environ.get("JAX_COORDINATOR_IP") is not None:
    coordinator_ip = str(os.getenv("JAX_COORDINATOR_IP"))
    coordinator_port = str(os.getenv("JAX_COORDINATOR_PORT"))
    # CUDA_VISIBLE_DEVICES is a comma-separated list of device ids, e.g. "0,1,2,3".
    visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", "").strip()
    # Fall back to JAX's default device detection if the variable is unset or empty.
    device_list = [int(d) for d in visible_devices.split(",")] if visible_devices else None
    jax.distributed.initialize(
        coordinator_address=f"{coordinator_ip}:{coordinator_port}",
        num_processes=int(os.getenv("NNODES")),
        process_id=int(os.getenv("NODE_RANK")),
        local_device_ids=device_list,
    )
    max_logging.log(f"JAX global devices: {jax.devices()}")
```
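For context, here is a minimal standalone sketch of the same override outside MaxText. The address, rank, and device counts below are hypothetical, purely to illustrate the intended effect:

```python
import jax

# Hypothetical two-node job, one process per node, four GPUs per node.
jax.distributed.initialize(
    coordinator_address="10.0.0.1:1234",  # assumed coordinator host:port
    num_processes=2,
    process_id=0,                          # this process's rank
    local_device_ids=[0, 1, 2, 3],         # without this, Slurm auto-detection
                                           # hands this process a single GPU
)
print(jax.local_devices())   # expect 4 local GPUs on this node
print(jax.device_count())    # expect 8 devices once all processes have joined
```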

This can probably use more robust error handling.
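As a starting point for that error handling, a hedged sketch (these helper names are hypothetical, not part of MaxText) might validate the environment before calling initialize:

```python
import os

def _require_env(name: str) -> str:
    """Fetch a required environment variable or fail with a clear message."""
    value = os.environ.get(name)
    if value is None:
        raise RuntimeError(f"{name} must be set when launching MaxText under Slurm")
    return value

def _parse_visible_devices() -> "list[int] | None":
    """Parse CUDA_VISIBLE_DEVICES ('0,1,2,3') into ids, or None to let JAX decide."""
    visible = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip()
    if not visible:
        return None
    return [int(d) for d in visible.split(",")]
```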
