Skip to content

Commit

Permalink
Add support for unary negation operator (#17560)
Browse files Browse the repository at this point in the history
This PR adds support for unary negation operator in libcudf and plumbs the changes through cudf python and cudf polars.

Authors:
  - Matthew Murray (https://github.com/Matt711)

Approvers:
  - Basit Ayantunde (https://github.com/lamarrr)
  - Matthew Roeschke (https://github.com/mroeschke)
  - David Wendt (https://github.com/davidwendt)
  - Bradley Dice (https://github.com/bdice)
  - Lawrence Mitchell (https://github.com/wence-)

URL: #17560
  • Loading branch information
Matt711 authored Jan 31, 2025
1 parent 94229d5 commit 51b0f9e
Show file tree
Hide file tree
Showing 13 changed files with 198 additions and 41 deletions.
3 changes: 2 additions & 1 deletion cpp/include/cudf/unary.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2024, NVIDIA CORPORATION.
* Copyright (c) 2018-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -139,6 +139,7 @@ enum class unary_operator : int32_t {
RINT, ///< Rounds the floating-point argument arg to an integer value
BIT_INVERT, ///< Bitwise Not (~)
NOT, ///< Logical Not (!)
NEGATE, ///< Unary negation (-), only for signed numeric and duration types.
};

/**
Expand Down
6 changes: 5 additions & 1 deletion cpp/include/cudf/utilities/traits.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
* Copyright (c) 2019-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -94,6 +94,8 @@ constexpr inline bool has_common_type_v = detail::has_common_type_impl<void, Ts.
/// Checks if a type is a timestamp type.
template <typename T>
using is_timestamp_t = cuda::std::disjunction<std::is_same<cudf::timestamp_D, T>,
std::is_same<cudf::timestamp_h, T>,
std::is_same<cudf::timestamp_m, T>,
std::is_same<cudf::timestamp_s, T>,
std::is_same<cudf::timestamp_ms, T>,
std::is_same<cudf::timestamp_us, T>,
Expand All @@ -102,6 +104,8 @@ using is_timestamp_t = cuda::std::disjunction<std::is_same<cudf::timestamp_D, T>
/// Checks if a type is a duration type.
template <typename T>
using is_duration_t = cuda::std::disjunction<std::is_same<cudf::duration_D, T>,
std::is_same<cudf::duration_h, T>,
std::is_same<cudf::duration_m, T>,
std::is_same<cudf::duration_s, T>,
std::is_same<cudf::duration_ms, T>,
std::is_same<cudf::duration_us, T>,
Expand Down
56 changes: 52 additions & 4 deletions cpp/src/unary/math_ops.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
* Copyright (c) 2019-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -234,6 +234,16 @@ struct DeviceNot {
}
};

// negation

struct DeviceNegate {
template <typename T>
T __device__ operator()(T data)
{
return -data;
}
};

// fixed_point ops

/*
Expand Down Expand Up @@ -278,6 +288,12 @@ struct fixed_point_abs {
__device__ T operator()(T data) { return numeric::detail::abs(data); }
};

template <typename T>
struct fixed_point_negate {
T n;
__device__ T operator()(T data) { return -data; }
};

template <typename T, template <typename> typename FixedPointFunctor>
std::unique_ptr<column> unary_op_with(column_view const& input,
rmm::cuda_stream_view stream,
Expand Down Expand Up @@ -414,6 +430,34 @@ struct MathOpDispatcher {
}
};

template <typename UFN>
struct NegateOpDispatcher {
template <typename T>
static constexpr bool is_supported()
{
return std::is_signed_v<T> || cudf::is_duration<T>();
}

template <typename T, std::enable_if_t<is_supported<T>()>* = nullptr>
std::unique_ptr<cudf::column> operator()(cudf::column_view const& input,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
return transform_fn<T, UFN>(input.begin<T>(),
input.end<T>(),
cudf::detail::copy_bitmask(input, stream, mr),
input.null_count(),
stream,
mr);
}

template <typename T, typename... Args>
std::enable_if_t<!is_supported<T>(), std::unique_ptr<cudf::column>> operator()(Args&&...)
{
CUDF_FAIL("Unsupported data type for negate operation");
}
};

template <typename UFN>
struct BitwiseOpDispatcher {
template <typename T, std::enable_if_t<std::is_integral_v<T>>* = nullptr>
Expand Down Expand Up @@ -550,9 +594,10 @@ struct FixedPointOpDispatcher {
{
// clang-format off
switch (op) {
case cudf::unary_operator::CEIL: return unary_op_with<T, fixed_point_ceil>(input, stream, mr);
case cudf::unary_operator::FLOOR: return unary_op_with<T, fixed_point_floor>(input, stream, mr);
case cudf::unary_operator::ABS: return unary_op_with<T, fixed_point_abs>(input, stream, mr);
case cudf::unary_operator::CEIL: return unary_op_with<T, fixed_point_ceil>(input, stream, mr);
case cudf::unary_operator::FLOOR: return unary_op_with<T, fixed_point_floor>(input, stream, mr);
case cudf::unary_operator::ABS: return unary_op_with<T, fixed_point_abs>(input, stream, mr);
case cudf::unary_operator::NEGATE: return unary_op_with<T, fixed_point_negate>(input, stream, mr);
default: CUDF_FAIL("Unsupported fixed_point unary operation");
}
// clang-format on
Expand Down Expand Up @@ -639,6 +684,9 @@ std::unique_ptr<cudf::column> unary_operation(cudf::column_view const& input,
case cudf::unary_operator::NOT:
return cudf::type_dispatcher(
input.type(), detail::LogicalOpDispatcher<detail::DeviceNot>{}, input, stream, mr);
case cudf::unary_operator::NEGATE:
return cudf::type_dispatcher(
input.type(), detail::NegateOpDispatcher<detail::DeviceNegate>{}, input, stream, mr);
default: CUDF_FAIL("Undefined unary operation");
}
}
Expand Down
67 changes: 65 additions & 2 deletions cpp/tests/unary/math_ops_test.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
* Copyright (c) 2019-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -25,6 +25,69 @@

#include <vector>

using TypesToNegate = cudf::test::Types<int8_t,
int16_t,
int32_t,
int64_t,
float,
double,
cudf::duration_D,
cudf::duration_s,
cudf::duration_ms,
cudf::duration_us,
cudf::duration_ns>;

template <typename T>
struct UnaryNegateTests : public cudf::test::BaseFixture {};

TYPED_TEST_SUITE(UnaryNegateTests, TypesToNegate);

TYPED_TEST(UnaryNegateTests, SimpleNEGATE)
{
using T = TypeParam;
cudf::test::fixed_width_column_wrapper<T> input{{0, 1, 2, 3}};
auto const v = cudf::test::make_type_param_vector<T>({0, -1, -2, -3});
cudf::test::fixed_width_column_wrapper<T> expected(v.begin(), v.end());
auto output = cudf::unary_operation(input, cudf::unary_operator::NEGATE);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, output->view());
}

using TypesNotToNegate = cudf::test::Types<uint8_t,
uint16_t,
uint32_t,
uint64_t,
cudf::timestamp_D,
cudf::timestamp_s,
cudf::timestamp_ms,
cudf::timestamp_us,
cudf::timestamp_ns>;

template <typename T>
struct UnaryNegateErrorTests : public cudf::test::BaseFixture {};

TYPED_TEST_SUITE(UnaryNegateErrorTests, TypesNotToNegate);

TYPED_TEST(UnaryNegateErrorTests, UnsupportedTypesFail)
{
using T = TypeParam;
cudf::test::fixed_width_column_wrapper<T> input({1, 2, 3, 4});
EXPECT_THROW(cudf::unary_operation(input, cudf::unary_operator::NEGATE), cudf::logic_error);
}

struct UnaryNegateComplexTypesErrorTests : public cudf::test::BaseFixture {};

TEST_F(UnaryNegateComplexTypesErrorTests, NegateStringColumnFail)
{
cudf::test::strings_column_wrapper input({"foo", "bar"});
EXPECT_THROW(cudf::unary_operation(input, cudf::unary_operator::NEGATE), cudf::logic_error);
}

TEST_F(UnaryNegateComplexTypesErrorTests, NegateListsColumnFail)
{
cudf::test::lists_column_wrapper<int32_t> input{{1, 2}, {3, 4}};
EXPECT_THROW(cudf::unary_operation(input, cudf::unary_operator::NEGATE), cudf::logic_error);
}

template <typename T>
struct UnaryLogicalOpsTest : public cudf::test::BaseFixture {};

Expand Down Expand Up @@ -274,7 +337,7 @@ TYPED_TEST(UnaryMathFloatOpsTest, SimpleTANH)
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, output->view());
}

TYPED_TEST(UnaryMathFloatOpsTest, SimpleiASINH)
TYPED_TEST(UnaryMathFloatOpsTest, SimpleASINH)
{
cudf::test::fixed_width_column_wrapper<TypeParam> input{{0.0}};
cudf::test::fixed_width_column_wrapper<TypeParam> expected{{0.0}};
Expand Down
16 changes: 15 additions & 1 deletion cpp/tests/unary/unary_ops_test.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
* Copyright (c) 2019-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -266,6 +266,20 @@ struct FixedPointUnaryTests : public cudf::test::BaseFixture {};

TYPED_TEST_SUITE(FixedPointUnaryTests, cudf::test::FixedPointTypes);

TYPED_TEST(FixedPointUnaryTests, FixedPointUnaryNegate)
{
using namespace numeric;
using decimalXX = TypeParam;
using RepType = cudf::device_storage_type_t<decimalXX>;
using fp_wrapper = cudf::test::fixed_point_column_wrapper<RepType>;

auto const input = fp_wrapper{{0, -1234, -3456, -6789, 1234, 3456, 6789}, scale_type{-3}};
auto const expected = fp_wrapper{{0, 1234, 3456, 6789, -1234, -3456, -6789}, scale_type{-3}};
auto const result = cudf::unary_operation(input, cudf::unary_operator::NEGATE);

CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, result->view());
}

TYPED_TEST(FixedPointUnaryTests, FixedPointUnaryAbs)
{
using namespace numeric;
Expand Down
29 changes: 1 addition & 28 deletions python/cudf/cudf/core/column/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import numpy as np
import pandas as pd
import pyarrow as pa
from numba.np import numpy_support
from typing_extensions import Self

import pylibcudf as plc
Expand All @@ -24,7 +23,6 @@
from cudf.core.mixins import BinaryOperand
from cudf.core.scalar import pa_scalar_to_plc_scalar
from cudf.errors import MixedTypeError
from cudf.utils import cudautils
from cudf.utils.dtypes import (
find_common_type,
min_column_type,
Expand All @@ -33,7 +31,7 @@
)

if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from collections.abc import Sequence

from cudf._typing import (
ColumnBinaryOperand,
Expand All @@ -45,13 +43,6 @@
from cudf.core.buffer import Buffer
from cudf.core.column import DecimalBaseColumn

_unaryop_map = {
"ASIN": "ARCSIN",
"ACOS": "ARCCOS",
"ATAN": "ARCTAN",
"INVERT": "BIT_INVERT",
}


class NumericalColumn(NumericalBaseColumn):
"""
Expand Down Expand Up @@ -192,24 +183,6 @@ def transform(self, compiled_op, np_dtype: np.dtype) -> ColumnBase:
)
return type(self).from_pylibcudf(plc_column)

def unary_operator(self, unaryop: str | Callable) -> ColumnBase:
if callable(unaryop):
nb_type = numpy_support.from_dtype(self.dtype)
nb_signature = (nb_type,)
compiled_op = cudautils.compile_udf(unaryop, nb_signature)
np_dtype = np.dtype(compiled_op[1])
return self.transform(compiled_op, np_dtype)

unaryop = unaryop.upper()
unaryop = _unaryop_map.get(unaryop, unaryop)
unaryop = plc.unary.UnaryOperator[unaryop]
with acquire_spill_lock():
return type(self).from_pylibcudf(
plc.unary.unary_operation(
self.to_pylibcudf(mode="read"), unaryop
)
)

def __invert__(self):
if self.dtype.kind in "ui":
return self.unary_operator("invert")
Expand Down
35 changes: 35 additions & 0 deletions python/cudf/cudf/core/column/numerical_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import TYPE_CHECKING, Literal, cast

import numpy as np
from numba.np import numpy_support

import pylibcudf as plc

Expand All @@ -14,12 +15,23 @@
from cudf.core.column.column import ColumnBase
from cudf.core.missing import NA
from cudf.core.mixins import Scannable
from cudf.utils import cudautils

if TYPE_CHECKING:
from collections.abc import Callable

from cudf._typing import ScalarLike
from cudf.core.column.decimal import DecimalDtype


_unaryop_map = {
"ASIN": "ARCSIN",
"ACOS": "ARCCOS",
"ATAN": "ARCTAN",
"INVERT": "BIT_INVERT",
}


class NumericalBaseColumn(ColumnBase, Scannable):
"""
A column composed of numerical (bool, integer, float, decimal) data.
Expand Down Expand Up @@ -268,3 +280,26 @@ def _scan(self, op: str) -> ColumnBase:
return self.scan(op.replace("cum", ""), True)._with_type_metadata(
self.dtype
)

def unary_operator(self, unaryop: str | Callable) -> ColumnBase:
if callable(unaryop):
nb_type = numpy_support.from_dtype(self.dtype)
nb_signature = (nb_type,)
compiled_op = cudautils.compile_udf(unaryop, nb_signature)
np_dtype = np.dtype(compiled_op[1])
return self.transform(compiled_op, np_dtype)

unaryop = unaryop.upper()
unaryop = _unaryop_map.get(unaryop, unaryop)
unaryop = plc.unary.UnaryOperator[unaryop]
with acquire_spill_lock():
return type(self).from_pylibcudf(
plc.unary.unary_operation(
self.to_pylibcudf(mode="read"), unaryop
)
)

def transform(self, compiled_op, np_dtype: np.dtype) -> ColumnBase:
raise NotImplementedError(
"transform is not implemented for NumericalBaseColumn"
)
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1644,7 +1644,7 @@ def __neg__(self):
(
col.unary_operator("not")
if col.dtype.kind == "b"
else -1 * col
else col.unary_operator("negate")
for col in self._columns
)
)
Expand Down
7 changes: 7 additions & 0 deletions python/cudf/cudf/tests/test_unaops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import itertools
import operator
import re
from decimal import Decimal

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -134,3 +135,9 @@ def test_series_bool_neg():
sr = Series([True, False, True, None, False, None, True, True])
psr = sr.to_pandas(nullable=True)
assert_eq((-sr).to_pandas(nullable=True), -psr, check_dtype=True)


def test_series_decimal_neg():
sr = Series([Decimal("0.0"), Decimal("1.23"), Decimal("4.567")])
psr = sr.to_pandas()
assert_eq((-sr).to_pandas(), -psr, check_dtype=True)
Loading

0 comments on commit 51b0f9e

Please sign in to comment.