Description
🚀 Feature
The previous work #92909 on DTensor support with the XLA backend enabled factory functions, including distribute_tensor, to create XLA sharded tensors when using XLA SPMD. However, that work only allows creating an XLAShardedTensor object; the object does not implement the full set of DTensor APIs that users need to work with these tensors in parallelism workflows. This RFC proposes implementing the DTensor APIs on XLAShardedTensor. This approach will enable PyTorch users to express tensor distributions consistently across different backends.
Currently, DTensor for XLA devices can be created through:
import torch
from torch.distributed import DeviceMesh, Shard, distribute_tensor
mesh = DeviceMesh("xla", list(range(world_size)))
big_tensor = torch.randn(100000, 88)
my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(0)])
This example is from RFC #92909. Once the changes in this proposal are in place, DTensor APIs and public properties can be used as well:
from torch.distributed.tensor import Replicate
desired_layouts = [Replicate()]
# Redistribute the tensor to be replicated on mesh dim 0
my_dtensor.redistribute(placements=desired_layouts)
# Create a zero tensor with the same shape and sharding as my_dtensor
zero_tensor = torch.distributed.tensor.zeros(my_dtensor.size(), device_mesh=mesh, placements=my_dtensor.placements)
Motivation
DTensor is becoming the new fundamental building block for distributed computation in PyTorch. Enabling DTensor with XLA devices offers several key benefits:
- DTensor provides simple and generic APIs to express tensor distribution in SPMD style.
- The integration reduces the overhead for users to learn new stacks and migrate their models.
- The integration allows access to many existing parallelism utilities, including tensor parallel and sequence parallel wrappers.
- The integration enables popular training frameworks built on top of DTensor, such as TorchTitan.
- The integration provides the opportunity to consolidate XLA into the DTensor ecosystem, improving alignment between XLA and native PyTorch.
Currently, Torch XLA provides the XLAShardedTensor, Mesh, and mark_sharding APIs for distributed training. The ongoing work in #92909 has already enabled creating a sharded tensor on an XLA device through DTensor's standard factory functions when using XLA SPMD. However, XLAShardedTensor has not yet implemented the full suite of DTensor interfaces, preventing users from using DTensor with XLA devices in the same way they use it with non-XLA devices.
Pitch
We propose implementing the DTensor SPMD APIs and related public properties on XLAShardedTensor, so that XLAShardedTensor can function like a DTensor. MPMD API enablement is out of scope because the DTensor SPMD APIs combined with XLA backend compilation already provide a complete solution for distributed computation, both functionally and non-functionally. While there may still be some benefit to enabling DTensor + XLA MPMD, this proposal does not focus on it.
DTensor SPMD APIs and related public properties are listed below: (reference: https://docs.pytorch.org/docs/stable/distributed.tensor.html)
### redistribute performs the necessary collective operations to redistribute the current DTensor from its current placements to new placements, or from its current DeviceMesh to a new DeviceMesh.
redistribute(device_mesh=None, placements=None, *, async_op=False)
### full_tensor returns the full tensor of this DTensor. It performs the necessary collectives to gather the local tensors from the other ranks in its DeviceMesh and concatenate them together.
full_tensor(*, grad_placements=None)
### The DeviceMesh attribute associated with this DTensor object.
device_mesh: DeviceMesh
### The placements attribute of this DTensor, describing its layout on its DeviceMesh.
placements: tuple[torch.distributed.tensor.placement_types.Placement, ...]
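For illustration, once these are implemented on XLAShardedTensor, usage would look like the following (a hedged sketch reusing mesh and my_dtensor from the earlier example; the printed placements value is indicative only):
from torch.distributed.tensor import Replicate
# Query the sharding metadata
assert my_dtensor.device_mesh == mesh
print(my_dtensor.placements)  # e.g. (Shard(dim=0),)
# Reshard to replicated, or materialize the global tensor directly
replicated = my_dtensor.redistribute(placements=[Replicate()])
global_view = my_dtensor.full_tensor()  # a regular torch.Tensor of shape (100000, 88)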
While #92909 implemented the distribute_tensor factory function to support creating a sharded tensor from a regular tensor, other convenience factory functions have not been implemented yet. We list them here and propose implementing them for full compatibility: (reference: https://docs.pytorch.org/docs/stable/distributed.tensor.html)
torch.distributed.tensor.zeros(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)
torch.distributed.tensor.ones(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)
torch.distributed.tensor.empty(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)
torch.distributed.tensor.full(size, fill_value, *, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)
torch.distributed.tensor.rand(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)
torch.distributed.tensor.randn(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)
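For example, once these factories are implemented for XLA, they would accept an XLA DeviceMesh the same way distribute_tensor does (an illustrative sketch reusing mesh from the earlier example):
from torch.distributed.tensor import Replicate, Shard, full, ones
# Sharded tensor of ones, and a replicated constant tensor, on the XLA mesh
ones_dt = ones(100000, 88, device_mesh=mesh, placements=[Shard(0)])
full_dt = full((100000, 88), 3.14, device_mesh=mesh, placements=[Replicate()])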
There is existing support for converting a DTensor DeviceMesh to an XLA Mesh and using it to create a sharded tensor on an XLA device. However, a DeviceMesh can also be used in the form of a submesh (a mesh covering a subset of the root mesh's dimensions, which is convenient when composing multiple sharding strategies), and we don't see this supported in XLA. Supporting it would require implementing the following APIs: (reference: https://github.com/pytorch/pytorch/blob/60abb0d3273749cb2a7d583c7c2863bd2819e87e/torch/distributed/device_mesh.py)
_MeshEnv.create_sub_mesh(self, device_mesh: "DeviceMesh", submesh_dim_names: tuple[str, ...], submesh_dims: list[tuple[int, ...]])
_MeshEnv.get_root_mesh(self, device_mesh: "DeviceMesh")
DeviceMesh.__getitem__(self, mesh_dim_names: Union[str, tuple[str, ...]])
High-level Approach
Changes to XLAShardedTensor
The first change is to make XLAShardedTensor inherit from DTensor. This ensures XLAShardedTensor can be used wherever a DTensor is required. We implement the SPMD APIs specified above and explicitly raise exceptions for MPMD API calls, with a clear message telling users that MPMD APIs are disabled when using XLA SPMD.
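A minimal sketch of the shape of this change (not the actual torch_xla implementation; treating from_local as a disabled MPMD-style entry point is just an illustration):
from torch.distributed.tensor import DTensor
class XLAShardedTensor(DTensor):
    # SPMD APIs (redistribute, full_tensor, device_mesh, placements) are
    # implemented as described in the sections below.
    @staticmethod
    def from_local(local_tensor, device_mesh=None, placements=None, **kwargs):
        # from_local assumes per-rank local shards, an MPMD-style usage
        # pattern that does not apply under XLA SPMD.
        raise NotImplementedError(
            "MPMD-style DTensor APIs are disabled when using XLA SPMD.")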
DTensor.redistribute
The redistribute API reshards a DTensor from its current placements to new placements by calling collectives. XLA does not expose explicit control over resharding a tensor, but the same effect can be achieved by cloning the existing tensor and calling mark_sharding on the new tensor. Alternatively, we could add a new XLA reshard API that hides the clone behind the scenes and possibly handles buffer donation, making the process more efficient.
def redistribute(self, device_mesh=None, placements=None, *, async_op=False):
    ...
    # Clone and detach, then re-shard the clone with the new placements.
    clone_tensor = self.clone().detach()
    return distribute_tensor(clone_tensor, device_mesh, placements)
DTensor.full_tensor
The full_tensor API returns the global tensor. With XLA, we can simply utilize the redistribute API with replicate placements.
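A minimal sketch of that approach (it assumes the replicated result exposes the dense tensor through XLAShardedTensor's existing global_tensor attribute; grad_placements handling is omitted):
from torch.distributed.tensor import Replicate
def full_tensor(self, *, grad_placements=None):
    # Replicate across every mesh dimension, then return the dense tensor.
    replicated = self.redistribute(
        device_mesh=self.device_mesh,
        placements=[Replicate()] * self.device_mesh.ndim)
    return replicated.global_tensor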
DTensor Properties
XLAShardedTensor does not currently store the device_mesh and placements objects, so we will need to save them. Once XLAShardedTensor inherits from DTensor, these fields are already provided by the base class.
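For reference, a sketch of how the base class exposes these properties (assuming DTensor's layout, where both come from the tensor's internal DTensorSpec):
@property
def device_mesh(self):
    # The mesh recorded in the tensor's DTensorSpec.
    return self._spec.mesh
@property
def placements(self):
    # The placements recorded in the tensor's DTensorSpec.
    return self._spec.placements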
DTensor Factory Functions
All the factory functions listed above (zeros, ones, etc.) rely on a DTensor internal helper function, _dtensor_init_helper (reference). This helper does not use the distribute_tensor factory function because it takes advantage of the constant-tensor case: it initializes only the local tensor and marks it as distributed, instead of distributing a global tensor. However, when using XLA SPMD the buffers are allocated with the sharded shape anyway, so for XLA we can simply route the implementation through distribute_tensor:
def _dtensor_init_helper(
    init_op,
    size: torch.Size,
    device_mesh: Optional[DeviceMesh] = None,
    placements: Optional[Sequence[Placement]] = None,
    **kwargs,
) -> DTensor:
    ...
    if device_mesh.device_type == "xla":
        # Initialize the global tensor, then shard it via distribute_tensor.
        if init_op == torch.full:
            fill_value = kwargs.pop("fill_value", 0)
            global_tensor = init_op(size, fill_value, **kwargs)
        elif init_op == torch.rand or init_op == torch.randn:
            # This tensor meta is only used for its `shape`.
            dtype = kwargs.get("dtype", torch.get_default_dtype())
            tensor_meta = TensorMeta(size, (0,), dtype)
            spec = DTensorSpec(device_mesh, tuple(placements), tensor_meta=tensor_meta)
            if random.is_rng_supported_mesh(device_mesh) and not random._rng_tracker:
                random._rng_tracker = random.OffsetBasedRNGTracker(device_mesh)
            assert random._rng_tracker is not None
            with random._rng_tracker._distribute_region(spec):
                global_tensor = init_op(size, **kwargs)
        else:
            global_tensor = init_op(size, **kwargs)
        return distribute_tensor(global_tensor, device_mesh, placements)
DeviceMesh Submesh Utils
A submesh is a slice of the root DeviceMesh obtained by indexing with a dim name. For example, with 8 ranks, a 2-D mesh mesh_2d with dim names ('dp', 'tp') and layout [[0,1,2,3],[4,5,6,7]] yields the 4-rank submesh mesh_2d['tp'] over ranks [0,1,2,3] when the current rank is one of [0,1,2,3]. Torch XLA does not have built-in support for submeshes. To make this work, we can enable submeshes in Torch XLA; this could simply be a convenience helper class on top of Mesh that handles indexing by dim names and mapping to the root mesh dims. Alternatively, we can avoid introducing the concept of a submesh in Torch XLA; instead, a DTensor submesh would be handled and mapped back to the root mesh when converting to an XLA mesh.
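For illustration, the DeviceMesh-side slicing behavior this section targets (standard torch.distributed semantics; the 2x4 layout and constructing the mesh with device_type="xla" are assumptions of this sketch):
import torch
from torch.distributed.device_mesh import DeviceMesh
# 8 ranks laid out as a 2x4 grid: [[0, 1, 2, 3], [4, 5, 6, 7]]
mesh_2d = DeviceMesh("xla", torch.arange(8).reshape(2, 4), mesh_dim_names=("dp", "tp"))
tp_mesh = mesh_2d["tp"]  # for ranks 0-3, a 1-D submesh over ranks [0, 1, 2, 3]
dp_mesh = mesh_2d["dp"]  # for rank 0, a 1-D submesh over ranks [0, 4]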
Alternatives
The goal of this RFC is to take advantage of the generic and simple DTensor API and to provide a consistent user experience across different backends. This could potentially be achieved by integrating the XLA backend with other APIs as well. Considering DTensor's native Torch ecosystem, wide user base, and existing integration with XLA, we think it is beneficial to continue this integration, with other candidates considered later.