| 
1 | 1 | import copy  | 
2 | 2 | 
 
  | 
3 |  | -import unittest  | 
4 |  | -from unittest.mock import patch  | 
 | 3 | +from collections import OrderedDict  | 
5 | 4 | import math  | 
6 | 5 | import numpy as np  | 
 | 6 | +import unittest  | 
 | 7 | +from unittest.mock import patch  | 
7 | 8 | import sys  | 
8 | 9 | 
 
  | 
9 | 10 | import torch  | 
@@ -762,9 +763,12 @@ def test_hybrid_mesh_shape(self):  | 
762 | 763 |                    "Crash on TPU v2")  | 
763 | 764 |   @patch('torch_xla.runtime.global_runtime_device_attributes')  | 
764 | 765 |   @patch('torch_xla.core.xla_model.xla_device_hw')  | 
765 |  | -  def test_hybrid_mesh(self, xla_device_mock, device_attributes_mock):  | 
 | 766 | +  @patch('torch_xla.runtime.global_runtime_device_count')  | 
 | 767 | +  def test_hybrid_mesh(self, device_count_mock, xla_device_mock,  | 
 | 768 | +                       device_attributes_mock):  | 
766 | 769 |     # mock device attributes for 2 slices of v4-8  | 
767 | 770 |     num_slices = 2  | 
 | 771 | +    device_count_mock.return_value = 8  | 
