Skip to content

Commit 60d4ab4

Browse files
committed
Use nnx.split to avoid compilation issue
1 parent: 4aa7bd7 · commit: 60d4ab4

File tree

1 file changed

+14 -46 lines changed

1 file changed

+14 -46 lines changed

flax/nnx/summary.py

Lines changed: 14 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,6 @@
2121
from types import MappingProxyType
2222
import functools
2323
import itertools
24-
from types import SimpleNamespace
2524

2625
import jax
2726
import numpy as np
@@ -54,38 +53,6 @@ class NoneDumper(yaml.SafeDumper):
5453
lambda dumper, data: dumper.represent_scalar('tag:yaml.org,2002:str', 'None'),
5554
)
5655

57-
class MaybeJit:
58-
"""
59-
Wraps a function with nnx.jit, but saves the original to run
60-
if the function turns out to be non-concrete. We can't get the flops of non-concrete functions,
61-
but we should still be able to trace the input and output shapes.
62-
"""
63-
def __init__(self, f):
64-
functools.update_wrapper(self, f)
65-
self.f = f
66-
self.jitted = nnx.jit(f)
67-
self.seen = set()
68-
69-
# implement descriptor protocol so that we can use this as a method
70-
def __get__(self, obj, objtype=None):
71-
if obj is None:
72-
return self
73-
return functools.partial(self, obj)
74-
75-
def __call__(self, *args, **kwargs):
76-
try:
77-
return self.jitted(*args, **kwargs)
78-
except TypeError as e:
79-
return self.f(*args, **kwargs)
80-
81-
def lower(self, *args, **kwargs):
82-
try:
83-
return self.jitted.lower(*args, **kwargs)
84-
except TypeError as e:
85-
result = self.f(*args, **kwargs)
86-
# Mock a `Lowered` instance with a SimpleNamespace
87-
return SimpleNamespace(cost_analysis="0", lowered=SimpleNamespace(out_info=(None, None, result)))
88-
8956
class SizeBytes(typing.SizeBytes):
9057
def __repr__(self) -> str:
9158
bytes_repr = _bytes_repr(self.bytes)
@@ -242,11 +209,11 @@ def do_vjp(*args, **kwargs):
242209

243210
method_name = f.__name__
244211

245-
def split_f(*args, **kwargs):
212+
@functools.partial(jax.jit)
213+
def jit_f(graphdef, state):
214+
args, kwargs = nnx.merge(graphdef, state)
246215
return f(*args, **kwargs)
247-
248-
nnx.jit(split_f)
249-
# TODO: write split_f and continue from here.
216+
jit_f._seen = set()
250217

251218
@wraps(f)
252219
def wrapper(obj, *args, **kwargs):
@@ -257,19 +224,20 @@ def wrapper(obj, *args, **kwargs):
257224
if method_name != '__call__':
258225
path = (*path, method_name)
259226
identifier = (inputs_repr, object_id)
260-
if identifier not in f.seen:
261-
counter_val = next(counter)
262-
lowered = f.lower(obj, *args, **kwargs)
227+
counter_val = next(counter)
228+
graphdef, state = nnx.split(((obj, *args), kwargs))
229+
lowered = jit_f.lower(graphdef, state)
230+
if identifier not in jit_f._seen:
231+
jit_f._seen.add(identifier)
263232
flops = _get_flops(lowered) if compute_flops else None
264-
outputs = lowered.lowered.out_info[2]
233+
outputs = lowered.out_info
265234
output_repr = jax.tree.map(_to_dummy_array, outputs)
266235
vjp_flops = _get_flops(jax.jit(do_vjp).lower(
267236
obj, *args, **kwargs)) if compute_vjp_flops else None
268237
tracer_args.append(
269238
CallInfo(counter_val, object_id, type(obj), path, inputs_repr,
270239
output_repr, flops, vjp_flops))
271-
f.seen.add(identifier)
272-
return f(obj, *args, **kwargs)
240+
return jit_f(graphdef, state)
273241
return wrapper
274242

275243
def _overwrite_methods(env):
@@ -430,12 +398,12 @@ def tabulate(
430398

431399
# Modify all the object's methods to save their lowered JIT representations.
432400
rows : list[CallInfo] = []
433-
maybejits = {k: _save_call_info(counter, rows, v, node_stats, compute_flops, compute_vjp_flops)
401+
jits = {k: _save_call_info(counter, rows, v, node_stats, compute_flops, compute_vjp_flops)
434402
for k, v in env.items()}
435-
_overwrite_methods(maybejits)
403+
_overwrite_methods(jits)
436404

437405
# Trace the top function (which indirectly traces all the others)
438-
maybejits[(type(obj), method)](obj, *input_args, **input_kwargs)
406+
jits[(type(obj), method)](obj, *input_args, **input_kwargs)
439407

440408
# Sort call info in pre-order traversal order
441409
rows.sort(key=lambda x: x.call_order)

0 commit comments

Comments
 (0)