Encapsulate Mesh invariants #8882

Merged: 1 commit, Mar 27, 2025
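
This PR moves the Mesh validity checks (axis-name count and uniqueness, device-count consistency with the mesh shape, device-ID uniqueness and bounds, and the global-device-count requirement) out of `mark_sharding` and into the `Mesh` constructor, each with a descriptive assertion message, and adds unit tests covering them. A minimal sketch of the resulting behavior, assuming an environment with XLA devices available (the shapes below are illustrative):

```python
import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

# Illustrative only: assumes at least one XLA device is available.
num_devices = xr.global_runtime_device_count()
device_ids = np.arange(num_devices)

# Valid: device IDs, mesh shape, and axis names are mutually consistent.
mesh = xs.Mesh(device_ids, (1, num_devices), ('data', 'model'))

# Invalid meshes now fail at construction time with a descriptive message,
# e.g. duplicate axis names:
try:
  xs.Mesh(device_ids, (1, num_devices), ('data', 'data'))
except AssertionError as e:
  print(e)  # Axis names must be unique, got: ('data', 'data')
```

Because the checks live in the constructor, any consumer of a `Mesh`, not just `mark_sharding`, can rely on these invariants.
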
101 changes: 98 additions & 3 deletions test/spmd/test_xla_sharding.py
@@ -1,9 +1,10 @@
import copy

-import unittest
-from unittest.mock import patch
from collections import OrderedDict
import math
import numpy as np
import unittest
from unittest.mock import patch
import sys

import torch
@@ -762,9 +763,12 @@ def test_hybrid_mesh_shape(self):
"Crash on TPU v2")
@patch('torch_xla.runtime.global_runtime_device_attributes')
@patch('torch_xla.core.xla_model.xla_device_hw')
-def test_hybrid_mesh(self, xla_device_mock, device_attributes_mock):
@patch('torch_xla.runtime.global_runtime_device_count')
def test_hybrid_mesh(self, device_count_mock, xla_device_mock,
device_attributes_mock):
# mock device attributes for 2 slices of v4-8
num_slices = 2
device_count_mock.return_value = 8
xla_device_mock.return_value = "TPU"
device_attributes_mock.return_value = [{
'coords': [0, 0, 0],
@@ -1565,6 +1569,97 @@ def test_mark_sharding_with_gradients_annotation(self):
# Check that the gradient has sharding.
self.assertIn(sharding_spec, x_grad_sharding)

def test_valid_mesh_creation(self):
mesh_shape = (1, self.n_devices)
axis_names = ('data', 'model')
mesh = xs.Mesh(self.device_ids, mesh_shape, axis_names)

self.assertEqual(mesh.device_ids.tolist(), list(range(self.n_devices)))
self.assertEqual(mesh.mesh_shape, mesh_shape)
self.assertEqual(mesh.axis_names, axis_names)

def test_valid_mesh_without_axis_names(self):
mesh_shape = (1, self.n_devices)
mesh = xs.Mesh(self.device_ids, mesh_shape)

self.assertEqual(mesh.device_ids.tolist(), list(range(self.n_devices)))
self.assertEqual(mesh.mesh_shape, mesh_shape)
self.assertIsNone(mesh.axis_names)

def test_invalid_axis_names_length(self):
mesh_shape = (1, self.n_devices)
axis_names = ('data', 'model', 'extra')

with self.assertRaisesRegex(
AssertionError, "Number of axis names .* must match mesh dimensions"):
xs.Mesh(self.device_ids, mesh_shape, axis_names)

def test_duplicate_axis_names(self):
mesh_shape = (1, self.n_devices)
axis_names = ('data', 'data')

with self.assertRaisesRegex(AssertionError, "Axis names must be unique"):
xs.Mesh(self.device_ids, mesh_shape, axis_names)

def test_invalid_device_count(self):
mesh_shape = (2, self.n_devices)

with self.assertRaisesRegex(AssertionError,
"Number of device IDs .* must match mesh size"):
xs.Mesh(self.device_ids, mesh_shape)

@unittest.skipIf(xr.global_runtime_device_count() == 1,
"Multiple devices needed for duplicated device IDs")
def test_duplicate_device_ids(self):
mesh_shape = (1, self.n_devices)
duplicate_ids = np.array([0] * self.n_devices)

with self.assertRaisesRegex(AssertionError, "Device IDs must be unique"):
xs.Mesh(duplicate_ids, mesh_shape)

def test_device_ids_out_of_bounds(self):
mesh_shape = (1, self.n_devices)
invalid_ids = np.arange(self.n_devices + 1, self.n_devices * 2 + 1)

with self.assertRaisesRegex(AssertionError,
"Device IDs must be less than mesh size"):
xs.Mesh(invalid_ids, mesh_shape)

def test_mesh_size(self):
mesh_shape = (1, self.n_devices)
mesh = xs.Mesh(self.device_ids, mesh_shape)
self.assertEqual(mesh.size(), self.n_devices)

def test_mesh_shape_method(self):
mesh_shape = (1, self.n_devices)
axis_names = ('data', 'model')
mesh = xs.Mesh(self.device_ids, mesh_shape, axis_names)

expected_shape = OrderedDict([('data', 1), ('model', self.n_devices)])
self.assertEqual(mesh.shape(), expected_shape)

@unittest.skipIf(xr.global_runtime_device_count() == 1,
"Multiple devices needed")
def test_mismatch_global_devices(self):
partial_num_devices = self.n_devices // 2
device_ids = np.arange(partial_num_devices)
mesh_shape = (1, partial_num_devices)
with self.assertRaisesRegex(
AssertionError,
"Number of device IDs .* must match the global number of devices"):
xs.Mesh(device_ids, mesh_shape)

@unittest.skipIf(xr.global_runtime_device_count() == 1,
"Multiple devices needed")
def test_get_logical_mesh(self):
device_ids = np.arange(self.n_devices)
mesh_shape = (2, self.n_devices // 2)
mesh = xs.Mesh(device_ids, mesh_shape)

logical_mesh = mesh.get_logical_mesh()
self.assertEqual(logical_mesh.shape, mesh_shape)
np.testing.assert_array_equal(np.sort(logical_mesh.flatten()), device_ids)


if __name__ == '__main__':
test = unittest.main()
36 changes: 23 additions & 13 deletions torch_xla/distributed/spmd/xla_sharding.py
@@ -69,14 +69,32 @@ def __init__(self,
axis_names: Optional[tuple[str, ...]] = None):
if not isinstance(device_ids, np.ndarray):
device_ids = np.array(device_ids)
-assert (axis_names is None) or (len(mesh_shape) == len(axis_names))
-assert axis_names is None or (len(set(axis_names)) == len(axis_names))
-assert (len(device_ids) == np.prod(mesh_shape))
-assert len(device_ids) == len(np.unique(device_ids))

# At the moment, XLA requires that the Mesh uses the global number of
# devices.
num_devices = xr.global_runtime_device_count()
assert num_devices > 0, "This requires XLA supported device(s)."
assert num_devices == len(
device_ids
), f"Number of device IDs ({len(device_ids)}) must match the global number of devices ({num_devices})"

if axis_names is not None:
assert len(mesh_shape) == len(axis_names), \
f"Number of axis names ({len(axis_names)}) must match mesh dimensions ({len(mesh_shape)})"
assert len(set(axis_names)) == len(axis_names), \
f"Axis names must be unique, got: {axis_names}"

expected_devices = np.prod(mesh_shape)
assert len(device_ids) == expected_devices, \
f"Number of device IDs ({len(device_ids)}) must match mesh size ({expected_devices})"
assert len(device_ids) == len(np.unique(device_ids)), \
f"Device IDs must be unique, got: {device_ids}"

self.device_ids = device_ids
self.mesh_shape = mesh_shape
self.axis_names = axis_names
-assert all(d < self.size() for d in device_ids)
assert all(d < self.size() for d in device_ids), \
f"Device IDs must be less than mesh size ({self.size()}), got: {device_ids}"

def size(self):
return np.prod(self.mesh_shape)
@@ -555,10 +573,6 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
>>> linear = nn.Linear(32, 10).to(xm.xla_device())
>>> xs.mark_sharding(linear.weight, mesh, (None, 1)) # 2-way model parallel
"""
-num_devices = xr.global_runtime_device_count()
-assert num_devices > 0, "This requires XLA supported device(s)."
-assert mesh.size() == num_devices, \
-    f"{mesh.mesh_shape} is not mappable over {num_devices} devices."
# We only allow fully specified `partition_spec` to be applicable, as opposed
to filling in the unspecified replicated dims. Fully specified `partition_spec`
# should be of the same rank as `t`. This is to support partial replication
@@ -603,10 +617,6 @@ def mark_sharding_with_gradients(

This version can also be used in AOTAutograd.
"""
-num_devices = xr.global_runtime_device_count()
-assert num_devices > 0, "This requires XLA supported device(s)."
-assert mesh.size() == num_devices, \
-    f"{mesh.mesh_shape} is not mappable over {num_devices} devices."
# We only allow fully specified `partition_spec` to be applicable, as opposed
to filling in the unspecified replicated dims. Fully specified `partition_spec`
# should be of the same rank as `t`. This is to support partial replication
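
With the invariants enforced in `Mesh.__init__`, the sharding entry points above no longer re-check the device count themselves; the mesh they receive is already known to be consistent. A hedged usage sketch following the `mark_sharding` docstring, assuming SPMD mode and a multi-device XLA environment (the layer size is illustrative):

```python
import numpy as np
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

# Assumption: SPMD execution is enabled and XLA devices are present.
xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (1, num_devices), ('data', 'model'))

# The mesh was validated at construction, so mark_sharding no longer needs to
# re-validate the device count; (None, 1) shards the weight's second dim
# across the 'model' axis of the mesh.
linear = nn.Linear(32, 10).to(xm.xla_device())
xs.mark_sharding(linear.weight, mesh, (None, 1))
```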