diff --git a/cpp/include/cudf/ast/detail/expression_parser.hpp b/cpp/include/cudf/ast/detail/expression_parser.hpp index f4cce8e6da6..b5973d0ace9 100644 --- a/cpp/include/cudf/ast/detail/expression_parser.hpp +++ b/cpp/include/cudf/ast/detail/expression_parser.hpp @@ -19,6 +19,10 @@ #include #include #include +#include +#include + +#include #include #include @@ -296,7 +300,7 @@ class expression_parser { * @return The indices of the operands stored in the data references. */ std::vector visit_operands( - std::vector> operands); + cudf::host_span const> operands); /** * @brief Add a data reference to the internal list. diff --git a/cpp/include/cudf/ast/expressions.hpp b/cpp/include/cudf/ast/expressions.hpp index 4299ee5f20f..bcc9ad1b391 100644 --- a/cpp/include/cudf/ast/expressions.hpp +++ b/cpp/include/cudf/ast/expressions.hpp @@ -22,6 +22,8 @@ #include #include +#include +#include namespace CUDF_EXPORT cudf { namespace ast { @@ -478,7 +480,7 @@ class operation : public expression { * * @return Vector of operands */ - [[nodiscard]] std::vector> get_operands() const + [[nodiscard]] std::vector> const& get_operands() const { return operands; } @@ -506,8 +508,8 @@ class operation : public expression { }; private: - ast_operator const op; - std::vector> const operands; + ast_operator op; + std::vector> operands; }; /** @@ -552,6 +554,98 @@ class column_name_reference : public expression { std::string column_name; }; +/** + * @brief An AST expression tree. It owns and contains multiple dependent expressions. All the + * expressions are destroyed when the tree is destructed. + */ +class tree { + public: + /** + * @brief construct an empty ast tree + */ + tree() = default; + + /** + * @brief Moves the ast tree + */ + tree(tree&&) = default; + + /** + * @brief move-assigns the AST tree + * @returns a reference to the move-assigned tree + */ + tree& operator=(tree&&) = default; + + ~tree() = default; + + // the tree is not copyable + tree(tree const&) = delete; + tree& operator=(tree const&) = delete; + + /** + * @brief Add an expression to the AST tree + * @param args Arguments to use to construct the ast expression + * @returns a reference to the added expression + */ + template + Expr const& emplace(Args&&... args) + { + static_assert(std::is_base_of_v); + auto expr = std::make_shared(std::forward(args)...); + Expr const& expr_ref = *expr; + expressions.emplace_back(std::static_pointer_cast(std::move(expr))); + return expr_ref; + } + + /** + * @brief Add an expression to the AST tree + * @param expr AST expression to be added + * @returns a reference to the added expression + */ + template + Expr const& push(Expr expr) + { + return emplace(std::move(expr)); + } + + /** + * @brief get the first expression in the tree + * @returns the first inserted expression into the tree + */ + expression const& front() const { return *expressions.front(); } + + /** + * @brief get the last expression in the tree + * @returns the last inserted expression into the tree + */ + expression const& back() const { return *expressions.back(); } + + /** + * @brief get the number of expressions added to the tree + * @returns the number of expressions added to the tree + */ + size_t size() const { return expressions.size(); } + + /** + * @brief get the expression at an index in the tree. Index is checked. + * @param index index of expression in the ast tree + * @returns the expression at the specified index + */ + expression const& at(size_t index) { return *expressions.at(index); } + + /** + * @brief get the expression at an index in the tree. Index is unchecked. + * @param index index of expression in the ast tree + * @returns the expression at the specified index + */ + expression const& operator[](size_t index) const { return *expressions[index]; } + + private: + // TODO: use better ownership semantics, the shared_ptr here is redundant. Consider using a bump + // allocator with type-erased deleters. + std::vector> expressions; +}; + /** @} */ // end of group } // namespace ast diff --git a/cpp/src/ast/expression_parser.cpp b/cpp/src/ast/expression_parser.cpp index 5815ce33e33..d0e4c59ca54 100644 --- a/cpp/src/ast/expression_parser.cpp +++ b/cpp/src/ast/expression_parser.cpp @@ -207,7 +207,7 @@ cudf::data_type expression_parser::output_type() const } std::vector expression_parser::visit_operands( - std::vector> operands) + cudf::host_span const> operands) { auto operand_data_reference_indices = std::vector(); for (auto const& operand : operands) { diff --git a/cpp/src/ast/expressions.cpp b/cpp/src/ast/expressions.cpp index 4c2b56dd4f5..b7e4e4609cb 100644 --- a/cpp/src/ast/expressions.cpp +++ b/cpp/src/ast/expressions.cpp @@ -20,36 +20,41 @@ #include #include +#include + namespace cudf { namespace ast { -operation::operation(ast_operator op, expression const& input) : op(op), operands({input}) +operation::operation(ast_operator op, expression const& input) : op{op}, operands{input} { - if (cudf::ast::detail::ast_operator_arity(op) != 1) { - CUDF_FAIL("The provided operator is not a unary operator."); - } + CUDF_EXPECTS(cudf::ast::detail::ast_operator_arity(op) == 1, + "The provided operator is not a unary operator.", + std::invalid_argument); } operation::operation(ast_operator op, expression const& left, expression const& right) - : op(op), operands({left, right}) + : op{op}, operands{left, right} { - if (cudf::ast::detail::ast_operator_arity(op) != 2) { - CUDF_FAIL("The provided operator is not a binary operator."); - } + CUDF_EXPECTS(cudf::ast::detail::ast_operator_arity(op) == 2, + "The provided operator is not a binary operator.", + std::invalid_argument); } cudf::size_type literal::accept(detail::expression_parser& visitor) const { return visitor.visit(*this); } + cudf::size_type column_reference::accept(detail::expression_parser& visitor) const { return visitor.visit(*this); } + cudf::size_type operation::accept(detail::expression_parser& visitor) const { return visitor.visit(*this); } + cudf::size_type column_name_reference::accept(detail::expression_parser& visitor) const { return visitor.visit(*this); @@ -60,16 +65,19 @@ auto literal::accept(detail::expression_transformer& visitor) const { return visitor.visit(*this); } + auto column_reference::accept(detail::expression_transformer& visitor) const -> decltype(visitor.visit(*this)) { return visitor.visit(*this); } + auto operation::accept(detail::expression_transformer& visitor) const -> decltype(visitor.visit(*this)) { return visitor.visit(*this); } + auto column_name_reference::accept(detail::expression_transformer& visitor) const -> decltype(visitor.visit(*this)) { diff --git a/cpp/src/io/parquet/predicate_pushdown.cpp b/cpp/src/io/parquet/predicate_pushdown.cpp index a965f3325d5..cd3dcd2bce4 100644 --- a/cpp/src/io/parquet/predicate_pushdown.cpp +++ b/cpp/src/io/parquet/predicate_pushdown.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -373,7 +374,7 @@ class stats_expression_converter : public ast::detail::expression_transformer { private: std::vector> visit_operands( - std::vector> operands) + cudf::host_span const> operands) { std::vector> transformed_operands; for (auto const& operand : operands) { @@ -553,7 +554,7 @@ std::reference_wrapper named_to_reference_converter::visi std::vector> named_to_reference_converter::visit_operands( - std::vector> operands) + cudf::host_span const> operands) { std::vector> transformed_operands; for (auto const& operand : operands) { @@ -623,7 +624,7 @@ class names_from_expression : public ast::detail::expression_transformer { } private: - void visit_operands(std::vector> operands) + void visit_operands(cudf::host_span const> operands) { for (auto const& operand : operands) { operand.get().accept(*this); diff --git a/cpp/src/io/parquet/reader_impl_helpers.hpp b/cpp/src/io/parquet/reader_impl_helpers.hpp index 6487c92f48f..fd692c0cdd6 100644 --- a/cpp/src/io/parquet/reader_impl_helpers.hpp +++ b/cpp/src/io/parquet/reader_impl_helpers.hpp @@ -425,7 +425,7 @@ class named_to_reference_converter : public ast::detail::expression_transformer private: std::vector> visit_operands( - std::vector> operands); + cudf::host_span const> operands); std::unordered_map column_name_to_index; std::optional> _stats_expr; diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 23632f6fbba..e9ba58ba224 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -650,7 +650,7 @@ ConfigureTest(ENCODE_TEST encode/encode_tests.cpp) # ################################################################################################## # * ast tests ------------------------------------------------------------------------------------- -ConfigureTest(AST_TEST ast/transform_tests.cpp) +ConfigureTest(AST_TEST ast/transform_tests.cpp ast/ast_tree_tests.cpp) # ################################################################################################## # * lists tests ---------------------------------------------------------------------------------- diff --git a/cpp/tests/ast/ast_tree_tests.cpp b/cpp/tests/ast/ast_tree_tests.cpp new file mode 100644 index 00000000000..1a960c68e23 --- /dev/null +++ b/cpp/tests/ast/ast_tree_tests.cpp @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include +#include +#include +#include +#include + +template +using column_wrapper = cudf::test::fixed_width_column_wrapper; + +TEST(AstTreeTest, ExpressionTree) +{ + namespace ast = cudf::ast; + using op = ast::ast_operator; + using operation = ast::operation; + + // computes (y = mx + c)... and linearly interpolates them using interpolator t + auto m0_col = column_wrapper{10, 20, 50, 100}; + auto x0_col = column_wrapper{10, 5, 2, 1}; + auto c0_col = column_wrapper{100, 100, 100, 100}; + + auto m1_col = column_wrapper{10, 20, 50, 100}; + auto x1_col = column_wrapper{20, 10, 4, 2}; + auto c1_col = column_wrapper{200, 200, 200, 200}; + + auto one_scalar = cudf::numeric_scalar{1}; + auto t_scalar = cudf::numeric_scalar{0.5F}; + + auto table = cudf::table_view{{m0_col, x0_col, c0_col, m1_col, x1_col, c1_col}}; + + ast::tree tree{}; + + auto const& one = tree.push(ast::literal{one_scalar}); + auto const& t = tree.push(ast::literal{t_scalar}); + auto const& m0 = tree.push(ast::column_reference(0)); + auto const& x0 = tree.push(ast::column_reference(1)); + auto const& c0 = tree.push(ast::column_reference(2)); + auto const& m1 = tree.push(ast::column_reference(3)); + auto const& x1 = tree.push(ast::column_reference(4)); + auto const& c1 = tree.push(ast::column_reference(5)); + + // compute: y0 = m0 x0 + c0 + auto const& y0 = tree.push(operation{op::ADD, tree.push(operation{op::MUL, m0, x0}), c0}); + + // compute: y1 = m1 x1 + c1 + auto const& y1 = tree.push(operation{op::ADD, tree.push(operation{op::MUL, m1, x1}), c1}); + + // compute weighted: (1 - t) * y0 + auto const& y0_w = tree.push(operation{op::MUL, tree.push(operation{op::SUB, one, t}), y0}); + + // compute weighted: y = t * y1 + auto const& y1_w = tree.push(operation{op::MUL, t, y1}); + + // add weighted: result = lerp(y0, y1, t) = (1 - t) * y0 + t * y1 + auto result = cudf::compute_column(table, tree.push(operation{op::ADD, y0_w, y1_w})); + + auto expected = column_wrapper{300, 300, 300, 300}; + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, result->view()); +} diff --git a/cpp/tests/ast/transform_tests.cpp b/cpp/tests/ast/transform_tests.cpp index 7af88d8aa34..e28d92bb615 100644 --- a/cpp/tests/ast/transform_tests.cpp +++ b/cpp/tests/ast/transform_tests.cpp @@ -530,9 +530,10 @@ TEST_F(TransformTest, UnaryTrigonometry) TEST_F(TransformTest, ArityCheckFailure) { auto col_ref_0 = cudf::ast::column_reference(0); - EXPECT_THROW(cudf::ast::operation(cudf::ast::ast_operator::ADD, col_ref_0), cudf::logic_error); + EXPECT_THROW(cudf::ast::operation(cudf::ast::ast_operator::ADD, col_ref_0), + std::invalid_argument); EXPECT_THROW(cudf::ast::operation(cudf::ast::ast_operator::ABS, col_ref_0, col_ref_0), - cudf::logic_error); + std::invalid_argument); } TEST_F(TransformTest, StringComparison)