diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 726353983a..aaf2cd2964 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -1,116 +1,75 @@ -import thunder +from __future__ import annotations import math -from typing import Any, Optional, Dict, Tuple, Literal -import builtins +from typing import Any import collections -from collections.abc import ValuesView, Iterable, Iterator from collections.abc import Callable, Sequence -import weakref -import random -from functools import partial, wraps, reduce -import linecache -import operator -import copy +from functools import wraps import contextvars from contextlib import contextmanager import dis import warnings -from enum import Enum, auto -from io import StringIO -import inspect -import time - -from thunder.core.compile_data import compile_data_and_stats, get_cache_option, get_compile_data -import thunder.clang as clang -import thunder.core.transforms -from thunder.core.baseutils import run_once - from types import ( + BuiltinMethodType, CellType, - ClassMethodDescriptorType, - CodeType, - CoroutineType, - FrameType, FunctionType, - MethodType, + GetSetDescriptorType, MethodDescriptorType, + MethodType, ModuleType, NoneType, - BuiltinFunctionType, - BuiltinMethodType, - MethodWrapperType, - WrapperDescriptorType, - TracebackType, - CellType, - ModuleType, - CodeType, - BuiltinFunctionType, - FunctionType, - MethodType, - GetSetDescriptorType, UnionType, + WrapperDescriptorType, ) import torch import torch.utils.checkpoint + +from thunder.core.compile_data import compile_data_and_stats, get_cache_option, get_compile_data +import thunder.clang as clang +from thunder.core import dtypes +import thunder.core.transforms from thunder.core.proxies import ( + AnyProxy, DistParallelType, - proxy, + NumberProxy, Proxy, ProxyTag, - AnyProxy, - NumberProxy, - StringProxy, TensorProxy, - FutureTensorProxy, - make_proxy_name, Variable, - variableify, - unvariableify, is_proxy_name_available, + proxy, + unvariableify, + variableify, ) from thunder.core.trace import set_tracectx, reset_tracectx, tracectx, from_trace from thunder.core.interpreter import ( - InterpreterLogItem, - InterpreterFrame, - interpret, - _interpret_call, - CapsuleType, - default_callbacks, INTERPRETER_CALLBACKS, INTERPRETER_SIGNALS, - default_opcode_interpreter, - _default_lookaside_map, + InterpreterRuntimeCtx, + ProvenanceRecord, + PseudoInst, + WrappedValue, + _interpret_call, + default_callbacks, default_lookaside, do_raise, get_interpreterruntimectx, - InterpreterRuntimeCtx, + interpret, + interpreter_needs_wrap, is_opaque, - Py_NULL, - member_descriptor, - WrappedValue, unwrap, wrap, wrap_const, - PseudoInst, - ProvenanceRecord, - interpreter_needs_wrap, ) from thunder.core.langctxs import set_langctx, reset_langctx, Languages, resolve_language -from thunder.core.baseutils import extract_callable_name -from thunder.core.codeutils import get_siginfo, SigInfo +from thunder.core.codeutils import SigInfo import thunder.core.prims as prims -from thunder.common import transform_for_execution from thunder.core.options import CACHE_OPTIONS, SHARP_EDGES_OPTIONS, DebugOptions -from thunder.core.symbol import Symbol, BoundSymbol, is_traceable - -from thunder.extend import Executor -from thunder.common import CompileData, CompileStats +from thunder.core.symbol import Symbol from thunder.core.trace import TraceCtx, TraceResults from thunder.torch import _torch_to_thunder_function_map from thunder.clang import _clang_fn_set from thunder.core.pytree import tree_map, tree_iter -from thunder.core.compile_data import compile_data_and_stats # # jit_ext.py implements extensions of thunder's interpreter @@ -266,7 +225,9 @@ def proxify(self, value: WrappedValue) -> Any: DistParallelType.REPLICATED, DistParallelType.FULLY_SHARDED, ): - p_new = thunder.distributed.prims.synchronize( + from thunder.distributed.prims import synchronize + + p_new = synchronize( p, self._process_group_for_ddp, ) @@ -889,8 +850,8 @@ def autocast_exit(autocast_obj, exc_type, exc_val, exc_tb): @register_general_jit_lookaside(torch.finfo) @interpreter_needs_wrap -def _general_jit_torch_finfo_lookaside(dtype: thunder.dtypes.dtype): - torch_dtype = thunder.dtypes.to_torch_dtype(dtype) +def _general_jit_torch_finfo_lookaside(dtype: dtypes.dtype): + torch_dtype = dtypes.to_torch_dtype(dtype) res = torch.finfo(torch_dtype) return res @@ -1400,6 +1361,8 @@ def get_parameter_or_buffer_or_submodule_name_and_root(provenance): def unpack_inputs(ctx, prologue_trace, pro_to_comp_inps, pro_to_epi_inps, args, kwargs): + from thunder import _get_cache_info + already_unpacked: dict[int, Proxy] = {} orig_modules: dict[int, Proxy] = {} @@ -1671,7 +1634,7 @@ def is_variableified_tensorproxy(v: Variable | Proxy) -> Proxy: prim(*args) - cache_info = thunder._get_cache_info() + cache_info = _get_cache_info() # assert len of cache info to ensure that we're not missing anything? if cache_info: cache_info_p = Proxy(name="cache_info") @@ -1760,7 +1723,7 @@ def bind_inputs(name, trace, input_vars, input_proxies): trace.args = input_proxies -def _get_process_group_from(*fn_and_args) -> Optional["ProcessGroup"]: +def _get_process_group_from(*fn_and_args) -> Optional[ProcessGroup]: # `ddp` and `fsdp` transforms add attribute `procses_group_for_ddp` # on the Module that they wrap. This module could be passed to `thunder.jit` # as the function to be jitted or as an argument of the function to be jitted. @@ -1815,7 +1778,7 @@ def thunder_general_jit( compile_data = get_compile_data() executor_lookasides = {k: interpreter_needs_wrap(v) for k, v in compile_data.executor_lookasides.items()} - process_group_for_ddp: Optional["ProcessGroup"] = _get_process_group_from(fn, *args, *kwargs.values()) + process_group_for_ddp: Optional[ProcessGroup] = _get_process_group_from(fn, *args, *kwargs.values()) ctx: JitCtx = JitCtx( prologue_trace, computation_trace,