| 
1 | 1 | import copy  | 
2 | 2 | 
 
  | 
3 |  | -import unittest  | 
4 |  | -from unittest.mock import patch  | 
 | 3 | +from collections import OrderedDict  | 
5 | 4 | import math  | 
6 | 5 | import numpy as np  | 
 | 6 | +import unittest  | 
 | 7 | +from unittest.mock import patch  | 
7 | 8 | import sys  | 
8 | 9 | 
 
  | 
9 | 10 | import torch  | 
@@ -1565,6 +1566,97 @@ def test_mark_sharding_with_gradients_annotation(self):  | 
1565 | 1566 |       # Check that the gradient has sharding.  | 
1566 | 1567 |       self.assertIn(sharding_spec, x_grad_sharding)  | 
1567 | 1568 | 
 
  | 
 | 1569 | +  def test_valid_mesh_creation(self):  | 
 | 1570 | +    mesh_shape = (1, self.n_devices)  | 
 | 1571 | +    axis_names = ('data', 'model')  | 
 | 1572 | +    mesh = xs.Mesh(self.device_ids, mesh_shape, axis_names)  | 
 | 1573 | + | 
 | 1574 | +    self.assertEqual(mesh.device_ids.tolist(), list(range(self.n_devices)))  | 
 | 1575 | +    self.assertEqual(mesh.mesh_shape, mesh_shape)  | 
 | 1576 | +    self.assertEqual(mesh.axis_names, axis_names)  | 
 | 1577 | + | 
 | 1578 | +  def test_valid_mesh_without_axis_names(self):  | 
 | 1579 | +    mesh_shape = (1, self.n_devices)  | 
 | 1580 | +    mesh = xs.Mesh(self.device_ids, mesh_shape)  | 
 | 1581 | + | 
 | 1582 | +    self.assertEqual(mesh.device_ids.tolist(), list(range(self.n_devices)))  | 
 | 1583 | +    self.assertEqual(mesh.mesh_shape, mesh_shape)  | 
 | 1584 | +    self.assertIsNone(mesh.axis_names)  | 
 | 1585 | + | 
 | 1586 | +  def test_invalid_axis_names_length(self):  | 
 | 1587 | +    mesh_shape = (1, self.n_devices)  | 
 | 1588 | +    axis_names = ('data', 'model', 'extra')  | 
 | 1589 | + | 
 | 1590 | +    with self.assertRaisesRegex(  | 
 | 1591 | +        AssertionError, "Number of axis names .* must match mesh dimensions"):  | 
 | 1592 | +      xs.Mesh(self.device_ids, mesh_shape, axis_names)  | 
 | 1593 | + | 
 | 1594 | +  def test_duplicate_axis_names(self):  | 
 | 1595 | +    mesh_shape = (1, self.n_devices)  | 
 | 1596 | +    axis_names = ('data', 'data')  | 
 | 1597 | + | 
 | 1598 | +    with self.assertRaisesRegex(AssertionError, "Axis names must be unique"):  | 
 | 1599 | +      xs.Mesh(self.device_ids, mesh_shape, axis_names)  | 
 | 1600 | + | 
 | 1601 | +  def test_invalid_device_count(self):  | 
 | 1602 | +    mesh_shape = (2, self.n_devices)  | 
 | 1603 | + | 
 | 1604 | +    with self.assertRaisesRegex(AssertionError,  | 
 | 1605 | +                                "Number of device IDs .* must match mesh size"):  | 
 | 1606 | +      xs.Mesh(self.device_ids, mesh_shape)  | 
 | 1607 | + | 
 | 1608 | +  @unittest.skipIf(xr.global_runtime_device_count() == 1,  | 
 | 1609 | +                   "Multiple devices needed for duplicated device IDs")  | 
 | 1610 | +  def test_duplicate_device_ids(self):  | 
 | 1611 | +    mesh_shape = (1, self.n_devices)  | 
 | 1612 | +    duplicate_ids = np.array([0] * self.n_devices)  | 
 | 1613 | + | 
 | 1614 | +    with self.assertRaisesRegex(AssertionError, "Device IDs must be unique"):  | 
 | 1615 | +      xs.Mesh(duplicate_ids, mesh_shape)  | 
 | 1616 | + | 
 | 1617 | +  def test_device_ids_out_of_bounds(self):  | 
 | 1618 | +    mesh_shape = (1, self.n_devices)  | 
 | 1619 | +    invalid_ids = np.arange(self.n_devices + 1, self.n_devices * 2 + 1)  | 
 | 1620 | + | 
 | 1621 | +    with self.assertRaisesRegex(AssertionError,  | 
 | 1622 | +                                "Device IDs must be less than mesh size"):  | 
 | 1623 | +      xs.Mesh(invalid_ids, mesh_shape)  | 
 | 1624 | + | 
 | 1625 | +  def test_mesh_size(self):  | 
 | 1626 | +    mesh_shape = (1, self.n_devices)  | 
 | 1627 | +    mesh = xs.Mesh(self.device_ids, mesh_shape)  | 
 | 1628 | +    self.assertEqual(mesh.size(), self.n_devices)  | 
 | 1629 | + | 
 | 1630 | +  def test_mesh_shape_method(self):  | 
 | 1631 | +    mesh_shape = (1, self.n_devices)  | 
 | 1632 | +    axis_names = ('data', 'model')  | 
 | 1633 | +    mesh = xs.Mesh(self.device_ids, mesh_shape, axis_names)  | 
 | 1634 | + | 
 | 1635 | +    expected_shape = OrderedDict([('data', 1), ('model', self.n_devices)])  | 
 | 1636 | +    self.assertEqual(mesh.shape(), expected_shape)  | 
 | 1637 | + | 
 | 1638 | +  @unittest.skipIf(xr.global_runtime_device_count() == 1,  | 
 | 1639 | +                   "Multiple devices needed")  | 
 | 1640 | +  def test_mismatch_global_devices(self):  | 
 | 1641 | +    partial_num_devices = self.n_devices // 2  | 
 | 1642 | +    device_ids = np.arange(partial_num_devices)  | 
 | 1643 | +    mesh_shape = (1, partial_num_devices)  | 
 | 1644 | +    with self.assertRaisesRegex(  | 
 | 1645 | +        AssertionError,  | 
 | 1646 | +        "Number of device IDs .* must match the global number of devices"):  | 
 | 1647 | +      xs.Mesh(device_ids, mesh_shape)  | 
 | 1648 | + | 
 | 1649 | +  @unittest.skipIf(xr.global_runtime_device_count() == 1,  | 
 | 1650 | +                   "Multiple devices needed")  | 
 | 1651 | +  def test_get_logical_mesh(self):  | 
 | 1652 | +    device_ids = np.arange(self.n_devices)  | 
 | 1653 | +    mesh_shape = (2, self.n_devices // 2)  | 
 | 1654 | +    mesh = xs.Mesh(device_ids, mesh_shape)  | 
 | 1655 | + | 
 | 1656 | +    logical_mesh = mesh.get_logical_mesh()  | 
 | 1657 | +    self.assertEqual(logical_mesh.shape, mesh_shape)  | 
 | 1658 | +    np.testing.assert_array_equal(np.sort(logical_mesh.flatten()), device_ids)  | 
 | 1659 | + | 
1568 | 1660 | 
 
  | 
1569 | 1661 | if __name__ == '__main__':  | 
1570 | 1662 |   test = unittest.main()  | 
 | 
0 commit comments