Skip to content

Commit e379298

Browse files
committed
Refactor tabulate
1 parent 5109e2c commit e379298

File tree

3 files changed

+87
-76
lines changed

3 files changed

+87
-76
lines changed

flax/nnx/rnglib.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def _to_keyless(
5858

5959

6060
def _function_to_method(random_f):
61+
@functools.wraps(random_f)
6162
def rngs_random_method(self: Rngs | RngStream, *args, **kwargs) -> jax.Array:
6263
return random_f(self(), *args, **kwargs)
6364

flax/nnx/summary.py

Lines changed: 71 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import io
2020
import typing as tp
2121
from types import MappingProxyType
22+
import functools
23+
import itertools
2224

2325
import jax
2426
import numpy as np
@@ -34,6 +36,7 @@
3436

3537
from functools import wraps
3638

39+
3740
try:
3841
from IPython import get_ipython
3942

@@ -130,14 +133,14 @@ def __str__(self):
130133

131134
@dataclasses.dataclass
132135
class CallInfo:
136+
call_order: int
133137
object_id: int
134138
type: type
135139
path: statelib.PathParts
136-
input_args: tuple[tp.Any, ...]
137-
input_kwargs: dict[str, tp.Any]
140+
inputs_repr: str
138141
outputs: tp.Any
139-
flops: int | None = None
140-
vjp_flops: int | None = None
142+
flops: int | None
143+
vjp_flops: int | None
141144

142145
class SimpleObjectRepr:
143146
def __init__(self, obj: tp.Any):
@@ -168,32 +171,6 @@ def inner(state, *args, **kwargs):
168171
return f(model, *args, **kwargs)
169172
return jax.vjp(inner, state, *args, **kwargs)
170173

171-
def _get_call_info(jitted, method_name, node_stats, obj, compute_flops: bool, *args, **kwargs):
172-
e = jitted.lower(obj, *args, **kwargs)
173-
flops = _get_flops(e) if compute_flops else None
174-
outputs = e.lowered.out_info[2]
175-
output_repr = jax.tree.map(_to_dummy_array, outputs)
176-
input_args_info, input_kwargs_info = jax.tree.map(
177-
_to_dummy_array, (args, kwargs)
178-
)
179-
object_id: int = getattr(obj, '_nnx_tabulate_id')
180-
node_info = node_stats[object_id]
181-
assert node_info is not None
182-
path = node_info.path
183-
if method_name != '__call__':
184-
path = (*path, method_name)
185-
186-
return CallInfo(
187-
object_id=object_id,
188-
type=type(obj),
189-
path=path,
190-
input_args=input_args_info,
191-
input_kwargs=input_kwargs_info,
192-
outputs=output_repr,
193-
flops=flops,
194-
)
195-
196-
197174
def filter_rng_streams(row: CallInfo):
  """Return True unless the row's recorded type is an ``nnx.RngStream`` subclass."""
  if issubclass(row.type, nnx.RngStream):
    return False
  return True
199176

@@ -206,13 +183,60 @@ def _create_obj_env(object_types):
206183
result[(obj_type, name)] = top_method
207184
return result
208185

209-
def _argsave(tracer_args, f):
186+
def _get_inputs_repr(args, kwargs):
  """Build the text shown for a call's inputs.

  Every leaf of ``args`` and ``kwargs`` is first passed through
  ``_to_dummy_array``. A lone positional argument with no keyword arguments is
  rendered on its own rather than as a one-element tuple; otherwise the
  positional arguments are rendered together. Keyword arguments, when present,
  are rendered after the positional ones, separated from them by a newline.

  Returns:
    The combined string, or ``''`` when there are no inputs at all.
  """
  dummy_args, dummy_kwargs = jax.tree.map(_to_dummy_array, (args, kwargs))
  sections = []
  if dummy_args:
    sole_positional = len(dummy_args) == 1 and not dummy_kwargs
    sections.append(
      _as_yaml_str(dummy_args[0] if sole_positional else dummy_args)
    )
  if dummy_kwargs:
    sections.append(_as_yaml_str(dummy_kwargs))
  return '\n'.join(sections)
201+
202+
def _save_call_info(counter, tracer_args, f, node_stats, compute_flops, compute_vjp_flops, seen):
210203
"Wrap a function to save its arguments"
211-
n = f.__name__
204+
205+
# Used when computing vjp flops
206+
def do_vjp(*args, **kwargs):
207+
primals, f_vjp = jax.vjp(f, *args, **kwargs)
208+
return f_vjp(primals)
209+
210+
method_name = f.__name__
211+
212+
@functools.partial(jax.jit)
213+
def jit_f(graphdef, state):
214+
args, kwargs = nnx.merge(graphdef, state)
215+
return f(*args, **kwargs)
216+
212217
@wraps(f)
213218
def wrapper(obj, *args, **kwargs):
214-
tracer_args.append((obj, n, args, kwargs))
215-
return f(obj, *args, **kwargs)
219+
inputs_repr = _get_inputs_repr(args, kwargs)
220+
object_id = getattr(obj, '_nnx_tabulate_id')
221+
node_info = node_stats[object_id]
222+
path = node_info.path
223+
if method_name != '__call__':
224+
path = (*path, method_name)
225+
identifier = (inputs_repr, object_id)
226+
counter_val = next(counter)
227+
graphdef, state = nnx.split(((obj, *args), kwargs))
228+
lowered = jit_f.lower(graphdef, state)
229+
if identifier not in seen:
230+
seen.add(identifier)
231+
flops = _get_flops(lowered) if compute_flops else None
232+
outputs = lowered.out_info
233+
output_repr = jax.tree.map(_to_dummy_array, outputs)
234+
vjp_flops = _get_flops(jax.jit(do_vjp).lower(
235+
obj, *args, **kwargs)) if compute_vjp_flops else None
236+
tracer_args.append(
237+
CallInfo(counter_val, object_id, type(obj), path, inputs_repr,
238+
output_repr, flops, vjp_flops))
239+
return jit_f(graphdef, state)
216240
return wrapper
217241

218242
def _overwrite_methods(env):
@@ -367,39 +391,22 @@ def tabulate(
367391
# iteration over methods easier.
368392
env = _create_obj_env(object_types)
369393

370-
# Modify all the object's methods to save their Tracer arguments.
371-
# tracer_args contains (object, name, args, kwargs) tuples.
372-
tracer_args: list[tuple[tp.Any, str, tuple, dict[str, tp.Any]]] = []
373-
saver_env = {k: _argsave(tracer_args, v) for k,v in env.items()}
374-
_overwrite_methods(saver_env)
375-
376-
# Add JIT calculation to each method. We can extract flops and output info from
377-
# the lowered JITs. We'll only call these jitted values, which guarantees
378-
# that each method will only be traced (and added to the table) once.
379-
jits = {} # Maps (class, method_name) to jit
380-
for key, value in saver_env.items():
381-
jits[key] = nnx.jit(value)
394+
# Information is recorded in post-order, but should be presented as a pre-order traversal.
395+
# This keeps track of the order of calls.
396+
counter = itertools.count(0)
397+
398+
# Modify all the object's methods to save their lowered JIT representations.
399+
rows : list[CallInfo] = []
400+
seen : set = set()
401+
jits = {k: _save_call_info(counter, rows, v, node_stats, compute_flops, compute_vjp_flops, seen)
402+
for k, v in env.items()}
382403
_overwrite_methods(jits)
383404

384405
# Trace the top function (which indirectly traces all the others)
385-
jits[(type(obj), method)].trace(obj, *input_args, **input_kwargs)
406+
jits[(type(obj), method)](obj, *input_args, **input_kwargs)
386407

387-
# Get call_info
388-
rows : list[CallInfo] = [_get_call_info(
389-
jits[(type(object), name)], name, node_stats, object,
390-
compute_flops, *args, **kwargs)
391-
for (object, name, args, kwargs) in tracer_args]
392-
393-
# Add VJP flops if required. This needs to be done separately because calls to `_pure_nnx_vjp`
394-
# can result in tracing the jitted functions a second time if there's shared structure.
395-
# This would add items to `tracer_args`, resulting in duplicate rows in the table.
396-
if compute_vjp_flops:
397-
for i, row in enumerate(rows):
398-
object, method_name, args, kwargs = tracer_args[i]
399-
def do_vjp(*args, **kwargs):
400-
primals, f_vjp = _pure_nnx_vjp(jits[(type(object), method_name)], *args, **kwargs)
401-
return f_vjp(primals)
402-
row.vjp_flops = _get_flops(jax.jit(do_vjp).lower(object, *args, **kwargs))
408+
# Sort call info in pre-order traversal order
409+
rows.sort(key=lambda x: x.call_order)
403410

404411
# Restore the object's original methods
405412
_overwrite_methods(env)
@@ -436,17 +443,7 @@ def do_vjp(*args, **kwargs):
436443
path_str = '/'.join(map(str, row.path))
437444
col_reprs.append(path_str)
438445
col_reprs.append(row.type.__name__)
439-
inputs_repr = ''
440-
if row.input_args:
441-
input_args = row.input_args
442-
if len(row.input_args) == 1 and not row.input_kwargs:
443-
input_args = row.input_args[0]
444-
inputs_repr += _as_yaml_str(input_args)
445-
if inputs_repr and row.input_kwargs:
446-
inputs_repr += '\n'
447-
if row.input_kwargs:
448-
inputs_repr += _as_yaml_str(row.input_kwargs)
449-
col_reprs.append(inputs_repr)
446+
col_reprs.append(row.inputs_repr)
450447
col_reprs.append(_as_yaml_str(row.outputs))
451448
if compute_flops:
452449
col_reprs.append(str(row.flops))

tests/nnx/summary_test.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@ def __call__(self, x):
4141

4242
foo = Foo(nnx.Rngs(0))
4343
x = jnp.ones((1, 32))
44-
table_repr = nnx.tabulate(
44+
table_repr_ = nnx.tabulate(
4545
foo, x, console_kwargs=CONSOLE_TEST_KWARGS
46-
).splitlines()
46+
)
47+
table_repr = table_repr_.splitlines()
4748

4849
self.assertIn('Foo Summary', table_repr[0])
4950
self.assertIn('path', table_repr[2])
@@ -295,5 +296,17 @@ def __call__(self, x):
295296
self.assertEqual(module.hooked_param.get_metadata('description'), 'Custom parameter')
296297
self.assertEqual(module.hooked_param.get_metadata('trainable'), True)
297298

299+
def test_tabulate_concrete_shape(self):
  """Run ``nnx.tabulate`` on a module that passes ``x.shape[0]`` to ``Rngs.uniform``."""

  class Net(nnx.Module):
    def __init__(self):
      self.rngs = nnx.Rngs(0)

    def __call__(self, x):
      leading_dim = x.shape[0]
      return self.rngs.uniform((leading_dim, 10))

  model = Net()
  inputs = jnp.zeros((4, 8))
  # Only checks that tabulation completes; the returned table is not inspected.
  nnx.tabulate(model, inputs, console_kwargs={"width": 200})
310+
298311
if __name__ == '__main__':
299312
absltest.main()

0 commit comments

Comments
 (0)