|
1 | 1 | import copy
|
2 | 2 |
|
3 |
| -import unittest |
4 |
| -from unittest.mock import patch |
| 3 | +from collections import OrderedDict |
5 | 4 | import math
|
6 | 5 | import numpy as np
|
| 6 | +import unittest |
| 7 | +from unittest.mock import patch |
7 | 8 | import sys
|
8 | 9 |
|
9 | 10 | import torch
|
@@ -1565,6 +1566,97 @@ def test_mark_sharding_with_gradients_annotation(self):
|
1565 | 1566 | # Check that the gradient has sharding.
|
1566 | 1567 | self.assertIn(sharding_spec, x_grad_sharding)
|
1567 | 1568 |
|
def test_valid_mesh_creation(self):
    # A mesh built with explicit axis names should expose the device ids,
    # the mesh shape and the axis names it was constructed with.
    names = ('data', 'model')
    shape = (1, self.n_devices)
    mesh = xs.Mesh(self.device_ids, shape, names)

    expected_ids = list(range(self.n_devices))
    self.assertEqual(mesh.device_ids.tolist(), expected_ids)
    self.assertEqual(mesh.mesh_shape, shape)
    self.assertEqual(mesh.axis_names, names)
| 1577 | + |
def test_valid_mesh_without_axis_names(self):
    # Axis names are optional: when omitted, `axis_names` is expected to be
    # None while device ids and mesh shape are still recorded.
    shape = (1, self.n_devices)
    unnamed_mesh = xs.Mesh(self.device_ids, shape)

    expected_ids = list(range(self.n_devices))
    self.assertEqual(unnamed_mesh.device_ids.tolist(), expected_ids)
    self.assertEqual(unnamed_mesh.mesh_shape, shape)
    self.assertIsNone(unnamed_mesh.axis_names)
| 1585 | + |
def test_invalid_axis_names_length(self):
    # Three axis names for a two-dimensional mesh shape must trip the
    # constructor's assertion.
    shape = (1, self.n_devices)
    too_many_names = ('data', 'model', 'extra')
    expected_error = "Number of axis names .* must match mesh dimensions"

    with self.assertRaisesRegex(AssertionError, expected_error):
        xs.Mesh(self.device_ids, shape, too_many_names)
| 1593 | + |
def test_duplicate_axis_names(self):
    # Repeating the same axis name for both mesh dimensions must be rejected.
    shape = (1, self.n_devices)
    repeated_names = ('data', 'data')
    expected_error = "Axis names must be unique"

    with self.assertRaisesRegex(AssertionError, expected_error):
        xs.Mesh(self.device_ids, shape, repeated_names)
| 1600 | + |
def test_invalid_device_count(self):
    # The mesh shape below has 2 * n_devices slots, so it cannot match the
    # device id list (presumably n_devices entries -- set up outside this
    # method) and construction must fail.
    oversized_shape = (2, self.n_devices)
    expected_error = "Number of device IDs .* must match mesh size"

    with self.assertRaisesRegex(AssertionError, expected_error):
        xs.Mesh(self.device_ids, oversized_shape)
| 1607 | + |
@unittest.skipIf(xr.global_runtime_device_count() == 1,
                 "Multiple devices needed for duplicated device IDs")
def test_duplicate_device_ids(self):
    # Every slot is filled with device id 0, so the ids only repeat when more
    # than one device is present (hence the skip on a single device).
    shape = (1, self.n_devices)
    all_zero_ids = np.array([0] * self.n_devices)
    expected_error = "Device IDs must be unique"

    with self.assertRaisesRegex(AssertionError, expected_error):
        xs.Mesh(all_zero_ids, shape)
| 1616 | + |
def test_device_ids_out_of_bounds(self):
    # Build n_devices distinct ids that all lie above n_devices, i.e. the
    # range [n_devices + 1, 2 * n_devices], and expect them to be rejected.
    shape = (1, self.n_devices)
    first_id = self.n_devices + 1
    stop_id = self.n_devices * 2 + 1
    too_large_ids = np.arange(first_id, stop_id)
    expected_error = "Device IDs must be less than mesh size"

    with self.assertRaisesRegex(AssertionError, expected_error):
        xs.Mesh(too_large_ids, shape)
| 1624 | + |
def test_mesh_size(self):
    # size() of a (1, n_devices) mesh should equal the device count.
    full_mesh = xs.Mesh(self.device_ids, (1, self.n_devices))
    reported_size = full_mesh.size()
    self.assertEqual(reported_size, self.n_devices)
| 1629 | + |
def test_mesh_shape_method(self):
    # shape() is expected to return an ordered mapping of axis name to the
    # size of that axis, in the order the axes were declared.
    names = ('data', 'model')
    dims = (1, self.n_devices)
    mesh = xs.Mesh(self.device_ids, dims, names)

    # Pair each axis name with its dimension: ('data', 1), ('model', n).
    name_to_dim = OrderedDict(zip(names, dims))
    self.assertEqual(mesh.shape(), name_to_dim)
| 1637 | + |
@unittest.skipIf(xr.global_runtime_device_count() == 1,
                 "Multiple devices needed")
def test_mismatch_global_devices(self):
    # A mesh that is internally consistent but only spans half of the
    # devices must still be rejected.
    half_count = self.n_devices // 2
    half_ids = np.arange(half_count)
    half_shape = (1, half_count)
    expected_error = (
        "Number of device IDs .* must match the global number of devices")

    with self.assertRaisesRegex(AssertionError, expected_error):
        xs.Mesh(half_ids, half_shape)
| 1648 | + |
@unittest.skipIf(xr.global_runtime_device_count() == 1,
                 "Multiple devices needed")
def test_get_logical_mesh(self):
    # Checks that get_logical_mesh() returns an array with the mesh's shape
    # that contains every device id exactly once.
    #
    # The (2, n // 2) layout only has n slots when n is even. With an odd
    # device count the Mesh constructor would fail its "must match mesh size"
    # assertion before the logical mesh is ever inspected, so skip instead of
    # reporting an unrelated error.
    if self.n_devices % 2 != 0:
        self.skipTest("Even number of devices needed for a (2, n // 2) mesh")

    device_ids = np.arange(self.n_devices)
    mesh_shape = (2, self.n_devices // 2)
    mesh = xs.Mesh(device_ids, mesh_shape)

    logical_mesh = mesh.get_logical_mesh()
    self.assertEqual(logical_mesh.shape, mesh_shape)
    # Sorting the flattened mesh makes the comparison independent of how the
    # ids are laid out inside the logical mesh.
    np.testing.assert_array_equal(np.sort(logical_mesh.flatten()), device_ids)
| 1659 | + |
1568 | 1660 |
|
1569 | 1661 | if __name__ == '__main__':
|
1570 | 1662 | test = unittest.main()
|
|
0 commit comments