Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions flax/jax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from jax import core, lax
from jax.extend import linear_util as lu
from jax.interpreters import partial_eval as pe
from jax.sharding import NamedSharding, PartitionSpec as P, AxisType


def _pmap_device_order():
Expand All @@ -42,7 +43,24 @@ def replicate(tree, devices=None):
A new pytree containing the replicated arrays.
"""
devices = devices or _pmap_device_order()
return jax.device_put_replicated(tree, devices)
mesh = jax.make_mesh(
(len(devices),),
("_flax_jax_utils_replicate_data_axis",),
(AxisType.Auto,),
devices=devices,
)
data_sharding = NamedSharding(mesh, P("_flax_jax_utils_replicate_data_axis"))

def _device_put_replicated(x):
  """Stacks one copy of leaf `x` per device and places it with `data_sharding`.

  The leaf gains a new leading axis of length `len(devices)`; the stacked
  buffer is then handed to `jax.device_put` with the enclosing `data_sharding`.
  """
  # Arrays are indexed directly; anything else (e.g. Python scalars) is first
  # converted with `jnp.asarray` so that it can take a leading axis.
  leaf = x if isinstance(x, (jax.Array, np.ndarray)) else jnp.asarray(x)
  leading = leaf[None]
  stacked = jnp.concat([leading] * len(devices))
  return jax.device_put(stacked, data_sharding)

with jax.set_mesh(mesh):
return jax.tree.map(_device_put_replicated, tree)


def unreplicate(tree):
Expand Down Expand Up @@ -137,17 +155,26 @@ def prefetch_to_device(iterator, size, devices=None):
queue = collections.deque()
devices = _pmap_device_order() if devices is None else devices

mesh = jax.make_mesh(
(len(devices),),
("_flax_jax_utils_prefetch_to_device_data_axis",),
(AxisType.Auto,),
devices=devices,
)
data_sharding = NamedSharding(mesh, P("_flax_jax_utils_prefetch_to_device_data_axis"))

def _prefetch(xs):
  """Places leaf `xs` on the devices using the enclosing `data_sharding`.

  The earlier `jax.device_put_sharded(list(xs), devices)` return has been
  dropped: it shadowed the `jax.device_put` call below and left it unreachable.
  """
  return jax.device_put(xs, data_sharding)

def enqueue(n):
  """Enqueues *up to* `n` elements from the iterator.

  Each element pulled from `iterator` is mapped leaf-by-leaf through
  `_prefetch` and appended to `queue` in iteration order.
  """
  upcoming = itertools.islice(iterator, n)
  queue.extend(jax.tree_util.tree_map(_prefetch, batch) for batch in upcoming)

enqueue(size) # Fill up the buffer.
while queue:
yield queue.popleft()
enqueue(1)
with jax.set_mesh(mesh):
enqueue(size) # Fill up the buffer.
while queue:
yield queue.popleft()
enqueue(1)


def _scan_nd(body_fn, init, xs, n=1, unroll=(1,)):
Expand Down
59 changes: 59 additions & 0 deletions tests/jax_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,5 +106,64 @@ def add(params, a, *, b):
np.testing.assert_allclose(np.float64(y), np.float64(5 * x + 10))


class DataShardingTest(parameterized.TestCase):
  """Checks that `jax_utils` helpers shard their outputs on the leading axis."""

  def setUp(self):
    super().setUp()
    # The tests below pick their devices from `jax.local_devices()`, so the
    # gate is expressed in terms of the local device count as well.
    if jax.local_device_count() < 4:
      self.skipTest('At least 4 devices required')

  @staticmethod
  def _select_devices(num_devices):
    """Returns all local devices, or the first `num_devices` if it is an int."""
    devices = jax.local_devices()
    if isinstance(num_devices, int):
      devices = devices[:num_devices]
    return devices

  def _assert_sharded_on_leading_axis(self, array, devices):
    """Asserts `array` is split one-slice-per-device along axis 0 only.

    The expected per-device index map looks like
    Device(id=2): (slice(2, 3, None), slice(None, None, None), ...).
    """
    self.assertIsNotNone(array.sharding)
    slices_per_device = array.sharding.devices_indices_map(tuple(array.shape))
    self.assertEqual(len(slices_per_device), len(devices))
    for i, dev in enumerate(devices):
      leading, *trailing = slices_per_device[dev]
      self.assertEqual(leading, slice(i, i + 1, None))
      # Every non-leading axis must be left whole on each device.
      for axis_slice in trailing:
        self.assertEqual(axis_slice, slice(None, None, None))

  @parameterized.product(num_devices=['all', 2])
  def test_prefetch_to_device(self, num_devices):
    devices = self._select_devices(num_devices)
    num_batches = 4
    shape = (len(devices), 4, 16, 16, 3)
    iterator = (jnp.ones(shape) for _ in range(num_batches))

    data_iter = jax_utils.prefetch_to_device(iterator, size=3, devices=devices)
    num_seen = 0
    for data in data_iter:
      num_seen += 1
      self.assertEqual(data.shape, shape)
      self._assert_sharded_on_leading_axis(data, devices)
    # Every batch from the source iterator must come out exactly once.
    self.assertEqual(num_seen, num_batches)

  @parameterized.product(num_devices=['all', 2])
  def test_replicate(self, num_devices):
    devices = self._select_devices(num_devices)
    shape = (2, 3)
    # Mix of jax arrays, a Python scalar and a numpy array as tree leaves.
    data_tree = [i * jnp.ones(shape) for i in range(3)]
    data_tree += [4, 5 * np.ones(shape)]
    out_tree = jax_utils.replicate(data_tree, devices=devices)

    def check_leaf(p):
      # The replicated scalar leaf is 1-D; array leaves keep their own shape
      # behind the new leading device axis.
      if p.ndim == 1:
        expected_shape = (len(devices),)
      else:
        expected_shape = (len(devices), *shape)
      self.assertEqual(p.shape, expected_shape)
      self._assert_sharded_on_leading_axis(p, devices)

    jax.tree.map(check_leaf, out_tree)


if __name__ == '__main__':
absltest.main()
Loading