1919import io
2020import typing as tp
2121from types import MappingProxyType
22+ import functools
23+ from types import SimpleNamespace
2224
2325import jax
2426import numpy as np
3436
3537from functools import wraps
3638
39+
3740try :
3841 from IPython import get_ipython
3942
@@ -50,6 +53,37 @@ class NoneDumper(yaml.SafeDumper):
5053 lambda dumper , data : dumper .represent_scalar ('tag:yaml.org,2002:str' , 'None' ),
5154)
5255
56+ class MaybeJit :
57+ """
58+ Wraps a function with nnx.jit, but saves the original to run
59+ if the function turns out to be non-concrete. We can't get the flops of non-concrete functions,
60+ but we should still be able to trace the input and output shapes.
61+ """
62+ def __init__ (self , f ):
63+ functools .update_wrapper (self , f )
64+ self .f = f
65+ self .jitted = nnx .jit (f )
66+
67+ # implement descriptor protocol so that we can use this as a method
68+ def __get__ (self , obj , objtype = None ):
69+ if obj is None :
70+ return self
71+ return functools .partial (self , obj )
72+
73+ def __call__ (self , * args , ** kwargs ):
74+ try :
75+ return self .jitted (* args , ** kwargs )
76+ except TypeError as e :
77+ return self .f (* args , ** kwargs )
78+
79+ def lower (self , * args , ** kwargs ):
80+ try :
81+ return self .jitted .lower (* args , ** kwargs )
82+ except TypeError as e :
83+ result = self .f (* args , ** kwargs )
84+ # Mock a `Lowered` instance with a SimpleNamespace
85+ return SimpleNamespace (cost_analysis = - 1 , lowered = SimpleNamespace (out_info = (None , None , result )))
86+
5387class SizeBytes (typing .SizeBytes ):
5488 def __repr__ (self ) -> str :
5589 bytes_repr = _bytes_repr (self .bytes )
@@ -151,14 +185,13 @@ def __repr__(self):
151185
152186
153187def _to_dummy_array (x ):
154- if isinstance ( x , jax . ShapeDtypeStruct ) :
188+ try :
155189 return ArrayRepr (x .shape , x .dtype )
156- elif isinstance (x , jax .Array | np .ndarray ):
157- return ArrayRepr .from_array (x )
158- elif graph .is_graph_node (x ):
159- return SimpleObjectRepr (x )
160- else :
161- return x
190+ except :
191+ if graph .is_graph_node (x ):
192+ return SimpleObjectRepr (x )
193+ else :
194+ return x
162195
163196def _pure_nnx_vjp (f , model , * args , ** kwargs ):
164197 "Wrap nnx functional api around jax.vjp. Only handles pure method calls."
@@ -168,10 +201,9 @@ def inner(state, *args, **kwargs):
168201 return f (model , * args , ** kwargs )
169202 return jax .vjp (inner , state , * args , ** kwargs )
170203
171- def _get_call_info (jitted , method_name , node_stats , obj , compute_flops : bool , * args , ** kwargs ):
172- e = jitted .lower (obj , * args , ** kwargs )
173- flops = _get_flops (e ) if compute_flops else None
174- outputs = e .lowered .out_info [2 ]
204+ def _get_call_info (lowered , method_name , node_stats , obj , compute_flops : bool , * args , ** kwargs ):
205+ flops = _get_flops (lowered ) if compute_flops else None
206+ outputs = lowered .lowered .out_info [2 ]
175207 output_repr = jax .tree .map (_to_dummy_array , outputs )
176208 input_args_info , input_kwargs_info = jax .tree .map (
177209 _to_dummy_array , (args , kwargs )
@@ -206,12 +238,15 @@ def _create_obj_env(object_types):
206238 result [(obj_type , name )] = top_method
207239 return result
208240
209- def _argsave (tracer_args , f ):
241+ def _argsave (counter , tracer_args , f ):
210242 "Wrap a function to save its arguments"
211243 n = f .__name__
212244 @wraps (f )
213245 def wrapper (obj , * args , ** kwargs ):
214- tracer_args .append ((obj , n , args , kwargs ))
246+ counter_val = counter [0 ]
247+ counter [0 ] += 1
248+ lowered = f .lower (obj , * args , ** kwargs )
249+ tracer_args .append ((counter_val , obj , n , lowered , args , kwargs ))
215250 return f (obj , * args , ** kwargs )
216251 return wrapper
217252
@@ -222,7 +257,7 @@ def _overwrite_methods(env):
222257
223258def _get_flops (e ) -> int :
224259 cost = e .cost_analysis () or e .compile ().cost_analysis ()
225- return 0 if cost is None or 'flops' not in cost else int (cost ['flops' ])
260+ return - 1 if cost is None or 'flops' not in cost else int (cost ['flops' ])
226261
227262def tabulate (
228263 obj ,
@@ -367,39 +402,35 @@ def tabulate(
367402 # iteration over methods easier.
368403 env = _create_obj_env (object_types )
369404
370- # Modify all the object's methods to save their Tracer arguments.
371- # tracer_args contains (object, name, args, kwargs) tuples.
372- tracer_args : list [tuple [tp .Any , str , tuple , dict [str , tp .Any ]]] = []
373- saver_env = {k : _argsave (tracer_args , v ) for k ,v in env .items ()}
374- _overwrite_methods (saver_env )
375-
376- # Add JIT calculation to each method. We can extract flops and output info from
377- # the lowered JITs. We'll only call these jitted values, which guarantees
378- # that each method will only be traced (and added to the table) once.
379- jits = {} # Maps (class, method_name) to jit
380- for key , value in saver_env .items ():
381- jits [key ] = nnx .jit (value )
405+ # Modify all the object's methods to save their lowered JIT representations.
406+ tracer_args = []
407+
408+ # Information is recorded in post-order, but should be presented as a pre-order traversal.
409+ # This counter is incremented in pre-order traversal to keep track of the order of calls.
410+ counter = [0 ]
411+
412+ jits = {k : _argsave (counter , tracer_args , MaybeJit (v )) for k ,v in env .items ()}
382413 _overwrite_methods (jits )
383414
384415 # Trace the top function (which indirectly traces all the others)
385- jits [(type (obj ), method )]. trace (obj , * input_args , ** input_kwargs )
416+ jits [(type (obj ), method )](obj , * input_args , ** input_kwargs )
386417
387418 # Get call_info
388419 rows : list [CallInfo ] = [_get_call_info (
389- jits [( type ( object ), name )] , name , node_stats , object ,
420+ lowered , name , node_stats , object ,
390421 compute_flops , * args , ** kwargs )
391- for (object , name , args , kwargs ) in tracer_args ]
422+ for (_ , object , name , lowered , args , kwargs ) in sorted ( tracer_args , key = lambda x : x [ 0 ]) ]
392423
393424 # Add VJP flops if required. This needs to be done separately because calls to `_pure_nnx_vjp`
394425 # can result in tracing the jitted functions a second time if there's shared structure.
395426 # This would add items to `tracer_args`, resulting in duplicate rows in the table.
396- if compute_vjp_flops :
397- for i , row in enumerate (rows ):
398- object , method_name , args , kwargs = tracer_args [i ]
399- def do_vjp (* args , ** kwargs ):
400- primals , f_vjp = _pure_nnx_vjp (jits [(type (object ), method_name )], * args , ** kwargs )
401- return f_vjp (primals )
402- row .vjp_flops = _get_flops (jax .jit (do_vjp ).lower (object , * args , ** kwargs ))
427+ # if compute_vjp_flops:
428+ # for i, row in enumerate(rows):
429+ # object, method_name, args, kwargs = tracer_args[i]
430+ # def do_vjp(*args, **kwargs):
431+ # primals, f_vjp = _pure_nnx_vjp(jits[(type(object), method_name)], *args, **kwargs)
432+ # return f_vjp(primals)
433+ # row.vjp_flops = _get_flops(jax.jit(do_vjp).lower(object, *args, **kwargs))
403434
404435 # Restore the object's original methods
405436 _overwrite_methods (env )
0 commit comments