Description
Currently, users cannot define a mesh whose device_ids start from anything other than 0. This blocks us from defining sub-meshes, and it also blocks users from using localized SPMD within a single node. The resulting error message is:
RuntimeError: Passing an empty index list to Tensor::index() is not valid syntax
To resolve this issue, we propose normalizing the tile_assignment object after creating an xla::OpSharding inside the CreateOpSharding function. This also ensures that the normalized tile_assignment is what gets passed through the entire process, from XlaMarkSharding() (inside the xla_sharding_util.cpp file) through SetSharding() (inside the ir.cpp file) to the HLO that will be generated.
This will allow users to define submeshes that start with device_ids other than 0, and it paves the way for inter-node localized SPMD.
We propose to make use of an anonymous function inside the xla_sharding_util.cpp file which can be used to normalize the tile_assignment_devices field of the OpSharding object.
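
To illustrate the idea, here is a minimal Python sketch of the normalization logic. It is not the C++ code from this change. It assumes that "normalizing" means shifting every device id by the smallest id in the list so that the ids start at 0 while keeping their relative order. The function name `normalize_device_ids` is hypothetical and used only for illustration.

```python
from typing import List


def normalize_device_ids(device_ids: List[int]) -> List[int]:
    """Shift device ids so that the smallest id becomes 0.

    Assumption: normalization is a uniform shift by the minimum id.
    The relative order of the ids is preserved.
    """
    if not device_ids:
        # Nothing to normalize for an empty assignment.
        return []
    offset = min(device_ids)
    return [device_id - offset for device_id in device_ids]


# A sub-mesh covering devices 4..7, flattened into a list of ids.
submesh_ids = [4, 5, 6, 7]
print(normalize_device_ids(submesh_ids))  # [0, 1, 2, 3]
```

Under this assumption, a sub-mesh over devices 4 to 7 ends up with the same tile assignment as a mesh over devices 0 to 3, which is the form the rest of the sharding path already handles.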