Skip to content

Commit 8611538

Browse files
author
Flax Authors
committed
Merge pull request #1456 from jheek:better-transform-error
PiperOrigin-RevId: 390124613
2 parents 470aaa6 + 964a92b commit 8611538

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

flax/core/tracers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
import jax
1818

19+
from .. import errors
20+
1921

2022
def current_trace():
2123
"""Returns the innermost Jax tracer."""
@@ -32,4 +34,4 @@ def trace_level(main):
3234
def check_trace_level(base_level):
3335
level = trace_level(current_trace())
3436
if level != base_level:
35-
raise ValueError('Jax transforms and modules cannot be mixed.')
37+
raise errors.JaxTransformError()

flax/errors.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,19 @@ def __init__(self, col, variable_name, scope_path):
263263
f'"{scope_path}" because collection "{col}" is immutable.')
264264

265265

266+
class JaxTransformError(FlaxError):
267+
"""
268+
JAX transforms and Flax modules cannot be mixed.
269+
270+
JAX's functional transformations expect pure function.
271+
When you want to use JAX transformations **inside** Flax models,
272+
you should make use of the Flax transformation wrappers
273+
(e.g.: ``flax.linen.vmap``, ``flax.linen.scan``, etc.).
274+
"""
275+
def __init__(self):
276+
super().__init__('Jax transforms and Flax models cannot be mixed.')
277+
278+
266279
#################################################
267280
# module.py errors #
268281
#################################################

0 commit comments

Comments
 (0)