Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

clean up imports of jit_ext.py #1588

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 37 additions & 74 deletions thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
@@ -1,116 +1,75 @@
import thunder
from __future__ import annotations
import math
from typing import Any, Optional, Dict, Tuple, Literal
import builtins
from typing import Any
import collections
from collections.abc import ValuesView, Iterable, Iterator
from collections.abc import Callable, Sequence
import weakref
import random
from functools import partial, wraps, reduce
import linecache
import operator
import copy
from functools import wraps
import contextvars
from contextlib import contextmanager
import dis
import warnings
from enum import Enum, auto
from io import StringIO
import inspect
import time

from thunder.core.compile_data import compile_data_and_stats, get_cache_option, get_compile_data
import thunder.clang as clang
import thunder.core.transforms
from thunder.core.baseutils import run_once

from types import (
BuiltinMethodType,
CellType,
ClassMethodDescriptorType,
CodeType,
CoroutineType,
FrameType,
FunctionType,
MethodType,
GetSetDescriptorType,
MethodDescriptorType,
MethodType,
ModuleType,
NoneType,
BuiltinFunctionType,
BuiltinMethodType,
MethodWrapperType,
WrapperDescriptorType,
TracebackType,
CellType,
ModuleType,
CodeType,
BuiltinFunctionType,
FunctionType,
MethodType,
GetSetDescriptorType,
UnionType,
WrapperDescriptorType,
)

import torch
import torch.utils.checkpoint

from thunder.core.compile_data import compile_data_and_stats, get_cache_option, get_compile_data
import thunder.clang as clang
from thunder.core import dtypes
import thunder.core.transforms
from thunder.core.proxies import (
AnyProxy,
DistParallelType,
proxy,
NumberProxy,
Proxy,
ProxyTag,
AnyProxy,
NumberProxy,
StringProxy,
TensorProxy,
FutureTensorProxy,
make_proxy_name,
Variable,
variableify,
unvariableify,
is_proxy_name_available,
proxy,
unvariableify,
variableify,
)
from thunder.core.trace import set_tracectx, reset_tracectx, tracectx, from_trace
from thunder.core.interpreter import (
InterpreterLogItem,
InterpreterFrame,
interpret,
_interpret_call,
CapsuleType,
default_callbacks,
INTERPRETER_CALLBACKS,
INTERPRETER_SIGNALS,
default_opcode_interpreter,
_default_lookaside_map,
InterpreterRuntimeCtx,
ProvenanceRecord,
PseudoInst,
WrappedValue,
_interpret_call,
default_callbacks,
default_lookaside,
do_raise,
get_interpreterruntimectx,
InterpreterRuntimeCtx,
interpret,
interpreter_needs_wrap,
is_opaque,
Py_NULL,
member_descriptor,
WrappedValue,
unwrap,
wrap,
wrap_const,
PseudoInst,
ProvenanceRecord,
interpreter_needs_wrap,
)
from thunder.core.langctxs import set_langctx, reset_langctx, Languages, resolve_language
from thunder.core.baseutils import extract_callable_name
from thunder.core.codeutils import get_siginfo, SigInfo
from thunder.core.codeutils import SigInfo
import thunder.core.prims as prims
from thunder.common import transform_for_execution
from thunder.core.options import CACHE_OPTIONS, SHARP_EDGES_OPTIONS, DebugOptions
from thunder.core.symbol import Symbol, BoundSymbol, is_traceable

from thunder.extend import Executor
from thunder.common import CompileData, CompileStats
from thunder.core.symbol import Symbol
from thunder.core.trace import TraceCtx, TraceResults
from thunder.torch import _torch_to_thunder_function_map
from thunder.clang import _clang_fn_set
from thunder.core.pytree import tree_map, tree_iter
from thunder.core.compile_data import compile_data_and_stats

#
# jit_ext.py implements extensions of thunder's interpreter
Expand Down Expand Up @@ -266,7 +225,9 @@ def proxify(self, value: WrappedValue) -> Any:
DistParallelType.REPLICATED,
DistParallelType.FULLY_SHARDED,
):
p_new = thunder.distributed.prims.synchronize(
from thunder.distributed.prims import synchronize

p_new = synchronize(
p,
self._process_group_for_ddp,
)
Expand Down Expand Up @@ -889,8 +850,8 @@ def autocast_exit(autocast_obj, exc_type, exc_val, exc_tb):

@register_general_jit_lookaside(torch.finfo)
@interpreter_needs_wrap
def _general_jit_torch_finfo_lookaside(dtype: thunder.dtypes.dtype):
torch_dtype = thunder.dtypes.to_torch_dtype(dtype)
def _general_jit_torch_finfo_lookaside(dtype: dtypes.dtype):
torch_dtype = dtypes.to_torch_dtype(dtype)
res = torch.finfo(torch_dtype)
return res

Expand Down Expand Up @@ -1400,6 +1361,8 @@ def get_parameter_or_buffer_or_submodule_name_and_root(provenance):


def unpack_inputs(ctx, prologue_trace, pro_to_comp_inps, pro_to_epi_inps, args, kwargs):
from thunder import _get_cache_info

already_unpacked: dict[int, Proxy] = {}
orig_modules: dict[int, Proxy] = {}

Expand Down Expand Up @@ -1671,7 +1634,7 @@ def is_variableified_tensorproxy(v: Variable | Proxy) -> Proxy:

prim(*args)

cache_info = thunder._get_cache_info()
cache_info = _get_cache_info()
# assert len of cache info to ensure that we're not missing anything?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the advantage here?
To my mind, the original version is clearer because you immediately see that it is a function from a different module (and which one). I can see how the shorter way is important for things that get used a lot, but this seems to distribute the information that this is a utility function from thunder/__init__.py into two lines that are 300 lines apart.

if cache_info:
cache_info_p = Proxy(name="cache_info")
Expand Down Expand Up @@ -1760,7 +1723,7 @@ def bind_inputs(name, trace, input_vars, input_proxies):
trace.args = input_proxies


def _get_process_group_from(*fn_and_args) -> Optional["ProcessGroup"]:
def _get_process_group_from(*fn_and_args) -> Optional[ProcessGroup]:
# `ddp` and `fsdp` transforms add attribute `process_group_for_ddp`
# on the Module that they wrap. This module could be passed to `thunder.jit`
# as the function to be jitted or as an argument of the function to be jitted.
Expand Down Expand Up @@ -1815,7 +1778,7 @@ def thunder_general_jit(
compile_data = get_compile_data()
executor_lookasides = {k: interpreter_needs_wrap(v) for k, v in compile_data.executor_lookasides.items()}

process_group_for_ddp: Optional["ProcessGroup"] = _get_process_group_from(fn, *args, *kwargs.values())
process_group_for_ddp: Optional[ProcessGroup] = _get_process_group_from(fn, *args, *kwargs.values())
ctx: JitCtx = JitCtx(
prologue_trace,
computation_trace,
Expand Down
Loading