From 993d46fd2ad033387e869c80bfe8f87b84274ebb Mon Sep 17 00:00:00 2001
From: Rui Silva
Date: Tue, 25 Mar 2025 22:09:30 +0000
Subject: [PATCH] Encapsulate Mesh invariants

---
 test/spmd/test_xla_sharding.py             | 101 ++++++++++++++++++++-
 torch_xla/distributed/spmd/xla_sharding.py |  36 +++++---
 2 files changed, 121 insertions(+), 16 deletions(-)

diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py
index 2f77ec210f5..6d7846eba56 100644
--- a/test/spmd/test_xla_sharding.py
+++ b/test/spmd/test_xla_sharding.py
@@ -1,9 +1,10 @@
 import copy
 
-import unittest
-from unittest.mock import patch
+from collections import OrderedDict
 import math
 import numpy as np
+import unittest
+from unittest.mock import patch
 import sys
 
 import torch
@@ -762,9 +763,12 @@ def test_hybrid_mesh_shape(self):
                    "Crash on TPU v2")
   @patch('torch_xla.runtime.global_runtime_device_attributes')
   @patch('torch_xla.core.xla_model.xla_device_hw')
-  def test_hybrid_mesh(self, xla_device_mock, device_attributes_mock):
+  @patch('torch_xla.runtime.global_runtime_device_count')
+  def test_hybrid_mesh(self, device_count_mock, xla_device_mock,
+                       device_attributes_mock):
     # mock device attributes for 2 slices of v4-8
     num_slices = 2
+    device_count_mock.return_value = 8
     xla_device_mock.return_value = "TPU"
     device_attributes_mock.return_value = [{
         'coords': [0, 0, 0],
@@ -1565,6 +1569,97 @@ def test_mark_sharding_with_gradients_annotation(self):
     # Check that the gradient has sharding.
     self.assertIn(sharding_spec, x_grad_sharding)
 
+  def test_valid_mesh_creation(self):
+    mesh_shape = (1, self.n_devices)
+    axis_names = ('data', 'model')
+    mesh = xs.Mesh(self.device_ids, mesh_shape, axis_names)
+
+    self.assertEqual(mesh.device_ids.tolist(), list(range(self.n_devices)))
+    self.assertEqual(mesh.mesh_shape, mesh_shape)
+    self.assertEqual(mesh.axis_names, axis_names)
+
+  def test_valid_mesh_without_axis_names(self):
+    mesh_shape = (1, self.n_devices)
+    mesh = xs.Mesh(self.device_ids, mesh_shape)
+
+    self.assertEqual(mesh.device_ids.tolist(), list(range(self.n_devices)))
+    self.assertEqual(mesh.mesh_shape, mesh_shape)
+    self.assertIsNone(mesh.axis_names)
+
+  def test_invalid_axis_names_length(self):
+    mesh_shape = (1, self.n_devices)
+    axis_names = ('data', 'model', 'extra')
+
+    with self.assertRaisesRegex(
+        AssertionError, "Number of axis names .* must match mesh dimensions"):
+      xs.Mesh(self.device_ids, mesh_shape, axis_names)
+
+  def test_duplicate_axis_names(self):
+    mesh_shape = (1, self.n_devices)
+    axis_names = ('data', 'data')
+
+    with self.assertRaisesRegex(AssertionError, "Axis names must be unique"):
+      xs.Mesh(self.device_ids, mesh_shape, axis_names)
+
+  def test_invalid_device_count(self):
+    mesh_shape = (2, self.n_devices)
+
+    with self.assertRaisesRegex(AssertionError,
+                                "Number of device IDs .* must match mesh size"):
+      xs.Mesh(self.device_ids, mesh_shape)
+
+  @unittest.skipIf(xr.global_runtime_device_count() == 1,
+                   "Multiple devices needed for duplicated device IDs")
+  def test_duplicate_device_ids(self):
+    mesh_shape = (1, self.n_devices)
+    duplicate_ids = np.array([0] * self.n_devices)
+
+    with self.assertRaisesRegex(AssertionError, "Device IDs must be unique"):
+      xs.Mesh(duplicate_ids, mesh_shape)
+
+  def test_device_ids_out_of_bounds(self):
+    mesh_shape = (1, self.n_devices)
+    invalid_ids = np.arange(self.n_devices + 1, self.n_devices * 2 + 1)
+
+    with self.assertRaisesRegex(AssertionError,
+                                "Device IDs must be less than mesh size"):
+      xs.Mesh(invalid_ids, mesh_shape)
+
+  def test_mesh_size(self):
+    mesh_shape = (1, self.n_devices)
+    mesh = xs.Mesh(self.device_ids, mesh_shape)
+    self.assertEqual(mesh.size(), self.n_devices)
+
+  def test_mesh_shape_method(self):
+    mesh_shape = (1, self.n_devices)
+    axis_names = ('data', 'model')
+    mesh = xs.Mesh(self.device_ids, mesh_shape, axis_names)
+
+    expected_shape = OrderedDict([('data', 1), ('model', self.n_devices)])
+    self.assertEqual(mesh.shape(), expected_shape)
+
+  @unittest.skipIf(xr.global_runtime_device_count() == 1,
+                   "Multiple devices needed")
+  def test_mismatch_global_devices(self):
+    partial_num_devices = self.n_devices // 2
+    device_ids = np.arange(partial_num_devices)
+    mesh_shape = (1, partial_num_devices)
+    with self.assertRaisesRegex(
+        AssertionError,
+        "Number of device IDs .* must match the global number of devices"):
+      xs.Mesh(device_ids, mesh_shape)
+
+  @unittest.skipIf(xr.global_runtime_device_count() == 1,
+                   "Multiple devices needed")
+  def test_get_logical_mesh(self):
+    device_ids = np.arange(self.n_devices)
+    mesh_shape = (2, self.n_devices // 2)
+    mesh = xs.Mesh(device_ids, mesh_shape)
+
+    logical_mesh = mesh.get_logical_mesh()
+    self.assertEqual(logical_mesh.shape, mesh_shape)
+    np.testing.assert_array_equal(np.sort(logical_mesh.flatten()), device_ids)
+
 
 if __name__ == '__main__':
   test = unittest.main()
diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py
index 9f1fecf6277..1ade0b4249d 100644
--- a/torch_xla/distributed/spmd/xla_sharding.py
+++ b/torch_xla/distributed/spmd/xla_sharding.py
@@ -69,14 +69,32 @@ def __init__(self,
                axis_names: Optional[tuple[str, ...]] = None):
     if not isinstance(device_ids, np.ndarray):
       device_ids = np.array(device_ids)
-    assert (axis_names is None) or (len(mesh_shape) == len(axis_names))
-    assert axis_names is None or (len(set(axis_names)) == len(axis_names))
-    assert (len(device_ids) == np.prod(mesh_shape))
-    assert len(device_ids) == len(np.unique(device_ids))
+
+    # At the moment, XLA requires that the Mesh uses the global number of
+    # devices.
+    num_devices = xr.global_runtime_device_count()
+    assert num_devices > 0, "This requires XLA supported device(s)."
+    assert num_devices == len(
+        device_ids
+    ), f"Number of device IDs ({len(device_ids)}) must match the global number of devices ({num_devices})"
+
+    if axis_names is not None:
+      assert len(mesh_shape) == len(axis_names), \
+          f"Number of axis names ({len(axis_names)}) must match mesh dimensions ({len(mesh_shape)})"
+      assert len(set(axis_names)) == len(axis_names), \
+          f"Axis names must be unique, got: {axis_names}"
+
+    expected_devices = np.prod(mesh_shape)
+    assert len(device_ids) == expected_devices, \
+        f"Number of device IDs ({len(device_ids)}) must match mesh size ({expected_devices})"
+    assert len(device_ids) == len(np.unique(device_ids)), \
+        f"Device IDs must be unique, got: {device_ids}"
+
     self.device_ids = device_ids
     self.mesh_shape = mesh_shape
     self.axis_names = axis_names
-    assert all(d < self.size() for d in device_ids)
+    assert all(d < self.size() for d in device_ids), \
+        f"Device IDs must be less than mesh size ({self.size()}), got: {device_ids}"
 
   def size(self):
     return np.prod(self.mesh_shape)
@@ -555,10 +573,6 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
         >>> linear = nn.Linear(32, 10).to(xm.xla_device())
         >>> xs.mark_sharding(linear.weight, mesh, (None, 1)) # 2-way model parallel
   """
-  num_devices = xr.global_runtime_device_count()
-  assert num_devices > 0, "This requires XLA supported device(s)."
-  assert mesh.size() == num_devices, \
-    f"{mesh.mesh_shape} is not mappable over {num_devices} devices."
   # We only allow fully specified `partition_spec` to be applicable, as opposed
   # to filling in the unspecified replicated dims. Fully specified `partiion_spec`
   # should be of the same rank as `t`. This is to support partial replication
@@ -603,10 +617,6 @@ def mark_sharding_with_gradients(
 
   This version can also be used in AOTAutograd.
   """
-  num_devices = xr.global_runtime_device_count()
-  assert num_devices > 0, "This requires XLA supported device(s)."
-  assert mesh.size() == num_devices, \
-    f"{mesh.mesh_shape} is not mappable over {num_devices} devices."
   # We only allow fully specified `partition_spec` to be applicable, as opposed
  # to filling in the unspecified replicated dims. Fully specified `partiion_spec`
  # should be of the same rank as `t`. This is to support partial replication
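
Usage sketch (illustrative only, not part of the patch): with the checks relocated into Mesh.__init__, a malformed mesh now fails at construction time rather than at the first mark_sharding() call. The snippet below assumes an XLA runtime with more than one device; the variable names are placeholders, not identifiers from the patch.

    import numpy as np
    import torch_xla.runtime as xr
    import torch_xla.distributed.spmd as xs

    num_devices = xr.global_runtime_device_count()
    device_ids = np.arange(num_devices)

    # Valid: uses every global device exactly once.
    mesh = xs.Mesh(device_ids, (1, num_devices), ('data', 'model'))

    # Invalid: covers only half of the global devices. The AssertionError is
    # now raised here, by the constructor, with a descriptive message.
    try:
      xs.Mesh(device_ids[:num_devices // 2], (1, num_devices // 2))
    except AssertionError as e:
      print(e)  # Number of device IDs (...) must match the global number of devices (...)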