1919import io
2020import typing as tp
2121from types import MappingProxyType
22+ import functools
23+ from types import SimpleNamespace
2224
2325import jax
2426import numpy as np
3436
3537from functools import wraps
3638
39+
3740try :
3841 from IPython import get_ipython
3942
@@ -50,6 +53,38 @@ class NoneDumper(yaml.SafeDumper):
5053 lambda dumper , data : dumper .represent_scalar ('tag:yaml.org,2002:str' , 'None' ),
5154)
5255
class MaybeJit:
  """
  Wraps a function with nnx.jit, but saves the original to run
  if the function turns out to be non-concrete. We can't get the flops of non-concrete functions,
  but we should still be able to trace the input and output shapes.

  Any ``TypeError`` raised by the jitted path is taken as the signal to fall
  back to the original, un-jitted function.
  """
  def __init__(self, f):
    # Copy the wrapped function's metadata (``__name__``, ``__doc__``, ...) onto this object.
    functools.update_wrapper(self, f)
    self.f = f
    self.jitted = nnx.jit(f)
    # Bookkeeping set of call identifiers; filled in by the code that wraps this object.
    self.seen = set()

  # implement descriptor protocol so that we can use this as a method
  def __get__(self, obj, objtype=None):
    if obj is None:
      # Accessed on the class rather than an instance: expose the wrapper itself.
      return self
    return functools.partial(self, obj)

  def __call__(self, *args, **kwargs):
    """Call the jitted function, falling back to the original on ``TypeError``."""
    try:
      return self.jitted(*args, **kwargs)
    except TypeError:
      return self.f(*args, **kwargs)

  def lower(self, *args, **kwargs):
    """Lower the jitted function, or return a stand-in built from an eager call.

    On ``TypeError`` the original function is executed and its result is
    placed where a real lowering keeps its output info (``lowered.out_info[2]``).
    """
    try:
      return self.jitted.lower(*args, **kwargs)
    except TypeError:
      result = self.f(*args, **kwargs)
      # Mock a `Lowered` instance with a SimpleNamespace. `cost_analysis` has to be
      # callable because `_get_flops` invokes it; a flops value of -1 marks "unknown".
      return SimpleNamespace(
        cost_analysis=lambda: {'flops': -1},
        lowered=SimpleNamespace(out_info=(None, None, result)),
      )
87+
5388class SizeBytes (typing .SizeBytes ):
5489 def __repr__ (self ) -> str :
5590 bytes_repr = _bytes_repr (self .bytes )
@@ -133,8 +168,7 @@ class CallInfo:
133168 object_id : int
134169 type : type
135170 path : statelib .PathParts
136- input_args : tuple [tp .Any , ...]
137- input_kwargs : dict [str , tp .Any ]
171+ inputs_repr : str
138172 outputs : tp .Any
139173 flops : int | None = None
140174 vjp_flops : int | None = None
@@ -151,14 +185,13 @@ def __repr__(self):
151185
152186
def _to_dummy_array(x):
  """Replace a pytree leaf with a lightweight stand-in suitable for display.

  Anything exposing ``shape`` and ``dtype`` becomes an ``ArrayRepr``; graph
  nodes become a ``SimpleObjectRepr``; every other value is returned unchanged.
  """
  try:
    return ArrayRepr(x.shape, x.dtype)
  except Exception:
    # Best-effort duck typing: `x` is not array-like (or could not be
    # summarised). `Exception` instead of a bare `except` so that
    # KeyboardInterrupt / SystemExit still propagate.
    if graph.is_graph_node(x):
      return SimpleObjectRepr(x)
    return x
