Changes needed for device_assignment in PjRtComputationClient::Compile to support submeshing/localized spmd #9391

Open
@kvshbg-aws

Description

Currently, the xla::DeviceAssignment built inside PjRtComputationClient::Compile is derived from the local devices and their device_count. For submeshing, and even for localized SPMD, it should instead use the global_device_ids that the user provided.

Once the wrapper class from #9390 is defined, we can change the sharding object inside the PjRtShardedData struct from xla::OpSharding to torch_xla::OpSharding (here and here). The torch_xla::OpSharding object gives access to the original_tile_assignment (with global device IDs), so the device_assignment for PJRT can use the device_ids that the user provided during Mesh initialization.

With the current implementation of PjRtComputationClient::Compile, a user cannot specify a subset of the addressable devices in a program, because PjRtComputationClient infers the number of participants from the PJRT local devices. To support submeshing, we will therefore also extend the function to use the torch_xla::OpSharding sharding.tile_assignment() (and hence the global_device_ids) when creating the device_assignments (here and here).
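The difference between the two ways of building a device assignment can be sketched in plain Python. This is only an illustration, not the actual C++ change: both function names are hypothetical, the assignment is modeled as a nested list with a single replica row, and filling the current assignment with 0..N-1 is an assumption about how the local device count is used.

```python
from typing import List, Sequence


def assignment_from_local_count(local_device_count: int) -> List[List[int]]:
    """Current behaviour as described in this issue (assumed detail: IDs are
    filled in as 0..N-1): one replica, one partition per local device."""
    return [list(range(local_device_count))]


def assignment_from_tile_assignment(tile_assignment: Sequence[int]) -> List[List[int]]:
    """Proposed behaviour: one replica whose partitions are the user's
    global device IDs, kept in tile order."""
    ids = list(tile_assignment)
    # A device can only take part once in a single assignment.
    if len(set(ids)) != len(ids):
        raise ValueError("device IDs in a tile assignment must be unique")
    return [ids]
```

With 8 local devices and a user mesh over devices 4 to 7, the first function still produces an 8-partition assignment, while the second produces a 4-partition assignment that names only the requested devices.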

This will also let us add asserts on the device_ids of the submesh. For example, the submesh should be a subset of client_->addressable_devices, and all sharded tensors within a process should use the same mesh/submesh, i.e. have the same tile_assignment.
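A minimal sketch of the two checks, again in plain Python rather than the real C++ code. The function name check_submesh and its parameters are hypothetical; device IDs are modeled as plain integers, and "same tile_assignment" is interpreted here as the same IDs in the same order, which is an assumption.

```python
from typing import Iterable, List, Sequence


def check_submesh(tile_assignments: Sequence[Sequence[int]],
                  addressable_ids: Iterable[int]) -> None:
    """Raise ValueError if any sharded tensor violates the submesh rules."""
    addressable = set(addressable_ids)
    if not tile_assignments:
        return
    reference: List[int] = list(tile_assignments[0])
    for tile_assignment in tile_assignments:
        ids = list(tile_assignment)
        # Rule 1: the submesh must be a subset of the addressable devices.
        missing = sorted(set(ids) - addressable)
        if missing:
            raise ValueError(f"devices {missing} are not addressable")
        # Rule 2: every sharded tensor in the process uses the same submesh.
        if ids != reference:
            raise ValueError(f"tile assignment {ids} differs from {reference}")
```

Raising on the first violation keeps the error message tied to a single offending tensor, which should make a bad Mesh definition easier to track down.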

To support the above, CompileInstance needs an additional field that holds the DataPtr to the pjrt_sharded_data. From it we can get the sharding (which will be a torch_xla::OpSharding object) and, from that, the original_tile_assignment.
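The lookup chain described above could look roughly like the following. All class names ending in Sketch are hypothetical stand-ins for the C++ types, the is_sharded flag is assumed, and the real field would hold a DataPtr rather than a Python object.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ShardingSketch:
    """Stand-in for torch_xla::OpSharding (the wrapper proposed in #9390)."""
    original_tile_assignment: List[int]  # global device IDs from the user


@dataclass
class ShardedDataSketch:
    """Stand-in for PjRtShardedData, which carries the sharding object."""
    sharding: ShardingSketch


@dataclass
class CompileInstanceSketch:
    """Stand-in for CompileInstance with the proposed extra field."""
    is_sharded: bool  # assumed flag, for illustration only
    pjrt_sharded_data: Optional[ShardedDataSketch] = None  # proposed field


def global_device_ids(instance: CompileInstanceSketch) -> Optional[List[int]]:
    """Follow instance -> sharded data -> sharding -> original tile assignment."""
    if not instance.is_sharded or instance.pjrt_sharded_data is None:
        return None  # nothing to look up for unsharded programs
    return list(instance.pjrt_sharded_data.sharding.original_tile_assignment)
```

Making the new field optional in this sketch keeps unsharded compilations unchanged: they simply leave it unset and no global device IDs are looked up.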

Labels: distributed (SPMD and other distributed things), enhancement (New feature or request)
