
Commit e4df499

Encapsulate Mesh invariants
1 parent a3ef52e commit e4df499

2 files changed: +101, -15 lines


Diff for: test/spmd/test_xla_sharding.py

+81, -2
@@ -1,9 +1,10 @@
 import copy
 
-import unittest
-from unittest.mock import patch
+from collections import OrderedDict
 import math
 import numpy as np
+import unittest
+from unittest.mock import patch
 import sys
 
 import torch
@@ -1565,6 +1566,84 @@ def test_mark_sharding_with_gradients_annotation(self):
     # Check that the gradient has sharding.
     self.assertIn(sharding_spec, x_grad_sharding)
 
+  def test_valid_mesh_creation(self):
+    mesh_shape = (1, self.n_devices)
+    axis_names = ('data', 'model')
+    mesh = xs.Mesh(self.device_ids, mesh_shape, axis_names)
+
+    self.assertEqual(mesh.device_ids.tolist(), list(range(self.n_devices)))
+    self.assertEqual(mesh.mesh_shape, mesh_shape)
+    self.assertEqual(mesh.axis_names, axis_names)
+
+  def test_valid_mesh_without_axis_names(self):
+    mesh_shape = (1, self.n_devices)
+    mesh = xs.Mesh(self.device_ids, mesh_shape)
+
+    self.assertEqual(mesh.device_ids.tolist(), list(range(self.n_devices)))
+    self.assertEqual(mesh.mesh_shape, mesh_shape)
+    self.assertIsNone(mesh.axis_names)
+
+  def test_invalid_axis_names_length(self):
+    mesh_shape = (1, self.n_devices)
+    axis_names = ('data', 'model', 'extra')
+
+    with self.assertRaisesRegex(
+        AssertionError, "Number of axis names .* must match mesh dimensions"):
+      xs.Mesh(self.device_ids, mesh_shape, axis_names)
+
+  def test_duplicate_axis_names(self):
+    mesh_shape = (1, self.n_devices)
+    axis_names = ('data', 'data')
+
+    with self.assertRaisesRegex(AssertionError, "Axis names must be unique"):
+      xs.Mesh(self.device_ids, mesh_shape, axis_names)
+
+  def test_invalid_device_count(self):
+    mesh_shape = (2, self.n_devices)
+
+    with self.assertRaisesRegex(AssertionError,
+                                "Number of device IDs .* must match mesh size"):
+      xs.Mesh(self.device_ids, mesh_shape)
+
+  @unittest.skipIf(xr.global_runtime_device_count() == 1,
+                   "Multiple devices needed for duplicated device IDs")
+  def test_duplicate_device_ids(self):
+    mesh_shape = (1, self.n_devices)
+    duplicate_ids = np.array([0] * self.n_devices)
+
+    with self.assertRaisesRegex(AssertionError, "Device IDs must be unique"):
+      xs.Mesh(duplicate_ids, mesh_shape)
+
+  def test_device_ids_out_of_bounds(self):
+    mesh_shape = (1, self.n_devices)
+    invalid_ids = np.array([self.n_devices + 1] * self.n_devices)
+
+    with self.assertRaisesRegex(AssertionError,
+                                "Device IDs must be less than mesh size"):
+      xs.Mesh(invalid_ids, mesh_shape)
+
+  def test_mesh_size(self):
+    mesh_shape = (1, self.n_devices)
+    mesh = xs.Mesh(self.device_ids, mesh_shape)
+    self.assertEqual(mesh.size(), self.n_devices)
+
+  def test_mesh_shape_method(self):
+    mesh_shape = (1, self.n_devices)
+    axis_names = ('data', 'model')
+    mesh = xs.Mesh(self.device_ids, mesh_shape, axis_names)
+
+    expected_shape = OrderedDict([('data', 1), ('model', self.n_devices)])
+    self.assertEqual(mesh.shape(), expected_shape)
+
+  def test_get_logical_mesh(self):
+    mesh_shape = (2, 2)
+    device_ids = np.array([0, 1, 2, 3])
+    mesh = xs.Mesh(device_ids, mesh_shape)
+
+    expected_logical_mesh = np.array([[0, 1], [2, 3]])
+    np.testing.assert_array_equal(mesh.get_logical_mesh(),
+                                  expected_logical_mesh)
+
 
 if __name__ == '__main__':
   test = unittest.main()
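For context, the following is a small illustrative sketch (not part of the commit) of what the new constructor-level checks look like from a caller's point of view. The 4-device count and the duplicated axis names are assumptions chosen for the example; the tests above derive the mesh size from self.n_devices instead.

# Illustrative sketch only, not from this commit; assumes a runtime with 4 XLA devices.
import numpy as np
import torch_xla.distributed.spmd as xs

device_ids = np.arange(4)  # assumption: 4 devices, matching a (1, 4) mesh

# Duplicate axis names are now rejected in Mesh.__init__ itself,
# which is what test_duplicate_axis_names exercises.
try:
  xs.Mesh(device_ids, (1, 4), ('data', 'data'))
except AssertionError as e:
  print(e)  # e.g. "Axis names must be unique, got: ('data', 'data')"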

Diff for: torch_xla/distributed/spmd/xla_sharding.py

+20, -13
@@ -69,14 +69,25 @@ def __init__(self,
                axis_names: Optional[tuple[str, ...]] = None):
     if not isinstance(device_ids, np.ndarray):
       device_ids = np.array(device_ids)
-    assert (axis_names is None) or (len(mesh_shape) == len(axis_names))
-    assert axis_names is None or (len(set(axis_names)) == len(axis_names))
-    assert (len(device_ids) == np.prod(mesh_shape))
-    assert len(device_ids) == len(np.unique(device_ids))
+    assert len(device_ids) > 0, "This requires XLA supported device(s)."
+
+    if axis_names is not None:
+      assert len(mesh_shape) == len(axis_names), \
+          f"Number of axis names ({len(axis_names)}) must match mesh dimensions ({len(mesh_shape)})"
+      assert len(set(axis_names)) == len(axis_names), \
+          f"Axis names must be unique, got: {axis_names}"
+
+    expected_devices = np.prod(mesh_shape)
+    assert len(device_ids) == expected_devices, \
+        f"Number of device IDs ({len(device_ids)}) must match mesh size ({expected_devices})"
+    assert len(device_ids) == len(np.unique(device_ids)), \
+        f"Device IDs must be unique, got: {device_ids}"
+
     self.device_ids = device_ids
     self.mesh_shape = mesh_shape
     self.axis_names = axis_names
-    assert all(d < self.size() for d in device_ids)
+    assert all(d < self.size() for d in device_ids), \
+        f"Device IDs must be less than mesh size ({self.size()}), got: {device_ids}"
 
   def size(self):
     return np.prod(self.mesh_shape)
@@ -555,16 +566,14 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
     >>> linear = nn.Linear(32, 10).to(xm.xla_device())
     >>> xs.mark_sharding(linear.weight, mesh, (None, 1)) # 2-way model parallel
   """
-  num_devices = xr.global_runtime_device_count()
-  assert num_devices > 0, "This requires XLA supported device(s)."
-  assert mesh.size() == num_devices, \
-    f"{mesh.mesh_shape} is not mappable over {num_devices} devices."
   # We only allow fully specified `partition_spec` to be applicable, as opposed
   # to filling in the unspecified replicated dims. Fully specified `partiion_spec`
   # should be of the same rank as `t`. This is to support partial replication
   # where the group assignment may vary with different input ranks.
   assert len(t.shape) == len(partition_spec), \
     f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})."
+  assert len(partition_spec) == mesh.size(), \
+    f"Partition spec length ({len(partition_spec)}) should be equal to the mesh size ({mesh.size()})."
 
   op_sharding = mesh.get_op_sharding(partition_spec)
   annotate_func = torch_xla._XLAC._xla_mark_sharding
@@ -603,16 +612,14 @@ def mark_sharding_with_gradients(
 
   This version can also be used in AOTAutograd.
   """
-  num_devices = xr.global_runtime_device_count()
-  assert num_devices > 0, "This requires XLA supported device(s)."
-  assert mesh.size() == num_devices, \
-    f"{mesh.mesh_shape} is not mappable over {num_devices} devices."
   # We only allow fully specified `partition_spec` to be applicable, as opposed
   # to filling in the unspecified replicated dims. Fully specified `partiion_spec`
   # should be of the same rank as `t`. This is to support partial replication
   # where the group assignment may vary with different input ranks.
   assert len(t.shape) == len(partition_spec), \
     f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})."
+  assert len(partition_spec) == mesh.size(), \
+    f"Partition spec length ({len(partition_spec)}) should be equal to the mesh size ({mesh.size()})."
 
   return MarkShardingFunction.apply(t, mesh, partition_spec)
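As a usage-level illustration (not part of the commit): with the checks moved into Mesh.__init__, an ill-formed mesh now fails at construction time with a descriptive message rather than later inside mark_sharding. The mesh shape below is an assumption chosen to trip the new device-count check.

# Illustrative sketch only, not from this commit.
import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

num_devices = xr.global_runtime_device_count()
device_ids = np.arange(num_devices)

# A (2, num_devices) mesh would need 2 * num_devices device IDs, so this
# fails immediately with "Number of device IDs ... must match mesh size"
# instead of surfacing when the mesh is first used for sharding.
try:
  xs.Mesh(device_ids, (2, num_devices))
except AssertionError as e:
  print(e)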
