Skip to content

Conversation

@vfdev-5
Copy link
Collaborator

@vfdev-5 vfdev-5 commented Nov 19, 2025

Description:

  • Rewritten flax.jax_utils.prefetch_to_device and flax.jax_utils.replicate using jax.device_put
  • Added tests

Context:

Jax deprecated the ops: jax.device_put_replicated and jax.device_put_sharded: https://docs.jax.dev/en/latest/changelog.html#jax-0-8-1-november-18-2025

flax.jax_utils.prefetch_to_device is used in ImageNet example and we want re-enable examples tests: #5099

In this PR flax.jax_utils.prefetch_to_device is recoded using jax.device_put, jax make_mesh and NamedSharding etc to fix deprecation warning. flax.jax_utils.replicate is also recoded.

@vfdev-5 vfdev-5 marked this pull request as ready for review November 19, 2025 21:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants