1919import io
2020import typing as tp
2121from types import MappingProxyType
22+ import functools
23+ from types import SimpleNamespace
2224
2325import jax
2426import numpy as np
3436
3537from functools import wraps
3638
39+
3740try :
3841 from IPython import get_ipython
3942
@@ -50,6 +53,41 @@ class NoneDumper(yaml.SafeDumper):
5053 lambda dumper , data : dumper .represent_scalar ('tag:yaml.org,2002:str' , 'None' ),
5154)
5255
56+ class MaybeJit :
57+ """
58+ Wraps a function with nnx.jit, but saves the original to run
59+ if the function turns out to be non-concrete. We can't get the flops of non-concrete functions,
60+ but we should still be able to trace the input and output shapes.
61+ """
62+ def __init__ (self , f ):
63+ self .f = f
64+ self .jitted = nnx .jit (f )
65+ functools .update_wrapper (self , self .f )
66+
67+ # implement descriptor protocol so that we can use this as a method
68+ def __get__ (self , obj , objtype = None ):
69+ if obj is None :
70+ return self
71+ return functools .partial (self , obj )
72+
73+ def __call__ (self , * args ):
74+ try :
75+ return self .jitted (* args )
76+ except TypeError as e :
77+ return self .f (* args )
78+
79+ # This will only be used for the top-level method
80+ def trace (self , * args ):
81+ return self (* args )
82+
83+ def lower (self , * args ):
84+ try :
85+ return self .jitted .lower (* args )
86+ except TypeError as e :
87+ result = self .f (* args )
88+ # Mock a `Lowered` instance with a SimpleNamespace
89+ return SimpleNamespace (cost_analysis = - 1 , lowered = SimpleNamespace (out_info = (None , None , result )))
90+
5391class SizeBytes (typing .SizeBytes ):
5492 def __repr__ (self ) -> str :
5593 bytes_repr = _bytes_repr (self .bytes )
@@ -222,7 +260,7 @@ def _overwrite_methods(env):
222260
223261def _get_flops (e ) -> int :
224262 cost = e .cost_analysis () or e .compile ().cost_analysis ()
225- return 0 if cost is None or 'flops' not in cost else int (cost ['flops' ])
263+ return - 1 if cost is None or 'flops' not in cost else int (cost ['flops' ])
226264
227265def tabulate (
228266 obj ,
@@ -378,7 +416,8 @@ def tabulate(
378416 # that each method will only be traced (and added to the table) once.
379417 jits = {} # Maps (class, method_name) to jit
380418 for key , value in saver_env .items ():
381- jits [key ] = nnx .jit (value )
419+ jits [key ] = MaybeJit (value )
420+
382421 _overwrite_methods (jits )
383422
384423 # Trace the top function (which indirectly traces all the others)
0 commit comments