2121from types import MappingProxyType
2222import functools
2323import itertools
24- from types import SimpleNamespace
2524
2625import jax
2726import numpy as np
@@ -54,38 +53,6 @@ class NoneDumper(yaml.SafeDumper):
5453 lambda dumper , data : dumper .represent_scalar ('tag:yaml.org,2002:str' , 'None' ),
5554)
5655
57- class MaybeJit :
58- """
59- Wraps a function with nnx.jit, but saves the original to run
60- if the function turns out to be non-concrete. We can't get the flops of non-concrete functions,
61- but we should still be able to trace the input and output shapes.
62- """
63- def __init__ (self , f ):
64- functools .update_wrapper (self , f )
65- self .f = f
66- self .jitted = nnx .jit (f )
67- self .seen = set ()
68-
69- # implement descriptor protocol so that we can use this as a method
70- def __get__ (self , obj , objtype = None ):
71- if obj is None :
72- return self
73- return functools .partial (self , obj )
74-
75- def __call__ (self , * args , ** kwargs ):
76- try :
77- return self .jitted (* args , ** kwargs )
78- except TypeError as e :
79- return self .f (* args , ** kwargs )
80-
81- def lower (self , * args , ** kwargs ):
82- try :
83- return self .jitted .lower (* args , ** kwargs )
84- except TypeError as e :
85- result = self .f (* args , ** kwargs )
86- # Mock a `Lowered` instance with a SimpleNamespace
87- return SimpleNamespace (cost_analysis = "0" , lowered = SimpleNamespace (out_info = (None , None , result )))
88-
8956class SizeBytes (typing .SizeBytes ):
9057 def __repr__ (self ) -> str :
9158 bytes_repr = _bytes_repr (self .bytes )
@@ -242,11 +209,11 @@ def do_vjp(*args, **kwargs):
242209
243210 method_name = f .__name__
244211
245- def split_f (* args , ** kwargs ):
212+ @functools .partial (jax .jit )
213+ def jit_f (graphdef , state ):
214+ args , kwargs = nnx .merge (graphdef , state )
246215 return f (* args , ** kwargs )
247-
248- nnx .jit (split_f )
249- # TODO: write split_f and continue from here.
216+ jit_f ._seen = set ()
250217
251218 @wraps (f )
252219 def wrapper (obj , * args , ** kwargs ):
@@ -257,19 +224,20 @@ def wrapper(obj, *args, **kwargs):
257224 if method_name != '__call__' :
258225 path = (* path , method_name )
259226 identifier = (inputs_repr , object_id )
260- if identifier not in f .seen :
261- counter_val = next (counter )
262- lowered = f .lower (obj , * args , ** kwargs )
227+ counter_val = next (counter )
228+ graphdef , state = nnx .split (((obj , * args ), kwargs ))
229+ lowered = jit_f .lower (graphdef , state )
230+ if identifier not in jit_f ._seen :
231+ jit_f ._seen .add (identifier )
263232 flops = _get_flops (lowered ) if compute_flops else None
264- outputs = lowered .lowered . out_info [ 2 ]
233+ outputs = lowered .out_info
265234 output_repr = jax .tree .map (_to_dummy_array , outputs )
266235 vjp_flops = _get_flops (jax .jit (do_vjp ).lower (
267236 obj , * args , ** kwargs )) if compute_vjp_flops else None
268237 tracer_args .append (
269238 CallInfo (counter_val , object_id , type (obj ), path , inputs_repr ,
270239 output_repr , flops , vjp_flops ))
271- f .seen .add (identifier )
272- return f (obj , * args , ** kwargs )
240+ return jit_f (graphdef , state )
273241 return wrapper
274242
275243def _overwrite_methods (env ):
@@ -430,12 +398,12 @@ def tabulate(
430398
431399 # Modify all the object's methods to save their lowered JIT representations.
432400 rows : list [CallInfo ] = []
433- maybejits = {k : _save_call_info (counter , rows , v , node_stats , compute_flops , compute_vjp_flops )
401+ jits = {k : _save_call_info (counter , rows , v , node_stats , compute_flops , compute_vjp_flops )
434402 for k , v in env .items ()}
435- _overwrite_methods (maybejits )
403+ _overwrite_methods (jits )
436404
437405 # Trace the top function (which indirectly traces all the others)
438- maybejits [(type (obj ), method )](obj , * input_args , ** input_kwargs )
406+ jits [(type (obj ), method )](obj , * input_args , ** input_kwargs )
439407
440408 # Sort call info in pre-order traversal order
441409 rows .sort (key = lambda x : x .call_order )
0 commit comments