Skip to content

Commit 9a113b6

Browse files
committed
Use MaybeJit for summaries
1 parent 74985b2 commit 9a113b6

File tree

3 files changed

+54
-2
lines changed

3 files changed

+54
-2
lines changed

flax/nnx/rnglib.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def _to_keyless(
5858

5959

6060
def _function_to_method(random_f):
61+
@functools.wraps(random_f)
6162
def rngs_random_method(self: Rngs | RngStream, *args, **kwargs) -> jax.Array:
6263
return random_f(self(), *args, **kwargs)
6364

flax/nnx/summary.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import io
2020
import typing as tp
2121
from types import MappingProxyType
22+
import functools
23+
from types import SimpleNamespace
2224

2325
import jax
2426
import numpy as np
@@ -34,6 +36,7 @@
3436

3537
from functools import wraps
3638

39+
3740
try:
3841
from IPython import get_ipython
3942

@@ -50,6 +53,41 @@ class NoneDumper(yaml.SafeDumper):
5053
lambda dumper, data: dumper.represent_scalar('tag:yaml.org,2002:str', 'None'),
5154
)
5255

56+
class MaybeJit:
57+
"""
58+
Wraps a function with nnx.jit, but saves the original to run
59+
if the function turns out to be non-concrete. We can't get the flops of non-concrete functions,
60+
but we should still be able to trace the input and output shapes.
61+
"""
62+
def __init__(self, f):
63+
self.f = f
64+
self.jitted = nnx.jit(f)
65+
functools.update_wrapper(self, self.f)
66+
67+
# implement descriptor protocol so that we can use this as a method
68+
def __get__(self, obj, objtype=None):
69+
if obj is None:
70+
return self
71+
return functools.partial(self, obj)
72+
73+
def __call__(self, *args):
74+
try:
75+
return self.jitted(*args)
76+
except TypeError as e:
77+
return self.f(*args)
78+
79+
# This will only be used for the top-level method
80+
def trace(self, *args):
81+
return self(*args)
82+
83+
def lower(self, *args):
84+
try:
85+
return self.jitted.lower(*args)
86+
except TypeError as e:
87+
result = self.f(*args)
88+
# Mock a `Lowered` instance with a SimpleNamespace
89+
return SimpleNamespace(cost_analysis=-1, lowered=SimpleNamespace(out_info=(None, None, result)))
90+
5391
class SizeBytes(typing.SizeBytes):
5492
def __repr__(self) -> str:
5593
bytes_repr = _bytes_repr(self.bytes)
@@ -222,7 +260,7 @@ def _overwrite_methods(env):
222260

223261
def _get_flops(e) -> int:
224262
cost = e.cost_analysis() or e.compile().cost_analysis()
225-
return 0 if cost is None or 'flops' not in cost else int(cost['flops'])
263+
return -1 if cost is None or 'flops' not in cost else int(cost['flops'])
226264

227265
def tabulate(
228266
obj,
@@ -378,7 +416,8 @@ def tabulate(
378416
# that each method will only be traced (and added to the table) once.
379417
jits = {} # Maps (class, method_name) to jit
380418
for key, value in saver_env.items():
381-
jits[key] = nnx.jit(value)
419+
jits[key] = MaybeJit(value)
420+
382421
_overwrite_methods(jits)
383422

384423
# Trace the top function (which indirectly traces all the others)

tests/nnx/summary_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,5 +295,17 @@ def __call__(self, x):
295295
self.assertEqual(module.hooked_param.get_metadata('description'), 'Custom parameter')
296296
self.assertEqual(module.hooked_param.get_metadata('trainable'), True)
297297

298+
def test_tabulate_concrete_shape(self):
299+
class Net(nnx.Module):
300+
def __init__(self):
301+
self.rngs = nnx.Rngs(0)
302+
303+
def __call__(self, x):
304+
return self.rngs.uniform((x.shape[0], 10))
305+
306+
net = Net()
307+
x = jnp.zeros((4, 8))
308+
print(nnx.tabulate(net, x, console_kwargs={"width": 200}))
309+
298310
if __name__ == '__main__':
299311
absltest.main()

0 commit comments

Comments
 (0)