Skip to content

Commit ee736c1

Browse files
committed
more general setattr
1 parent 780407d commit ee736c1

File tree

3 files changed

+178
-41
lines changed

3 files changed

+178
-41
lines changed

thunder/core/interpreter.py

+17-10
Original file line numberDiff line numberDiff line change
@@ -932,7 +932,9 @@ class PseudoInst(str, enum.Enum):
932932
SUPER = "SUPER"
933933
BUILTINS = "BUILTINS"
934934
STORE_SUBSCR = "STORE_SUBSCR"
935+
STORE_ATTR = "STORE_ATTR"
935936
LIST_TO_TUPLE = "LIST_TO_TUPLE"
937+
NEW = "NEW"
936938

937939

938940
@dataclasses.dataclass
@@ -2073,9 +2075,13 @@ def impl(fn, iterable, initializer, null):
20732075
return _interpret_call(impl, fn, iterable, initializer, null)
20742076

20752077

2078+
class ThunderInterpreterObject:
2079+
pass
2080+
2081+
20762082
# An iterator to be returned from Sequence.__iter__ lookasides below. This will be run in the interpreter
20772083
# Note: this potentially might imitate a list_iterator / tuple_iterator more...
2078-
class SequenceIter:
2084+
class SequenceIter(ThunderInterpreterObject):
20792085
def __init__(self, s, is_reversed=False):
20802086
self.s = s
20812087
self.next_pos = 0 if not is_reversed else len(s) - 1
@@ -2377,7 +2383,7 @@ def reverse(self, /):
23772383
return wrap_const(None)
23782384

23792385

2380-
class MappingKeysIterator(Iterator):
2386+
class MappingKeysIterator(Iterator, ThunderInterpreterObject):
23812387
# note: the __init__ will be executed by Python itself, and
23822388
# the caller needs to set up the wrapped_attribute for _mapping
23832389
# The other methods are called through the interpreter mechanism.
@@ -2395,7 +2401,7 @@ def __next__(self):
23952401
return k
23962402

23972403

2398-
class MappingKeysView:
2404+
class MappingKeysView(ThunderInterpreterObject):
23992405
def __init__(self, mapping):
24002406
self._mapping = mapping
24012407

@@ -2425,7 +2431,7 @@ def __reversed__(self):
24252431
return mapping_iter
24262432

24272433

2428-
class MappingValuesIterator:
2434+
class MappingValuesIterator(ThunderInterpreterObject):
24292435
def __init__(self, mapping, is_reversed=False):
24302436
self._mapping = mapping
24312437
if is_reversed:
@@ -2440,15 +2446,15 @@ def __next__(self):
24402446
return dict.__getitem__(self._mapping, next(self._key_iter))
24412447

24422448

2443-
class MappingValuesWrapper:
2449+
class MappingValuesWrapper(ThunderInterpreterObject):
24442450
def __init__(self, mapping):
24452451
self._mapping = mapping
24462452

24472453
def __iter__(self):
24482454
return MappingValuesIterator(self._mapping)
24492455

24502456

2451-
class MappingItemsIterator:
2457+
class MappingItemsIterator(ThunderInterpreterObject):
24522458
def __init__(self, mapping, is_reversed=False):
24532459
self._mapping = mapping
24542460
if is_reversed:
@@ -2464,7 +2470,7 @@ def __next__(self):
24642470
return k, dict.__getitem__(self._mapping, k)
24652471

24662472

2467-
class MappingItemsWrapper:
2473+
class MappingItemsWrapper(ThunderInterpreterObject):
24682474
def __init__(self, mapping):
24692475
self._mapping = mapping
24702476

@@ -2476,7 +2482,7 @@ class MutMappingWrapperMethods(WrappedValue):
24762482
def __new__(cls, /, *args, **kwds):
24772483
uvalue = unwrap(cls)()
24782484
# todo: for subclasses, better record the call to the constructor
2479-
return wrap_const(uvalue)
2485+
return wrap(uvalue, provenance=ProvenanceRecord(PseudoInst.NEW, inputs=[cls.provenance]))
24802486

