Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
0bc8457
Remove unused make_proxy_name import from jit_ext.py
IvanYashchuk Dec 30, 2024
97e346a
Do not ignore specified name for Proxy creation
IvanYashchuk Dec 30, 2024
2204893
Make TensorProxy requires_grad argument to be False by default
IvanYashchuk Dec 30, 2024
369e8ad
Allow Proxy creation without a set TraceCtx
IvanYashchuk Dec 30, 2024
58d6151
Remove unused code
IvanYashchuk Dec 30, 2024
12ef335
Use prefix='m' instead of name='m' in jit_ext.py
IvanYashchuk Dec 30, 2024
70a5b92
Merge branch 'proxy-update1' into proxy-update2
IvanYashchuk Dec 30, 2024
cea37a3
Merge branch 'proxy-update2' into proxy-update3
IvanYashchuk Dec 30, 2024
c66b78b
Use prefix='obj' instead of name 'obj' in jit_ext.py
IvanYashchuk Dec 30, 2024
05cf51d
Use prefix='subscr' instead of name 'subscr' in jit_ext.py
IvanYashchuk Dec 30, 2024
8d4507a
Merge branch 'proxy-update1' into proxy-update2
IvanYashchuk Dec 30, 2024
7b0b34c
Merge branch 'proxy-update2' into proxy-update3
IvanYashchuk Dec 30, 2024
bfc5813
Use prefix='module' instead of name 'module' in jit_ext.py
IvanYashchuk Dec 30, 2024
eb92cc5
Merge branch 'proxy-update1' into proxy-update2
IvanYashchuk Dec 31, 2024
8e436cf
Merge branch 'proxy-update2' into proxy-update3
IvanYashchuk Dec 31, 2024
b639afb
Merge remote-tracking branch 'upstream/main' into proxy-update1
IvanYashchuk Jan 7, 2025
63505b0
Raise an error when the name being added to the trace is already used
IvanYashchuk Jan 7, 2025
e8a14ea
Update test to check if the error is raised
IvanYashchuk Jan 7, 2025
1913f49
Merge branch 'proxy-update1' into proxy-update2
IvanYashchuk Jan 7, 2025
609b2d4
Merge branch 'proxy-update2' into proxy-update3
IvanYashchuk Jan 7, 2025
0737e35
Discard name from computation_trace as _general_jit_wrap_callback may…
IvanYashchuk Jan 7, 2025
7e8e24d
Merge branch 'proxy-update1' into proxy-update2
IvanYashchuk Jan 17, 2025
28062ae
Merge remote-tracking branch 'upstream/main' into proxy-update2
IvanYashchuk Jan 17, 2025
0ffc0ae
Merge branch 'proxy-update2' into proxy-update3
IvanYashchuk Jan 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion thunder/core/proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def is_proxy_name_available(name: None | str = None):

def make_proxy_name(*, name: None | str = None, prefix: None | str = None) -> str:
trc = get_tracectx()
if trc is None:
return name
return trc.make_name(name=name, prefix=prefix)


Expand Down Expand Up @@ -1421,7 +1423,7 @@ def __init__(
shape: ShapeLike | None = None,
device: devices.Device | None = None,
dtype: dtypes.dtype | None = None,
requires_grad: bool | None = None,
requires_grad: bool = False,
grad: TensorProxy | None = None,
prefix: None | str = None,
distparallel_type: DistParallelType | None = None,
Expand Down
21 changes: 16 additions & 5 deletions thunder/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3118,16 +3118,27 @@ def test_debug_options():
assert dill.dumps(dict(DebugOptions.__dict__)) == initial_state


def test_default_tensor_proxy():
from thunder.core.proxies import TensorProxy
from thunder.core.dtypes import float32
from thunder.core.devices import cpu

# It should be possible to create a TensorProxy with default values for all
# optional arguments
t = TensorProxy(shape=(1,), device=cpu, dtype=float32)
assert not t.requires_grad
assert t.device == cpu
assert t.dtype == float32


def test_proxy_same_name():
from thunder.core.proxies import TensorProxy
from thunder.core.trace import detached_trace
from thunder.core.dtypes import float32
from thunder.core.devices import cpu

with detached_trace():
t = TensorProxy(name="test", shape=(1,), device=cpu, dtype=float32, requires_grad=True)
with pytest.raises(RuntimeError, match="already used"):
t2 = TensorProxy(name="test", shape=(1,), device=cpu, dtype=float32, requires_grad=True)
t = TensorProxy(name="test", shape=(1,), device=cpu, dtype=float32)
with pytest.raises(RuntimeError, match="already used"):
t2 = TensorProxy(name="test", shape=(1,), device=cpu, dtype=float32)


def test_save_trace():
Expand Down
Loading