Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added ast tree to simplify expression lifetime management #17156

Merged
merged 29 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
4d60344
added ast tree to simplify expression lifetime management
lamarrr Oct 23, 2024
bf9d71f
Merge branch 'branch-24.12' into ast-expr-enhancement
lamarrr Oct 23, 2024
58c12ba
updated copyright notices
lamarrr Oct 23, 2024
479a7b0
Merge branch 'ast-expr-enhancement' of https://github.com/lamarrr/cud…
lamarrr Oct 23, 2024
04414db
updated ast tree documentation
lamarrr Oct 24, 2024
7d398f1
fixed ast tree docs formatting
lamarrr Oct 28, 2024
6c82411
temporarily changed back to using vector for storing ast operation's …
lamarrr Oct 28, 2024
e3b4823
made operation arity check throw std::invalid_argument
lamarrr Oct 28, 2024
5695f29
Merge branch 'branch-24.12' into ast-expr-enhancement
lamarrr Oct 28, 2024
342b184
corrected ast builder push return type
lamarrr Oct 28, 2024
daa882b
updated documentation for ast tree member functions
lamarrr Oct 28, 2024
239bfd4
started drafting ast tree test
lamarrr Oct 28, 2024
ea1d430
updated ast tree expression management
lamarrr Oct 29, 2024
4b7eebd
added ast tree tests
lamarrr Oct 29, 2024
e17c2ef
updated ast transform test
lamarrr Oct 29, 2024
d9fa6d8
updated ast tree documentation
lamarrr Oct 29, 2024
e33a7a0
Merge remote-tracking branch 'upstream/branch-24.12' into ast-expr-en…
lamarrr Oct 29, 2024
e356ee5
Merge remote-tracking branch 'upstream/branch-24.12' into ast-expr-en…
lamarrr Oct 29, 2024
8f2265c
Update cpp/include/cudf/ast/expressions.hpp
lamarrr Nov 4, 2024
8a6dac6
Update cpp/include/cudf/ast/expressions.hpp
lamarrr Nov 4, 2024
26b2989
Update cpp/include/cudf/ast/expressions.hpp
lamarrr Nov 4, 2024
aa1d17b
Update cpp/include/cudf/ast/expressions.hpp
lamarrr Nov 4, 2024
4cd078f
Update cpp/include/cudf/ast/expressions.hpp
lamarrr Nov 4, 2024
279f047
Update cpp/tests/ast/ast_tree_tests.cpp
lamarrr Nov 4, 2024
984e34f
Merge branch 'branch-24.12' into ast-expr-enhancement
lamarrr Nov 4, 2024
6afd563
Merge branch 'branch-24.12' into ast-expr-enhancement
lamarrr Nov 4, 2024
1dc38fc
Merge branch 'branch-24.12' into ast-expr-enhancement
lamarrr Nov 6, 2024
72b2f58
Merge branch 'branch-24.12' into ast-expr-enhancement
lamarrr Nov 6, 2024
dccfeb9
Merge branch 'branch-24.12' into ast-expr-enhancement
vyasr Nov 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cpp/include/cudf/ast/detail/expression_parser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <cudf/table/table_view.hpp>
#include <cudf/types.hpp>
#include <cudf/utilities/memory_resource.hpp>
#include <cudf/utilities/span.hpp>

#include <thrust/scan.h>

Expand Down Expand Up @@ -300,7 +301,7 @@ class expression_parser {
* @return The indices of the operands stored in the data references.
*/
std::vector<cudf::size_type> visit_operands(
std::vector<std::reference_wrapper<expression const>> operands);
cudf::host_span<std::reference_wrapper<cudf::ast::expression const> const> operands);

