2121from types import MappingProxyType
2222import functools
2323import itertools
24- from types import SimpleNamespace
2524
2625import jax
2726import numpy as np
@@ -54,38 +53,6 @@ class NoneDumper(yaml.SafeDumper):
5453 lambda dumper , data : dumper .represent_scalar ('tag:yaml.org,2002:str' , 'None' ),
5554)
5655
57- class MaybeJit :
58- """
59- Wraps a function with nnx.jit, but saves the original to run
60- if the function turns out to be non-concrete. We can't get the flops of non-concrete functions,
61- but we should still be able to trace the input and output shapes.
62- """
63- def __init__ (self , f ):
64- functools .update_wrapper (self , f )
65- self .f = f
66- self .jitted = nnx .jit (f )
67- self .seen = set ()
68-
69- # implement descriptor protocol so that we can use this as a method
70- def __get__ (self , obj , objtype = None ):
71- if obj is None :
72- return self
73- return functools .partial (self , obj )
74-
75- def __call__ (self , * args , ** kwargs ):
76- try :
77- return self .jitted (* args , ** kwargs )
78- except TypeError as e :
79- return self .f (* args , ** kwargs )
80-
81- def lower (self , * args , ** kwargs ):
82- try :
83- return self .jitted .lower (* args , ** kwargs )
84- except TypeError as e :
85- result = self .f (* args , ** kwargs )
86- # Mock a `Lowered` instance with a SimpleNamespace
87- return SimpleNamespace (cost_analysis = "0" , lowered = SimpleNamespace (out_info = (None , None , result )))
88-
8956class SizeBytes (typing .SizeBytes ):
9057 def __repr__ (self ) -> str :
9158 bytes_repr = _bytes_repr (self .bytes )
@@ -242,6 +209,12 @@ def do_vjp(*args, **kwargs):
242209
243210 method_name = f .__name__
244211
212+ @functools .partial (jax .jit )
213+ def jit_f (graphdef , state ):
214+ args , kwargs = nnx .merge (graphdef , state )
215+ return f (* args , ** kwargs )
216+ jit_f ._seen = set ()
217+
245218 @wraps (f )
246219 def wrapper (obj , * args , ** kwargs ):
247220 inputs_repr = _get_inputs_repr (args , kwargs )
@@ -251,19 +224,20 @@ def wrapper(obj, *args, **kwargs):
251224 if method_name != '__call__' :
252225 path = (* path , method_name )
253226 identifier = (inputs_repr , object_id )
254- if identifier not in f .seen :
255- counter_val = next (counter )
256- lowered = f .lower (obj , * args , ** kwargs )
227+ counter_val = next (counter )
228+ graphdef , state = nnx .split (((obj , * args ), kwargs ))
229+ lowered = jit_f .lower (graphdef , state )
230+ if identifier not in jit_f ._seen :
231+ jit_f ._seen .add (identifier )
257232 flops = _get_flops (lowered ) if compute_flops else None
258- outputs = lowered .lowered . out_info [ 2 ]
233+ outputs = lowered .out_info
259234 output_repr = jax .tree .map (_to_dummy_array , outputs )
260235 vjp_flops = _get_flops (jax .jit (do_vjp ).lower (
261236 obj , * args , ** kwargs )) if compute_vjp_flops else None
262237 tracer_args .append (
263238 CallInfo (counter_val , object_id , type (obj ), path , inputs_repr ,
264239 output_repr , flops , vjp_flops ))
265- f .seen .add (identifier )
266- return f (obj , * args , ** kwargs )
240+ return jit_f (graphdef , state )
267241 return wrapper
268242
269243def _overwrite_methods (env ):
@@ -424,12 +398,12 @@ def tabulate(
424398
425399 # Modify all the object's methods to save their lowered JIT representations.
426400 rows : list [CallInfo ] = []
427- maybejits = {k : _save_call_info (counter , rows , MaybeJit ( v ) , node_stats , compute_flops , compute_vjp_flops )
401+ jits = {k : _save_call_info (counter , rows , v , node_stats , compute_flops , compute_vjp_flops )
428402 for k , v in env .items ()}
429- _overwrite_methods (maybejits )
403+ _overwrite_methods (jits )
430404
431405 # Trace the top function (which indirectly traces all the others)
432- maybejits [(type (obj ), method )](obj , * input_args , ** input_kwargs )
406+ jits [(type (obj ), method )](obj , * input_args , ** input_kwargs )
433407
434408 # Sort call info in pre-order traversal order
435409 rows .sort (key = lambda x : x .call_order )
0 commit comments