Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add logsigmoid op #1520

Merged
merged 10 commits into from
Dec 14, 2024
15 changes: 15 additions & 0 deletions thunder/core/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2217,6 +2217,21 @@ def embedding_backward(a, num_weights, padding_idx, scale_grad_by_freq, sparse,
return gweight


@register_augmented_forward("torch.nn.functional.logsigmoid")
def log_sigmoid_aug_fwd(a):
    """Augmented forward rule for ``torch.nn.functional.logsigmoid``.

    Computes the primal ``logsigmoid(a)`` and saves the residuals the
    backward rule consumes. The backward rule's second residual mirrors
    aten.log_sigmoid_backward's buffer argument, which the thunder
    implementation ignores, so the input ``a`` is saved in both slots.
    """
    # Removed the unused `relu` import that was previously pulled in here.
    from thunder.torch import logsigmoid

    primal = logsigmoid(a)
    return VJPDual(primal, (a, a))


@register_backward("torch.nn.functional.logsigmoid")
def log_sigmoid_backward(a, _, g):
    """Backward rule for logsigmoid: delegates to the torch-level backward symbol.

    ``a`` is the saved forward input, ``_`` the unused buffer residual, and
    ``g`` the incoming gradient.
    """
    import thunder.torch as _ltorch

    return _ltorch.log_sigmoid_backward(g, a, _)


@register_augmented_forward("torch.cumsum")
def cumsum_aug_fwd(a: Proxy, dim: int, *, dtype: None | dtypes.dtype = None) -> VJPDual:
from thunder.torch import cumsum
Expand Down
16 changes: 12 additions & 4 deletions thunder/executors/torchex.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,11 +835,15 @@ def _erfcinv_impl(a: torch.Tensor) -> torch.Tensor:
celu = _register_torch_operation("celu", module=torch.nn.functional)
elu = _register_torch_operation("elu", module=torch.nn.functional)
gelu = _register_torch_operation("gelu", module=torch.nn.functional)
hardshrink = _register_torch_operation("hardshrink", module=torch.nn.functional)
hardswish = _register_torch_operation("hardswish", module=torch.nn.functional)
leaky_relu = _register_torch_operation("leaky_relu", module=torch.nn.functional)
logsigmoid = _register_torch_operation("logsigmoid", module=torch.nn.functional)
log_sigmoid_backward = _register_torch_operation(
"torch.ops.aten.log_sigmoid_backward", like=ltorch.log_sigmoid_backward
)
relu = _register_torch_operation("relu", module=torch.nn.functional)
relu6 = _register_torch_operation("relu6", module=torch.nn.functional)
hardshrink = _register_torch_operation("hardshrink", module=torch.nn.functional)
hardswish = _register_torch_operation("hardswish", module=torch.nn.functional)
selu = _register_torch_operation("selu", module=torch.nn.functional)
silu = _register_torch_operation("silu", module=torch.nn.functional)

Expand All @@ -851,11 +855,15 @@ def _elementwise_unary_with_inplace_checker(a: TensorProxy, /, inplace: bool = F
# Map each thunder torch-level unary symbol onto its torch executor op.
# Ops with an `inplace` flag use the inplace-aware checker; the rest are
# always executable by the torch executor.
_register_elementwise_unary_implementation(ltorch.elu, elu, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.celu, celu, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.gelu, gelu, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.hardshrink, hardshrink, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.hardswish, hardswish, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.leaky_relu, leaky_relu, checker=_always_executable)
_register_elementwise_unary_implementation(
    ltorch.log_sigmoid_backward, log_sigmoid_backward, checker=_always_executable
)
# Pass checker explicitly for consistency with the sibling registrations
# above (the original call omitted it and fell back to the default).
_register_elementwise_unary_implementation(ltorch.logsigmoid, logsigmoid, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.relu, relu, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.relu6, relu6, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.selu, selu, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.silu, silu, checker=_always_executable)

Expand Down
11 changes: 11 additions & 0 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -1678,6 +1678,17 @@ def gen(op, device, dtype, requires_grad):
elementwise_unary_ops.append(leaky_relu_opinfo)


# OpInfo driving the automated tests for torch.nn.functional.logsigmoid.
logsigmoid_opinfo = OpInfo(
    ltorch.logsigmoid,
    # Restricted to floating dtypes; integer inputs are not sampled.
    dtypes=(datatypes.floating,),
    sample_input_generator=elementwise_unary_generator,
    torch_reference=torch.nn.functional.logsigmoid,
    # Sample inputs drawn from (-1, 1). NOTE(review): presumably chosen to
    # keep gradient checks well-conditioned away from large |a| — confirm.
    domain=(-1, 1),
    # No per-device/dtype skips or xfails needed so far.
    test_directives=(),
)
elementwise_unary_ops.append(logsigmoid_opinfo)


relu_opinfo = OpInfo(
ltorch.relu,
sample_input_generator=elementwise_unary_generator,
Expand Down
19 changes: 19 additions & 0 deletions thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1812,6 +1812,25 @@ def leaky_relu(a: TensorProxy, /, negative_slope: float = 0.01, inplace: bool =
_inplace_to_out_of_place[leaky_relu] = leaky_relu, 2


@torchsymbol(torch.nn.functional.logsigmoid, is_method=False)
def logsigmoid(a: TensorProxy, /) -> TensorLike:
    """Numerically stable ``log(sigmoid(a))``.

    For positive inputs uses ``-log1p(exp(-a))``; for non-positive inputs the
    algebraically equivalent ``a - log1p(exp(a))`` avoids overflow in ``exp``.
    """
    is_positive = a > 0
    positive_branch = -log1p(exp(-a))
    nonpositive_branch = a - log1p(exp(a))
    return where(is_positive, positive_branch, nonpositive_branch)


_inplace_to_out_of_place[logsigmoid] = logsigmoid, -1
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved


# @torchsymbol("log_sigmoid_backward", id="log_sigmoid_backward")
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved
def log_sigmoid_backward(g: TensorProxy, a: TensorProxy, _: TensorProxy) -> TensorLike:
    """Gradient of ``logsigmoid`` with respect to its input.

    d/da log(sigmoid(a)) = sigmoid(-a). ``exp(-|a|) / (1 + exp(-|a|))``
    equals sigmoid(-a) when a > 0 and sigmoid(a) otherwise, hence the
    ``where`` selects between ``z`` and ``1 - z``.

    Args:
        g: incoming gradient.
        a: the forward input.
        _: unused buffer argument, kept to mirror
           aten.log_sigmoid_backward's signature.

    Returns:
        ``g`` scaled by sigmoid(-a), computed without overflow for large |a|.
    """
    # Removed the commented-out unstable variant that exponentiated a directly.
    exp_neg_abs = exp(-abs(a))
    z = exp_neg_abs / (1 + exp_neg_abs)
    return g * where(a > 0, z, 1 - z)


_inplace_to_out_of_place[log_sigmoid_backward] = log_sigmoid_backward, -1


# TODO Should this use clamp? -- Would that propagate NaNs properly?
@torchsymbol(torch.relu, torch.nn.functional.relu, id="torch.relu", is_method=True)
def relu(a: TensorLike, /, inplace: bool = False) -> TensorLike:
Expand Down
1 change: 0 additions & 1 deletion thunder/torch/default_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,6 @@
torch.nn.functional.kl_div,
torch.nn.functional.l1_loss,
torch.nn.functional.local_response_norm,
torch.nn.functional.logsigmoid,
torch.nn.functional.lp_pool1d,
torch.nn.functional.lp_pool2d,
torch.nn.functional.lp_pool3d,
Expand Down
Loading