/**
* @brief Add a data reference to the internal list.
Expand Down
86 changes: 83 additions & 3 deletions cpp/include/cudf/ast/expressions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#include <cudf/utilities/error.hpp>

#include <cstdint>
#include <memory>
#include <vector>

namespace CUDF_EXPORT cudf {
namespace ast {
Expand Down Expand Up @@ -478,7 +480,7 @@ class operation : public expression {
*
* @return Vector of operands
*/
[[nodiscard]] std::vector<std::reference_wrapper<expression const>> get_operands() const
[[nodiscard]] std::vector<std::reference_wrapper<expression const>> const& get_operands() const
{
return operands;
}
Expand Down Expand Up @@ -506,8 +508,8 @@ class operation : public expression {
};

private:
ast_operator const op;
std::vector<std::reference_wrapper<expression const>> const operands;
ast_operator op;
std::vector<std::reference_wrapper<expression const>> operands;
};

/**
Expand Down Expand Up @@ -552,6 +554,84 @@ class column_name_reference : public expression {
std::string column_name;
};

/**
* @brief An AST expression tree. it owns and contains multiple dependent expressions. All the
lamarrr marked this conversation as resolved.
Show resolved Hide resolved
* expressions are destroyed once the tree is destructed.
lamarrr marked this conversation as resolved.
Show resolved Hide resolved
*/
class tree {
public:
/**
* @brief construct an empty ast tree
*/
tree() = default;
tree(tree&&) = default;
tree& operator=(tree&&) = default;
~tree() = default;

// the tree is not copyable
tree(tree const&) = delete;
tree& operator=(tree const&) = delete;

/**
* @brief Add an expression to the AST tree
* @param args Arguments to use to construct the ast expression
* @returns a reference to the added expression
*/
template <typename Expr, typename... Args>
expression const& emplace(Args&&... args)
{
static_assert(std::is_base_of_v<expression, Expr>);
return *expressions.emplace_back(std::make_unique<Expr>(std::forward<Args>(args)...));
}

/**
* @brief Add an expression to the AST tree
* @param expr AST expression to be added
* @returns a reference to the added expression
*/
template <typename Expr>
expression const& push(Expr expr)
{
return emplace<Expr>(std::move(expr));
}

/**
* @brief get the first expression in the tree
*/
expression const& front() const { return *expressions.front(); }

/**
* @brief get the last expression in the tree
*/
expression const& back() const { return *expressions.back(); }

/**
* @brief get the number of expressions added to the tree
*/
size_t size() const { return expressions.size(); }

/**
* @brief get the expression at a checked index in the tree
* @returns the expression at the specified index
*/
expression const& at(size_t index) { return *expressions.at(index); }

/**
* @brief get the expression at an unchecked index in the tree
* @returns the expression at the specified index
*/
expression const& operator[](size_t index) const { return *expressions[index]; }

/**
* @brief get an immutable span to the expressions in the tree
* @returns all expressions added to the tree
*/
std::vector<std::unique_ptr<expression>> const& get_expressions() const { return expressions; }

private:
std::vector<std::unique_ptr<expression>> expressions;
bdice marked this conversation as resolved.
Show resolved Hide resolved
};

/** @} */ // end of group
} // namespace ast

Expand Down
4 changes: 2 additions & 2 deletions cpp/src/ast/expression_parser.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
* Copyright (c) 2020-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -210,7 +210,7 @@ cudf::data_type expression_parser::output_type() const
}