768 | 772 |     xla_device_mock.return_value = "TPU"  | 
769 | 773 |     device_attributes_mock.return_value = [{  | 
770 | 774 |         'coords': [0, 0, 0],  | 
@@ -1565,6 +1569,97 @@ def test_mark_sharding_with_gradients_annotation(self):  | 
1565 | 1569 |       # Check that the gradient has sharding.  | 
1566 | 1570 |       self.assertIn(sharding_spec, x_grad_sharding)  | 
1567 | 1571 | 
 
  | 
 | 1572 | +  def test_valid_mesh_creation(self):  | 
 | 1573 | +    mesh_shape = (1, self.n_devices)  | 
 | 1574 | +    axis_names = ('data', 'model')  | 
 | 1575 | +    mesh = xs.Mesh(self.device_ids, mesh_shape, axis_names)  | 
 | 1576 | + | 
 | 1577 | +    self.assertEqual(mesh.device_ids.tolist(), list(range(self.n_devices)))  | 
 | 1578 | +    self.assertEqual(mesh.mesh_shape, mesh_shape)  | 
 | 1579 | +    self.assertEqual(mesh.axis_names, axis_names)  | 
 | 1580 | + | 
 | 1581 | +  def test_valid_mesh_without_axis_names(self):  | 
 | 1582 | +    mesh_shape = (1, self.n_devices)  | 
 | 1583 | +    mesh = xs.Mesh(self.device_ids, mesh_shape)  | 
 | 1584 | + | 
 | 1585 | +    self.assertEqual(mesh.device_ids.tolist(), list(range(self.n_devices)))  | 
 | 1586 | +    self.assertEqual(mesh.mesh_shape, mesh_shape)  | 
 | 1587 | +    self.assertIsNone(mesh.axis_names)  | 
 | 1588 | + | 
 | 1589 | +  def test_invalid_axis_names_length(self):  | 
 | 1590 | +    mesh_shape = (1, self.n_devices)  | 
 | 1591 | +    axis_names = ('data', 'model', 'extra')  | 
 | 1592 | + | 
 | 1593 | +    with self.assertRaisesRegex(  | 
 | 1594 | +        AssertionError, "Number of axis names .* must match mesh dimensions"):  | 
 | 1595 | +      xs.Mesh(self.device_ids, mesh_shape, axis_names)  | 
 | 1596 | + | 
 | 1597 | +  def test_duplicate_axis_names(self):  | 
 | 1598 | +    mesh_shape = (1, self.n_devices)  | 
 | 1599 | +    axis_names = ('data', 'data')  | 
 | 1600 | + | 
 | 1601 | +    with self.assertRaisesRegex(AssertionError, "Axis names must be unique"):  | 
 | 1602 | +      xs.Mesh(self.device_ids, mesh_shape, axis_names)  | 
 | 1603 | + | 
 | 1604 | +  def test_invalid_device_count(self):  | 
 | 1605 | +    mesh_shape = (2, self.n_devices)  | 
 | 1606 | + | 
 | 1607 | +    with self.assertRaisesRegex(AssertionError,  | 
 | 1608 | +                                "Number of device IDs .* must match mesh size"):  | 
 | 1609 | +      xs.Mesh(self.device_ids, mesh_shape)  | 
 | 1610 | + | 
 | 1611 | +  @unittest.skipIf(xr.global_runtime_device_count() == 1,  | 
 | 1612 | +                   "Multiple devices needed for duplicated device IDs")  | 
 | 1613 | +  def test_duplicate_device_ids(self):  | 
 | 1614 | +    mesh_shape = (1, self.n_devices)  | 
 | 1615 | +    duplicate_ids = np.array([0] * self.n_devices)  | 
 | 1616 | + | 
 | 1617 | +    with self.assertRaisesRegex(AssertionError, "Device IDs must be unique"):  | 
 | 1618 | +      xs.Mesh(duplicate_ids, mesh_shape)  | 
 | 1619 | + | 
 | 1620 | +  def test_device_ids_out_of_bounds(self):  | 
 | 1621 | +    mesh_shape = (1, self.n_devices)  | 
 | 1622 | +    invalid_ids = np.arange(self.n_devices + 1, self.n_devices * 2 + 1)  | 
 | 1623 | + | 
 | 1624 | +    with self.assertRaisesRegex(AssertionError,  | 
 | 1625 | +                                "Device IDs must be less than mesh size"):  | 
 | 1626 | +      xs.Mesh(invalid_ids, mesh_shape)  | 
 | 1627 | + | 
 | 1628 | +  def test_mesh_size(self):  | 
 | 1629 | +    mesh_shape = (1, self.n_devices)  | 
 | 1630 | +    mesh = xs.Mesh(self.device_ids, mesh_shape)  | 
 | 1631 | +    self.assertEqual(mesh.size(), self.n_devices)  | 
 | 1632 | + | 
 | 1633 | +  def test_mesh_shape_method(self):  | 
 | 1634 | +    mesh_shape = (1, self.n_devices)  | 
 | 1635 | +    axis_names = ('data', 'model')  | 
 | 1636 | +    mesh = xs.Mesh(self.device_ids, mesh_shape, axis_names)  | 
 | 1637 | + | 
 | 1638 | +    expected_shape = OrderedDict([('data', 1), ('model', self.n_devices)])  | 
 | 1639 | +    self.assertEqual(mesh.shape(), expected_shape)  | 
 | 1640 | + | 
 | 1641 | +  @unittest.skipIf(xr.global_runtime_device_count() == 1,  | 
 | 1642 | +                   "Multiple devices needed")  | 
 | 1643 | +  def test_mismatch_global_devices(self):  | 
 | 1644 | +    partial_num_devices = self.n_devices // 2  | 
 | 1645 | +    device_ids = np.arange(partial_num_devices)  | 
 | 1646 | +    mesh_shape = (1, partial_num_devices)  | 
 | 1647 | +    with self.assertRaisesRegex(  | 
 | 1648 | +        AssertionError,  | 
 | 1649 | +        "Number of device IDs .* must match the global number of devices"):  | 
 | 1650 | +      xs.Mesh(device_ids, mesh_shape)  | 
 | 1651 | + | 
 | 1652 | +  @unittest.skipIf(xr.global_runtime_device_count() == 1,  | 
 | 1653 | +                   "Multiple devices needed")  | 
 | 1654 | +  def test_get_logical_mesh(self):  | 
 | 1655 | +    device_ids = np.arange(self.n_devices)  | 
 | 1656 | +    mesh_shape = (2, self.n_devices // 2)  | 
 | 1657 | +    mesh = xs.Mesh(device_ids, mesh_shape)  | 
 | 1658 | + | 
 | 1659 | +    logical_mesh = mesh.get_logical_mesh()  | 
 | 1660 | +    self.assertEqual(logical_mesh.shape, mesh_shape)  | 
 | 1661 | +    np.testing.assert_array_equal(np.sort(logical_mesh.flatten()), device_ids)  | 
 | 1662 | + | 
1568 | 1663 | 
 
  | 
1569 | 1664 | if __name__ == '__main__':  | 
1570 | 1665 |   test = unittest.main()  | 
 | 
0 commit comments