diff --git a/flax/jax_utils.py b/flax/jax_utils.py
index bfe6849f3..466d715bd 100644
--- a/flax/jax_utils.py
+++ b/flax/jax_utils.py
@@ -25,6 +25,7 @@
 from jax import core, lax
 from jax.extend import linear_util as lu
 from jax.interpreters import partial_eval as pe
+from jax.sharding import NamedSharding, PartitionSpec as P, AxisType


 def _pmap_device_order():
@@ -42,7 +43,24 @@ def replicate(tree, devices=None):
     A new pytree containing the replicated arrays.
   """
   devices = devices or _pmap_device_order()
-  return jax.device_put_replicated(tree, devices)
+  mesh = jax.make_mesh(
+      (len(devices),),
+      ("_flax_jax_utils_replicate_data_axis",),
+      (AxisType.Auto,),
+      devices=devices,
+  )
+  data_sharding = NamedSharding(mesh, P("_flax_jax_utils_replicate_data_axis"))
+
+  def _device_put_replicated(x):
+    if isinstance(x, (jax.Array, np.ndarray)):
+      buf = x[None]
+    else:
+      buf = jnp.asarray(x)[None]
+    buf = jnp.concat([buf] * len(devices))
+    return jax.device_put(buf, data_sharding)
+
+  with jax.set_mesh(mesh):
+    return jax.tree.map(_device_put_replicated, tree)


 def unreplicate(tree):
@@ -137,17 +155,26 @@ def prefetch_to_device(iterator, size, devices=None):
   queue = collections.deque()
   devices = _pmap_device_order() if devices is None else devices

+  mesh = jax.make_mesh(
+      (len(devices),),
+      ("_flax_jax_utils_prefetch_to_device_data_axis",),
+      (AxisType.Auto,),
+      devices=devices,
+  )
+  data_sharding = NamedSharding(mesh, P("_flax_jax_utils_prefetch_to_device_data_axis"))
+
   def _prefetch(xs):
-    return jax.device_put_sharded(list(xs), devices)
+    return jax.device_put(xs, data_sharding)

   def enqueue(n):  # Enqueues *up to* `n` elements from the iterator.
     for data in itertools.islice(iterator, n):
       queue.append(jax.tree_util.tree_map(_prefetch, data))

-  enqueue(size)  # Fill up the buffer.
-  while queue:
-    yield queue.popleft()
-    enqueue(1)
+  with jax.set_mesh(mesh):
+    enqueue(size)  # Fill up the buffer.
+    while queue:
+      yield queue.popleft()
+      enqueue(1)


 def _scan_nd(body_fn, init, xs, n=1, unroll=(1,)):
diff --git a/tests/jax_utils_test.py b/tests/jax_utils_test.py
index c2130e700..feb636e1a 100644
--- a/tests/jax_utils_test.py
+++ b/tests/jax_utils_test.py
@@ -106,5 +106,64 @@ def add(params, a, *, b):
     np.testing.assert_allclose(np.float64(y), np.float64(5 * x + 10))


+class DataShardingTest(parameterized.TestCase):
+  def setUp(self):
+    if jax.device_count() < 4:
+      self.skipTest('At least 4 devices required')
+
+  @parameterized.product(num_devices=["all", 2])
+  def test_prefetch_to_device(self, num_devices):
+    devices = jax.local_devices()
+    if isinstance(num_devices, int):
+      devices = devices[:num_devices]
+    shape = (len(devices), 4, 16, 16, 3)
+    iterator = (jnp.ones(shape) for _ in range(4))
+
+    data_iter = jax_utils.prefetch_to_device(iterator, size=3, devices=devices)
+    for _ in range(4):
+      data = next(data_iter)
+      self.assertEqual(data.shape, shape)
+      self.assertIsNotNone(data.sharding)
+      sharding_slices_per_device = data.sharding.devices_indices_map(tuple(data.shape))
+      self.assertEqual(len(sharding_slices_per_device), len(devices))
+      # Check that sharding_slices_per_device maps each device to slices like
+      # Device(id=2): (slice(2, 3, None), slice(None, None, None), ..., slice(None, None, None))
+      for i, dev in enumerate(devices):
+        sharding_slice = sharding_slices_per_device[dev]
+        self.assertEqual(sharding_slice[0], slice(i, i + 1, None))
+        for sharding_slice_j in sharding_slice[1:]:
+          self.assertEqual(sharding_slice_j, slice(None, None, None))
+
+  @parameterized.product(num_devices=["all", 2])
+  def test_replicate(self, num_devices):
+    devices = jax.local_devices()
+    if isinstance(num_devices, int):
+      devices = devices[:num_devices]
+    num_batches = 5
+    shape = (2, 3)
+    data_tree = [
+        i * jnp.ones(shape) for i in range(num_batches - 2)
+    ] + [4, 5 * np.ones(shape)]
+    out_tree = jax_utils.replicate(data_tree, devices=devices)
+
+    def check_sharding(p):
+      if p.ndim == 1:
+        self.assertEqual(p.shape, (len(devices),))
+      else:
+        self.assertEqual(p.shape, (len(devices), *shape))
+      self.assertIsNotNone(p.sharding)
+      sharding_slices_per_device = p.sharding.devices_indices_map(tuple(p.shape))
+      self.assertEqual(len(sharding_slices_per_device), len(devices))
+      # Check that sharding_slices_per_device maps each device to slices like
+      # Device(id=2): (slice(2, 3, None), slice(None, None, None), slice(None, None, None))
+      for i, dev in enumerate(devices):
+        sharding_slice = sharding_slices_per_device[dev]
+        self.assertEqual(sharding_slice[0], slice(i, i + 1, None))
+        for sharding_slice_j in sharding_slice[1:]:
+          self.assertEqual(sharding_slice_j, slice(None, None, None))
+
+    jax.tree.map(check_sharding, out_tree)
+
+
 if __name__ == '__main__':
   absltest.main()
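
For context, a minimal usage sketch of the patched helpers. This is not part of the diff, only a sketch: it assumes the patch above is applied and that the JAX runtime exposes multiple local devices; the shapes and the pytree below are illustrative, not taken from the test suite.

  import jax
  import jax.numpy as jnp
  import numpy as np
  from flax import jax_utils

  devices = jax.local_devices()

  # replicate() now backs every leaf with a NamedSharding over a 1-D mesh:
  # a new leading axis of length len(devices), one replica per device.
  params = {"w": jnp.ones((3, 4)), "b": np.zeros(4), "step": 0}
  replicated = jax_utils.replicate(params, devices=devices)
  assert replicated["w"].shape == (len(devices), 3, 4)
  assert len(replicated["w"].sharding.device_set) == len(devices)

  # prefetch_to_device() shards the existing leading batch axis across the
  # same 1-D mesh, so each device holds one slice of every prefetched batch.
  batches = (jnp.ones((len(devices), 8)) for _ in range(10))
  for batch in jax_utils.prefetch_to_device(batches, size=2, devices=devices):
    assert batch.shape == (len(devices), 8)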