@@ -93,7 +93,7 @@ def _collect_stats(
9393 if id (value ) in node_stats :
9494 continue
9595 elif isinstance (value , variablelib .Variable ):
96- var_type = type ( value )
96+ var_type = value . var_type
9797 if issubclass (var_type , nnx .RngState ):
9898 var_type = nnx .RngState
9999 size_bytes = SizeBytes .from_any (value .get_value ())
@@ -168,10 +168,32 @@ def inner(state, *args, **kwargs):
168168 return f (model , * args , ** kwargs )
169169 return jax .vjp (inner , state , * args , ** kwargs )
170170
171- def _get_call_info (jitted , method_name , node_stats , obj , compute_flops : bool , * args , ** kwargs ):
172- e = jitted .lower (obj , * args , ** kwargs )
173- flops = _get_flops (e ) if compute_flops else None
174- outputs = e .lowered .out_info [2 ]
171+ def _get_call_info (
172+ jitted ,
173+ method_name ,
174+ node_stats ,
175+ obj ,
176+ compute_flops : bool ,
177+ compute_vjp_flops : bool ,
178+ args ,
179+ kwargs ,
180+ outputs ,
181+ ):
182+ if compute_flops :
183+ e = jitted .lower (obj , * args , ** kwargs )
184+ flops = _get_flops (e )
185+ else :
186+ flops = None
187+ if compute_vjp_flops :
188+
189+ def do_vjp (* args , ** kwargs ):
190+ primals , f_vjp = _pure_nnx_vjp (jitted , obj , * args , ** kwargs )
191+ return f_vjp (primals )
192+
193+ e_vjp = jax .jit (do_vjp ).lower (obj , * args , ** kwargs )
194+ vjp_flops = _get_flops (e_vjp )
195+ else :
196+ vjp_flops = None
175197 output_repr = jax .tree .map (_to_dummy_array , outputs )
176198 input_args_info , input_kwargs_info = jax .tree .map (
177199 _to_dummy_array , (args , kwargs )
@@ -191,6 +213,7 @@ def _get_call_info(jitted, method_name, node_stats, obj, compute_flops: bool, *a
191213 input_kwargs = input_kwargs_info ,
192214 outputs = output_repr ,
193215 flops = flops ,
216+ vjp_flops = vjp_flops ,
194217 )
195218
196219
@@ -211,8 +234,9 @@ def _argsave(tracer_args, f):
211234 n = f .__name__
212235 @wraps (f )
213236 def wrapper (obj , * args , ** kwargs ):
214- tracer_args .append ((obj , n , args , kwargs ))
215- return f (obj , * args , ** kwargs )
237+ out = f (obj , * args , ** kwargs )
238+ tracer_args .append ((obj , n , args , kwargs , out ))
239+ return out
216240 return wrapper
217241
218242def _overwrite_methods (env ):
@@ -358,6 +382,8 @@ def tabulate(
358382 _variable_types : set [type ] = {
359383 nnx .RngState # type: ignore[misc]
360384 if isinstance (leaf , nnx .RngState )
385+ else leaf .var_type
386+ if isinstance (leaf , variablelib .Variable )
361387 else type (leaf )
362388 for _ , leaf in nnx .to_flat_state (nnx .state (obj ))
363389 }
@@ -368,38 +394,32 @@ def tabulate(
368394 env = _create_obj_env (object_types )
369395
370396 # Modify all the object's methods to save their Tracer arguments.
371- # tracer_args contains (object, name, args, kwargs) tuples.
372- tracer_args : list [tuple [tp .Any , str , tuple , dict [str , tp .Any ]]] = []
373- saver_env = {k : _argsave (tracer_args , v ) for k ,v in env .items ()}
374- _overwrite_methods (saver_env )
375-
397+ # tracer_args contains (object, name, args, kwargs, out) tuples.
376398 # Add JIT calculation to each method. We can extract flops and output info from
377399 # the lowered JITs. We'll only call these jitted values, which guarantees
378400 # that each method will only be traced (and added to the table) once.
379- jits = {} # Maps (class, method_name) to jit
380- for key , value in saver_env .items ():
381- jits [key ] = nnx .jit (value )
401+ tracer_args : list [tuple [tp .Any , str , tuple , dict [str , tp .Any ], tp .Any ]] = []
402+ jits = {k : nnx .jit (_argsave (tracer_args , v )) for k , v in env .items ()}
382403 _overwrite_methods (jits )
383404
384405 # Trace the top function (which indirectly traces all the others)
385406 jits [(type (obj ), method )].trace (obj , * input_args , ** input_kwargs )
386407
387408 # Get call_info
388- rows : list [CallInfo ] = [_get_call_info (
389- jits [(type (object ), name )], name , node_stats , object ,
390- compute_flops , * args , ** kwargs )
391- for (object , name , args , kwargs ) in tracer_args ]
392-
393- # Add VJP flops if required. This needs to be done separately because calls to `_pure_nnx_vjp`
394- # can result in tracing the jitted functions a second time if there's shared structure.
395- # This would add items to `tracer_args`, resulting in duplicate rows in the table.
396- if compute_vjp_flops :
397- for i , row in enumerate (rows ):
398- object , method_name , args , kwargs = tracer_args [i ]
399- def do_vjp (* args , ** kwargs ):
400- primals , f_vjp = _pure_nnx_vjp (jits [(type (object ), method_name )], * args , ** kwargs )
401- return f_vjp (primals )
402- row .vjp_flops = _get_flops (jax .jit (do_vjp ).lower (object , * args , ** kwargs ))
409+ rows : list [CallInfo ] = [
410+ _get_call_info (
411+ jits [(type (object ), name )],
412+ name ,
413+ node_stats ,
414+ object ,
415+ compute_flops ,
416+ compute_vjp_flops ,
417+ args ,
418+ kwargs ,
419+ out ,
420+ )
421+ for (object , name , args , kwargs , out ) in list (tracer_args )
422+ ]
403423
404424 # Restore the object's original methods
405425 _overwrite_methods (env )
0 commit comments