24812487
def __init__(self, *other, **kwds):
24822488
MutMappingWrapperMethods.update(self, *other, **kwds)
@@ -2775,7 +2781,6 @@ def _type_call_lookaside(wrapped_typ, *args, **kwargs):
27752781
obj = _interpret_call(typ.__new__, wrapped_typ, *args, **kwargs)
27762782
if obj is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
27772783
return obj
2778-
27792784
wrapped_init = _interpret_call(getattr, obj, wrap_const("__init__"))
27802785
assert not isinstance(wrapped_init, INTERPRETER_SIGNALS)
27812786
populate_attribute_wrapper(wrapped_init, "__self__", obj)
@@ -7151,6 +7156,7 @@ def interpret(
71517156
callbacks: dict[INTERPRETER_CALLBACKS, Callable] = default_callbacks,
71527157
debug_log: None | StringIO = None,
71537158
with_provenance_tracking: bool = False,
7159+
unwrap_result: bool = True,
71547160
uncacheable_classes: list[type] | None = None,
71557161
record_history: bool = False,
71567162
) -> Callable:
@@ -7205,7 +7211,8 @@ def fn_2(args, kwargs):
72057211
populate_attribute_wrapper(wrapped_cell, "cell_contents", fn_wrapped)
72067212

72077213
interpretation_result: Any = _interpret_call(wrapped_fn_2, args, kwargs)
7208-
interpretation_result = unwrap(interpretation_result)
7214+
if unwrap_result:
7215+
interpretation_result = unwrap(interpretation_result)
72097216

72107217
except BaseException as e:
72117218
# TODO Highlight the portion of the line that originated the opcode on Python versions that include

thunder/core/jit_ext.py

+118-31
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@
9494
PseudoInst,
9595
ProvenanceRecord,
9696
interpreter_needs_wrap,
97+
ThunderInterpreterObject,
9798
)
9899
from thunder.core.langctxs import set_langctx, reset_langctx, Languages, resolve_language
99100
from thunder.core.baseutils import extract_callable_name
@@ -350,7 +351,7 @@ def proxify(self, value: WrappedValue) -> Any:
350351
)
351352
return proxy_s
352353
else:
353-
raise ValueError("cannot proxify value of {type(uvalue).__type} objects")
354+
raise ValueError(f"cannot proxify value of {type(uvalue).__type__} objects")
354355

355356

356357
_jit_ctx = contextvars.ContextVar("jitctx")
@@ -445,6 +446,21 @@ def _general_jit_getattr_lookaside(obj: Any, name: str, *maybe_default: Any):
445446
getattr_lookaside = default_lookaside(getattr)
446447
assert getattr_lookaside is not None
447448

449+
uobj = unwrap(obj)
450+
uname = unwrap(name)
451+
if isinstance(uobj, AnyProxy):
452+
if uname == "__dict__":
453+
return wrap(
454+
obj.original_value.__dict__,
455+
provenance=ProvenanceRecord(
456+
PseudoInst.LOAD_ATTR,
457+
inputs=[
458+
obj.provenance,
459+
name.provenance,
460+
],
461+
),
462+
)
463+
448464
value = getattr_lookaside(obj, name, *maybe_default)
449465
if value is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
450466
return value
@@ -510,16 +526,66 @@ def _general_jit_ordered_dict_setitem(d, key, value):
510526
return dict_setitem_lookaside(d, key, value)
511527

512528

529+
_TORCH_DYNAMIC_TYPES = {
530+
torch.amp.autocast_mode.autocast,
531+
torch.autograd.grad_mode.set_grad_enabled,
532+
torch.autograd.grad_mode.no_grad,
533+
}
534+
535+
536+
def is_created_during_tracing(provenance):
537+
if (
538+
provenance.inst is PseudoInst.OPAQUE
539+
and provenance.inputs[0].inst is PseudoInst.CONSTANT
540+
and provenance.inputs[0].value == object.__new__
541+
):
542+
return True
543+
if provenance.inst is PseudoInst.NEW:
544+
return True
545+
return False
546+
547+
548+
@interpreter_needs_wrap
549+
def _raw_object_setattr(obj: Any, name: str, value: Any):
550+
return object.__setattr__(obj, name, value)
551+
552+
553+
@register_general_jit_lookaside(object.__setattr__)
554+
def _general_jit_object_setattr_lookaside(obj: Any, name: str, value: Any):
555+
uobj = unwrap(obj)
556+
if is_created_during_tracing(obj.provenance) or type(uobj) in _TORCH_DYNAMIC_TYPES:
557+
return _raw_object_setattr(obj, name, value)
558+
559+
if should_register_for_prologue(obj.provenance) and (obj.original_value is obj.nothing):
560+
if getattr(obj.provenance, "proxy", None) is None:
561+
p: AnyProxy = AnyProxy(uobj, history=obj.provenance)
562+
obj.provenance.proxy = p
563+
obj.register_proxy(p)
564+
uobj = p
565+
566+
d = _interpret_call(getattr, obj, wrap_const("__dict__"))
567+
if d is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
568+
return d
569+
d.provenance.ext_flag |= EXT_FLAG_IS_MODULE_MEMBER_DICT
570+
ud = unwrap(d)
571+
assert type(ud) == dict
572+
res = _interpret_call(ud.__setitem__, name, value)
573+
return res
574+
575+
513576
@register_general_jit_lookaside(setattr)
514577
def _general_jit_setattr_lookaside(obj: Any, name: str, value: Any):
515578
setattr_lookaside = default_lookaside(setattr)
516579
assert setattr_lookaside is not None
517580

