Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion flax/core/lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -1458,9 +1458,18 @@ def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args, **kwargs):
# add 2 to each static_argnums because we add two initial arguments to rematted
static_argnums_ = jax.tree_util.tree_map(lambda x: x + 2, static_argnums)

# After JAX v0.3.16, concrete=False is a no-op and concrete=True raises
# NotImplementedError. Starting in JAX v0.8.2, the concrete argument is
# deprecated and will be removed in the future.
if concrete:
raise NotImplementedError(
"The concrete argument is deprecated. Use static_argnums instead."
" for more information, see"
" https://docs.jax.dev/en/latest/jep/11830-new-remat-checkpoint.html"
)

@functools.partial(
jax.remat,
concrete=concrete,
static_argnums=static_argnums_,
prevent_cse=prevent_cse,
policy=policy,
Expand Down
12 changes: 11 additions & 1 deletion flax/linen/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,8 +553,18 @@ def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args):
static_args = tuple(x for i, x in enumerate(args) if i in static_argnums)
dyn_args = tuple(x for i, x in enumerate(args) if i not in static_argnums)

# After JAX v0.3.16, concrete=False is a no-op and concrete=True raises
# NotImplementedError. Starting in JAX v0.8.2, the concrete argument is
# deprecated and will be removed in the future.
if concrete:
raise NotImplementedError(
"The concrete argument is deprecated. Use static_argnums instead."
" for more information, see"
" https://docs.jax.dev/en/latest/jep/11830-new-remat-checkpoint.html"
)

@functools.partial(
jax.remat, concrete=concrete, prevent_cse=prevent_cse, policy=policy
jax.remat, prevent_cse=prevent_cse, policy=policy
)
@functools.wraps(fn)
def rematted(variable_groups, rng_groups, *dyn_args):
Expand Down
Loading