162195
163196def _pure_nnx_vjp (f , model , * args , ** kwargs ):
164197 "Wrap nnx functional api around jax.vjp. Only handles pure method calls."
@@ -168,14 +201,10 @@ def inner(state, *args, **kwargs):
168201 return f (model , * args , ** kwargs )
169202 return jax .vjp (inner , state , * args , ** kwargs )
170203
171- def _get_call_info (jitted , method_name , node_stats , obj , compute_flops : bool , * args , ** kwargs ):
172- e = jitted .lower (obj , * args , ** kwargs )
173- flops = _get_flops (e ) if compute_flops else None
174- outputs = e .lowered .out_info [2 ]
204+ def _get_call_info (lowered , method_name , node_stats , obj , compute_flops , inputs_repr ):
205+ flops = _get_flops (lowered ) if compute_flops else None
206+ outputs = lowered .lowered .out_info [2 ]
175207 output_repr = jax .tree .map (_to_dummy_array , outputs )
176- input_args_info , input_kwargs_info = jax .tree .map (
177- _to_dummy_array , (args , kwargs )
178- )
179208 object_id : int = getattr (obj , '_nnx_tabulate_id' )
180209 node_info = node_stats [object_id ]
181210 assert node_info is not None
@@ -187,8 +216,7 @@ def _get_call_info(jitted, method_name, node_stats, obj, compute_flops: bool, *a
187216 object_id = object_id ,
188217 type = type (obj ),
189218 path = path ,
190- input_args = input_args_info ,
191- input_kwargs = input_kwargs_info ,
219+ inputs_repr = inputs_repr ,
192220 outputs = output_repr ,
193221 flops = flops ,
194222 )
@@ -206,13 +234,33 @@ def _create_obj_env(object_types):
206234 result [(obj_type , name )] = top_method
207235 return result
208236
def _argsave(counter, tracer_args, f):
  """Wrap ``f`` so that each distinct call is lowered and recorded once.

  A call is identified by the YAML rendering of its (dummy-array) inputs
  together with the receiver's ``_nnx_tabulate_id``. The first time an
  identifier is seen, ``f.lower`` is invoked and a
  ``(order, obj, name, lowered, inputs_repr)`` tuple is appended to
  ``tracer_args``; ``order`` is read from ``counter[0]``, which is bumped
  before lowering. The wrapped function is then always called and its
  result returned.
  """
  method_name = f.__name__

  @wraps(f)
  def wrapper(obj, *args, **kwargs):
    dummy_args, dummy_kwargs = jax.tree.map(_to_dummy_array, (args, kwargs))

    # Render positional inputs first, then keyword inputs, one chunk per line.
    pieces = []
    if dummy_args:
      sole_positional = len(dummy_args) == 1 and not dummy_kwargs
      # A lone positional argument is shown unwrapped rather than as a 1-tuple.
      pieces.append(_as_yaml_str(dummy_args[0] if sole_positional else dummy_args))
    if dummy_kwargs:
      pieces.append(_as_yaml_str(dummy_kwargs))
    inputs_repr = '\n'.join(pieces)

    key = (inputs_repr, getattr(obj, '_nnx_tabulate_id'))
    if key not in f.seen:
      # Claim the slot before lowering: calls to other wrapped functions made
      # while lowering advance the same counter.
      order = counter[0]
      counter[0] += 1
      lowered = f.lower(obj, *args, **kwargs)
      tracer_args.append((order, obj, method_name, lowered, inputs_repr))
      f.seen.add(key)
    return f(obj, *args, **kwargs)

  return wrapper
