Commit

Iterate over fwd_args for hopefully more precise new_fwd_args (#1565)

Signed-off-by: Masaki Kozuki <[email protected]>
crcrpar authored Dec 18, 2024
1 parent d201c8c commit c4148ce
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion thunder/core/jit_ext.py
@@ -792,7 +792,22 @@ def _generate_random_str_id() -> str:
     # note that this key is quite new: https://github.com/pytorch/pytorch/pull/134087
     # non_differentiable_idx = fwd_kwargs.get("non_differentiable_idx")
     length_of_tensor_args = sum(args_tensor_mask)
-    new_fwd_args = (wrap_const(None),) + fwd_args[:length_of_tensor_args]
+
+    # N.B.(crcrpar): With `torch.compile(..., dynamic=True)`, a GraphModule's
+    # forward may take `SymInt` and other non-tensor values as arguments.
+    # Unfortunately, that information does not seem to be reflected in
+    # ``args_tensor_mask`` or ``non_differentiable_idx``, so we optimistically
+    # iterate over ``fwd_args`` and also gather non-tensor values into ``new_fwd_args``.
+    new_fwd_args = []
+    for i, v in enumerate(fwd_args):
+        if i < length_of_tensor_args:
+            new_fwd_args.append(v)
+        else:
+            # note(crcrpar): we might want to include `FutureTensorProxy` and
+            # a proxy of a tensor subclass in the near future.
+            if not isinstance(unwrap(v), TensorProxy):
+                new_fwd_args.append(v)
+    new_fwd_args = (wrap_const(None),) + tuple(new_fwd_args)
 
     aug_fwd_trace, aug_fwd_provenance = _convert_pytorchfunc_to_thundertrace(fwd, False, *new_fwd_args)
     if aug_fwd_trace is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
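The selection logic in this diff can be sketched in isolation. Below is a minimal, self-contained illustration, not Thunder's actual code: `FakeTensorProxy` and the trivial `unwrap` are hypothetical stand-ins for Thunder's `TensorProxy` and its wrapper-unwrapping helper, and `None` stands in for `wrap_const(None)`.

```python
class FakeTensorProxy:
    """Hypothetical stand-in for thunder.core.proxies.TensorProxy."""


def unwrap(v):
    # In Thunder, wrapped values carry an underlying proxy; here values are bare,
    # so unwrapping is the identity.
    return v


def build_new_fwd_args(fwd_args, args_tensor_mask):
    # Keep the leading tensor arguments indicated by the mask, then append any
    # trailing values that are NOT tensor proxies (e.g. SymInt values that can
    # appear under torch.compile(..., dynamic=True)).
    length_of_tensor_args = sum(args_tensor_mask)
    new_fwd_args = []
    for i, v in enumerate(fwd_args):
        if i < length_of_tensor_args:
            new_fwd_args.append(v)
        elif not isinstance(unwrap(v), FakeTensorProxy):
            new_fwd_args.append(v)
    # The real code prepends wrap_const(None); plain None stands in here.
    return (None,) + tuple(new_fwd_args)


t1, t2 = FakeTensorProxy(), FakeTensorProxy()
out = build_new_fwd_args((t1, t2, 8, t1), args_tensor_mask=[True, True, False, False])
# out keeps t1 and t2 (masked tensor args) plus the SymInt-like 8,
# and drops the trailing unmasked tensor proxy.
```

The key difference from the removed one-liner is the `elif` branch: instead of truncating `fwd_args` at `length_of_tensor_args`, non-tensor trailing values now survive into `new_fwd_args`.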