518581
uobj = unwrap(obj)
519582
uname = unwrap(name)
583+
520584
if isinstance(uobj, torch.nn.Module):
521-
# 1) modify the inner thing
522-
# 2) divert the actual setattr...
585+
# 1) populate the wrappeers for the member dicts
586+
# 2) let the original setattr do it's thing by modifying the
587+
# the member dict
588+
# This might generalize to other things, too...
523589
for n in MODULE_MEMBER_DICT_ATTRS:
524590
member_dict = _interpret_call(getattr, obj, wrap_const(n))
525591
member_dict.provenance.ext_flag |= EXT_FLAG_IS_MODULE_MEMBER_DICT
@@ -682,9 +748,12 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar
682748

683749
custom_autograd_function_cls = unwrap(obj)
684750
custom_forward = custom_autograd_function_cls.forward
685-
ctx = torch.autograd.function.FunctionCtx()
751+
typ = torch.autograd.function.FunctionCtx
752+
ctx = typ()
686753
ctx_proxy = proxy(ctx, name=None, history=None)
687-
wrapped_ctx = wrap_const(ctx_proxy)
754+
wrapped_ctx = wrap_const(
755+
ctx_proxy, provenance=ProvenanceRecord(PseudoInst.NEW, inputs=[wrap_const(typ).provenance])
756+
)
688757
trace_of_fwd, fwd_output_provenance = _convert_pytorchfunc_to_thundertrace(
689758
custom_forward, True, wrapped_ctx, *args, **kwargs
690759
)
@@ -1241,10 +1310,10 @@ def _general_jit_global_callback(orig_value: Any, name: str) -> Any:
12411310
return orig_value
12421311

12431312

