[Diloco] Diloco trainer #2920
base: main
Conversation
Force-pushed from 246893d to 364cf4e
Force-pushed from 364cf4e to cdba187
Codecov Report❌ Patch coverage is
_BASE_CONFIG_PATH = os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "configs", "base.yml")
...
class SimpleNNXModel(nnx.Module):
use SimpleLayer
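For context, a minimal Flax NNX test model of the kind under discussion might look like the sketch below. SimpleLayer here is a hypothetical stand-in for the existing test helper the comment refers to; the real class in the MaxText test suite may differ.

from flax import nnx

# Hypothetical minimal NNX module for unit tests; illustrative only.
class SimpleLayer(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dout, rngs=rngs)

  def __call__(self, x):
    return self.linear(x)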
khatwanimohit left a comment:
Add train_compile tests for Diloco
    eval_step,
    eval_data_iterator,
    params_shardings,
)
maybe move this logic out of train.py
with jax.profiler.StepTraceAnnotation("train", step_num=step):
  example_batch = data_loader.load_next_batch(rampup_manager=rampup_manager)
  if config.enable_diloco:
    example_batch = diloco.reshape_first_axis_with_diloco(config.num_diloco_replicas, example_batch)
actually you could consider moving this logic, along with the config.input_data_sharding_logical_axes change in sharding.py, into MaxText.data_loader, e.g.

def load_next_batch(**kwargs):
  if enable_diloco:
    ...
  else:
    ...  # original logic
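A slightly fuller sketch of that suggestion (illustrative only: the DataLoader method signature and the _load_batch helper are assumptions, while reshape_first_axis_with_diloco and the config flags follow the diff above):

def load_next_batch(self, rampup_manager=None):
  # Original loading logic, stubbed out here as a hypothetical helper.
  example_batch = self._load_batch(rampup_manager)
  if self.config.enable_diloco:
    # Split the leading global-batch axis into one slice per DiLoCo replica.
    example_batch = diloco.reshape_first_axis_with_diloco(
        self.config.num_diloco_replicas, example_batch
    )
  return example_batch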
in this case, along with the previously suggested change, you don't need to change anything in train.py
While making this change I realized we are calling sharding.maybe_shard_with_name twice: first inside data_loader.load_next_batch, and then again after data_loader.load_next_batch is called in train.py.
@NuojCheng can you double-check whether this is true? If so, I can remove one of them along with this change.
Yes please remove the one in train.py. Thanks!
I have made this change in #2926
Description
This PR introduces support for Distributed Low-Communication (DiLoCo) training in MaxText. It implements standard DiLoCo, enabling efficient model training across disjoint clusters ("islands") by synchronizing gradients infrequently via an outer optimizer.
Key Changes
- DiLoCo training logic and communication synchronization using drjax.
- A DiLoCo-specific training step, used when enable_diloco is active; this includes handling data reshaping for multiple replicas.
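To make the outer-optimizer idea concrete, the following is a minimal sketch of the DiLoCo update pattern in plain JAX/optax. It is illustrative only and is not the drjax-based implementation in this PR; the optimizer choices, hyperparameters, and function names are assumptions.

import jax
import jax.numpy as jnp
import optax

# Per-replica (inner) optimizer and infrequent cross-replica (outer) optimizer.
# The specific choices are illustrative, not taken from this PR.
inner_opt = optax.adamw(1e-3)
outer_opt = optax.sgd(0.7, momentum=0.9, nesterov=True)

def inner_steps(params, opt_state, batches, loss_fn):
  """Run one replica's local updates between synchronizations."""
  for batch in batches:
    grads = jax.grad(loss_fn)(params, batch)
    updates, opt_state = inner_opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
  return params, opt_state

def outer_step(global_params, outer_state, replica_params_list):
  """Synchronize replicas by applying the outer optimizer to the averaged delta."""
  avg_params = jax.tree_util.tree_map(
      lambda *ps: jnp.mean(jnp.stack(ps), axis=0), *replica_params_list)
  # Pseudo-gradient: how far the replicas moved, on average, from the shared start.
  pseudo_grad = jax.tree_util.tree_map(lambda g, a: g - a, global_params, avg_params)
  updates, outer_state = outer_opt.update(pseudo_grad, outer_state, global_params)
  global_params = optax.apply_updates(global_params, updates)
  return global_params, outer_state

In this PR, the equivalent inner/outer split and cross-replica communication are expressed with drjax inside MaxText and gated by the enable_diloco and num_diloco_replicas config options seen in the diff above.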
Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.
Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.
Tests
Please describe how you tested this change, and include any instructions and/or commands to reproduce.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.