Adding basic elastic training (pause-and-resume) #1256


Draft: lukebaumann wants to merge 12 commits into main

Conversation

lukebaumann

@lukebaumann lukebaumann commented Jun 12, 2025

Elastic training with Pathways on Cloud allows jobs to react to failures without a full restart.

The elastic training mode in this PR is one form of Fast Resume: the main run is wrapped in a while loop and try-except blocks so that if a slice is lost mid-run, the run waits for the lost slice to rejoin the Pathways cluster and then recovers from the most recent checkpoint.

This type of elastic training re-initializes much of the run and relies on checkpoints for recovery. Its primary benefit is avoiding a restart of the entire JobSet on failure; only a single slice's worth of pods is restarted.
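
The control flow described above can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the code in this PR: `SliceDownError`, `run_with_pause_and_resume`, `train_from_latest_checkpoint`, and `wait_for_slices` are hypothetical names chosen for the example, and the sketch does not import JAX or pathwaysutils.

```python
from typing import Callable


class SliceDownError(RuntimeError):
    """Hypothetical stand-in for a runtime error caused by a lost slice."""


def run_with_pause_and_resume(
    train_from_latest_checkpoint: Callable[[], str],
    wait_for_slices: Callable[[], None],
) -> str:
    """Runs training, pausing and resuming whenever a slice is lost."""
    while True:
        try:
            # Each attempt re-initializes the trainer and restores the most
            # recent checkpoint before continuing.
            return train_from_latest_checkpoint()
        except SliceDownError:
            # Pause: block until the lost slice has rejoined the cluster,
            # then loop around and resume from the checkpoint.
            wait_for_slices()
        # Any other exception is not caught here and ends the run.
```

Because only the slice-down case is caught, unrelated failures still surface immediately instead of being retried.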

pyproject.toml Outdated
pathways-tpu = [
    "axlearn[gcp]",
    "jax==0.5.3",  # must be >=0.4.19 for compat with v5p.
    "pathwaysutils @ git+https://github.com/AI-Hypercomputer/pathways-utils",  # For JAX+Pathways single-controller accelerator coordinator.
Contributor

Will 0.1.1 not work for the pause-and-resume use case?

Author

The Manager.wait_for_slices was added recently. I will cut a new release of pathwaysutils that includes it. Before I do that, I need to ensure there is not a newly introduced dependency on jax>0.5.3 within pathwaysutils.

This change was for verification.

0.1.1 will work if we add a similar wait_for_slices to axlearn (as has been done in MaxText for the last couple of months), but I would rather not do that.
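
For context, a locally maintained waiting helper of the kind mentioned above might look roughly like the sketch below. It is not the pathwaysutils or MaxText implementation: `poll_until_slices_available`, `count_available_slices`, and every parameter are hypothetical, and the ten-minute default timeout is only a placeholder.

```python
import time
from typing import Callable


def poll_until_slices_available(
    count_available_slices: Callable[[], int],
    expected_slices: int,
    poll_interval_seconds: float = 10.0,
    timeout_seconds: float = 600.0,
    sleep: Callable[[float], None] = time.sleep,
    clock: Callable[[], float] = time.monotonic,
) -> None:
    """Blocks until enough slices are available, or raises TimeoutError."""
    deadline = clock() + timeout_seconds
    while count_available_slices() < expected_slices:
        if clock() >= deadline:
            raise TimeoutError(
                f"Fewer than {expected_slices} slices available "
                f"after {timeout_seconds} seconds."
            )
        # Wait before asking again so the cluster is not polled in a tight loop.
        sleep(poll_interval_seconds)
```

The `sleep` and `clock` arguments exist only so the helper can be exercised without real waiting.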

Contributor

@Ethanlm Ethanlm left a comment

Could you include something in the PR summary to demonstrate and validate that the pause-and-resume feature works?

Thanks

trainer_config.set(recorder=config_for_function(lambda: measurement.global_recorder))
measurement.start_monitoring()
launch_trainer.run_trainer(trainer_config)
elastic_manager = manager.Manager()
Contributor

Can we make it enabled only in pathways environment?

Contributor

Looks like we need pathwaysutils.initialize() before calling this
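
One way to address both review points is sketched below. It assumes that a `jax_backend` value of `"proxy"` identifies a Pathways run, and that calling `pathwaysutils.initialize()` immediately before constructing the manager is sufficient; `maybe_create_elastic_manager` is a hypothetical helper name, not something in this PR.

```python
from typing import Any, Optional


def maybe_create_elastic_manager(jax_backend: str) -> Optional[Any]:
    """Returns an elastic manager under Pathways, otherwise None."""
    if jax_backend != "proxy":
        # Not a Pathways run: leave elastic training disabled.
        return None

    # Imported lazily so non-Pathways environments do not need pathwaysutils.
    # pylint: disable-next=import-error,import-outside-toplevel
    import pathwaysutils
    from pathwaysutils.elastic import manager

    # Assumption: initialization must happen before the manager is created.
    pathwaysutils.initialize()
    return manager.Manager()
```

A caller would then take the plain `run_trainer` path whenever the helper returns `None`.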

@shauryagup

Note that we also need to set elasticity flags on the proxy server inside the JobSet YAML for the complete e2e test to work.

@lukebaumann lukebaumann marked this pull request as draft June 17, 2025 21:12
@lukebaumann
Author

FYI this is a draft PR and I am still verifying functionality. At this point for this flavor of fast-resume, there should not be any major additional code changes.

except jax.errors.JaxRuntimeError as error:
    if not elastic_manager.is_error_due_to_slice_down(error):
        raise
    ten_minutes = 10 * 60
Contributor

Why ten minutes?

trainer: SpmdTrainer = trainer_config.instantiate(parent=None)
prng_key = jax.random.PRNGKey(seed=FLAGS.trainer_prng_seed)
output = trainer.run(prng_key)
if False and FLAGS.jax_backend == "proxy":
Contributor

Why is this disabled?

# pylint: disable-next=import-error,import-outside-toplevel
from pathwaysutils.elastic import manager
elastic_manager = manager.Manager()
while True:
Contributor

Do we want a max number of failed retries before we terminate? (E.g., if every retry terminates with an error in <1hr more than 5 times in a row, fail, or something?)
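
A retry cap along these lines could be sketched as follows. The class name `RetryBudget`, its methods, and its defaults (5 consecutive failures, a one-hour threshold) are hypothetical and simply mirror the numbers floated in this comment; nothing like it exists in the PR.

```python
import time
from typing import Callable


class RetryBudget:
    """Tracks consecutive quick failures and says when to give up."""

    def __init__(
        self,
        max_consecutive_fast_failures: int = 5,
        fast_failure_seconds: float = 3600.0,
        clock: Callable[[], float] = time.monotonic,
    ):
        self._max = max_consecutive_fast_failures
        self._threshold = fast_failure_seconds
        self._clock = clock
        self._consecutive = 0
        self._attempt_start = clock()

    def start_attempt(self) -> None:
        """Marks the beginning of a new training attempt."""
        self._attempt_start = self._clock()

    def record_failure(self) -> bool:
        """Records a failed attempt; returns True if another retry is allowed."""
        elapsed = self._clock() - self._attempt_start
        if elapsed < self._threshold:
            self._consecutive += 1
        else:
            # The attempt ran long enough before failing, so reset the streak.
            self._consecutive = 0
        return self._consecutive <= self._max
```

Inside the `while True:` loop, the trainer would call `start_attempt()` before each run and re-raise the error once `record_failure()` returns `False`.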

@apghml
Contributor

apghml commented Jun 19, 2025

> Could you include something in the PR summary to demonstrate and validate that the pause-and-resume feature works?
>
> Thanks

Is there an internal PR for this? (If so, could you DM me the link internally?) Maybe Luke can post a job link there?

@Ethanlm
Contributor

Ethanlm commented Jun 20, 2025

> > Could you include something in the PR summary to demonstrate and validate that the pause-and-resume feature works?
> > Thanks
>
> Is there an internal PR for this? (If so, could you DM me the link internally?) Maybe Luke can post a job link there?

I was able to verify this feature works. I will Slack you an internal doc.

@Ethanlm Ethanlm changed the title Adding basic elastic training Adding basic elastic training (pause-and-resume) Jun 20, 2025
4 participants