Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add logsigmoid op #1520

Merged
merged 10 commits into from
Dec 14, 2024
27 changes: 27 additions & 0 deletions thunder/core/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
)
import thunder.clang as clang
from thunder.clang import (
empty,
full,
full_like,
unsqueeze,
Expand Down Expand Up @@ -1435,6 +1436,32 @@ def _copy_with_setitem_grad(a: TensorProxy, index, value: Number | TensorProxy):

register_grad(pids.COPY_WITH_SETITEM, _copy_with_setitem_grad)


# def _log_sigmoid_grad(
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved
# a: TensorProxy,
# ) -> TensorProxy:
# from thunder.torch import abs, exp, log_sigmoid_backward, logsigmoid

# fwd = logsigmoid(a)

# g = get_grad(fwd)
# if a.device.type == "cpu":
# # NOTE PyTorch's CPU computation for logsigmoid's grad uses an additional "buffer" tensor, see
# # https://github.com/pytorch/pytorch/blob/7667235a23e2ffca4d32e6e16aa60a683418e159/torch/_decomp/decompositions.py#L332
# buffer = exp(-abs(a))
# a_grad = log_sigmoid_backward(g, a, buffer)
# else:
# # Here a placeholder tensor is provided.
# placeholder_buffer = empty((0,), device=a.device, dtype=a.dtype)
# a_grad = log_sigmoid_backward(g, a, placeholder_buffer)
# put_grad(a, a_grad)

# return fwd


# register_grad("torch.nn.functional.logsigmoid", _log_sigmoid_grad)


#
# Phantom grad transform helpers
#
Expand Down
34 changes: 30 additions & 4 deletions thunder/executors/torchex.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,11 +835,15 @@ def _erfcinv_impl(a: torch.Tensor) -> torch.Tensor:
celu = _register_torch_operation("celu", module=torch.nn.functional)
elu = _register_torch_operation("elu", module=torch.nn.functional)
gelu = _register_torch_operation("gelu", module=torch.nn.functional)
hardshrink = _register_torch_operation("hardshrink", module=torch.nn.functional)
hardswish = _register_torch_operation("hardswish", module=torch.nn.functional)
leaky_relu = _register_torch_operation("leaky_relu", module=torch.nn.functional)
logsigmoid = _register_torch_operation("logsigmoid", module=torch.nn.functional)
log_sigmoid_backward = _register_torch_operation(
"torch.ops.aten.log_sigmoid_backward", like=ltorch.log_sigmoid_backward
)
relu = _register_torch_operation("relu", module=torch.nn.functional)
relu6 = _register_torch_operation("relu6", module=torch.nn.functional)
hardshrink = _register_torch_operation("hardshrink", module=torch.nn.functional)
hardswish = _register_torch_operation("hardswish", module=torch.nn.functional)
selu = _register_torch_operation("selu", module=torch.nn.functional)
silu = _register_torch_operation("silu", module=torch.nn.functional)

Expand All @@ -851,11 +855,33 @@ def _elementwise_unary_with_inplace_checker(a: TensorProxy, /, inplace: bool = F
_register_elementwise_unary_implementation(ltorch.elu, elu, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.celu, celu, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.gelu, gelu, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.hardshrink, hardshrink, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.hardswish, hardswish, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.leaky_relu, leaky_relu, checker=_always_executable)
_register_elementwise_unary_implementation(
ltorch.log_sigmoid_backward, log_sigmoid_backward, checker=_always_executable
)
# _register_elementwise_unary_implementation(ltorch.logsigmoid, logsigmoid)


def log_sigmoid_grad_transform(a):
    """Grad transform for ``logsigmoid``.

    Runs the forward ``logsigmoid`` on ``a``, fetches the gradient of that
    result via ``get_grad``, computes the gradient of ``a`` with
    ``log_sigmoid_backward`` and records it via ``put_grad``.

    Returns the forward result.
    """
    out = logsigmoid(a)
    grad_out = get_grad(out)

    # NOTE PyTorch's CPU computation for logsigmoid's grad uses an additional "buffer" tensor, see
    # https://github.com/pytorch/pytorch/blob/7667235a23e2ffca4d32e6e16aa60a683418e159/torch/_decomp/decompositions.py#L332
    # The buffer is computed here unconditionally (there is no device check).
    cpu_buffer = exp(-abs(a))
    put_grad(a, log_sigmoid_backward(grad_out, a, cpu_buffer))

    return out


ex.register_implementation(
ltorch.logsigmoid, logsigmoid, checker=_elementwise_unary_checker, grad_transform=log_sigmoid_grad_transform
)
_register_elementwise_unary_implementation(ltorch.relu, relu, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.relu6, relu6, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.hardshrink, hardshrink, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.hardswish, hardswish, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.selu, selu, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.silu, silu, checker=_always_executable)

Expand Down
11 changes: 11 additions & 0 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -1685,6 +1685,17 @@ def gen(op, device, dtype, requires_grad):
elementwise_unary_ops.append(leaky_relu_opinfo)


# OpInfo entry for ltorch.logsigmoid: restricted to floating dtypes, inputs produced by
# elementwise_unary_generator, with torch.nn.functional.logsigmoid as the torch reference.
# NOTE(review): domain=(-1, 1) presumably bounds the sampled input values -- confirm against OpInfo.
logsigmoid_opinfo = OpInfo(
ltorch.logsigmoid,
dtypes=(datatypes.floating,),
sample_input_generator=elementwise_unary_generator,
torch_reference=torch.nn.functional.logsigmoid,
domain=(-1, 1),
test_directives=(),
)
# Registered alongside the other elementwise unary op infos.
elementwise_unary_ops.append(logsigmoid_opinfo)


relu_opinfo = OpInfo(
ltorch.relu,
sample_input_generator=elementwise_unary_generator,
Expand Down
16 changes: 13 additions & 3 deletions thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1812,6 +1812,19 @@ def leaky_relu(a: TensorProxy, /, negative_slope: float = 0.01, inplace: bool =
_inplace_to_out_of_place[leaky_relu] = leaky_relu, 2


@torchsymbol(torch.nn.functional.logsigmoid, is_method=False)
def logsigmoid(a: TensorProxy, /) -> TensorLike:
    """Elementwise ``log(sigmoid(a))`` expressed with ``where``, ``log1p`` and ``exp``.

    Two algebraically equivalent forms are computed and ``where`` picks one
    per element based on the sign of ``a``.
    """
    is_positive = a > 0
    # Form selected where a > 0: -log(1 + exp(-a))
    positive_branch = -log1p(exp(-a))
    # Form selected elsewhere: a - log(1 + exp(a))
    nonpositive_branch = a - log1p(exp(a))
    return where(is_positive, positive_branch, nonpositive_branch)


@torchsymbol("log_sigmoid_backward", id="log_sigmoid_backward")
def log_sigmoid_backward(g: TensorProxy, a: TensorProxy, buffer: TensorProxy) -> TensorLike:
    """Gradient of ``logsigmoid`` with respect to ``a``, scaled by the incoming gradient ``g``.

    ``buffer`` is accepted but not read by this decomposition.
    """
    # buffer is used by PyTorch in cpu-based calculations. See
    # https://github.com/pytorch/pytorch/blob/7667235a23e2ffca4d32e6e16aa60a683418e159/torch/_decomp/decompositions.py#L332
    # The buffer is supplied by the grad transform log_sigmoid_grad_transform
    # in thunder/executors/torchex.py.
    is_positive = a > 0
    # Form selected where a > 0: exp(-a) / (1 + exp(-a))
    slope_positive = exp(-a) / (1 + exp(-a))
    # Form selected elsewhere: 1 - exp(a) / (1 + exp(a))
    slope_nonpositive = 1 - exp(a) / (1 + exp(a))
    return g * where(is_positive, slope_positive, slope_nonpositive)


# TODO Should this use clamp? -- Would that propagate NaNs properly?
@torchsymbol(torch.relu, torch.nn.functional.relu, id="torch.relu", is_method=True)
def relu(a: TensorLike, /, inplace: bool = False) -> TensorLike:
Expand Down Expand Up @@ -1858,9 +1871,6 @@ def hardshrink(a: TensorProxy, /, lambd: float = 0.5) -> TensorLike:
return where(abs(a) <= lambd, 0, a)


_inplace_to_out_of_place[hardshrink] = hardshrink, -1


@torchsymbol(torch.nn.functional.hardswish, id="torch.hardswish", is_method=False)
def hardswish(a: TensorProxy, /, inplace: bool = False) -> TensorLike:
utils.check(
Expand Down
1 change: 0 additions & 1 deletion thunder/torch/default_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,6 @@
torch.nn.functional.kl_div,
torch.nn.functional.l1_loss,
torch.nn.functional.local_response_norm,
torch.nn.functional.logsigmoid,
torch.nn.functional.lp_pool1d,
torch.nn.functional.lp_pool2d,
torch.nn.functional.lp_pool3d,
Expand Down
Loading