94
94
PseudoInst ,
95
95
ProvenanceRecord ,
96
96
interpreter_needs_wrap ,
97
+ ThunderInterpreterObject ,
97
98
)
98
99
from thunder .core .langctxs import set_langctx , reset_langctx , Languages , resolve_language
99
100
from thunder .core .baseutils import extract_callable_name
@@ -350,7 +351,7 @@ def proxify(self, value: WrappedValue) -> Any:
350
351
)
351
352
return proxy_s
352
353
else :
353
- raise ValueError ("cannot proxify value of {type(uvalue).__type } objects" )
354
+ raise ValueError (f "cannot proxify value of { type (uvalue ).__type__ } objects" )
354
355
355
356
356
357
_jit_ctx = contextvars .ContextVar ("jitctx" )
@@ -445,6 +446,21 @@ def _general_jit_getattr_lookaside(obj: Any, name: str, *maybe_default: Any):
445
446
getattr_lookaside = default_lookaside (getattr )
446
447
assert getattr_lookaside is not None
447
448
449
+ uobj = unwrap (obj )
450
+ uname = unwrap (name )
451
+ if isinstance (uobj , AnyProxy ):
452
+ if uname == "__dict__" :
453
+ return wrap (
454
+ obj .original_value .__dict__ ,
455
+ provenance = ProvenanceRecord (
456
+ PseudoInst .LOAD_ATTR ,
457
+ inputs = [
458
+ obj .provenance ,
459
+ name .provenance ,
460
+ ],
461
+ ),
462
+ )
463
+
448
464
value = getattr_lookaside (obj , name , * maybe_default )
449
465
if value is INTERPRETER_SIGNALS .EXCEPTION_RAISED :
450
466
return value
@@ -510,16 +526,66 @@ def _general_jit_ordered_dict_setitem(d, key, value):
510
526
return dict_setitem_lookaside (d , key , value )
511
527
512
528
529
+ _TORCH_DYNAMIC_TYPES = {
530
+ torch .amp .autocast_mode .autocast ,
531
+ torch .autograd .grad_mode .set_grad_enabled ,
532
+ torch .autograd .grad_mode .no_grad ,
533
+ }
534
+
535
+
536
+ def is_created_during_tracing (provenance ):
537
+ if (
538
+ provenance .inst is PseudoInst .OPAQUE
539
+ and provenance .inputs [0 ].inst is PseudoInst .CONSTANT
540
+ and provenance .inputs [0 ].value == object .__new__
541
+ ):
542
+ return True
543
+ if provenance .inst is PseudoInst .NEW :
544
+ return True
545
+ return False
546
+
547
+
548
+ @interpreter_needs_wrap
549
+ def _raw_object_setattr (obj : Any , name : str , value : Any ):
550
+ return object .__setattr__ (obj , name , value )
551
+
552
+
553
+ @register_general_jit_lookaside (object .__setattr__ )
554
+ def _general_jit_object_setattr_lookaside (obj : Any , name : str , value : Any ):
555
+ uobj = unwrap (obj )
556
+ if is_created_during_tracing (obj .provenance ) or type (uobj ) in _TORCH_DYNAMIC_TYPES :
557
+ return _raw_object_setattr (obj , name , value )
558
+
559
+ if should_register_for_prologue (obj .provenance ) and (obj .original_value is obj .nothing ):
560
+ if getattr (obj .provenance , "proxy" , None ) is None :
561
+ p : AnyProxy = AnyProxy (uobj , history = obj .provenance )
562
+ obj .provenance .proxy = p
563
+ obj .register_proxy (p )
564
+ uobj = p
565
+
566
+ d = _interpret_call (getattr , obj , wrap_const ("__dict__" ))
567
+ if d is INTERPRETER_SIGNALS .EXCEPTION_RAISED :
568
+ return d
569
+ d .provenance .ext_flag |= EXT_FLAG_IS_MODULE_MEMBER_DICT
570
+ ud = unwrap (d )
571
+ assert type (ud ) == dict
572
+ res = _interpret_call (ud .__setitem__ , name , value )
573
+ return res
574
+
575
+
513
576
@register_general_jit_lookaside (setattr )
514
577
def _general_jit_setattr_lookaside (obj : Any , name : str , value : Any ):
515
578
setattr_lookaside = default_lookaside (setattr )
516
579
assert setattr_lookaside is not None
517
580
518
581
uobj = unwrap (obj )
519
582
uname = unwrap (name )
583
+
520
584
if isinstance (uobj , torch .nn .Module ):
521
- # 1) modify the inner thing
522
- # 2) divert the actual setattr...
585
+ # 1) populate the wrappeers for the member dicts
586
+ # 2) let the original setattr do it's thing by modifying the
587
+ # the member dict
588
+ # This might generalize to other things, too...
523
589
for n in MODULE_MEMBER_DICT_ATTRS :
524
590
member_dict = _interpret_call (getattr , obj , wrap_const (n ))
525
591
member_dict .provenance .ext_flag |= EXT_FLAG_IS_MODULE_MEMBER_DICT
@@ -682,9 +748,12 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar
682
748
683
749
custom_autograd_function_cls = unwrap (obj )
684
750
custom_forward = custom_autograd_function_cls .forward
685
- ctx = torch .autograd .function .FunctionCtx ()
751
+ typ = torch .autograd .function .FunctionCtx
752
+ ctx = typ ()
686
753
ctx_proxy = proxy (ctx , name = None , history = None )
687
- wrapped_ctx = wrap_const (ctx_proxy )
754
+ wrapped_ctx = wrap_const (
755
+ ctx_proxy , provenance = ProvenanceRecord (PseudoInst .NEW , inputs = [wrap_const (typ ).provenance ])
756
+ )
688
757
trace_of_fwd , fwd_output_provenance = _convert_pytorchfunc_to_thundertrace (
689
758
custom_forward , True , wrapped_ctx , * args , ** kwargs
690
759
)
@@ -1241,10 +1310,10 @@ def _general_jit_global_callback(orig_value: Any, name: str) -> Any:
1241
1310
return orig_value
1242
1311
1243
1312
1244
- _safe_provenance_inst = {
1313
+ _input_provenance_inst = {
1245
1314
"INPUT_ARGS" ,
1246
1315
"INPUT_KWARGS" ,
1247
- "INPUT_FN" ,
1316
+ "INPUT_FN" , # or self
1248
1317
"LOAD_ATTR" ,
1249
1318
"CONSTANT" ,
1250
1319
"BINARY_SUBSCR" ,
@@ -1259,7 +1328,7 @@ def should_register_for_prologue(pr):
1259
1328
inst = inst .opname
1260
1329
else :
1261
1330
inst = inst .value
1262
- if inst not in _safe_provenance_inst :
1331
+ if inst not in _input_provenance_inst :
1263
1332
return False
1264
1333
if inst == "CONSTANT" and callable (pr .value ):
1265
1334
if pr .value .__name__ != "__getitem__" and pr .value != GetSetDescriptorType .__get__ :
@@ -1509,6 +1578,7 @@ def from_load_attr(provenance, *, new_output=False):
1509
1578
output = Proxy (prefix = "obj" )
1510
1579
else :
1511
1580
output = p
1581
+
1512
1582
param_ordering [id (output )] = (output , param_ordering [id (orig_obj )][1 ] + [math .inf , "." + str (name )])
1513
1583
bsym = prims .unpack_attr .bind (obj , name , output = output )
1514
1584
prologue_trace .bound_symbols .append (bsym )
@@ -1766,27 +1836,43 @@ def process_recorded_modifications(ctx, epilogue_trace):
1766
1836
for k , (inst , * args ) in last_modification .items ():
1767
1837
if inst == PseudoInst .STORE_SUBSCR :
1768
1838
(value ,) = args
1769
- assert isinstance (value .value , Proxy )
1770
1839
1771
- assert modified_object .provenance .inst is PseudoInst .LOAD_ATTR
1772
- assert modified_object .provenance .inputs [1 ].inst is PseudoInst .CONSTANT
1773
- assert modified_object .provenance .inputs [1 ].value == "_buffers"
1774
-
1775
- typ , name , root_module_provenance = get_parameter_or_buffer_or_submodule_name_and_root (
1776
- modified_object .provenance .inputs [0 ]
1777
- )
1778
- assert typ == "_modules"
1779
- root_module_proxy = root_for_provenances .get (root_module_provenance )
1780
- if root_module_proxy is None :
1781
- ## we want this to created in the compute trace context for namespace...
1782
- root_module_proxy = Proxy (history = root_module_provenance )
1783
- epilogue_trace .add_name (root_module_proxy .name )
1784
- root_for_provenances [root_module_provenance ] = root_module_proxy
1785
-
1786
- name = "." .join (name + [k ])
1787
- with tracectx (epilogue_trace ):
1788
- bsym = prims .pack_buffer .bind (root_module_proxy , name , value .value , output = None )
1789
- epilogue_trace .bound_symbols .append (bsym )
1840
+ assert isinstance (value .value , (Proxy , int , tuple )) ## todo: better criterion
1841
+
1842
+ if (
1843
+ modified_object .provenance .inst is PseudoInst .LOAD_ATTR
1844
+ and modified_object .provenance .inputs [1 ].inst is PseudoInst .CONSTANT
1845
+ and modified_object .provenance .inputs [1 ].value == "_buffers"
1846
+ ):
1847
+ typ , name , root_module_provenance = get_parameter_or_buffer_or_submodule_name_and_root (
1848
+ modified_object .provenance .inputs [0 ]
1849
+ )
1850
+ assert typ == "_modules"
1851
+ root_module_proxy = root_for_provenances .get (root_module_provenance )
1852
+ if root_module_proxy is None :
1853
+ ## we want this to created in the compute trace context for namespace...
1854
+ root_module_proxy = Proxy (history = root_module_provenance )
1855
+ epilogue_trace .add_name (root_module_proxy .name )
1856
+ root_for_provenances [root_module_provenance ] = root_module_proxy
1857
+
1858
+ name = "." .join (name + [k ])
1859
+ with tracectx (epilogue_trace ):
1860
+ bsym = prims .pack_buffer .bind (root_module_proxy , name , value .value , output = None )
1861
+ epilogue_trace .bound_symbols .append (bsym )
1862
+ elif (
1863
+ modified_object .provenance .inst is PseudoInst .LOAD_ATTR
1864
+ and modified_object .provenance .inputs [1 ].inst is PseudoInst .CONSTANT
1865
+ and modified_object .provenance .inputs [1 ].value == "__dict__"
1866
+ ):
1867
+ name = k
1868
+ setattr_obj_provenance = modified_object .provenance .inputs [0 ]
1869
+ if hasattr (setattr_obj_provenance , "proxy" ):
1870
+ setattr_obj_proxy = setattr_obj_provenance .proxy
1871
+ with tracectx (epilogue_trace ):
1872
+ bsym = prims .pack_attr .bind (setattr_obj_proxy , name , value .value , output = None )
1873
+ epilogue_trace .bound_symbols .append (bsym )
1874
+ else :
1875
+ raise NotImplementedError (f"Modifications of { modified_object .provenance } are not supported" )
1790
1876
else :
1791
1877
raise NotImplementedError (f"Modifications { inst } on dicts are not supported" )
1792
1878
else :
@@ -1882,6 +1968,7 @@ def thunder_general_jit(
1882
1968
fn_lookaside = general_jit_lookaside ,
1883
1969
callbacks = general_jit_callbacks ,
1884
1970
with_provenance_tracking = True ,
1971
+ unwrap_result = False ,
1885
1972
uncacheable_classes = (torch .Tensor , int , float , str , NoneType ),
1886
1973
record_history = compile_data .debug_options .record_interpreter_history ,
1887
1974
)
@@ -1891,11 +1978,12 @@ def thunder_general_jit(
1891
1978
result = jfn (* args , ** kwargs )
1892
1979
computation_trace .set_current_source_location (None , None )
1893
1980
process_recorded_modifications (ctx , epilogue_trace )
1981
+ uresult = unwrap (result )
1894
1982
last_interpreter_log = jfn ._last_interpreter_log
1895
- result_proxies = tuple (p for p in tree_iter (result ) if isinstance (p , (TensorProxy , NumberProxy )))
1983
+ result_proxies = tuple (p for p in tree_iter (uresult ) if isinstance (p , (TensorProxy , NumberProxy )))
1896
1984
prims .python_return (result_proxies )
1897
1985
with tracectx (epilogue_trace ):
1898
- prims .python_return (result )
1986
+ prims .python_return (uresult )
1899
1987
1900
1988
pro_to_comp , pro_to_comp_set , computation_intermediates = get_computation_inputs_and_intermediates (
1901
1989
computation_trace
@@ -1958,5 +2046,4 @@ def restrict_proxy_swapmap(proxies: tuple[Proxy]) -> dict[Variable, Proxy]:
1958
2046
epilogue_trace = _apply_trace_proxy_rename (
1959
2047
epilogue_trace , restrict_proxy_swapmap (pro_to_epi_proxies + comp_to_epi_proxies ), "epilogue"
1960
2048
)
1961
-
1962
2049
return TraceResults (prologue_trace , computation_trace , epilogue_trace , last_interpreter_log )
0 commit comments