Added PjRtClient::UpdateGlobalProcessInfo method. #28011

Merged
merged 1 commit into from
Jul 16, 2025

copybara-service[bot]
Added PjRtClient::UpdateGlobalProcessInfo method.

@copybara-service copybara-service bot force-pushed the test_773072434 branch 3 times, most recently from d278fa8 to 7746175 Compare July 7, 2025 21:57
@copybara-service copybara-service bot force-pushed the test_773072434 branch 4 times, most recently from f5985a7 to 12d82b6 Compare July 8, 2025 20:19
## Overview

Recall that a [multi-controller JAX program][mcjax] involves multiple PjRt
clients running across multiple processes. These clients perform collective
operations, like AllReduce and AllGather, to execute a distributed program.

This commit adds an `UpdateGlobalProcessInfo` method that updates a client with
information about all processes. For example, if there are four processes, we
might call `UpdateGlobalProcessInfo` on process 0 with the information that
processes 0, 1, and 2 are healthy but process 3 is dead.
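
To make the four-process example concrete, here is a minimal Python sketch of what such an update might carry. It is only an illustration: `ProcessState`, `ProcessInfo`, `ToyClient`, and `update_global_process_info` are hypothetical names invented for this sketch and are not the real PjRt API, whose actual argument types are defined by the commit itself.

```python
from dataclasses import dataclass
from enum import Enum


class ProcessState(Enum):
    """Hypothetical process states; the real API may track more."""
    HEALTHY = "healthy"
    DEAD = "dead"


@dataclass(frozen=True)
class ProcessInfo:
    """Hypothetical per-process record passed to the client."""
    process_id: int
    state: ProcessState


class ToyClient:
    """Stand-in for a PjRt client. Not the real PjRtClient."""

    def __init__(self, process_id: int) -> None:
        self.process_id = process_id
        self.global_info: dict[int, ProcessInfo] = {}

    def update_global_process_info(self, infos: list[ProcessInfo]) -> None:
        # Replace the client's view with the latest snapshot of all processes.
        self.global_info = {info.process_id: info for info in infos}

    def dead_processes(self) -> list[int]:
        # Process ids that the latest snapshot reports as dead.
        return sorted(
            pid
            for pid, info in self.global_info.items()
            if info.state is ProcessState.DEAD
        )


# Process 0 learns that processes 0, 1, and 2 are healthy and process 3 is dead.
client0 = ToyClient(process_id=0)
client0.update_global_process_info(
    [ProcessInfo(i, ProcessState.HEALTHY) for i in range(3)]
    + [ProcessInfo(3, ProcessState.DEAD)]
)
```

The sketch treats each call as a full snapshot that replaces the previous view. Whether the real method takes snapshots or incremental updates is an assumption made here for simplicity.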

## Motivation

I am currently working on making multi-controller JAX fault tolerant. Part of
this work involves cancelling collectives where one of the participants of the
collective has failed. The `UpdateGlobalProcessInfo` method will allow a PjRt
client to notice when a peer process has failed and abort any collectives it is
performing with this failed peer.
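
The intended behavior can be sketched in a few lines of Python. This is a toy model under stated assumptions, not the implementation (which this commit does not yet include): `Collective`, `ToyCollectiveClient`, and the `healthy` mapping are hypothetical, and "abort" is modeled as setting a flag.

```python
from dataclasses import dataclass, field


@dataclass
class Collective:
    """Hypothetical record of an in-flight collective operation."""
    name: str
    participants: frozenset[int]
    aborted: bool = False


@dataclass
class ToyCollectiveClient:
    """Stand-in for a PjRt client that tracks its in-flight collectives."""
    in_flight: list[Collective] = field(default_factory=list)

    def update_global_process_info(self, healthy: dict[int, bool]) -> list[str]:
        """Abort every in-flight collective that includes a failed process.

        Returns the names of the collectives aborted by this update.
        """
        failed = {pid for pid, ok in healthy.items() if not ok}
        aborted = []
        for collective in self.in_flight:
            # A collective is affected if any participant has failed.
            if not collective.aborted and collective.participants & failed:
                collective.aborted = True
                aborted.append(collective.name)
        return aborted


client = ToyCollectiveClient(
    in_flight=[
        Collective("all_reduce", frozenset({0, 1, 2, 3})),
        Collective("all_gather", frozenset({0, 1})),
    ]
)
# Process 3 has failed, so only the collective that includes it is aborted.
aborted = client.update_global_process_info({0: True, 1: True, 2: True, 3: False})
```

Collectives among healthy processes only (here, `all_gather` between processes 0 and 1) are left running.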

Previously, I was using the [coordination service][coordination_service] to
determine when processes failed, but PjRt clients executed via C plugins do not
have access to the coordination service.

## Future Work

This commit introduces the new `UpdateGlobalProcessInfo` method and pipes it
through the C++ sandwich, but it doesn't actually implement it yet. Nothing is
calling the new method either. These things will come in future changes.

## Alternatives

Rather than introducing a new `PjRtClient` method, I could have piped a
coordination service client through the C plugin API into `PjRtClient`s.
However, this would be very complicated. The [code to pipe a key-value store
client is complicated][kvs], and the API for the coordination service client is
significantly more complex.

I could also have shoehorned the new API into the existing key-value store. For
example, I could have established a convention that the state of every process
i is stored under a special key `process_{i}` in the key-value store. This felt
roundabout.
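
For comparison, the rejected key-value convention might look roughly like the following Python sketch. A plain `dict` stands in for the key-value store, and `publish_state`, `read_states`, and the state strings are hypothetical. Only the `process_{i}` key format comes from the description above.

```python
def publish_state(kv_store: dict[str, str], process_id: int, state: str) -> None:
    # Store the state of process i under the conventional key process_{i}.
    kv_store[f"process_{process_id}"] = state


def read_states(kv_store: dict[str, str], num_processes: int) -> dict[int, str]:
    # A client has to know the key convention and poll one key per process.
    return {
        i: kv_store.get(f"process_{i}", "unknown")
        for i in range(num_processes)
    }


kv_store: dict[str, str] = {}
for i in range(3):
    publish_state(kv_store, i, "healthy")
publish_state(kv_store, 3, "dead")

states = read_states(kv_store, num_processes=4)
```

Every reader and writer would need to agree on the key format and the state encoding out of band, which is part of why a dedicated method is more direct.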

[mcjax]: https://docs.jax.dev/en/latest/multi_process.html
[coordination_service]: https://github.com/openxla/xla/tree/main/xla/tsl/distributed_runtime/coordination
[kvs]: https://github.com/openxla/xla/blob/f5813dd522b9ef28eb58e638c45e60430b686d1a/xla/pjrt/c/pjrt_c_api.h#L324-L406

PiperOrigin-RevId: 783799013
@copybara-service copybara-service bot merged commit d4292b4 into main Jul 16, 2025
@copybara-service copybara-service bot deleted the test_773072434 branch July 16, 2025 17:07