Encapsulate Mesh invariants #8882
Conversation
cc @lsy323
Switched the assert on the mesh shape's size (np.prod) to one on the mesh shape's length, since the partition spec (0, 1) is still a valid way to represent a single device with a (1, 1) mesh shape (size = 1, len = 2).
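A minimal sketch of the distinction this comment describes (the helper name is hypothetical, not the actual torch_xla code): the check compares the spec length to the number of mesh axes rather than to the device count.

```python
import numpy as np

def check_partition_spec(partition_spec, mesh_shape):
    # Validate against the number of mesh axes (len), not the total
    # device count (np.prod): a (1, 1) mesh has size 1 but two axes,
    # so partition spec (0, 1) should still be accepted.
    assert len(partition_spec) == len(mesh_shape), (
        f"Partition spec length ({len(partition_spec)}) should be equal "
        f"to the mesh shape dimensions ({len(mesh_shape)})."
    )

mesh_shape = (1, 1)
print(int(np.prod(mesh_shape)), len(mesh_shape))  # size = 1, len = 2
check_partition_spec((0, 1), mesh_shape)  # valid under the len-based check
```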
We also support an empty tuple as the partition spec for scalar values, as well as specifying ('x', None) for 1D mesh shapes, as long as the tensor shape matches it. These semantics seem quite relaxed (we could check with JAX), but I don't see a major reason to change that, especially since it would make things more restrictive for existing use cases. @tengyifei, in this case, do you see any concern with removing the added expected constraint in https://github.com/pytorch/xla/pull/8882/files#diff-3dcff2b7395bbf1f8a09170775388ef686a1e5f593b3c3889996d78c93a9c394R582 (L582)? I updated the PR, so see below. Technically, it was meant as a way to enforce a constraint between the specified partition spec and the provided mesh, but that is really an added delta. We just moved the global device check to the mesh, since it's more suitable there (the mesh should enforce that invariant, instead of each individual sharding annotation).
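A hedged sketch of the relaxed semantics mentioned above (the validator name and shape are illustrative, not the actual torch_xla API): an empty tuple is accepted for scalars, and ('x', None) works against a 1D mesh whenever the spec length matches the tensor rank, with None marking a replicated dimension.

```python
def is_valid_spec(partition_spec, tensor_shape, mesh_axis_names):
    # Hypothetical validator illustrating the relaxed rules described
    # in the comment above.
    if partition_spec == ():
        return tensor_shape == ()          # scalar: empty spec allowed
    if len(partition_spec) != len(tensor_shape):
        return False                       # spec must match tensor rank
    # Each entry is either a known mesh axis or None (replicated dim).
    return all(p is None or p in mesh_axis_names for p in partition_spec)

assert is_valid_spec((), (), ('x',))               # scalar value
assert is_valid_spec(('x', None), (8, 4), ('x',))  # 1D mesh, 2D tensor
assert not is_valid_spec(('x',), (8, 4), ('x',))   # rank mismatch
```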
tengyifei
left a comment
assert len(partition_spec) == len(mesh.shape()), \
f"Partition spec length ({len(partition_spec)}) should be equal to the mesh shape dimensions ({len(mesh.shape())})."
This check wouldn't make sense. I can shard a 2D tensor over a 4D mesh by e.g. sharding each tensor dim over two mesh axes.
It sounds like you're removing it, and I don't see this check in the latest commit, so that sounds reasonable to me.
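The reviewer's counterexample can be sketched as follows (a plain-Python illustration, not torch_xla code): grouping mesh axes per tensor dimension makes the spec length track the tensor rank, so the rejected assert would fail on a perfectly valid sharding.

```python
# A 2D tensor sharded over a 4D mesh: each tensor dim maps to a
# *group* of mesh axes, e.g. dim 0 -> axes (0, 1), dim 1 -> axes (2, 3).
mesh_shape = (2, 2, 2, 2)          # 4D mesh, 16 devices
partition_spec = ((0, 1), (2, 3))  # length 2 == tensor rank, not mesh rank

flat_axes = [axis for group in partition_spec for axis in group]
assert sorted(flat_axes) == [0, 1, 2, 3]       # every mesh axis used once
assert len(partition_spec) != len(mesh_shape)  # the rejected invariant
```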
However, there's a failed HybridMesh test. That test failed because it didn't mock out the global_runtime_device_count method. We can probably fix it by mocking out this function to return the desired number of devices for the test.
Absolutely, agreed - thanks. Added the mock; I'll re-request review once the CI succeeds.
This PR improves the input validation in the Mesh class by making error messages more descriptive. The changes make it easier to debug mesh configuration issues and provide clearer feedback to users. In addition, it moves the invariant checks into the Mesh constructor rather than mark_sharding. As a result, sharding annotations are validated only against the specified Mesh, instead of through standalone checks against the number of globally participating devices.
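The direction described above can be sketched as follows (a hypothetical simplification, not the actual torch_xla `Mesh` implementation): the constructor owns the invariants and raises descriptive errors, so later sharding calls need only validate against the mesh itself.

```python
import numpy as np

class Mesh:
    # Hypothetical sketch: the mesh enforces that its device IDs fill
    # its shape exactly, instead of each sharding annotation
    # re-checking the global device count.
    def __init__(self, device_ids, mesh_shape):
        device_ids = np.asarray(list(device_ids))
        if device_ids.size != np.prod(mesh_shape):
            raise ValueError(
                f"Number of device IDs ({device_ids.size}) must match "
                f"the mesh size ({int(np.prod(mesh_shape))}) for mesh "
                f"shape {tuple(mesh_shape)}.")
        if len(set(device_ids.tolist())) != device_ids.size:
            raise ValueError("Device IDs must be unique.")
        self.device_ids = device_ids
        self.mesh_shape = tuple(mesh_shape)

mesh = Mesh(range(4), (2, 2))  # OK: 4 unique devices fill a 2x2 mesh
```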