Skip to content

Commit 3cbe27b

Browse files
Jake VanderPlasFlax Authors
authored andcommitted
Avoid passing concrete argument to jax.remat
This argument has had no effect since JAX v0.3.17, aside from raising `NotImplementedError` if it is set to `True`. It will be deprecated in JAX v0.8.2 and eventually removed (jax-ml/jax#33674). Flax should probably deprecate this argument from its own `remat` wrappers, but I'll leave that up to the team. #jax-fixit PiperOrigin-RevId: 839319133
1 parent 59cfd99 commit 3cbe27b

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

flax/core/lift.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1458,9 +1458,18 @@ def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args, **kwargs):
14581458
# add 2 to each static_argnums because we add two initial arguments to rematted
14591459
static_argnums_ = jax.tree_util.tree_map(lambda x: x + 2, static_argnums)
14601460

1461+
# After JAX v0.3.16, concrete=False is a no-op and concrete=True raises
1462+
# NotImplementedError. Starting in JAX v0.8.2, the concrete argument is
1463+
# deprecated and will be removed in the future.
1464+
if concrete:
1465+
raise NotImplementedError(
1466+
"The concrete argument is deprecated. Use static_argnums instead."
1467+
" for more information, see"
1468+
" https://docs.jax.dev/en/latest/jep/11830-new-remat-checkpoint.html"
1469+
)
1470+
14611471
@functools.partial(
14621472
jax.remat,
1463-
concrete=concrete,
14641473
static_argnums=static_argnums_,
14651474
prevent_cse=prevent_cse,
14661475
policy=policy,

flax/linen/partitioning.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -553,8 +553,18 @@ def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args):
553553
static_args = tuple(x for i, x in enumerate(args) if i in static_argnums)
554554
dyn_args = tuple(x for i, x in enumerate(args) if i not in static_argnums)
555555

556+
# After JAX v0.3.16, concrete=False is a no-op and concrete=True raises
557+
# NotImplementedError. Starting in JAX v0.8.2, the concrete argument is
558+
# deprecated and will be removed in the future.
559+
if concrete:
560+
raise NotImplementedError(
561+
"The concrete argument is deprecated. Use static_argnums instead."
562+
" for more information, see"
563+
" https://docs.jax.dev/en/latest/jep/11830-new-remat-checkpoint.html"
564+
)
565+
556566
@functools.partial(
557-
jax.remat, concrete=concrete, prevent_cse=prevent_cse, policy=policy
567+
jax.remat, prevent_cse=prevent_cse, policy=policy
558568
)
559569
@functools.wraps(fn)
560570
def rematted(variable_groups, rng_groups, *dyn_args):

0 commit comments

Comments
 (0)