Conversation

@khatwanimohit khatwanimohit commented Jan 8, 2026

Description

This PR introduces support for Distributed Low-Communication (DiLoCo) training in MaxText. It implements standard DiLoCo, enabling efficient model training across disjoint clusters ("islands") by synchronizing gradients infrequently via an outer optimizer.
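
For orientation, here is a minimal sketch of the inner/outer optimization pattern DiLoCo relies on, written with plain optax over toy parameters. The function names, optimizer choices, and loss are illustrative assumptions; the actual MaxText implementation (DiLoCoTrainState and the drjax-based synchronization in src/MaxText/diloco.py) is structured differently.

# Illustrative sketch of the DiLoCo pattern, not MaxText's actual API.
# Each replica runs several local ("inner") steps; the averaged parameter
# delta is then applied as a pseudo-gradient by an infrequent "outer" update.
import jax
import jax.numpy as jnp
import optax

def loss_fn(params, batch):
  # Toy quadratic loss, for illustration only.
  pred = batch["x"] @ params["w"]
  return jnp.mean((pred - batch["y"]) ** 2)

inner_opt = optax.adamw(1e-3)                            # runs every step, per replica
outer_opt = optax.sgd(0.7, momentum=0.9, nesterov=True)  # runs once per synchronization

def inner_steps(global_params, inner_state, batches):
  """Take several local steps starting from the globally synchronized params."""
  params = global_params
  for batch in batches:
    grads = jax.grad(loss_fn)(params, batch)
    updates, inner_state = inner_opt.update(grads, inner_state, params)
    params = optax.apply_updates(params, updates)
  return params, inner_state

def outer_step(global_params, per_replica_params, outer_state):
  """Average the replicas' parameter deltas and apply them with the outer optimizer."""
  mean_params = jax.tree_util.tree_map(
      lambda *p: jnp.mean(jnp.stack(p), axis=0), *per_replica_params)
  outer_grad = jax.tree_util.tree_map(lambda g, m: g - m, global_params, mean_params)
  updates, outer_state = outer_opt.update(outer_grad, outer_state, global_params)
  return optax.apply_updates(global_params, updates), outer_state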

Key Changes

  • Core Logic: Added src/MaxText/diloco.py, which implements the DiLoCoTrainState, inner/outer optimization steps,
    and communication synchronization using drjax.
  • Training Loop Integration: Modified src/MaxText/train.py to initialize the DiLoCo state and adapt the training
    step when enable_diloco is active. This includes handling data reshaping for multiple replicas (see the reshape
    sketch after this list).
  • Sharding & Configuration:
    • Updated src/MaxText/sharding.py to support a hierarchical "diloco" sharding axis.
    • Added new flags (e.g., enable_diloco, num_diloco_replicas, diloco_outer_optimizer) to base.yml and types.py.
  • Dependencies: Added drjax to the project requirements.
  • Testing: Added comprehensive unit tests in tests/diloco_test.py.
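
To make the data-reshaping point above concrete, here is a hedged sketch of what reshape_first_axis_with_diloco could do: split the leading (global-batch) axis of every array in the batch into a num_diloco_replicas axis plus a per-replica batch axis. The function name and argument order match the call site in train.py, but the body and the example shapes are illustrative assumptions, not the code in src/MaxText/diloco.py.

# Illustrative only: give each DiLoCo replica ("island") its own shard of the
# global batch by splitting the leading axis of every array in the batch.
import jax
import jax.numpy as jnp

def reshape_first_axis_with_diloco(num_diloco_replicas, batch):
  def _split_leading_axis(x):
    global_batch = x.shape[0]
    assert global_batch % num_diloco_replicas == 0, (
        "global batch size must be divisible by num_diloco_replicas")
    per_replica = global_batch // num_diloco_replicas
    return x.reshape((num_diloco_replicas, per_replica) + x.shape[1:])
  return jax.tree_util.tree_map(_split_leading_axis, batch)

# Example: a global batch of 8 sequences split across 2 replicas.
batch = {"inputs": jnp.zeros((8, 1024), dtype=jnp.int32)}
sharded = reshape_first_axis_with_diloco(2, batch)
print(sharded["inputs"].shape)  # (2, 4, 1024)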

Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.

Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have added necessary comments to my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@khatwanimohit khatwanimohit changed the title from Mohit/diloco trainer to [Diloco] Diloco trainer on Jan 8, 2026

codecov bot commented Jan 8, 2026

Codecov Report

❌ Patch coverage is 0% with 104 lines in your changes missing coverage. Please review.

Files with missing lines        Patch %   Lines
src/MaxText/diloco.py           0.00%     68 Missing ⚠️
src/MaxText/train.py            0.00%     22 Missing ⚠️
src/MaxText/train_utils.py      0.00%     7 Missing ⚠️
src/MaxText/sharding.py         0.00%     4 Missing ⚠️
src/MaxText/maxtext_utils.py    0.00%     3 Missing ⚠️


_BASE_CONFIG_PATH = os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "configs", "base.yml")


class SimpleNNXModel(nnx.Module):
Collaborator Author

use SimpleLayer

@khatwanimohit khatwanimohit left a comment

Add train_compile tests for Diloco

eval_step,
eval_data_iterator,
params_shardings,
)
Collaborator

maybe move this logic out of train.py

with jax.profiler.StepTraceAnnotation("train", step_num=step):
  example_batch = data_loader.load_next_batch(rampup_manager=rampup_manager)
  if config.enable_diloco:
    example_batch = diloco.reshape_first_axis_with_diloco(config.num_diloco_replicas, example_batch)
Collaborator

Actually, you could consider moving this logic, along with the config.input_data_sharding_logical_axes change in sharding.py, into MaxText.data_loader, e.g.

def load_next_batch(**args):
  if enable_diloco:
    ...
  else:
    ...  # original logic

Collaborator

In this case, along with the previously suggested change, you don't need to change anything in train.py.

@khatwanimohit khatwanimohit Jan 9, 2026

While making this change, I realized we are calling sharding.maybe_shard_with_name twice: first inside data_loader.load_next_batch, and then again after data_loader.load_next_batch is called in train.py.

@NuojCheng can you double-check whether this is true? If so, I can remove one of them along with this change.

Collaborator

Yes, please remove the one in train.py. Thanks!

Collaborator

I have made this change in #2926
