Description
Currently, the xla::DeviceAssignment used inside the PjRtComputationClient::Compile call is built from the process-local devices/device count. For submeshing, or even for localized SPMD, we would instead want to build it from the global_device_ids the user provided.
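As a rough sketch of the direction (not the final implementation), the device assignment could be constructed directly from the user-provided global device IDs rather than from the local device count; the helper name below is hypothetical:

```cpp
#include <cstdint>
#include <vector>

#include "xla/service/computation_placer.h"  // xla::DeviceAssignment

// Hypothetical helper: build a 1-replica, N-partition device assignment from
// the global device IDs the user supplied, instead of enumerating the
// process-local PJRT devices.
xla::DeviceAssignment BuildDeviceAssignmentFromGlobalIds(
    const std::vector<int64_t>& global_device_ids) {
  xla::DeviceAssignment assignment(
      /*replica_count=*/1,
      /*computation_count=*/static_cast<int>(global_device_ids.size()));
  for (int i = 0; i < static_cast<int>(global_device_ids.size()); ++i) {
    assignment(0, i) = global_device_ids[i];
  }
  return assignment;
}
```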
Once we have the wrapper class defined in #9390, we can update the sharding object inside the PjRtShardedData struct to use torch_xla::OpSharding instead of the xla::OpSharding object (here and here). This will allow accessing the original_tile_assignment (with global device IDs) through the torch_xla::OpSharding sharding object, so the device_assignment for PJRT can make use of the device_ids the user provided during Mesh initialization.
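For illustration, a minimal sketch of the struct change, assuming the #9390 wrapper exposes something like a GetOriginalTileAssignment() accessor (the accessor name and shard types are assumptions, not the final API):

```cpp
// Sketch only; field and accessor names are simplified/assumed.
struct PjRtShardedData {
  std::vector<std::shared_ptr<PjRtData>> shards;
  // Before: xla::OpSharding sharding;
  torch_xla::OpSharding sharding;  // wrapper that also keeps the original tile assignment
};

// During compilation the global device IDs could then be recovered, e.g.:
//   const auto& tile = sharded_data->sharding.GetOriginalTileAssignment();
```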
With the current implementation of PjRtComputationClient::Compile, a user cannot specify a subset of addressable devices in a program, because PjRtComputationClient infers the number of participants from the PJRT local devices. To support sub-meshing, we will therefore also extend the function to use the torch_xla::OpSharding sharding.tile_assignment() (and hence the global_device_ids) while creating the device_assignments (here and here).
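Inside Compile this could look roughly like the following, reusing the BuildDeviceAssignmentFromGlobalIds sketch above; FlattenTileAssignment and the exact shape of the wrapper's tile_assignment() are assumptions:

```cpp
// Sketch of the Compile-side change, assuming `sharding` is the
// torch_xla::OpSharding wrapper attached to the input sharded data.
std::vector<int64_t> global_device_ids =
    FlattenTileAssignment(sharding.tile_assignment());  // hypothetical helper
xla::DeviceAssignment device_assignment =
    BuildDeviceAssignmentFromGlobalIds(global_device_ids);
// xla::CompileOptions / ExecutableBuildOptions already accept an explicit
// device assignment.
compile_options.executable_build_options.set_device_assignment(device_assignment);
```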
This will further enable us to introduce additional asserts on the device_ids of the submesh, for example: the submesh should be a subset of client_->addressable_devices(), and all sharded tensors within a process should have the same mesh/submesh, i.e. the same tile_assignment.
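The checks could be sketched as follows (the XLA_CHECK-style macros and the way the device-id lists are obtained are placeholders):

```cpp
#include <cstdint>
#include <unordered_set>
#include <vector>

// Every submesh device must be addressable by this process.
void CheckSubmeshIsAddressable(
    const std::vector<int64_t>& submesh_device_ids,
    const std::vector<int64_t>& addressable_device_ids) {
  std::unordered_set<int64_t> addressable(addressable_device_ids.begin(),
                                          addressable_device_ids.end());
  for (int64_t id : submesh_device_ids) {
    XLA_CHECK(addressable.count(id) > 0)
        << "Submesh device " << id << " is not addressable";
  }
}

// All sharded tensors in the process must agree on the tile assignment.
void CheckConsistentTileAssignment(
    const std::vector<std::vector<int64_t>>& per_tensor_tile_assignments) {
  if (per_tensor_tile_assignments.empty()) return;
  for (const auto& tile : per_tensor_tile_assignments) {
    XLA_CHECK(tile == per_tensor_tile_assignments.front())
        << "All sharded tensors in a process must share the same mesh/submesh";
  }
}
```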
To support the above, we would also have to extend the CompileInstance object with an additional field holding the DataPtr to the pjrt_sharded_data, which we can then use to get the sharding (a torch_xla::OpSharding object) and, through it, the original_tile_assignment.
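As a sketch, the CompileInstance change could look like this (existing fields abbreviated; the new field name is an assumption):

```cpp
// Abbreviated sketch of the proposed CompileInstance extension.
struct CompileInstance {
  xla::XlaComputation computation;
  std::string compilation_device;
  std::vector<std::string> devices;
  const xla::Shape* output_shape = nullptr;
  // ... existing fields elided ...

  // New: handle to sharded input data whose torch_xla::OpSharding carries the
  // original (global) tile assignment used to build the device assignment.
  ComputationClient::DataPtr pjrt_sharded_data;
};
```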