Skip to content

Commit 415454a

Browse files
committed
refactor summary
1 parent 09537f4 commit 415454a

File tree

1 file changed

+50
-30
lines changed

1 file changed

+50
-30
lines changed

flax/nnx/summary.py

Lines changed: 50 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -93,7 +93,7 @@ def _collect_stats(
9393
if id(value) in node_stats:
9494
continue
9595
elif isinstance(value, variablelib.Variable):
96-
var_type = type(value)
96+
var_type = value.var_type
9797
if issubclass(var_type, nnx.RngState):
9898
var_type = nnx.RngState
9999
size_bytes = SizeBytes.from_any(value.get_value())
@@ -168,10 +168,32 @@ def inner(state, *args, **kwargs):
168168
return f(model, *args, **kwargs)
169169
return jax.vjp(inner, state, *args, **kwargs)
170170

171-
def _get_call_info(jitted, method_name, node_stats, obj, compute_flops: bool, *args, **kwargs):
172-
e = jitted.lower(obj, *args, **kwargs)
173-
flops = _get_flops(e) if compute_flops else None
174-
outputs = e.lowered.out_info[2]
171+
def _get_call_info(
172+
jitted,
173+
method_name,
174+
node_stats,
175+
obj,
176+
compute_flops: bool,
177+
compute_vjp_flops: bool,
178+
args,
179+
kwargs,
180+
outputs,
181+
):
182+
if compute_flops:
183+
e = jitted.lower(obj, *args, **kwargs)
184+
flops = _get_flops(e)
185+
else:
186+
flops = None
187+
if compute_vjp_flops:
188+
189+
def do_vjp(*args, **kwargs):
190+
primals, f_vjp = _pure_nnx_vjp(jitted, obj, *args, **kwargs)
191+
return f_vjp(primals)
192+
193+
e_vjp = jax.jit(do_vjp).lower(obj, *args, **kwargs)
194+
vjp_flops = _get_flops(e_vjp)
195+
else:
196+
vjp_flops = None
175197
output_repr = jax.tree.map(_to_dummy_array, outputs)
176198
input_args_info, input_kwargs_info = jax.tree.map(
177199
_to_dummy_array, (args, kwargs)
@@ -191,6 +213,7 @@ def _get_call_info(jitted, method_name, node_stats, obj, compute_flops: bool, *a
191213
input_kwargs=input_kwargs_info,
192214
outputs=output_repr,
193215
flops=flops,
216+
vjp_flops=vjp_flops,
194217
)
195218

196219

@@ -211,8 +234,9 @@ def _argsave(tracer_args, f):
211234
n = f.__name__
212235
@wraps(f)
213236
def wrapper(obj, *args, **kwargs):
214-
tracer_args.append((obj, n, args, kwargs))
215-
return f(obj, *args, **kwargs)
237+
out = f(obj, *args, **kwargs)
238+
tracer_args.append((obj, n, args, kwargs, out))
239+
return out
216240
return wrapper
217241

218242
def _overwrite_methods(env):
@@ -358,6 +382,8 @@ def tabulate(
358382
_variable_types: set[type] = {
359383
nnx.RngState # type: ignore[misc]
360384
if isinstance(leaf, nnx.RngState)
385+
else leaf.var_type
386+
if isinstance(leaf, variablelib.Variable)
361387
else type(leaf)
362388
for _, leaf in nnx.to_flat_state(nnx.state(obj))
363389
}
@@ -368,38 +394,32 @@ def tabulate(
368394
env = _create_obj_env(object_types)
369395

370396
# Modify all the object's methods to save their Tracer arguments.
371-
# tracer_args contains (object, name, args, kwargs) tuples.
372-
tracer_args: list[tuple[tp.Any, str, tuple, dict[str, tp.Any]]] = []
373-
saver_env = {k: _argsave(tracer_args, v) for k,v in env.items()}
374-
_overwrite_methods(saver_env)
375-
397+
# tracer_args contains (object, name, args, kwargs, out) tuples.
376398
# Add JIT calculation to each method. We can extract flops and output info from
377399
# the lowered JITs. We'll only call these jitted values, which guarantees
378400
# that each method will only be traced (and added to the table) once.
379-
jits = {} # Maps (class, method_name) to jit
380-
for key, value in saver_env.items():
381-
jits[key] = nnx.jit(value)
401+
tracer_args: list[tuple[tp.Any, str, tuple, dict[str, tp.Any], tp.Any]] = []
402+
jits = {k: nnx.jit(_argsave(tracer_args, v)) for k, v in env.items()}
382403
_overwrite_methods(jits)
383404

384405
# Trace the top function (which indirectly traces all the others)
385406
jits[(type(obj), method)].trace(obj, *input_args, **input_kwargs)
386407

387408
# Get call_info
388-
rows : list[CallInfo] = [_get_call_info(
389-
jits[(type(object), name)], name, node_stats, object,
390-
compute_flops, *args, **kwargs)
391-
for (object, name, args, kwargs) in tracer_args]
392-
393-
# Add VJP flops if required. This needs to be done separately because calls to `_pure_nnx_vjp`
394-
# can result in tracing the jitted functions a second time if there's shared structure.
395-
# This would add items to `tracer_args`, resulting in duplicate rows in the table.
396-
if compute_vjp_flops:
397-
for i, row in enumerate(rows):
398-
object, method_name, args, kwargs = tracer_args[i]
399-
def do_vjp(*args, **kwargs):
400-
primals, f_vjp = _pure_nnx_vjp(jits[(type(object), method_name)], *args, **kwargs)
401-
return f_vjp(primals)
402-
row.vjp_flops = _get_flops(jax.jit(do_vjp).lower(object, *args, **kwargs))
409+
rows: list[CallInfo] = [
410+
_get_call_info(
411+
jits[(type(object), name)],
412+
name,
413+
node_stats,
414+
object,
415+
compute_flops,
416+
compute_vjp_flops,
417+
args,
418+
kwargs,
419+
out,
420+
)
421+
for (object, name, args, kwargs, out) in list(tracer_args)
422+
]
403423

404424
# Restore the object's original methods
405425
_overwrite_methods(env)

0 commit comments

Comments
 (0)