
Create and Expose the torch_xla::OpSharding wrapper class instead of xla::OpSharding class #9390

Open
@kvshbg-aws

Description


As part of the localized SPMD/submeshing effort, we need to abstract the xla::OpSharding proto object behind a torch_xla-specific wrapper class and expose the torch_xla::OpSharding object to the user instead of the xla proto object.

Currently we expose the xla::OpSharding object directly to the user when they call get_op_sharding. This abstraction is relatively low level: it works well when operating with all devices, but not when users start their mesh at a device other than 0.

To resolve this, we propose creating a new torch_xla::OpSharding class that wraps the xla::OpSharding proto. This wrapper class will have a constructor that takes the xla::OpSharding object and the tile_assignment object. (The tile_assignment object holds the global_device_ids, which are required further down the stack when creating the device_assignment for the pjrt_client; hence we pass it as a constructor parameter.)

The wrapper class needs forwarded methods so that the user can still use the same APIs as xla::OpSharding (for example, sharding.type()) to access the proto's fields. The only difference is that the user will now call these APIs on a torch_xla::OpSharding object instead of the xla proto. The upside of this abstraction is that it lets us define custom fields inside the torch_xla::OpSharding class; for example, a global_tile_assignment field holding the global_device_ids needed down the line for our localized SPMD/sub-meshing use case.
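The wrapper described above might be sketched as follows. This is a hypothetical, self-contained sketch, not the final API: the xla::OpSharding stub stands in for the real proto from xla/xla_data.pb.h, and the names `tile_assignment` / `global_tile_assignment` are assumptions taken from the issue text.

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Stand-in for the real xla::OpSharding proto so this sketch compiles
// on its own; the real class comes from xla/xla_data.pb.h.
namespace xla {
class OpSharding {
 public:
  enum Type { REPLICATED, MAXIMAL, TUPLE, OTHER };
  Type type() const { return type_; }
  void set_type(Type t) { type_ = t; }

 private:
  Type type_ = REPLICATED;
};
}  // namespace xla

namespace torch_xla {

// Hypothetical sketch of the proposed wrapper class.
class OpSharding {
 public:
  OpSharding(xla::OpSharding proto, std::vector<int64_t> tile_assignment)
      : proto_(std::move(proto)),
        global_tile_assignment_(std::move(tile_assignment)) {}

  // Forwarded accessor: callers keep the familiar xla::OpSharding API,
  // e.g. sharding.type().
  xla::OpSharding::Type type() const { return proto_.type(); }

  // Escape hatch for code paths that still need the raw proto.
  const xla::OpSharding& GetXlaOpSharding() const { return proto_; }

  // Custom field: the global device ids needed later when building the
  // device_assignment for the pjrt_client.
  const std::vector<int64_t>& global_tile_assignment() const {
    return global_tile_assignment_;
  }

 private:
  xla::OpSharding proto_;
  std::vector<int64_t> global_tile_assignment_;
};

}  // namespace torch_xla
```

Storing the tile assignment next to the proto means the global device ids travel with the sharding instead of being reconstructed downstream, which is the point of the abstraction for sub-meshing.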

Once the wrapper class is ready, CreateOpSharding in xla_sharding_util.cpp will return it instead of xla::OpSharding. Consequently, init_python_bindings.cpp will also incur some changes, wherein references to xla::OpSharding will be converted to torch_xla::OpSharding.
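The return-type change might look roughly like the sketch below. The stand-in types and the CreateOpSharding parameter list are assumptions made so the example compiles on its own; the real signature in xla_sharding_util.cpp differs.

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Stand-ins so this sketch compiles on its own; the real types live in
// xla/xla_data.pb.h and the torch_xla sharding headers.
namespace xla {
class OpSharding {};
}  // namespace xla

namespace torch_xla {

class OpSharding {
 public:
  OpSharding(xla::OpSharding proto, std::vector<int64_t> tile_assignment)
      : proto_(std::move(proto)),
        global_tile_assignment_(std::move(tile_assignment)) {}

  const std::vector<int64_t>& global_tile_assignment() const {
    return global_tile_assignment_;
  }

 private:
  xla::OpSharding proto_;
  std::vector<int64_t> global_tile_assignment_;
};

// Hypothetical sketch of the signature change: previously something like
//   xla::OpSharding CreateOpSharding(...);
// afterwards the wrapper is returned, so the tile assignment (and hence
// the global device ids) travels together with the proto.
OpSharding CreateOpSharding(const std::vector<int64_t>& tile_assignment) {
  xla::OpSharding proto;
  // ... populate `proto` from the user's sharding spec ...
  return OpSharding(proto, tile_assignment);
}

}  // namespace torch_xla
```

Because the binding layer now traffics in torch_xla::OpSharding, every caller of CreateOpSharding (and any function it feeds into) changes type in lockstep, which is the ripple effect the note below describes.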

Note: changing the return type here will also require changing the type in all subsequent functions/calls that use the xla::OpSharding object.

Labels

distributed (SPMD and other distributed things), enhancement (New feature or request)
