Adding basic elastic training (pause-and-resume) #1256
Conversation
pyproject.toml (Outdated)
pathways-tpu = [
    "axlearn[gcp]",
    "jax==0.5.3",  # must be >=0.4.19 for compat with v5p.
    "pathwaysutils @ git+https://github.com/AI-Hypercomputer/pathways-utils",  # For JAX+Pathways single-controller accelerator coordinator.
Will 0.1.1 not work for the pause-and-resume use case?
The Manager.wait_for_slices was added recently. I will cut a new release of pathwaysutils that includes it. Before I do that, I need to ensure there is no newly introduced dependency on jax>0.5.3 within pathwaysutils. This change was for verification.
0.1.1 will work if we add a similar wait_for_slices to axlearn (as has been done in MaxText for the last couple of months), but I would rather not do that.
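For reference, if a wait_for_slices-style helper were added to axlearn instead, it might look roughly like the sketch below. The use of jax.devices() and the slice_index device attribute to detect healthy slices is an assumption about how such a helper could work, not code from this PR or from pathwaysutils.

# Hedged sketch of a local wait_for_slices-style helper; detection via
# jax.devices() and slice_index is an assumption, not the pathwaysutils code.
import time

import jax


def wait_for_slices(expected_slices: int, poll_interval_seconds: int = 30) -> None:
    """Blocks until JAX reports devices from at least `expected_slices` slices."""
    while True:
        healthy_slices = {getattr(d, "slice_index", 0) for d in jax.devices()}
        if len(healthy_slices) >= expected_slices:
            return
        time.sleep(poll_interval_seconds)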
Could you include in the PR summary how you demonstrated and validated that the pause-and-resume feature works?
Thanks
trainer_config.set(recorder=config_for_function(lambda: measurement.global_recorder))
measurement.start_monitoring()
launch_trainer.run_trainer(trainer_config)
elastic_manager = manager.Manager()
Can we make it enabled only in the Pathways environment?
Looks like we need pathwaysutils.initialize() before calling this.
Note that we also need to set elasticity flags on the proxy server inside the JobSet YAML for the complete e2e test to work.
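To illustrate the two points above, a guard along the following lines could restrict elastic setup to the Pathways proxy backend and initialize pathwaysutils first. This is a sketch, not the PR's code; the helper name is hypothetical, while FLAGS.jax_backend, pathwaysutils.initialize(), and manager.Manager() all appear in this PR.

# Hedged sketch: only construct the elastic manager on the proxy backend,
# after initializing pathwaysutils.
from absl import flags

FLAGS = flags.FLAGS


def _maybe_create_elastic_manager():
    """Returns a pathwaysutils elastic Manager on the proxy backend, else None."""
    if FLAGS.jax_backend != "proxy":
        return None
    # pylint: disable-next=import-error,import-outside-toplevel
    import pathwaysutils
    pathwaysutils.initialize()
    # pylint: disable-next=import-error,import-outside-toplevel
    from pathwaysutils.elastic import manager
    return manager.Manager()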
FYI, this is a draft PR and I am still verifying functionality. At this point, for this flavor of fast-resume, there should not be any major additional code changes.
Added guards to only use fast-resume if the proxy backend is used.
…ntiation of the trainer and creation of the PRNGKey inside the elastic loop
…rease checkpoint frequency
except jax.errors.JaxRuntimeError as error:
    if not elastic_manager.is_error_due_to_slice_down(error):
        raise
    ten_minutes = 10 * 60
Why ten minutes?
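Presumably the ten_minutes value bounds how long the run waits for the lost slice to come back before giving up; that reading is an assumption, not confirmed by the visible diff. A purely hypothetical polling helper with such a bound might look like this (slices_available is a made-up callback, not a pathwaysutils API):

# Illustrative only: bounded wait for a lost slice. `slices_available` is a
# hypothetical callback; this is not code from the PR or from pathwaysutils.
import time


def wait_with_timeout(slices_available, timeout_seconds: int = 10 * 60,
                      poll_interval_seconds: int = 30) -> bool:
    """Returns True if the slices come back before the timeout elapses."""
    deadline = time.monotonic() + timeout_seconds
    while time.monotonic() < deadline:
        if slices_available():
            return True
        time.sleep(poll_interval_seconds)
    return False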
trainer: SpmdTrainer = trainer_config.instantiate(parent=None)
prng_key = jax.random.PRNGKey(seed=FLAGS.trainer_prng_seed)
output = trainer.run(prng_key)
if False and FLAGS.jax_backend == "proxy":
Why is this disabled?
# pylint: disable-next=import-error,import-outside-toplevel
from pathwaysutils.elastic import manager
elastic_manager = manager.Manager()
while True:
Do we want a max number of failed retries before we terminate? (E.g., if every retry terminates with an error in under an hour more than 5 times in a row, fail, or something?)
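To make that concrete, a bounded-retry policy could look roughly like the sketch below. The cap, the one-hour window, and the reset rule are illustrative and not part of this PR; trainer.run, jax.errors.JaxRuntimeError, and is_error_due_to_slice_down come from the diff above, and waiting for the slice to rejoin before retrying is omitted for brevity.

# Illustrative sketch: cap consecutive fast failures instead of retrying forever.
import time

import jax


def run_with_retry_cap(trainer, prng_key, elastic_manager,
                       max_consecutive_fast_failures: int = 5,
                       fast_failure_window_seconds: int = 60 * 60):
    """Runs the trainer, re-raising after too many quick slice-down failures in a row."""
    consecutive_fast_failures = 0
    while True:
        attempt_start = time.monotonic()
        try:
            return trainer.run(prng_key)
        except jax.errors.JaxRuntimeError as error:
            if not elastic_manager.is_error_due_to_slice_down(error):
                raise
            if time.monotonic() - attempt_start < fast_failure_window_seconds:
                consecutive_fast_failures += 1
            else:
                consecutive_fast_failures = 0  # a long-running attempt resets the counter
            if consecutive_fast_failures >= max_consecutive_fast_failures:
                raise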
Is there an internal PR for this? (If so, could you DM me the link internally?) Maybe Luke can post the job link there?
I was able to verify this feature works. I will Slack you an internal doc.
Elastic training with Pathways on Cloud allows jobs to react to failures without a full restart.
The elastic training mode in this PR is one form of Fast Resume: the main run is wrapped in a while loop and try-except blocks so that if a slice is lost mid-run, the run waits for the lost slice to rejoin the Pathways cluster and recovers from the most recent checkpoint.
This type of elastic training largely re-initializes the run and relies on checkpoints for recovery. The primary benefit of this flavor of elastic training is avoiding the need to restart the entire JobSet on failure; instead, only a single slice's worth of pods is restarted.
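Putting the visible snippets together, the overall shape of the loop is roughly the following. This is a sketch assembled from the diff excerpts above; the exact call details (for example, a timeout passed to wait_for_slices) are assumptions rather than confirmed code, and SpmdTrainer, trainer_config, and FLAGS come from the launch_trainer context in this PR.

# Sketch of the pause-and-resume structure described above, assembled from the
# snippets in this PR; call details (e.g. wait_for_slices arguments) are assumptions.
import jax

from pathwaysutils.elastic import manager

elastic_manager = manager.Manager()
while True:
    try:
        # Re-create the trainer and PRNG key inside the loop so a resumed attempt
        # restores cleanly from the most recent checkpoint.
        trainer: SpmdTrainer = trainer_config.instantiate(parent=None)
        prng_key = jax.random.PRNGKey(seed=FLAGS.trainer_prng_seed)
        output = trainer.run(prng_key)
        break
    except jax.errors.JaxRuntimeError as error:
        if not elastic_manager.is_error_due_to_slice_down(error):
            raise
        # Block until the lost slice rejoins the Pathways cluster, then retry.
        elastic_manager.wait_for_slices()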