Encapsulate Mesh invariants #8882
Conversation
Force-pushed from 9b8e2db to e4df499.
Force-pushed from 636eb45 to ddb791a.
cc @lsy323
Force-pushed from e746546 to c0647a4.
Switched the assert on the mesh shape's size (np.prod) to one on the mesh shape's length, since the partition spec (0, 1) is still a valid way to represent a single device with a (1, 1) mesh shape (size = 1, len = 2).
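A plain-Python illustration of that reasoning (not the actual torch_xla code): with a (1, 1) mesh, the partition spec (0, 1) has length 2, so a check against the mesh size (np.prod == 1) rejects it, while a check against the number of mesh dimensions accepts it.

```python
import numpy as np

mesh_shape = (1, 1)      # a single device arranged as a 2D mesh
partition_spec = (0, 1)  # shard tensor dim 0 over mesh axis 0, dim 1 over axis 1

assert len(partition_spec) != np.prod(mesh_shape)  # size-based check would reject this
assert len(partition_spec) == len(mesh_shape)      # length-based check passes
```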
Force-pushed from c0647a4 to 3aac5ce.
We also support an empty tuple as the partition spec for scalar values, as well as specifying ('x', None) for 1D mesh shapes, as long as the tensor shape matches it. These semantics are quite relaxed (we could check against JAX), but I don't see a strong reason to change them, especially since doing so would make things more restrictive for existing use cases. @tengyifei, in this case, do you see any concern with removing the newly added constraint in https://github.com/pytorch/xla/pull/8882/files#diff-3dcff2b7395bbf1f8a09170775388ef686a1e5f593b3c3889996d78c93a9c394R582 (L582)? I updated the PR, so see below:
Technically, it was meant to enforce a constraint between the specified partition spec and the provided mesh, but that is really an added delta. We just moved the global device check to the mesh, since it's more suitable there (the mesh should enforce that invariant, rather than each individual sharding annotation).
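A hedged sketch of the relaxed semantics described above, using the torch_xla SPMD API (module paths and exact behavior may differ across versions; shapes here are illustrative):

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()
num_devices = xr.global_runtime_device_count()

# 1D mesh with a single named axis 'x'.
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('x',))

# Scalar value: an empty partition spec (fully replicated) is accepted.
scalar = torch.tensor(1.0, device=xm.xla_device())
xs.mark_sharding(scalar, mesh, ())

# 2D tensor over the 1D mesh: shard dim 0 over 'x', replicate dim 1.
t = torch.zeros(num_devices * 2, 4, device=xm.xla_device())
xs.mark_sharding(t, mesh, ('x', None))
```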
Force-pushed from 3aac5ce to b59039c.
```python
assert len(partition_spec) == len(mesh.shape()), \
    f"Partition spec length ({len(partition_spec)}) should be equal to the mesh shape dimensions ({len(mesh.shape())})."
```
This check wouldn't make sense. I can shard a 2D tensor over a 4D mesh by e.g. sharding each tensor dim over two mesh axes.
It sounds like you're removing it, and I don't see this check in the latest commit, so that sounds reasonable to me.
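A sketch of that counterexample, assuming the torch_xla SPMD API: a 2D tensor sharded over a 4D mesh by grouping two mesh axes per tensor dimension, so len(partition_spec) == 2 while len(mesh.shape()) == 4. The device count and mesh shape are illustrative.

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

n = xr.global_runtime_device_count()  # e.g. 16 devices
mesh = xs.Mesh(np.arange(n), (2, 2, 2, n // 8), ('a', 'b', 'c', 'd'))

t = torch.zeros(8, 8, device=xm.xla_device())
# Tensor dim 0 is sharded over mesh axes ('a', 'b'), dim 1 over ('c', 'd').
xs.mark_sharding(t, mesh, (('a', 'b'), ('c', 'd')))
```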
However, there's a failing HybridMesh test. It failed because it didn't mock out the global_runtime_device_count method. We can probably fix it by mocking out this function to return the desired number of devices for the test.
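A minimal sketch of that suggestion. The test name, mesh shape, and patch target are assumptions (the correct target depends on how the module under test imports global_runtime_device_count), and a plain Mesh is used here instead of HybridMesh to keep the example small.

```python
import unittest
from unittest.mock import patch

import numpy as np
import torch_xla.distributed.spmd as xs


class MeshInvariantTest(unittest.TestCase):

  @patch('torch_xla.runtime.global_runtime_device_count', return_value=8)
  def test_mesh_with_mocked_device_count(self, _mock_device_count):
    # With 8 "devices" mocked, a (2, 4) mesh satisfies the device-count invariant
    # regardless of how many devices the CI host actually has.
    mesh = xs.Mesh(np.arange(8), (2, 4), ('x', 'y'))
    self.assertEqual(mesh.size(), 8)
```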
Force-pushed from b59039c to 993d46f.
Absolutely, agreed - thanks. Added the mock; I'll re-request review once the CI succeeds.
This PR improves the input validation in the Mesh class by making error messages more descriptive. The changes make it easier to debug mesh configuration issues and provide clearer feedback to users. In addition, it moves the invariant enforcement into the Mesh constructor rather than mark_sharding, so sharding annotations are validated only against the specified Mesh instead of through standalone checks against the number of globally participating devices.
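A simplified sketch (not the actual torch_xla implementation) of what moving the invariant into the constructor looks like: the Mesh itself checks its device IDs against its shape and the global runtime device count, so mark_sharding only needs to validate a partition spec against the mesh it is given.

```python
import numpy as np


def global_runtime_device_count():
  # Stand-in for torch_xla.runtime.global_runtime_device_count().
  return 8


class Mesh:

  def __init__(self, device_ids, mesh_shape, axis_names=None):
    device_ids = np.asarray(device_ids)
    # Invariant 1: the device IDs must fill the mesh exactly.
    assert device_ids.size == np.prod(mesh_shape), (
        f"Number of device IDs ({device_ids.size}) must match the mesh size "
        f"({np.prod(mesh_shape)}).")
    # Invariant 2: the mesh must cover all globally participating devices,
    # the check previously done per sharding annotation in mark_sharding.
    assert device_ids.size == global_runtime_device_count(), (
        f"Number of device IDs ({device_ids.size}) must match the number of "
        f"global runtime devices ({global_runtime_device_count()}).")
    self.device_ids = device_ids
    self.mesh_shape = tuple(mesh_shape)
    self.axis_names = axis_names
```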