diff --git a/flax/nnx/rnglib.py b/flax/nnx/rnglib.py
index 2c8479b3f..71a3550cb 100644
--- a/flax/nnx/rnglib.py
+++ b/flax/nnx/rnglib.py
@@ -57,6 +57,7 @@ def _to_keyless(
 
 
 def _function_to_method(random_f):
+  @functools.wraps(random_f)
   def rngs_random_method(self: Rngs | RngStream, *args, **kwargs) -> jax.Array:
     return random_f(self(), *args, **kwargs)
 
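The `functools.wraps` addition above matters because every generated RNG method shares the same inner closure; without it, each method's `__name__` would be `rngs_random_method`. A minimal standalone sketch of the pattern (the `sample` function is hypothetical, standing in for a `jax.random` sampler; this is not flax code):

import functools

def sample(key, minval=0.0, maxval=1.0):
  "Stand-in for a jax.random sampler that takes a key first."
  return (key, minval, maxval)

def function_to_method(random_f):
  @functools.wraps(random_f)  # preserves __name__, __doc__, etc.
  def rngs_random_method(self, *args, **kwargs):
    # `self()` would produce a fresh key from the stream.
    return random_f(self(), *args, **kwargs)
  return rngs_random_method

assert function_to_method(sample).__name__ == 'sample'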
diff --git a/flax/nnx/summary.py b/flax/nnx/summary.py
index 2789732d3..b9a84dd64 100644
--- a/flax/nnx/summary.py
+++ b/flax/nnx/summary.py
@@ -19,6 +19,8 @@
 import io
 import typing as tp
 from types import MappingProxyType
+import functools
+import itertools
 
 import jax
 import numpy as np
@@ -34,6 +36,7 @@
 
 from functools import wraps
 
+
 try:
   from IPython import get_ipython
 
@@ -130,14 +133,14 @@ def __str__(self):
 
 @dataclasses.dataclass
 class CallInfo:
+  call_order: int
   object_id: int
   type: type
   path: statelib.PathParts
-  input_args: tuple[tp.Any, ...]
-  input_kwargs: dict[str, tp.Any]
+  inputs_repr: str
   outputs: tp.Any
-  flops: int | None = None
-  vjp_flops: int | None = None
+  flops: int | None
+  vjp_flops: int | None
 
 class SimpleObjectRepr:
   def __init__(self, obj: tp.Any):
@@ -168,32 +171,6 @@ def inner(state, *args, **kwargs):
       return f(model, *args, **kwargs)
     return jax.vjp(inner, state, *args, **kwargs)
 
-def _get_call_info(jitted, method_name, node_stats, obj, compute_flops: bool, *args, **kwargs):
-  e = jitted.lower(obj, *args, **kwargs)
-  flops = _get_flops(e) if compute_flops else None
-  outputs = e.lowered.out_info[2]
-  output_repr = jax.tree.map(_to_dummy_array, outputs)
-  input_args_info, input_kwargs_info = jax.tree.map(
-    _to_dummy_array, (args, kwargs)
-  )
-  object_id: int = getattr(obj, '_nnx_tabulate_id')
-  node_info = node_stats[object_id]
-  assert node_info is not None
-  path = node_info.path
-  if method_name != '__call__':
-    path = (*path, method_name)
-
-  return CallInfo(
-    object_id=object_id,
-    type=type(obj),
-    path=path,
-    input_args=input_args_info,
-    input_kwargs=input_kwargs_info,
-    outputs=output_repr,
-    flops=flops,
-  )
-
-
 def filter_rng_streams(row: CallInfo):
   return not issubclass(row.type, nnx.RngStream)
 
@@ -206,13 +183,64 @@ def _create_obj_env(object_types):
       result[(obj_type, name)] = top_method
   return result
 
-def _argsave(tracer_args, f):
-  "Wrap a function to save its arguments"
-  n = f.__name__
+def _get_inputs_repr(args, kwargs):
+  input_args, input_kwargs = jax.tree.map(
+    _to_dummy_array, (args, kwargs)
+  )
+  inputs_repr = ''
+  if input_args:
+    if len(input_args) == 1 and not input_kwargs:
+      inputs_repr += _as_yaml_str(input_args[0])
+    else:
+      inputs_repr += _as_yaml_str(input_args)
+    if input_kwargs:
+      inputs_repr += '\n'
+  if input_kwargs:
+    inputs_repr += _as_yaml_str(input_kwargs)
+  return inputs_repr
+
+def _save_call_info(counter, rows, f, node_stats, compute_flops, compute_vjp_flops, seen):
+  "Wrap a method to record a CallInfo row (inputs, outputs, flops) for each distinct call"
+
+  # Used when computing vjp flops
+  def do_vjp(*args, **kwargs):
+    primals, f_vjp = jax.vjp(f, *args, **kwargs)
+    return f_vjp(primals)
+
+  method_name = f.__name__
+
+  @jax.jit
+  def jit_f(graphdef, state):
+    args, kwargs = nnx.merge(graphdef, state)
+    return f(*args, **kwargs)
+
   @wraps(f)
   def wrapper(obj, *args, **kwargs):
-    tracer_args.append((obj, n, args, kwargs))
-    return f(obj, *args, **kwargs)
+    inputs_repr = _get_inputs_repr(args, kwargs)
+    object_id = getattr(obj, '_nnx_tabulate_id')
+    node_info = node_stats[object_id]
+    path = node_info.path
+    if method_name != '__call__':
+      path = (*path, method_name)
+    identifier = (inputs_repr, object_id)
+    counter_val = next(counter)
+    graphdef, state = nnx.split(((obj, *args), kwargs))
+    if compute_flops:
+      lowered = jit_f.lower(graphdef, state)
+      flops = _get_flops(lowered)
+      outputs = lowered.out_info
+    else:
+      flops = None
+      outputs = jit_f(graphdef, state)
+    if identifier not in seen:
+      seen.add(identifier)
+      output_repr = jax.tree.map(_to_dummy_array, outputs)
+      vjp_flops = _get_flops(jax.jit(do_vjp).lower(
+        obj, *args, **kwargs)) if compute_vjp_flops else None
+      rows.append(
+        CallInfo(counter_val, object_id, type(obj), path, inputs_repr,
                 output_repr, flops, vjp_flops))
+    return jit_f(graphdef, state)
   return wrapper
 
 def _overwrite_methods(env):
@@ -367,39 +395,22 @@ def tabulate(
   # iteration over methods easier.
   env = _create_obj_env(object_types)
 
-  # Modify all the object's methods to save their Tracer arguments.
-  # tracer_args contains (object, name, args, kwargs) tuples.
-  tracer_args: list[tuple[tp.Any, str, tuple, dict[str, tp.Any]]] = []
-  saver_env = {k: _argsave(tracer_args, v) for k,v in env.items()}
-  _overwrite_methods(saver_env)
-
-  # Add JIT calculation to each method. We can extract flops and output info from
-  # the lowered JITs. We'll only call these jitted values, which guarantees
-  # that each method will only be traced (and added to the table) once.
-  jits = {}  # Maps (class, method_name) to jit
-  for key, value in saver_env.items():
-    jits[key] = nnx.jit(value)
+  # Rows are recorded in post-order (a method returns before its caller does),
+  # but are presented in pre-order; the counter records each call's entry order.
+  counter = itertools.count(0)
+
+  # Modify all the object's methods to record call info from their lowered JITs.
+  rows: list[CallInfo] = []
+  seen: set = set()
+  jits = {k: _save_call_info(counter, rows, v, node_stats, compute_flops, compute_vjp_flops, seen)
+          for k, v in env.items()}
   _overwrite_methods(jits)
 
   # Trace the top function (which indirectly traces all the others)
-  jits[(type(obj), method)].trace(obj, *input_args, **input_kwargs)
+  jits[(type(obj), method)](obj, *input_args, **input_kwargs)
 
-  # Get call_info
-  rows : list[CallInfo] = [_get_call_info(
-    jits[(type(object), name)], name, node_stats, object,
-    compute_flops, *args, **kwargs)
-    for (object, name, args, kwargs) in tracer_args]
-
-  # Add VJP flops if required. This needs to be done separately because calls to `_pure_nnx_vjp`
-  # can result in tracing the jitted functions a second time if there's shared structure.
-  # This would add items to `tracer_args`, resulting in duplicate rows in the table.
-  if compute_vjp_flops:
-    for i, row in enumerate(rows):
-      object, method_name, args, kwargs = tracer_args[i]
-      def do_vjp(*args, **kwargs):
-        primals, f_vjp = _pure_nnx_vjp(jits[(type(object), method_name)], *args, **kwargs)
-        return f_vjp(primals)
-      row.vjp_flops = _get_flops(jax.jit(do_vjp).lower(object, *args, **kwargs))
+  # Sort rows into pre-order traversal order
+  rows.sort(key=lambda x: x.call_order)
 
   # Restore the object's original methods
   _overwrite_methods(env)
@@ -436,17 +447,7 @@ def do_vjp(*args, **kwargs):
     path_str = '/'.join(map(str, row.path))
     col_reprs.append(path_str)
     col_reprs.append(row.type.__name__)
-    inputs_repr = ''
-    if row.input_args:
-      input_args = row.input_args
-      if len(row.input_args) == 1 and not row.input_kwargs:
-        input_args = row.input_args[0]
-      inputs_repr += _as_yaml_str(input_args)
-      if inputs_repr and row.input_kwargs:
-        inputs_repr += '\n'
-    if row.input_kwargs:
-      inputs_repr += _as_yaml_str(row.input_kwargs)
-    col_reprs.append(inputs_repr)
+    col_reprs.append(row.inputs_repr)
     col_reprs.append(_as_yaml_str(row.outputs))
     if compute_flops:
       col_reprs.append(str(row.flops))
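The pre-order bookkeeping is the subtle part of this change: `next(counter)` fires on entry (pre-order position), while the row is appended only after nested calls return (post-order). A self-contained sketch of the record-and-sort pattern (hypothetical names, not the flax API):

import itertools

def make_recorder():
  counter = itertools.count(0)
  rows: list[tuple[int, str]] = []

  def wrap(name, f):
    def wrapper(*args, **kwargs):
      order = next(counter)       # entry order: pre-order position
      out = f(*args, **kwargs)    # nested wrapped calls append first...
      rows.append((order, name))  # ...so rows arrive in post-order
      return out
    return wrapper

  return wrap, rows

wrap, rows = make_recorder()
inner = wrap('inner', lambda x: x + 1)
outer = wrap('outer', lambda x: inner(x) * 2)
outer(1)
assert rows == [(1, 'inner'), (0, 'outer')]  # post-order as recorded
rows.sort(key=lambda r: r[0])
assert rows == [(0, 'outer'), (1, 'inner')]  # pre-order as presented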
diff --git a/tests/nnx/summary_test.py b/tests/nnx/summary_test.py
index c8b3b7e49..cb6d918a6 100644
--- a/tests/nnx/summary_test.py
+++ b/tests/nnx/summary_test.py
@@ -41,9 +41,10 @@ def __call__(self, x):
 
     foo = Foo(nnx.Rngs(0))
     x = jnp.ones((1, 32))
-    table_repr = nnx.tabulate(
+    table_repr_ = nnx.tabulate(
       foo, x, console_kwargs=CONSOLE_TEST_KWARGS
-    ).splitlines()
+    )
+    table_repr = table_repr_.splitlines()
 
     self.assertIn('Foo Summary', table_repr[0])
     self.assertIn('path', table_repr[2])
@@ -224,6 +225,33 @@ def __call__(self, x):
     # We should see 3 calls per block, plus one overall call
     self.assertEqual(sum([s.startswith("├─") for s in table.splitlines()]), 7)
 
+  def test_time_complexity(self):
+    counter = []
+
+    class Block(nnx.Module):
+      def __init__(self, rngs):
+        self.linear = nnx.Linear(2, 2, rngs=rngs)
+
+      def __call__(self, x):
+        counter.append(1)
+        return self.linear(x)
+
+    class Model(nnx.Module):
+      def __init__(self, rngs):
+        for d in range(10):
+          setattr(self, f"linear{d}", Block(rngs))
+
+      def __call__(self, x):
+        for d in range(10):
+          x = getattr(self, f"linear{d}")(x)
+        return x
+
+    m = Model(nnx.Rngs(0))
+    x = jnp.ones((4, 2))
+    nnx.tabulate(m, x, compute_flops=True, compute_vjp_flops=False)
+    # Each Block.__call__ should be traced exactly once: 10 blocks, 10 calls.
+    self.assertEqual(len(counter), 10)
+
   def test_shared(self):
     class Block(nnx.Module):
       def __init__(self, linear: nnx.Linear, *, rngs):
@@ -295,5 +323,18 @@ def __call__(self, x):
     self.assertEqual(module.hooked_param.get_metadata('description'), 'Custom parameter')
     self.assertEqual(module.hooked_param.get_metadata('trainable'), True)
 
+  def test_tabulate_concrete_shape(self):
+    class Net(nnx.Module):
+      def __init__(self):
+        self.rngs = nnx.Rngs(0)
+
+      def __call__(self, x):
+        # x.shape[0] must stay concrete while tabulate traces __call__.
+        return self.rngs.uniform((x.shape[0], 10))
+
+    net = Net()
+    x = jnp.zeros((4, 8))
+    nnx.tabulate(net, x, console_kwargs={"width": 200})
+
 if __name__ == '__main__':
   absltest.main()
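For reference, a small end-to-end usage sketch consistent with the tests above (not part of the patch; it assumes the `compute_flops`/`compute_vjp_flops` flags exercised by the new code paths):

import jax.numpy as jnp
from flax import nnx

class MLP(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.fc1 = nnx.Linear(8, 16, rngs=rngs)
    self.fc2 = nnx.Linear(16, 4, rngs=rngs)

  def __call__(self, x):
    return self.fc2(self.fc1(x))

model = MLP(nnx.Rngs(0))
x = jnp.ones((2, 8))
# Each method is traced once; rows are sorted into pre-order (MLP, fc1, fc2).
print(nnx.tabulate(model, x, compute_flops=True, compute_vjp_flops=True))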