@@ -170,8 +170,8 @@ class CallInfo:
170170 path : statelib .PathParts
171171 inputs_repr : str
172172 outputs : tp .Any
173- flops : int | None = None
174- vjp_flops : int | None = None
173+ flops : int | None
174+ vjp_flops : int | None
175175
176176class SimpleObjectRepr :
177177 def __init__ (self , obj : tp .Any ):
@@ -201,7 +201,7 @@ def inner(state, *args, **kwargs):
201201 return f (model , * args , ** kwargs )
202202 return jax .vjp (inner , state , * args , ** kwargs )
203203
204- def _get_call_info (lowered , method_name , node_stats , obj , compute_flops , inputs_repr ):
204+ def _get_call_info (lowered , method_name , node_stats , obj , compute_flops , inputs_repr , vjp_flops ):
205205 flops = _get_flops (lowered ) if compute_flops else None
206206 outputs = lowered .lowered .out_info [2 ]
207207 output_repr = jax .tree .map (_to_dummy_array , outputs )
@@ -219,6 +219,7 @@ def _get_call_info(lowered, method_name, node_stats, obj, compute_flops, inputs_
219219 inputs_repr = inputs_repr ,
220220 outputs = output_repr ,
221221 flops = flops ,
222+ vjp_flops = vjp_flops
222223 )
223224
224225
@@ -234,8 +235,11 @@ def _create_obj_env(object_types):
234235 result [(obj_type , name )] = top_method
235236 return result
236237
237- def _argsave (counter , tracer_args , f ):
238+ def _argsave (counter , tracer_args , f , compute_vjp_flops ):
238239 "Wrap a function to save its arguments"
240+ def do_vjp (* args , ** kwargs ):
241+ primals , f_vjp = jax .vjp (f , * args , ** kwargs )
242+ return f_vjp (primals )
239243 n = f .__name__
240244 @wraps (f )
241245 def wrapper (obj , * args , ** kwargs ):
@@ -258,7 +262,11 @@ def wrapper(obj, *args, **kwargs):
258262 counter_val = counter [0 ]
259263 counter [0 ] += 1
260264 lowered = f .lower (obj , * args , ** kwargs )
261- tracer_args .append ((counter_val , obj , n , lowered , inputs_repr ))
265+ if compute_vjp_flops :
266+ vjp_flops = _get_flops (jax .jit (do_vjp ).lower (obj , * args , ** kwargs ))
267+ else :
268+ vjp_flops = None
269+ tracer_args .append ((counter_val , obj , n , lowered , inputs_repr , vjp_flops ))
262270 f .seen .add (identifier )
263271 return f (obj , * args , ** kwargs )
264272 return wrapper
@@ -415,14 +423,13 @@ def tabulate(
415423 # iteration over methods easier.
416424 env = _create_obj_env (object_types )
417425
418- # Modify all the object's methods to save their lowered JIT representations.
419- tracer_args = []
420-
421426 # Information is recorded in post-order, but should be presented as a pre-order traversal.
422427 # This counter is incremented in pre-order traversal to keep track of the order of calls.
423428 counter = [0 ]
424429
425- jits = {k : _argsave (counter , tracer_args , MaybeJit (v )) for k ,v in env .items ()}
430+ # Modify all the object's methods to save their lowered JIT representations.
431+ tracer_args = []
432+ jits = {k : _argsave (counter , tracer_args , MaybeJit (v ), compute_vjp_flops ) for k ,v in env .items ()}
426433 _overwrite_methods (jits )
427434
428435 # Trace the top function (which indirectly traces all the others)
@@ -431,19 +438,8 @@ def tabulate(
431438 # Get call_info
432439 rows : list [CallInfo ] = [_get_call_info (
433440 lowered , name , node_stats , object ,
434- compute_flops , inputs_repr )
435- for (_ , object , name , lowered , inputs_repr ) in sorted (tracer_args , key = lambda x : x [0 ])]
436-
437- # Add VJP flops if required. This needs to be done separately because calls to `_pure_nnx_vjp`
438- # can result in tracing the jitted functions a second time if there's shared structure.
439- # This would add items to `tracer_args`, resulting in duplicate rows in the table.
440- # if compute_vjp_flops:
441- # for i, row in enumerate(rows):
442- # object, method_name, args, kwargs = tracer_args[i]
443- # def do_vjp(*args, **kwargs):
444- # primals, f_vjp = _pure_nnx_vjp(jits[(type(object), method_name)], *args, **kwargs)
445- # return f_vjp(primals)
446- # row.vjp_flops = _get_flops(jax.jit(do_vjp).lower(object, *args, **kwargs))
441+ compute_flops , inputs_repr , vjp_flops )
442+ for (_ , object , name , lowered , inputs_repr , vjp_flops ) in sorted (tracer_args , key = lambda x : x [0 ])]
447443
448444 # Restore the object's original methods
449445 _overwrite_methods (env )
0 commit comments