Description
As part of the localized SPMD/submeshing effort, we need to abstract the xla::OpSharding proto object behind a torch_xla-specific wrapper class and expose the torch_xla::OpSharding object to the user instead of the xla proto object.
Currently we expose the xla::OpSharding object directly to the user when they call get_op_sharding. This abstraction is relatively low level: it works well when operating on all devices, but not when users start their mesh at a device ID other than 0.
To resolve this, we propose creating a new torch_xla::OpSharding class that wraps the xla::OpSharding proto. The wrapper's constructor will take the xla::OpSharding object and the tile_assignment object. (The tile_assignment object holds the global_device_ids, which are required further down the stack when creating the device_assignment for pjrt_client, so we would like to pass it as a constructor parameter.)
The wrapper class needs forwarding methods so that the user can keep using the same APIs as xla::OpSharding (for example, sharding.type()) to access the proto's fields. The only difference is that the user will now call these APIs on a torch_xla::OpSharding object instead of the xla proto. The advantage of this abstraction is that custom fields can be defined inside the torch_xla::OpSharding class. For example, we can define a global_tile_assignment field holding the global_device_ids needed down the line for our localized SPMD/sub-meshing use case.
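The wrapper described above could look roughly like the following sketch. It is not the final design. To keep it compilable without XLA, `fake_xla::OpSharding` is a hand-written stand-in for the generated proto class (the real accessors return protobuf repeated fields, not `std::vector`). The `torch_xla_sketch` namespace, the `GetXlaOpSharding` accessor, the `MakeExampleSharding` helper, and the choice of a flat `std::vector<int64_t>` for the tile assignment are all assumptions made for illustration, as is the idea that the proto stores submesh-local IDs while the wrapper keeps the global ones.

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Stand-in for the generated xla::OpSharding proto (assumption: only the two
// accessors used below are mimicked, and std::vector replaces RepeatedField).
namespace fake_xla {
class OpSharding {
 public:
  enum Type { REPLICATED = 0, MAXIMAL = 1, TUPLE = 2, OTHER = 3 };

  Type type() const { return type_; }
  void set_type(Type t) { type_ = t; }

  const std::vector<int64_t>& tile_assignment_devices() const {
    return tile_assignment_devices_;
  }
  void add_tile_assignment_devices(int64_t id) {
    tile_assignment_devices_.push_back(id);
  }

 private:
  Type type_ = REPLICATED;
  std::vector<int64_t> tile_assignment_devices_;
};
}  // namespace fake_xla

// Hypothetical wrapper; names and layout are illustrative, not the real API.
namespace torch_xla_sketch {
class OpSharding {
 public:
  // Takes the proto plus the tile assignment carrying the global device IDs.
  OpSharding(fake_xla::OpSharding proto,
             std::vector<int64_t> global_tile_assignment)
      : proto_(std::move(proto)),
        global_tile_assignment_(std::move(global_tile_assignment)) {}

  // Forwarded accessors: same names as on the proto, so call sites such as
  // sharding.type() keep working unchanged.
  fake_xla::OpSharding::Type type() const { return proto_.type(); }
  const std::vector<int64_t>& tile_assignment_devices() const {
    return proto_.tile_assignment_devices();
  }

  // Custom field that the plain proto cannot carry.
  const std::vector<int64_t>& global_tile_assignment() const {
    return global_tile_assignment_;
  }

  // Access to the wrapped proto for code that still needs it (assumed name).
  const fake_xla::OpSharding& GetXlaOpSharding() const { return proto_; }

 private:
  fake_xla::OpSharding proto_;
  std::vector<int64_t> global_tile_assignment_;
};

// Builds a wrapper for a 4-device submesh whose global device IDs start at 4.
// Assumption: the proto holds local IDs 0..3, the wrapper keeps 4..7.
inline OpSharding MakeExampleSharding() {
  fake_xla::OpSharding proto;
  proto.set_type(fake_xla::OpSharding::OTHER);
  for (int64_t local_id = 0; local_id < 4; ++local_id) {
    proto.add_tile_assignment_devices(local_id);
  }
  return OpSharding(std::move(proto), {4, 5, 6, 7});
}
}  // namespace torch_xla_sketch
```

Holding the proto by value and forwarding by name keeps existing call sites source-compatible, while the extra member gives later stages, such as building the device assignment, a place to read the global IDs from.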
Once the wrapper class is ready, CreateOpSharding in xla_sharding_util.cpp will return it instead of xla::OpSharding. init_python_bindings.cpp will therefore also change: its references to xla::OpSharding will be converted to torch_xla::OpSharding.
Note: changing the return type here will also require changing the type in all downstream functions and calls that use the xla::OpSharding object.