Core Data Structures & Communication Primitives for Tensor Parallel for Keras #21697
base: master
Conversation
Summary of Changes
Hello @buildwithsuhana, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request lays the fundamental groundwork for introducing backend-agnostic auto-sharding and Tensor Parallelism into Keras 3.0. It establishes a modular and extensible architecture by defining core data structures, abstracting distributed backend functionalities, and providing high-level communication primitives. This initial set of changes is crucial for enabling future capabilities that will allow users to train very large models across multiple devices with significantly simplified code.
Code Review
This pull request lays a solid foundation for tensor parallelism in Keras by introducing backend-agnostic abstractions for distributed operations and core data structures for sharding. The overall design is well-structured, separating concerns between backend-specific implementations, communication primitives, and configuration. However, there are several areas that need attention, particularly regarding the correctness of some backend implementations (especially JAX), placeholder logic, API clarity, and code consistency. Addressing these points will strengthen the foundation and prevent issues in future development.
Codecov Report
Additional details and impacted files:
@@ Coverage Diff @@
## master #21697 +/- ##
========================================
Coverage 82.59% 82.60%
========================================
Files 572 574 +2
Lines 58327 58596 +269
Branches 9131 9165 +34
========================================
+ Hits 48177 48403 +226
- Misses 7818 7854 +36
- Partials 2332 2339 +7
I've added a few initial comments and questions during my first look.
To make the review more manageable, I propose we split this change up. At almost 1,800 lines, the current change is quite difficult to review properly. What do you think about limiting this PR to just the JAX backend, and introducing the others in subsequent, smaller PRs?
Thank you for the PR!
Some high level comments:
- Out of context, it's really hard for me to understand why these abstractions are needed for Tensor Parallel.
- Why do we need all these primitives?
- Why do we need 3 layers of abstraction for the same concepts: the `communications` layer, the `state_actions` layer, and the `keras.distributed.get_communication_ops` layer? Can we just have one?
- These abstractions look Torch-like rather than JAX-like. On JAX you never have to manually split and do an all-gather, you simply shard. You never have to explicitly do a "collective sum". You just do a sum, and if the tensors are sharded, it will magically do all the needed collectives for you (see the sketch after this list). So it's unclear to me why any of these are needed for JAX.
- I wouldn't export these symbols that you added to `keras.distributed`; I don't think they are needed. What we'll expose is the "Tensor Parallel" API.
- For better or worse, we don't do type annotations in Keras. And unfortunately, mixing code with type annotations and code without them doesn't work well. It's better to not have any type annotations at all.
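To illustrate the JAX point above, here is a minimal sketch (not taken from this PR) of how sharded operands let `jax.jit` insert the needed collectives automatically; the mesh axis name `model` and the tensor shapes are arbitrary assumptions:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D device mesh over whatever devices are available.
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

x = jnp.ones((8, 16))
w = jnp.ones((16, 32))

# Shard the weight column-wise over the "model" axis; replicate the input.
w_sharded = jax.device_put(w, NamedSharding(mesh, PartitionSpec(None, "model")))
x_replicated = jax.device_put(x, NamedSharding(mesh, PartitionSpec()))


@jax.jit
def forward(x, w):
    # A plain matmul and sum: XLA inserts any collectives required by the
    # operand shardings, with no explicit psum/all_gather in user code.
    return jnp.sum(x @ w)


print(forward(x_replicated, w_sharded))
```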
This Pull Request introduces the foundational components for a new, backend-agnostic auto-sharding system in Keras, specifically designed for tensor parallelism. It establishes the core data structures and the JAX-specific implementation of communication primitives.
The most significant part of this PR is the creation of a generic, backend-agnostic system for defining sharding plans. This logic resides in keras/src/distribution/tensor_parallel/tensor_layout.py.
LayoutAction: An abstract base class that serves as a blueprint for any sharding operation. It defines a contract that requires a forward method (call) to apply the sharding and a reverse method (undo) to reconstruct the original tensor.
Split: The first concrete implementation of LayoutAction. This class contains the logic to manually slice a tensor along a specific dimension. It is robustly designed to handle cases where the tensor size is not perfectly divisible by the number of devices.
Note on Design: This class performs explicit, manual slicing rather than using a declarative approach. This imperative design is intentional and crucial for ensuring multi-backend compatibility. While a backend like JAX could infer this slicing automatically, other frameworks like PyTorch require these explicit instructions to distribute data correctly. This universal approach allows the core Keras logic to remain backend-agnostic.
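As an illustration, a minimal backend-agnostic sketch of `LayoutAction` and `Split` could look like the following; only the class names and the `call`/`undo` contract come from this PR's description, while the constructor arguments and slicing details are assumptions:

```python
import numpy as np


class LayoutAction:
    """Blueprint for a sharding operation: apply it, or reverse it."""

    def call(self, tensor, rank):
        """Return the shard of `tensor` owned by device `rank`."""
        raise NotImplementedError

    def undo(self, shards):
        """Reconstruct the original tensor from all device shards."""
        raise NotImplementedError


class Split(LayoutAction):
    """Slice a tensor along `dim` into `world_size` (possibly uneven) pieces."""

    def __init__(self, world_size, dim):
        self.world_size = world_size
        self.dim = dim

    def call(self, tensor, rank):
        size = tensor.shape[self.dim]
        base, remainder = divmod(size, self.world_size)
        # The first `remainder` ranks get one extra element, so sizes that
        # are not divisible by world_size are still fully covered.
        start = rank * base + min(rank, remainder)
        length = base + (1 if rank < remainder else 0)
        index = [slice(None)] * tensor.ndim
        index[self.dim] = slice(start, start + length)
        return tensor[tuple(index)]

    def undo(self, shards):
        return np.concatenate(shards, axis=self.dim)
```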
LayoutMap: A simple data class that acts as a container for the complete sharding plan. It is designed to hold the LayoutAction rules for both the model's weights (state_rules) and its intermediate outputs (output_rules).
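A `LayoutMap` holding such rules might then be assembled as sketched below (reusing the `Split` sketch above); the regex-style rule keys are purely hypothetical, and only the `state_rules`/`output_rules` fields come from the description:

```python
from dataclasses import dataclass, field


@dataclass
class LayoutMap:
    state_rules: dict = field(default_factory=dict)   # rules for model weights
    output_rules: dict = field(default_factory=dict)  # rules for layer outputs


layout_map = LayoutMap(
    state_rules={
        # Hypothetical key: shard Dense kernels column-wise across 4 devices.
        r".*dense.*kernel": Split(world_size=4, dim=1),
    },
    output_rules={
        # Hypothetical key: the matching output is sharded on its last axis
        # and can be reconstructed with Split.undo.
        r".*dense.*output": Split(world_size=4, dim=-1),
    },
)
```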
This PR provides the first backend-specific implementation of the required distributed communication primitives, targeting JAX. This code is located in keras/src/backend/jax/distributed_backend.py.
get_device_info() and is_multi_device_capable(): Utility functions to query the JAX environment for available devices.
get_communication_ops(): This is the key function that provides the necessary tools for cross-device communication. It returns a dictionary containing backend-agnostic names (all_reduce, all_gather) mapped to their efficient, native JAX implementations (jax.lax.psum and jax.lax.all_gather). This acts as a translation layer, allowing the core Keras logic to use these operations without needing to know it's running on JAX.
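A condensed sketch of what that JAX surface could look like is shown below; the function names and the psum/all_gather mapping come from the description, while the returned dictionary keys, the default axis name, and the wrapper signatures are assumptions:

```python
import jax
import jax.lax


def get_device_info():
    """Query the JAX environment for the available devices."""
    devices = jax.devices()
    return {
        "backend": "jax",
        "device_count": len(devices),
        "devices": [str(d) for d in devices],
    }


def is_multi_device_capable():
    """Return True if more than one JAX device is visible."""
    return jax.device_count() > 1


def get_communication_ops(axis_name="model"):
    """Map backend-agnostic op names to native JAX collectives.

    These wrappers are only valid inside a pmap/shard_map that defines
    `axis_name`; outside such a context JAX will raise an error.
    """
    return {
        "all_reduce": lambda x: jax.lax.psum(x, axis_name=axis_name),
        "all_gather": lambda x, axis=0: jax.lax.all_gather(
            x, axis_name=axis_name, axis=axis
        ),
    }
```

Used inside, for example, `jax.pmap(..., axis_name="model")`, calling `ops["all_reduce"](x)` then lowers to the native collective on each backend device.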
Design Document: Autosharding for Keras
The full Tensor Parallel code for Keras has been divided into 4 PRs; this is the first of them.