Encapsulate Mesh invariants #8882

Merged
1 commit merged into pytorch:master on Mar 27, 2025

Conversation

@rpsilva-aws rpsilva-aws commented Mar 25, 2025

This PR improves the input validation in the Mesh class by making error messages more descriptive, which makes mesh configuration issues easier to debug and gives users clearer feedback. In addition, it moves the invariant checks into the Mesh constructor rather than mark_sharding, so that sharding annotations are validated only against the specified Mesh instead of through standalone checks against the number of global participating devices.
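
For context, a minimal sketch of what constructor-level validation can look like; the class body, argument names, and messages below are illustrative only, not the actual diff in this PR:

  import numpy as np

  class Mesh:
    def __init__(self, device_ids, mesh_shape, axis_names=None):
      # Hypothetical sketch: check the mesh invariants once at construction time,
      # instead of re-checking global device counts in every mark_sharding call.
      device_ids = np.asarray(device_ids)
      if axis_names is not None:
        assert len(axis_names) == len(mesh_shape), \
          f"Number of axis names ({len(axis_names)}) must match mesh dimensions ({len(mesh_shape)})."
      assert device_ids.size == np.prod(mesh_shape), \
        f"Number of device IDs ({device_ids.size}) must match the mesh size ({np.prod(mesh_shape)})."
      assert len(set(device_ids.tolist())) == device_ids.size, \
        "Duplicate device IDs are not allowed."
      self.device_ids = device_ids
      self.mesh_shape = mesh_shape
      self.axis_names = axis_names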

@rpsilva-aws rpsilva-aws marked this pull request as ready for review March 25, 2025 22:15
@rpsilva-aws rpsilva-aws force-pushed the rpsilva_sharding_ref branch from 9b8e2db to e4df499 on March 25, 2025 22:16
@rpsilva-aws rpsilva-aws force-pushed the rpsilva_sharding_ref branch 2 times, most recently from 636eb45 to ddb791a on March 25, 2025 23:28
@tengyifei

cc @lsy323

@rpsilva-aws rpsilva-aws force-pushed the rpsilva_sharding_ref branch 2 times, most recently from e746546 to c0647a4 on March 26, 2025 01:49
@rpsilva-aws rpsilva-aws force-pushed the rpsilva_sharding_ref branch from c0647a4 to 3aac5ce on March 26, 2025 02:18
rpsilva-aws commented Mar 26, 2025

Switched the assert from the mesh shape's size (np.prod) to the mesh shape's length, since the partition spec (0, 1) is still a valid way to represent a single device with a (1, 1) mesh shape (size = 1, len = 2).

We also support an empty tuple as the partition spec for scalar values, as well as specifying ('x', None) for 1D mesh shapes - as long as the tensor shape matches it. These semantics seem quite relaxed (we could check against JAX), but I don't see a major reason to change that, especially since it would make things more restrictive for existing use cases. @tengyifei in this case, do you see any concern with removing the added expected constraint in https://github.com/pytorch/xla/pull/8882/files#diff-3dcff2b7395bbf1f8a09170775388ef686a1e5f593b3c3889996d78c93a9c394R582 (L582)? I updated the PR, so see below:

  assert len(partition_spec) == len(mesh.shape()), \
    f"Partition spec length ({len(partition_spec)}) should be equal to the mesh shape dimensions ({len(mesh.shape())})."

Technically, it was meant as a way to enforce a constraint between the specified partition spec and the provided mesh - but that is really an added delta. We just moved the global device check to the mesh, since it is more suitable there (the mesh should enforce that invariant, instead of each individual sharding annotation).
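
To make the size-versus-length distinction concrete, a toy example (values picked purely for illustration):

  import numpy as np

  mesh_shape = (1, 1)         # a single device laid out as a 2D mesh
  partition_spec = (0, 1)     # annotate dim 0 on mesh axis 0, dim 1 on mesh axis 1

  np.prod(mesh_shape)         # 1 -> the mesh size (number of devices)
  len(mesh_shape)             # 2 -> the number of mesh dimensions

  # Checking the partition spec against np.prod(mesh_shape) would reject (0, 1)
  # here, even though it is a valid annotation for a (1, 1) mesh; comparing
  # against len(mesh_shape) accepts it.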

@rpsilva-aws rpsilva-aws force-pushed the rpsilva_sharding_ref branch from 3aac5ce to b59039c on March 26, 2025 17:52

@tengyifei tengyifei left a comment

  assert len(partition_spec) == len(mesh.shape()), \
    f"Partition spec length ({len(partition_spec)}) should be equal to the mesh shape dimensions ({len(mesh.shape())})."

This check wouldn't make sense. I can shard a 2D tensor over a 4D mesh by e.g. sharding each tensor dim over two mesh axes.
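
For concreteness, a rough sketch of that case using torch_xla's SPMD API, assuming a run that actually has 16 devices; the mesh layout and axis names here are made up for illustration:

  import numpy as np
  import torch
  import torch_xla.core.xla_model as xm
  import torch_xla.distributed.spmd as xs
  import torch_xla.runtime as xr

  xr.use_spmd()  # enable SPMD execution mode

  # Assume 16 devices arranged as a 4D (2, 2, 2, 2) mesh.
  mesh = xs.Mesh(np.arange(16), (2, 2, 2, 2), ('a', 'b', 'c', 'd'))

  t = torch.zeros(8, 8).to(xm.xla_device())
  # Each tensor dim is sharded over two mesh axes, so the partition spec has
  # length 2 even though the mesh has 4 dimensions.
  xs.mark_sharding(t, mesh, (('a', 'b'), ('c', 'd')))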

It sounds like you're removing it, and I don't see this check in the latest commit, so that sounds reasonable to me.

However, there's a failed HybridMesh test. That test failed because it didn't mock out the global_runtime_device_count method. We can probably fix it by mocking out this function to return the desired number of devices for the test.
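
Something along these lines should work in the test; the exact patch target depends on how the test module imports the function, so treat this as a sketch:

  from unittest import mock

  import torch_xla.runtime as xr

  # Pretend there are 8 global devices so the HybridMesh construction in the
  # test does not depend on the hardware the CI job happens to run on.
  with mock.patch.object(xr, 'global_runtime_device_count', return_value=8):
    ...  # build the HybridMesh / Mesh under test here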

@rpsilva-aws rpsilva-aws force-pushed the rpsilva_sharding_ref branch from b59039c to 993d46f on March 26, 2025 22:46
@rpsilva-aws

Absolutely, agreed - thanks. Added the mock; I'll re-request review once the CI succeeds.

@rpsilva-aws rpsilva-aws merged commit 96ad8f5 into pytorch:master Mar 27, 2025
23 checks passed