1919import io
2020import typing as tp
2121from types import MappingProxyType
22+ import functools
23+ import itertools
2224
2325import jax
2426import numpy as np
3436
3537from functools import wraps
3638
39+
3740try :
3841 from IPython import get_ipython
3942
@@ -130,14 +133,14 @@ def __str__(self):
130133
131134@dataclasses .dataclass
132135class CallInfo :
136+ call_order : int
133137 object_id : int
134138 type : type
135139 path : statelib .PathParts
136- input_args : tuple [tp .Any , ...]
137- input_kwargs : dict [str , tp .Any ]
140+ inputs_repr : str
138141 outputs : tp .Any
139- flops : int | None = None
140- vjp_flops : int | None = None
142+ flops : int | None
143+ vjp_flops : int | None
141144
142145class SimpleObjectRepr :
143146 def __init__ (self , obj : tp .Any ):
@@ -168,32 +171,6 @@ def inner(state, *args, **kwargs):
168171 return f (model , * args , ** kwargs )
169172 return jax .vjp (inner , state , * args , ** kwargs )
170173
171- def _get_call_info (jitted , method_name , node_stats , obj , compute_flops : bool , * args , ** kwargs ):
172- e = jitted .lower (obj , * args , ** kwargs )
173- flops = _get_flops (e ) if compute_flops else None
174- outputs = e .lowered .out_info [2 ]
175- output_repr = jax .tree .map (_to_dummy_array , outputs )
176- input_args_info , input_kwargs_info = jax .tree .map (
177- _to_dummy_array , (args , kwargs )
178- )
179- object_id : int = getattr (obj , '_nnx_tabulate_id' )
180- node_info = node_stats [object_id ]
181- assert node_info is not None
182- path = node_info .path
183- if method_name != '__call__' :
184- path = (* path , method_name )
185-
186- return CallInfo (
187- object_id = object_id ,
188- type = type (obj ),
189- path = path ,
190- input_args = input_args_info ,
191- input_kwargs = input_kwargs_info ,
192- outputs = output_repr ,
193- flops = flops ,
194- )
195-
196-
197174def filter_rng_streams (row : CallInfo ):
198175 return not issubclass (row .type , nnx .RngStream )
199176
@@ -206,13 +183,60 @@ def _create_obj_env(object_types):
206183 result [(obj_type , name )] = top_method
207184 return result
208185
209- def _argsave (tracer_args , f ):
186+ def _get_inputs_repr (args , kwargs ):
187+ input_args , input_kwargs = jax .tree .map (
188+ _to_dummy_array , (args , kwargs )
189+ )
190+ inputs_repr = ''
191+ if input_args :
192+ if len (input_args ) == 1 and not input_kwargs :
193+ inputs_repr += _as_yaml_str (input_args [0 ])
194+ else :
195+ inputs_repr += _as_yaml_str (input_args )
196+ if input_kwargs :
197+ inputs_repr += '\n '
198+ if input_kwargs :
199+ inputs_repr += _as_yaml_str (input_kwargs )
200+ return inputs_repr
201+
202+ def _save_call_info (counter , tracer_args , f , node_stats , compute_flops , compute_vjp_flops , seen ):
210203 "Wrap a function to save its arguments"
211- n = f .__name__
204+
205+ # Used when computing vjp flops
206+ def do_vjp (* args , ** kwargs ):
207+ primals , f_vjp = jax .vjp (f , * args , ** kwargs )
208+ return f_vjp (primals )
209+
210+ method_name = f .__name__
211+
212+ @functools .partial (jax .jit )
213+ def jit_f (graphdef , state ):
214+ args , kwargs = nnx .merge (graphdef , state )
215+ return f (* args , ** kwargs )
216+
212217 @wraps (f )
213218 def wrapper (obj , * args , ** kwargs ):
214- tracer_args .append ((obj , n , args , kwargs ))
215- return f (obj , * args , ** kwargs )
219+ inputs_repr = _get_inputs_repr (args , kwargs )
220+ object_id = getattr (obj , '_nnx_tabulate_id' )
221+ node_info = node_stats [object_id ]
222+ path = node_info .path
223+ if method_name != '__call__' :
224+ path = (* path , method_name )
225+ identifier = (inputs_repr , object_id )
226+ counter_val = next (counter )
227+ graphdef , state = nnx .split (((obj , * args ), kwargs ))
228+ lowered = jit_f .lower (graphdef , state )
229+ if identifier not in seen :
230+ seen .add (identifier )
231+ flops = _get_flops (lowered ) if compute_flops else None
232+ outputs = lowered .out_info
233+ output_repr = jax .tree .map (_to_dummy_array , outputs )
234+ vjp_flops = _get_flops (jax .jit (do_vjp ).lower (
235+ obj , * args , ** kwargs )) if compute_vjp_flops else None
236+ tracer_args .append (
237+ CallInfo (counter_val , object_id , type (obj ), path , inputs_repr ,
238+ output_repr , flops , vjp_flops ))
239+ return jit_f (graphdef , state )
216240 return wrapper
217241
218242def _overwrite_methods (env ):
@@ -367,39 +391,22 @@ def tabulate(
367391 # iteration over methods easier.
368392 env = _create_obj_env (object_types )
369393
370- # Modify all the object's methods to save their Tracer arguments.
371- # tracer_args contains (object, name, args, kwargs) tuples.
372- tracer_args : list [tuple [tp .Any , str , tuple , dict [str , tp .Any ]]] = []
373- saver_env = {k : _argsave (tracer_args , v ) for k ,v in env .items ()}
374- _overwrite_methods (saver_env )
375-
376- # Add JIT calculation to each method. We can extract flops and output info from
377- # the lowered JITs. We'll only call these jitted values, which guarantees
378- # that each method will only be traced (and added to the table) once.
379- jits = {} # Maps (class, method_name) to jit
380- for key , value in saver_env .items ():
381- jits [key ] = nnx .jit (value )
394+ # Information is recorded in post-order, but should be presented as a pre-order traversal.
395+ # This keeps track of the order of calls.
396+ counter = itertools .count (0 )
397+
398+ # Modify all the object's methods to save their lowered JIT representations.
399+ rows : list [CallInfo ] = []
400+ seen : set = set ()
401+ jits = {k : _save_call_info (counter , rows , v , node_stats , compute_flops , compute_vjp_flops , seen )
402+ for k , v in env .items ()}
382403 _overwrite_methods (jits )
383404
384405 # Trace the top function (which indirectly traces all the others)
385- jits [(type (obj ), method )]. trace (obj , * input_args , ** input_kwargs )
406+ jits [(type (obj ), method )](obj , * input_args , ** input_kwargs )
386407
387- # Get call_info
388- rows : list [CallInfo ] = [_get_call_info (
389- jits [(type (object ), name )], name , node_stats , object ,
390- compute_flops , * args , ** kwargs )
391- for (object , name , args , kwargs ) in tracer_args ]
392-
393- # Add VJP flops if required. This needs to be done separately because calls to `_pure_nnx_vjp`
394- # can result in tracing the jitted functions a second time if there's shared structure.
395- # This would add items to `tracer_args`, resulting in duplicate rows in the table.
396- if compute_vjp_flops :
397- for i , row in enumerate (rows ):
398- object , method_name , args , kwargs = tracer_args [i ]
399- def do_vjp (* args , ** kwargs ):
400- primals , f_vjp = _pure_nnx_vjp (jits [(type (object ), method_name )], * args , ** kwargs )
401- return f_vjp (primals )
402- row .vjp_flops = _get_flops (jax .jit (do_vjp ).lower (object , * args , ** kwargs ))
408+ # Sort call info in pre-order traversal order
409+ rows .sort (key = lambda x : x .call_order )
403410
404411 # Restore the object's original methods
405412 _overwrite_methods (env )
@@ -436,17 +443,7 @@ def do_vjp(*args, **kwargs):
436443 path_str = '/' .join (map (str , row .path ))
437444 col_reprs .append (path_str )
438445 col_reprs .append (row .type .__name__ )
439- inputs_repr = ''
440- if row .input_args :
441- input_args = row .input_args
442- if len (row .input_args ) == 1 and not row .input_kwargs :
443- input_args = row .input_args [0 ]
444- inputs_repr += _as_yaml_str (input_args )
445- if inputs_repr and row .input_kwargs :
446- inputs_repr += '\n '
447- if row .input_kwargs :
448- inputs_repr += _as_yaml_str (row .input_kwargs )
449- col_reprs .append (inputs_repr )
446+ col_reprs .append (row .inputs_repr )
450447 col_reprs .append (_as_yaml_str (row .outputs ))
451448 if compute_flops :
452449 col_reprs .append (str (row .flops ))
0 commit comments