217265
218266def _overwrite_methods (env ):
@@ -222,7 +270,7 @@ def _overwrite_methods(env):
222270
223271def _get_flops (e ) -> int :
224272 cost = e .cost_analysis () or e .compile ().cost_analysis ()
225- return 0 if cost is None or 'flops' not in cost else int (cost ['flops' ])
273+ return - 1 if cost is None or 'flops' not in cost else int (cost ['flops' ])
226274
227275def tabulate (
228276 obj ,
@@ -367,39 +415,35 @@ def tabulate(
367415 # iteration over methods easier.
368416 env = _create_obj_env (object_types )
369417
370- # Modify all the object's methods to save their Tracer arguments.
371- # tracer_args contains (object, name, args, kwargs) tuples.
372- tracer_args : list [tuple [tp .Any , str , tuple , dict [str , tp .Any ]]] = []
373- saver_env = {k : _argsave (tracer_args , v ) for k ,v in env .items ()}
374- _overwrite_methods (saver_env )
375-
376- # Add JIT calculation to each method. We can extract flops and output info from
377- # the lowered JITs. We'll only call these jitted values, which guarantees
378- # that each method will only be traced (and added to the table) once.
379- jits = {} # Maps (class, method_name) to jit
380- for key , value in saver_env .items ():
381- jits [key ] = nnx .jit (value )
418+ # Modify all the object's methods to save their lowered JIT representations.
419+ tracer_args = []
420+
421+ # Information is recorded in post-order, but should be presented as a pre-order traversal.
422+ # This counter is incremented in pre-order traversal to keep track of the order of calls.
423+ counter = [0 ]
424+
425+ jits = {k : _argsave (counter , tracer_args , MaybeJit (v )) for k ,v in env .items ()}
382426 _overwrite_methods (jits )
383427
384428 # Trace the top function (which indirectly traces all the others)
385- jits [(type (obj ), method )]. trace (obj , * input_args , ** input_kwargs )
429+ jits [(type (obj ), method )](obj , * input_args , ** input_kwargs )
386430
387431 # Get call_info
388432 rows : list [CallInfo ] = [_get_call_info (
389- jits [( type ( object ), name )] , name , node_stats , object ,
390- compute_flops , * args , ** kwargs )
391- for (object , name , args , kwargs ) in tracer_args ]
433+ lowered , name , node_stats , object ,
434+ compute_flops , inputs_repr )
435+ for (_ , object , name , lowered , inputs_repr ) in sorted ( tracer_args , key = lambda x : x [ 0 ]) ]
392436
393437 # Add VJP flops if required. This needs to be done separately because calls to `_pure_nnx_vjp`
394438 # can result in tracing the jitted functions a second time if there's shared structure.
395439 # This would add items to `tracer_args`, resulting in duplicate rows in the table.
396- if compute_vjp_flops :
397- for i , row in enumerate (rows ):
398- object , method_name , args , kwargs = tracer_args [i ]
399- def do_vjp (* args , ** kwargs ):
400- primals , f_vjp = _pure_nnx_vjp (jits [(type (object ), method_name )], * args , ** kwargs )
401- return f_vjp (primals )
402- row .vjp_flops = _get_flops (jax .jit (do_vjp ).lower (object , * args , ** kwargs ))
440+ # if compute_vjp_flops:
441+ # for i, row in enumerate(rows):
442+ # object, method_name, args, kwargs = tracer_args[i]
443+ # def do_vjp(*args, **kwargs):
444+ # primals, f_vjp = _pure_nnx_vjp(jits[(type(object), method_name)], *args, **kwargs)
445+ # return f_vjp(primals)
446+ # row.vjp_flops = _get_flops(jax.jit(do_vjp).lower(object, *args, **kwargs))
403447
404448 # Restore the object's original methods
405449 _overwrite_methods (env )
@@ -436,17 +480,7 @@ def do_vjp(*args, **kwargs):
436480 path_str = '/' .join (map (str , row .path ))
437481 col_reprs .append (path_str )
438482 col_reprs .append (row .type .__name__ )
439- inputs_repr = ''
440- if row .input_args :
441- input_args = row .input_args
442- if len (row .input_args ) == 1 and not row .input_kwargs :
443- input_args = row .input_args [0 ]
444- inputs_repr += _as_yaml_str (input_args )
445- if inputs_repr and row .input_kwargs :
446- inputs_repr += '\n '
447- if row .input_kwargs :
448- inputs_repr += _as_yaml_str (row .input_kwargs )
449- col_reprs .append (inputs_repr )
483+ col_reprs .append (row .inputs_repr )
450484 col_reprs .append (_as_yaml_str (row .outputs ))
451485 if compute_flops :
452486 col_reprs .append (str (row .flops ))
0 commit comments