std::vector<cudf::size_type> expression_parser::visit_operands(
std::vector<std::reference_wrapper<expression const>> operands)
cudf::host_span<std::reference_wrapper<expression const> const> operands)
{
auto operand_data_reference_indices = std::vector<cudf::size_type>();
for (auto const& operand : operands) {
Expand Down
26 changes: 17 additions & 9 deletions cpp/src/ast/expressions.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -23,36 +23,41 @@
#include <cudf/types.hpp>
#include <cudf/utilities/error.hpp>

#include <stdexcept>

namespace cudf {
namespace ast {

operation::operation(ast_operator op, expression const& input) : op(op), operands({input})
operation::operation(ast_operator op, expression const& input) : op{op}, operands{input}
{
if (cudf::ast::detail::ast_operator_arity(op) != 1) {
CUDF_FAIL("The provided operator is not a unary operator.");
}
CUDF_EXPECTS(cudf::ast::detail::ast_operator_arity(op) == 1,
"The provided operator is not a unary operator.",
std::invalid_argument);
}

operation::operation(ast_operator op, expression const& left, expression const& right)
: op(op), operands({left, right})
: op{op}, operands{left, right}
{
if (cudf::ast::detail::ast_operator_arity(op) != 2) {
CUDF_FAIL("The provided operator is not a binary operator.");
}
CUDF_EXPECTS(cudf::ast::detail::ast_operator_arity(op) == 2,
"The provided operator is not a binary operator.",
std::invalid_argument);
}

cudf::size_type literal::accept(detail::expression_parser& visitor) const
{
return visitor.visit(*this);
}

cudf::size_type column_reference::accept(detail::expression_parser& visitor) const
{
return visitor.visit(*this);
}

cudf::size_type operation::accept(detail::expression_parser& visitor) const
{
return visitor.visit(*this);
}

cudf::size_type column_name_reference::accept(detail::expression_parser& visitor) const
{
return visitor.visit(*this);
Expand All @@ -63,16 +68,19 @@ auto literal::accept(detail::expression_transformer& visitor) const
{
return visitor.visit(*this);
}

auto column_reference::accept(detail::expression_transformer& visitor) const
-> decltype(visitor.visit(*this))
{
return visitor.visit(*this);
}

auto operation::accept(detail::expression_transformer& visitor) const
-> decltype(visitor.visit(*this))
{
return visitor.visit(*this);
}

auto column_name_reference::accept(detail::expression_transformer& visitor) const
-> decltype(visitor.visit(*this))
{
Expand Down
7 changes: 4 additions & 3 deletions cpp/src/io/parquet/predicate_pushdown.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <cudf/utilities/default_stream.hpp>
#include <cudf/utilities/error.hpp>
#include <cudf/utilities/memory_resource.hpp>
#include <cudf/utilities/span.hpp>
#include <cudf/utilities/traits.hpp>
#include <cudf/utilities/type_dispatcher.hpp>

Expand Down Expand Up @@ -374,7 +375,7 @@ class stats_expression_converter : public ast::detail::expression_transformer {

private:
std::vector<std::reference_wrapper<ast::expression const>> visit_operands(
std::vector<std::reference_wrapper<ast::expression const>> operands)
cudf::host_span<std::reference_wrapper<ast::expression const> const> operands)
{
std::vector<std::reference_wrapper<ast::expression const>> transformed_operands;
for (auto const& operand : operands) {
Expand Down Expand Up @@ -554,7 +555,7 @@ std::reference_wrapper<ast::expression const> named_to_reference_converter::visi

std::vector<std::reference_wrapper<ast::expression const>>
named_to_reference_converter::visit_operands(
std::vector<std::reference_wrapper<ast::expression const>> operands)
cudf::host_span<std::reference_wrapper<ast::expression const> const> operands)
{
std::vector<std::reference_wrapper<ast::expression const>> transformed_operands;
for (auto const& operand : operands) {
Expand Down Expand Up @@ -624,7 +625,7 @@ class names_from_expression : public ast::detail::expression_transformer {
}

private:
void visit_operands(std::vector<std::reference_wrapper<ast::expression const>> operands)
void visit_operands(cudf::host_span<std::reference_wrapper<ast::expression const> const> operands)
{
for (auto const& operand : operands) {
operand.get().accept(*this);
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/io/parquet/reader_impl_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ class named_to_reference_converter : public ast::detail::expression_transformer

private:
std::vector<std::reference_wrapper<ast::expression const>> visit_operands(
std::vector<std::reference_wrapper<ast::expression const>> operands);
cudf::host_span<std::reference_wrapper<ast::expression const> const> operands);

std::unordered_map<std::string, size_type> column_name_to_index;
std::optional<std::reference_wrapper<ast::expression const>> _stats_expr;
Expand Down
Loading