Skip to content

Commit 0a6b1b1

Browse files
Jake VanderPlasFlax Authors
authored andcommitted
Avoid passing concrete argument to jax.remat
This argument has had no effect since JAX v0.3.17, aside from raising `NotImplementedError` if it is set to `True`. It will be deprecated in JAX v0.8.2 and eventually removed. #jax-fixit PiperOrigin-RevId: 839319133
1 parent 6ac5a78 commit 0a6b1b1

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

flax/core/lift.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1458,9 +1458,9 @@ def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args, **kwargs):
14581458
# add 2 to each static_argnums because we add two initial arguments to rematted
14591459
static_argnums_ = jax.tree_util.tree_map(lambda x: x + 2, static_argnums)
14601460

1461+
14611462
@functools.partial(
14621463
jax.remat,
1463-
concrete=concrete,
14641464
static_argnums=static_argnums_,
14651465
prevent_cse=prevent_cse,
14661466
policy=policy,

flax/linen/partitioning.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,9 +552,20 @@ def _repack_remat_args(dyn_args, static_args):
552552
def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args):
553553
static_args = tuple(x for i, x in enumerate(args) if i in static_argnums)
554554
dyn_args = tuple(x for i, x in enumerate(args) if i not in static_argnums)
555+
556+
# After JAX v0.3.16, concrete=False is a no-op and concrete=True raises
557+
# NotImplementedError. Starting in JAX v0.8.2, the concrete argument is
558+
# deprecated and will be removed in the future.
559+
if concrete:
560+
raise NotImplementedError(
561+
"The concrete argument is deprecated. Use static_argnums instead, and"
562+
" for more information, see"
563+
" https://docs.jax.dev/en/latest/jep/11830-new-remat-checkpoint.html"
564+
)
565+
del concrete
555566

556567
@functools.partial(
557-
jax.remat, concrete=concrete, prevent_cse=prevent_cse, policy=policy
568+
jax.remat, prevent_cse=prevent_cse, policy=policy
558569
)
559570
@functools.wraps(fn)
560571
def rematted(variable_groups, rng_groups, *dyn_args):

0 commit comments

Comments
 (0)