1244-
_safe_provenance_inst = {
1313+
_input_provenance_inst = {
12451314
"INPUT_ARGS",
12461315
"INPUT_KWARGS",
1247-
"INPUT_FN",
1316+
"INPUT_FN", # or self
12481317
"LOAD_ATTR",
12491318
"CONSTANT",
12501319
"BINARY_SUBSCR",
@@ -1259,7 +1328,7 @@ def should_register_for_prologue(pr):
12591328
inst = inst.opname
12601329
else:
12611330
inst = inst.value
1262-
if inst not in _safe_provenance_inst:
1331+
if inst not in _input_provenance_inst:
12631332
return False
12641333
if inst == "CONSTANT" and callable(pr.value):
12651334
if pr.value.__name__ != "__getitem__" and pr.value != GetSetDescriptorType.__get__:
@@ -1509,6 +1578,7 @@ def from_load_attr(provenance, *, new_output=False):
15091578
output = Proxy(prefix="obj")
15101579
else:
15111580
output = p
1581+
15121582
param_ordering[id(output)] = (output, param_ordering[id(orig_obj)][1] + [math.inf, "." + str(name)])
15131583
bsym = prims.unpack_attr.bind(obj, name, output=output)
15141584
prologue_trace.bound_symbols.append(bsym)
@@ -1766,27 +1836,43 @@ def process_recorded_modifications(ctx, epilogue_trace):
17661836
for k, (inst, *args) in last_modification.items():
17671837
if inst == PseudoInst.STORE_SUBSCR:
17681838
(value,) = args
1769-
assert isinstance(value.value, Proxy)
17701839

1771-
assert modified_object.provenance.inst is PseudoInst.LOAD_ATTR
1772-
assert modified_object.provenance.inputs[1].inst is PseudoInst.CONSTANT
1773-
assert modified_object.provenance.inputs[1].value == "_buffers"
1774-
1775-
typ, name, root_module_provenance = get_parameter_or_buffer_or_submodule_name_and_root(
1776-
modified_object.provenance.inputs[0]
1777-
)
1778-
assert typ == "_modules"
1779-
root_module_proxy = root_for_provenances.get(root_module_provenance)
1780-
if root_module_proxy is None:
1781-
## we want this to created in the compute trace context for namespace...
1782-
root_module_proxy = Proxy(history=root_module_provenance)
1783-
epilogue_trace.add_name(root_module_proxy.name)
1784-
root_for_provenances[root_module_provenance] = root_module_proxy
1785-
1786-
name = ".".join(name + [k])
1787-
with tracectx(epilogue_trace):
1788-
bsym = prims.pack_buffer.bind(root_module_proxy, name, value.value, output=None)
1789-
epilogue_trace.bound_symbols.append(bsym)
1840+
assert isinstance(value.value, (Proxy, int, tuple)) ## todo: better criterion
1841+
1842+
if (
1843+
modified_object.provenance.inst is PseudoInst.LOAD_ATTR
1844+
and modified_object.provenance.inputs[1].inst is PseudoInst.CONSTANT
1845+
and modified_object.provenance.inputs[1].value == "_buffers"
1846+
):
1847+
typ, name, root_module_provenance = get_parameter_or_buffer_or_submodule_name_and_root(
1848+
modified_object.provenance.inputs[0]
1849+
)
1850+
assert typ == "_modules"
1851+
root_module_proxy = root_for_provenances.get(root_module_provenance)
1852+
if root_module_proxy is None:
1853+
## we want this to created in the compute trace context for namespace...
1854+
root_module_proxy = Proxy(history=root_module_provenance)
1855+
epilogue_trace.add_name(root_module_proxy.name)
1856+
root_for_provenances[root_module_provenance] = root_module_proxy
1857+
1858+
name = ".".join(name + [k])
1859+
with tracectx(epilogue_trace):
1860+
bsym = prims.pack_buffer.bind(root_module_proxy, name, value.value, output=None)
1861+
epilogue_trace.bound_symbols.append(bsym)
1862+
elif (
1863+
modified_object.provenance.inst is PseudoInst.LOAD_ATTR
1864+
and modified_object.provenance.inputs[1].inst is PseudoInst.CONSTANT
1865+
and modified_object.provenance.inputs[1].value == "__dict__"
1866+
):
1867+
name = k
1868+
setattr_obj_provenance = modified_object.provenance.inputs[0]
1869+
if hasattr(setattr_obj_provenance, "proxy"):
1870+
setattr_obj_proxy = setattr_obj_provenance.proxy
1871+
with tracectx(epilogue_trace):
1872+
bsym = prims.pack_attr.bind(setattr_obj_proxy, name, value.value, output=None)
1873+
epilogue_trace.bound_symbols.append(bsym)
1874+
else:
1875+
raise NotImplementedError(f"Modifications of {modified_object.provenance} are not supported")
17901876
else:
17911877
raise NotImplementedError(f"Modifications {inst} on dicts are not supported")
17921878
else:
@@ -1882,6 +1968,7 @@ def thunder_general_jit(
18821968
fn_lookaside=general_jit_lookaside,
18831969
callbacks=general_jit_callbacks,
18841970
with_provenance_tracking=True,
1971+
unwrap_result=False,
18851972
uncacheable_classes=(torch.Tensor, int, float, str, NoneType),
18861973
record_history=compile_data.debug_options.record_interpreter_history,
18871974
)
@@ -1891,11 +1978,12 @@ def thunder_general_jit(
18911978
result = jfn(*args, **kwargs)
18921979
computation_trace.set_current_source_location(None, None)
18931980
process_recorded_modifications(ctx, epilogue_trace)
1981+
uresult = unwrap(result)
18941982
last_interpreter_log = jfn._last_interpreter_log
1895-
result_proxies = tuple(p for p in tree_iter(result) if isinstance(p, (TensorProxy, NumberProxy)))
1983+
result_proxies = tuple(p for p in tree_iter(uresult) if isinstance(p, (TensorProxy, NumberProxy)))
18961984
prims.python_return(result_proxies)
18971985
with tracectx(epilogue_trace):
1898-
prims.python_return(result)
1986+
prims.python_return(uresult)
18991987

19001988
pro_to_comp, pro_to_comp_set, computation_intermediates = get_computation_inputs_and_intermediates(
19011989
computation_trace
@@ -1958,5 +2046,4 @@ def restrict_proxy_swapmap(proxies: tuple[Proxy]) -> dict[Variable, Proxy]:
19582046
epilogue_trace = _apply_trace_proxy_rename(
19592047
epilogue_trace, restrict_proxy_swapmap(pro_to_epi_proxies + comp_to_epi_proxies), "epilogue"
19602048
)
1961-
19622049
return TraceResults(prologue_trace, computation_trace, epilogue_trace, last_interpreter_log)

0 commit comments

Comments
 (0)