9494 PseudoInst ,
9595 ProvenanceRecord ,
9696 interpreter_needs_wrap ,
97+ ThunderInterpreterObject ,
9798)
9899from thunder .core .langctxs import set_langctx , reset_langctx , Languages , resolve_language
99100from thunder .core .baseutils import extract_callable_name
@@ -350,7 +351,7 @@ def proxify(self, value: WrappedValue) -> Any:
350351 )
351352 return proxy_s
352353 else :
353- raise ValueError ("cannot proxify value of {type(uvalue).__type } objects" )
354+ raise ValueError (f "cannot proxify value of { type (uvalue ).__type__ } objects" )
354355
355356
356357_jit_ctx = contextvars .ContextVar ("jitctx" )
@@ -445,6 +446,21 @@ def _general_jit_getattr_lookaside(obj: Any, name: str, *maybe_default: Any):
445446 getattr_lookaside = default_lookaside (getattr )
446447 assert getattr_lookaside is not None
447448
449+ uobj = unwrap (obj )
450+ uname = unwrap (name )
451+ if isinstance (uobj , AnyProxy ):
452+ if uname == "__dict__" :
453+ return wrap (
454+ obj .original_value .__dict__ ,
455+ provenance = ProvenanceRecord (
456+ PseudoInst .LOAD_ATTR ,
457+ inputs = [
458+ obj .provenance ,
459+ name .provenance ,
460+ ],
461+ ),
462+ )
463+
448464 value = getattr_lookaside (obj , name , * maybe_default )
449465 if value is INTERPRETER_SIGNALS .EXCEPTION_RAISED :
450466 return value
@@ -510,16 +526,66 @@ def _general_jit_ordered_dict_setitem(d, key, value):
510526 return dict_setitem_lookaside (d , key , value )
511527
512528
529+ _TORCH_DYNAMIC_TYPES = {
530+ torch .amp .autocast_mode .autocast ,
531+ torch .autograd .grad_mode .set_grad_enabled ,
532+ torch .autograd .grad_mode .no_grad ,
533+ }
534+
535+
536+ def is_created_during_tracing (provenance ):
537+ if (
538+ provenance .inst is PseudoInst .OPAQUE
539+ and provenance .inputs [0 ].inst is PseudoInst .CONSTANT
540+ and provenance .inputs [0 ].value == object .__new__
541+ ):
542+ return True
543+ if provenance .inst is PseudoInst .NEW :
544+ return True
545+ return False
546+
547+
548+ @interpreter_needs_wrap
549+ def _raw_object_setattr (obj : Any , name : str , value : Any ):
550+ return object .__setattr__ (obj , name , value )
551+
552+
553+ @register_general_jit_lookaside (object .__setattr__ )
554+ def _general_jit_object_setattr_lookaside (obj : Any , name : str , value : Any ):
555+ uobj = unwrap (obj )
556+ if is_created_during_tracing (obj .provenance ) or type (uobj ) in _TORCH_DYNAMIC_TYPES :
557+ return _raw_object_setattr (obj , name , value )
558+
559+ if should_register_for_prologue (obj .provenance ) and (obj .original_value is obj .nothing ):
560+ if getattr (obj .provenance , "proxy" , None ) is None :
561+ p : AnyProxy = AnyProxy (uobj , history = obj .provenance )
562+ obj .provenance .proxy = p
563+ obj .register_proxy (p )
564+ uobj = p
565+
566+ d = _interpret_call (getattr , obj , wrap_const ("__dict__" ))
567+ if d is INTERPRETER_SIGNALS .EXCEPTION_RAISED :
568+ return d
569+ d .provenance .ext_flag |= EXT_FLAG_IS_MODULE_MEMBER_DICT
570+ ud = unwrap (d )
571+ assert type (ud ) == dict
572+ res = _interpret_call (ud .__setitem__ , name , value )
573+ return res
574+
575+
513576@register_general_jit_lookaside (setattr )
514577def _general_jit_setattr_lookaside (obj : Any , name : str , value : Any ):
515578 setattr_lookaside = default_lookaside (setattr )
516579 assert setattr_lookaside is not None
517580
518581 uobj = unwrap (obj )
519582 uname = unwrap (name )
583+
520584 if isinstance (uobj , torch .nn .Module ):
521- # 1) modify the inner thing
522- # 2) divert the actual setattr...
585+ # 1) populate the wrappeers for the member dicts
586+ # 2) let the original setattr do it's thing by modifying the
587+ # the member dict
588+ # This might generalize to other things, too...
523589 for n in MODULE_MEMBER_DICT_ATTRS :
524590 member_dict = _interpret_call (getattr , obj , wrap_const (n ))
525591 member_dict .provenance .ext_flag |= EXT_FLAG_IS_MODULE_MEMBER_DICT
@@ -682,9 +748,12 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar
682748
683749 custom_autograd_function_cls = unwrap (obj )
684750 custom_forward = custom_autograd_function_cls .forward
685- ctx = torch .autograd .function .FunctionCtx ()
751+ typ = torch .autograd .function .FunctionCtx
752+ ctx = typ ()
686753 ctx_proxy = proxy (ctx , name = None , history = None )
687- wrapped_ctx = wrap_const (ctx_proxy )
754+ wrapped_ctx = wrap_const (
755+ ctx_proxy , provenance = ProvenanceRecord (PseudoInst .NEW , inputs = [wrap_const (typ ).provenance ])
756+ )
688757 trace_of_fwd , fwd_output_provenance = _convert_pytorchfunc_to_thundertrace (
689758 custom_forward , True , wrapped_ctx , * args , ** kwargs
690759 )
@@ -1241,10 +1310,10 @@ def _general_jit_global_callback(orig_value: Any, name: str) -> Any:
12411310 return orig_value
12421311
12431312
1244- _safe_provenance_inst = {
1313+ _input_provenance_inst = {
12451314 "INPUT_ARGS" ,
12461315 "INPUT_KWARGS" ,
1247- "INPUT_FN" ,
1316+ "INPUT_FN" , # or self
12481317 "LOAD_ATTR" ,
12491318 "CONSTANT" ,
12501319 "BINARY_SUBSCR" ,
@@ -1259,7 +1328,7 @@ def should_register_for_prologue(pr):
12591328 inst = inst .opname
12601329 else :
12611330 inst = inst .value
1262- if inst not in _safe_provenance_inst :
1331+ if inst not in _input_provenance_inst :
12631332 return False
12641333 if inst == "CONSTANT" and callable (pr .value ):
12651334 if pr .value .__name__ != "__getitem__" and pr .value != GetSetDescriptorType .__get__ :
@@ -1509,6 +1578,7 @@ def from_load_attr(provenance, *, new_output=False):
15091578 output = Proxy (prefix = "obj" )
15101579 else :
15111580 output = p
1581+
15121582 param_ordering [id (output )] = (output , param_ordering [id (orig_obj )][1 ] + [math .inf , "." + str (name )])
15131583 bsym = prims .unpack_attr .bind (obj , name , output = output )
15141584 prologue_trace .bound_symbols .append (bsym )
@@ -1766,27 +1836,43 @@ def process_recorded_modifications(ctx, epilogue_trace):
17661836 for k , (inst , * args ) in last_modification .items ():
17671837 if inst == PseudoInst .STORE_SUBSCR :
17681838 (value ,) = args
1769- assert isinstance (value .value , Proxy )
17701839
1771- assert modified_object .provenance .inst is PseudoInst .LOAD_ATTR
1772- assert modified_object .provenance .inputs [1 ].inst is PseudoInst .CONSTANT
1773- assert modified_object .provenance .inputs [1 ].value == "_buffers"
1774-
1775- typ , name , root_module_provenance = get_parameter_or_buffer_or_submodule_name_and_root (
1776- modified_object .provenance .inputs [0 ]
1777- )
1778- assert typ == "_modules"
1779- root_module_proxy = root_for_provenances .get (root_module_provenance )
1780- if root_module_proxy is None :
1781- ## we want this to created in the compute trace context for namespace...
1782- root_module_proxy = Proxy (history = root_module_provenance )
1783- epilogue_trace .add_name (root_module_proxy .name )
1784- root_for_provenances [root_module_provenance ] = root_module_proxy
1785-
1786- name = "." .join (name + [k ])
1787- with tracectx (epilogue_trace ):
1788- bsym = prims .pack_buffer .bind (root_module_proxy , name , value .value , output = None )
1789- epilogue_trace .bound_symbols .append (bsym )
1840+ assert isinstance (value .value , (Proxy , int , tuple )) ## todo: better criterion
1841+
1842+ if (
1843+ modified_object .provenance .inst is PseudoInst .LOAD_ATTR
1844+ and modified_object .provenance .inputs [1 ].inst is PseudoInst .CONSTANT
1845+ and modified_object .provenance .inputs [1 ].value == "_buffers"
1846+ ):
1847+ typ , name , root_module_provenance = get_parameter_or_buffer_or_submodule_name_and_root (
1848+ modified_object .provenance .inputs [0 ]
1849+ )
1850+ assert typ == "_modules"
1851+ root_module_proxy = root_for_provenances .get (root_module_provenance )
1852+ if root_module_proxy is None :
1853+ ## we want this to created in the compute trace context for namespace...
1854+ root_module_proxy = Proxy (history = root_module_provenance )
1855+ epilogue_trace .add_name (root_module_proxy .name )
1856+ root_for_provenances [root_module_provenance ] = root_module_proxy
1857+
1858+ name = "." .join (name + [k ])
1859+ with tracectx (epilogue_trace ):
1860+ bsym = prims .pack_buffer .bind (root_module_proxy , name , value .value , output = None )
1861+ epilogue_trace .bound_symbols .append (bsym )
1862+ elif (
1863+ modified_object .provenance .inst is PseudoInst .LOAD_ATTR
1864+ and modified_object .provenance .inputs [1 ].inst is PseudoInst .CONSTANT
1865+ and modified_object .provenance .inputs [1 ].value == "__dict__"
1866+ ):
1867+ name = k
1868+ setattr_obj_provenance = modified_object .provenance .inputs [0 ]
1869+ if hasattr (setattr_obj_provenance , "proxy" ):
1870+ setattr_obj_proxy = setattr_obj_provenance .proxy
1871+ with tracectx (epilogue_trace ):
1872+ bsym = prims .pack_attr .bind (setattr_obj_proxy , name , value .value , output = None )
1873+ epilogue_trace .bound_symbols .append (bsym )
1874+ else :
1875+ raise NotImplementedError (f"Modifications of { modified_object .provenance } are not supported" )
17901876 else :
17911877 raise NotImplementedError (f"Modifications { inst } on dicts are not supported" )
17921878 else :
@@ -1882,6 +1968,7 @@ def thunder_general_jit(
18821968 fn_lookaside = general_jit_lookaside ,
18831969 callbacks = general_jit_callbacks ,
18841970 with_provenance_tracking = True ,
1971+ unwrap_result = False ,
18851972 uncacheable_classes = (torch .Tensor , int , float , str , NoneType ),
18861973 record_history = compile_data .debug_options .record_interpreter_history ,
18871974 )
@@ -1891,11 +1978,12 @@ def thunder_general_jit(
18911978 result = jfn (* args , ** kwargs )
18921979 computation_trace .set_current_source_location (None , None )
18931980 process_recorded_modifications (ctx , epilogue_trace )
1981+ uresult = unwrap (result )
18941982 last_interpreter_log = jfn ._last_interpreter_log
1895- result_proxies = tuple (p for p in tree_iter (result ) if isinstance (p , (TensorProxy , NumberProxy )))
1983+ result_proxies = tuple (p for p in tree_iter (uresult ) if isinstance (p , (TensorProxy , NumberProxy )))
18961984 prims .python_return (result_proxies )
18971985 with tracectx (epilogue_trace ):
1898- prims .python_return (result )
1986+ prims .python_return (uresult )
18991987
19001988 pro_to_comp , pro_to_comp_set , computation_intermediates = get_computation_inputs_and_intermediates (
19011989 computation_trace
@@ -1958,5 +2046,4 @@ def restrict_proxy_swapmap(proxies: tuple[Proxy]) -> dict[Variable, Proxy]:
19582046 epilogue_trace = _apply_trace_proxy_rename (
19592047 epilogue_trace , restrict_proxy_swapmap (pro_to_epi_proxies + comp_to_epi_proxies ), "epilogue"
19602048 )
1961-
19622049 return TraceResults (prologue_trace , computation_trace , epilogue_trace , last_interpreter_log )
0 commit comments