Normalize tile_assignment after constructing the xla::OpSharding object #9389

Open
@kvshbg-aws

Description

Currently, users cannot define a mesh whose device_ids start from anything other than 0. This blocks us from defining sub-meshes, and also blocks users from using localized SPMD within a single node. The attempt fails with the following error:

RuntimeError: Passing an empty index list to Tensor::index() is not valid syntax

To resolve this issue, we propose normalizing the tile_assignment after the xla::OpSharding object is created inside the CreateOpSharding function. This also ensures that the normalized tile_assignment is the one passed through the entire process, from XlaMarkSharding() (in xla_sharding_util.cpp) through SetSharding() (in ir.cpp) to the generated HLO.
This will allow users to define submeshes whose device_ids start at a value other than 0, and pave the way for inter-node localized SPMD.

We propose adding an anonymous function inside xla_sharding_util.cpp that normalizes the tile_assignment_devices field of the OpSharding object.
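
As a rough illustration of what such a normalization could look like, here is a minimal sketch. It assumes that "normalizing" means shifting the device ids so that the smallest one becomes 0, which the issue does not spell out. The helper name `NormalizeTileAssignmentDevices` is hypothetical, and a plain `std::vector<int64_t>` stands in for the `tile_assignment_devices` field so that the snippet stays self-contained.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

namespace {

// Hypothetical helper (not an existing torch_xla function).
// Assumption: normalization means subtracting the smallest device id from
// every entry, so a sub-mesh such as {4, 5, 6, 7} maps to {0, 1, 2, 3}.
// The relative order of the entries is preserved.
std::vector<int64_t> NormalizeTileAssignmentDevices(
    std::vector<int64_t> device_ids) {
  if (device_ids.empty()) {
    return device_ids;  // Nothing to normalize.
  }
  const int64_t min_id =
      *std::min_element(device_ids.begin(), device_ids.end());
  for (int64_t& id : device_ids) {
    id -= min_id;
  }
  return device_ids;
}

}  // namespace
```

With this sketch, a tile assignment of `{4, 5, 6, 7}` becomes `{0, 1, 2, 3}`, and one that already starts at 0 is returned unchanged. In the real change the same logic would presumably be applied to the OpSharding object returned by CreateOpSharding.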

Labels

distributed (SPMD and other distributed things), enhancement (New feature or request)
