Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add logsigmoid op #1520

Merged
merged 10 commits into from
Dec 14, 2024
17 changes: 17 additions & 0 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -1678,6 +1678,23 @@ def gen(op, device, dtype, requires_grad):
elementwise_unary_ops.append(leaky_relu_opinfo)


logsigmoid_opinfo = OpInfo(
ltorch.logsigmoid,
dtypes=(datatypes.floating,),
sample_input_generator=elementwise_unary_generator,
torch_reference=torch.nn.functional.logsigmoid,
test_directives=(
# test tols are too tight for these half precision tests
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved
DecorateInfo(
pytest.mark.skip,
"test_core_vs_torch_consistency",
dtypes=(datatypes.float16, datatypes.bfloat16),
),
),
)
elementwise_unary_ops.append(logsigmoid_opinfo)


relu_opinfo = OpInfo(
ltorch.relu,
sample_input_generator=elementwise_unary_generator,
Expand Down
8 changes: 8 additions & 0 deletions thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1812,6 +1812,14 @@ def leaky_relu(a: TensorProxy, /, negative_slope: float = 0.01, inplace: bool =
_inplace_to_out_of_place[leaky_relu] = leaky_relu, 2


@torchsymbol(torch.nn.functional.logsigmoid, is_method=False)
def logsigmoid(a: TensorProxy, /):
return log(sigmoid(a))
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved


_inplace_to_out_of_place[logsigmoid] = logsigmoid, -1
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved


# TODO Should this use clamp? -- Would that propagate NaNs properly?
@torchsymbol(torch.relu, torch.nn.functional.relu, id="torch.relu", is_method=True)
def relu(a: TensorLike, /, inplace: bool = False) -> TensorLike:
Expand Down
1 change: 0 additions & 1 deletion thunder/torch/default_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,6 @@
torch.nn.functional.kl_div,
torch.nn.functional.l1_loss,
torch.nn.functional.local_response_norm,
torch.nn.functional.logsigmoid,
torch.nn.functional.lp_pool1d,
torch.nn.functional.lp_pool2d,
torch.nn.functional.lp_pool3d,
Expand Down
Loading