Skip to content

Commit 2d18b7a

Browse files
add atleast_{1d, 2d, 3d} ops & fix snippet_phantom_grad_vs_torch_consistency test (#1881)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 7eb2e07 commit 2d18b7a

File tree

5 files changed

+100
-5
lines changed

5 files changed

+100
-5
lines changed

thunder/executors/torchex.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,6 +1175,9 @@ def _triu_transform(a: TensorLike, /, diagonal: int = 0, *, fill_value: None | N
11751175
# Bind torch-executor operation symbols; the atleast_{1,2,3}d entries are the new
# additions in this commit (diff line-number artifacts from the scrape are kept verbatim).
argmax = _register_torch_operation("argmax")
11761176
argmin = _register_torch_operation("argmin")
11771177
topk = _register_torch_operation("topk")
1178+
atleast_1d = _register_torch_operation("atleast_1d")
1179+
atleast_2d = _register_torch_operation("atleast_2d")
1180+
atleast_3d = _register_torch_operation("atleast_3d")
11781181

11791182

11801183
#
@@ -1259,6 +1262,9 @@ def _topk_transform(
12591262
# Route thunder.torch (ltorch) symbols to the torch operations registered above.
# All four new atleast_* entries use checker=_always_executable and need no
# execution transform (direct passthrough to the torch ops).
_register_implementation(ltorch.argmax, argmax, checker=_always_executable)
12601263
_register_implementation(ltorch.argmin, argmin, checker=_always_executable)
12611264
_register_implementation(ltorch.topk, topk, checker=_always_executable, execution_transform=_topk_transform)
1265+
_register_implementation(ltorch.atleast_1d, atleast_1d, checker=_always_executable)
1266+
_register_implementation(ltorch.atleast_2d, atleast_2d, checker=_always_executable)
1267+
_register_implementation(ltorch.atleast_3d, atleast_3d, checker=_always_executable)
12621268

12631269

12641270
#

thunder/tests/opinfos.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5356,6 +5356,7 @@ def unsqueeze_sample_generator(op, device, dtype, requires_grad, **kwargs):
53565356

53575357
unsqueeze_opinfo = OpInfo(
53585358
clang.unsqueeze,
5359+
supports_grad=True,
53595360
sample_input_generator=unsqueeze_sample_generator,
53605361
jax_reference=jax.lax.expand_dims if JAX_AVAILABLE else None,
53615362
test_directives=(
@@ -6018,6 +6019,53 @@ def topk_error_generator(op, device, **kwargs):
60186019
reduction_ops.append(topk_opinfo)
60196020

60206021

6022+
def atleast_1d2d3d_sample_generator(op, device, dtype, requires_grad, **kwargs):
    """Yield SampleInputs shared by the atleast_1d/2d/3d OpInfos.

    Covers single tensors of rank 0 through 4, then several multi-tensor
    calls mixing ranks, so both the single-argument (tensor) and
    multi-argument (tuple) return paths are exercised.
    """
    make = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    # Single-tensor samples: one per rank from scalar up to 4-D.
    single_shapes = (
        (),
        (4,),
        (5, 5),
        (6, 7, 8),
        (3, 3, 3, 3),
    )
    for shape in single_shapes:
        yield SampleInput(make(shape))

    # Multi-tensor samples with mixed ranks.
    mixed_rank_groups = (
        ((), (2,)),
        ((2,), (5, 5)),
        ((), (2,), (4, 4)),
    )
    for shapes in mixed_rank_groups:
        yield SampleInput(*(make(shape) for shape in shapes))

    # Four tensors at once, shapes given positionally (make_tensor accepts *shape).
    yield SampleInput(make(2, 3), make(4, 5), make(6, 6, 6), make(5, 5, 5, 5))
6040+
6041+
6042+
def _make_atleast_opinfo(thunder_op, torch_ref):
    # Build one OpInfo of the atleast_{1,2,3}d family; they differ only in
    # the thunder symbol and the torch reference, and share a sample generator.
    return OpInfo(
        thunder_op,
        supports_grad=True,
        sample_input_generator=atleast_1d2d3d_sample_generator,
        torch_reference=torch_ref,
    )


atleast_1d_opinfo = _make_atleast_opinfo(ltorch.atleast_1d, torch.atleast_1d)
reduction_ops.append(atleast_1d_opinfo)

atleast_2d_opinfo = _make_atleast_opinfo(ltorch.atleast_2d, torch.atleast_2d)
reduction_ops.append(atleast_2d_opinfo)

# NOTE(review): these are shape ops, but the original change files them under
# reduction_ops; kept as-is to preserve which test groups pick them up.
atleast_3d_opinfo = _make_atleast_opinfo(ltorch.atleast_3d, torch.atleast_3d)
reduction_ops.append(atleast_3d_opinfo)
6067+
6068+
60216069
opinfos.extend(reduction_ops)
60226070

60236071

thunder/tests/test_grad.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1320,7 +1320,12 @@ def is_output_differentiable(x):
13201320
# torch.return_types.topk(
13211321
# values=tensor([1., 1.]),
13221322
# indices=tensor([0, 1]))
1323-
return x.grad_fn is not None
1323+
return x.grad_fn is not None or is_returning_self(x)
1324+
1325+
def is_returning_self(x):
1326+
if x.is_leaf and x.requires_grad:
1327+
return True
1328+
return False
13241329

13251330
def filter_differentiable_outputs(outputs):
13261331
if isinstance(outputs, torch.Tensor):
@@ -1380,7 +1385,10 @@ def upcast_tensors(x: Any) -> Any:
13801385
thunder_flat_grads = grad_op(*sample.args, **sample.kwargs)
13811386

13821387
assert_closer(
1383-
reference=reference_grad_result, candidate=thunder_flat_grads, competitor=torch_grad_result, comparator=comp
1388+
reference=reference_grad_result,
1389+
candidate=thunder_flat_grads,
1390+
competitor=torch_grad_result,
1391+
comparator=comp,
13841392
)
13851393

13861394

thunder/torch/__init__.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3005,6 +3005,42 @@ def topk(
30053005
return clang.topk(a, k, dim, largest, sorted, out=out)
30063006

30073007

3008+
@torchsymbol(torch.atleast_1d, is_method=True)
3009+
def atleast_1d(*args: Union[TensorLike, Sequence[TensorLike]]) -> Union[TensorLike, tuple[TensorLike, ...]]:
3010+
res = tuple(a if a.ndim >= 1 else unsqueeze(a, 0) for a in args)
3011+
return res if len(res) > 1 else res[0]
3012+
3013+
3014+
@torchsymbol(torch.atleast_2d, is_method=True)
3015+
def atleast_2d(*args: Union[TensorLike, Sequence[TensorLike]]) -> Union[TensorLike, tuple[TensorLike, ...]]:
3016+
3017+
def _unsqueeze_atleast(a):
3018+
if a.ndim == 0:
3019+
return a.unsqueeze(0).unsqueeze(1)
3020+
elif a.ndim == 1:
3021+
return a.unsqueeze(0)
3022+
return a
3023+
3024+
res = tuple(_unsqueeze_atleast(a) if isinstance(a, TensorProxy) else a for a in args)
3025+
return res if len(res) > 1 else res[0]
3026+
3027+
3028+
@torchsymbol(torch.atleast_3d, is_method=True)
3029+
def atleast_3d(*args: Union[TensorLike, Sequence[TensorLike]]) -> Union[TensorLike, tuple[TensorLike, ...]]:
3030+
3031+
def _unsqueeze_atleast(a):
3032+
if a.ndim == 0:
3033+
return a.reshape(1, 1, 1)
3034+
elif a.ndim == 1:
3035+
return a.reshape(1, -1, 1)
3036+
elif a.ndim == 2:
3037+
return a.unsqueeze(-1)
3038+
return a
3039+
3040+
res = tuple(_unsqueeze_atleast(a) if isinstance(a, TensorProxy) else a for a in args)
3041+
return res if len(res) > 1 else res[0]
3042+
3043+
30083044
@torchsymbol(torch.sort, is_method=True)
30093045
def sort(
30103046
a: TensorLike, /, dim: None | int = None, descending: bool = False, stable: bool = False, *, out=None

thunder/torch/default_torch_ops.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,6 @@
2626
torch.arctanh,
2727
torch.argsort,
2828
torch.argwhere,
29-
torch.atleast_1d,
30-
torch.atleast_2d,
31-
torch.atleast_3d,
3229
torch.batch_norm_backward_elemt,
3330
torch.batch_norm_backward_reduce,
3431
torch.batch_norm_elemt,

0 commit comments

Comments
 (0)