Skip to content

Commit c1273fc

Browse files
j2kuncopybara-github
authored andcommitted
Add a data structure for a tree of arithmetic operations
The basic structure just represents a (leaf-type-agnostic) tree of operations and provides a mechanism to create a visitor using std::visit. This is needed for lower_eval to separate the construction of the lowered arithmetic tree from the materialization of the IR. However, I think it will (with adaptations) be useful for other situations as well: - Symbolic noise analysis in #1817 - To simplify the core routine in operation-balancer (https://github.com/google/heir/blob/b0cf72da113e6c7282733f8ba6bfcb7754a7495c/lib/Transforms/OperationBalancer/OperationBalancer.cpp#L74) PiperOrigin-RevId: 770267746
1 parent b0cf72d commit c1273fc

File tree

3 files changed

+201
-0
lines changed

3 files changed

+201
-0
lines changed

lib/Utils/ArithmeticTree.h

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
#ifndef LIB_UTILS_ARITHMETICTREE_H_
2+
#define LIB_UTILS_ARITHMETICTREE_H_
3+
4+
#include <memory>
5+
#include <variant>
6+
7+
namespace mlir {
8+
namespace heir {
9+
10+
// This file contains a generic tree structure that can be used for representing
11+
// arithmetic trees with leaf nodes of various types.
12+
template <typename T>
13+
struct ArithmeticTreeNode;
14+
15+
// A leaf node for the tree
16+
template <typename T>
17+
struct LeafNode {
18+
T value;
19+
};
20+
21+
using ConstantNode = LeafNode<double>;
22+
23+
template <typename T>
24+
struct AddNode {
25+
std::unique_ptr<ArithmeticTreeNode<T>> left;
26+
std::unique_ptr<ArithmeticTreeNode<T>> right;
27+
};
28+
29+
template <typename T>
30+
struct MultiplyNode {
31+
std::unique_ptr<ArithmeticTreeNode<T>> left;
32+
std::unique_ptr<ArithmeticTreeNode<T>> right;
33+
};
34+
35+
template <typename T>
36+
struct PowerNode {
37+
std::unique_ptr<ArithmeticTreeNode<T>> base;
38+
size_t exponent;
39+
};
40+
41+
template <typename T>
42+
struct ArithmeticTreeNode {
43+
public:
44+
std::variant<ConstantNode, LeafNode<T>, AddNode<T>, MultiplyNode<T>,
45+
PowerNode<T>>
46+
node_variant;
47+
48+
explicit ArithmeticTreeNode(double constant)
49+
: node_variant(ConstantNode{constant}) {}
50+
explicit ArithmeticTreeNode(const T& value)
51+
: node_variant(LeafNode<T>{value}) {}
52+
explicit ArithmeticTreeNode(T&& value)
53+
: node_variant(LeafNode<T>{std::move(value)}) {}
54+
55+
private:
56+
ArithmeticTreeNode() = default;
57+
58+
public:
59+
// Static factory methods
60+
static std::unique_ptr<ArithmeticTreeNode<T>> constant(double constant) {
61+
return std::unique_ptr<ArithmeticTreeNode<T>>(
62+
new ArithmeticTreeNode<T>(constant));
63+
}
64+
65+
static std::unique_ptr<ArithmeticTreeNode<T>> leaf(const T& value) {
66+
return std::unique_ptr<ArithmeticTreeNode<T>>(
67+
new ArithmeticTreeNode<T>(value));
68+
}
69+
70+
static std::unique_ptr<ArithmeticTreeNode<T>> add(
71+
std::unique_ptr<ArithmeticTreeNode<T>> lhs,
72+
std::unique_ptr<ArithmeticTreeNode<T>> rhs) {
73+
assert(lhs && rhs && "invalid add");
74+
auto node =
75+
std::unique_ptr<ArithmeticTreeNode<T>>(new ArithmeticTreeNode<T>());
76+
node->node_variant.template emplace<AddNode<T>>(std::move(lhs),
77+
std::move(rhs));
78+
return node;
79+
}
80+
81+
static std::unique_ptr<ArithmeticTreeNode<T>> mul(
82+
std::unique_ptr<ArithmeticTreeNode<T>> lhs,
83+
std::unique_ptr<ArithmeticTreeNode<T>> rhs) {
84+
assert(lhs && rhs && "invalid mul");
85+
auto node =
86+
std::unique_ptr<ArithmeticTreeNode<T>>(new ArithmeticTreeNode<T>());
87+
node->node_variant.template emplace<MultiplyNode<T>>(std::move(lhs),
88+
std::move(rhs));
89+
return node;
90+
}
91+
92+
static std::unique_ptr<ArithmeticTreeNode<T>> power(
93+
std::unique_ptr<ArithmeticTreeNode<T>> base, size_t exponent) {
94+
assert(base && "invalid base for power");
95+
auto node =
96+
std::unique_ptr<ArithmeticTreeNode<T>>(new ArithmeticTreeNode<T>());
97+
node->node_variant.template emplace<PowerNode<T>>(std::move(base),
98+
exponent);
99+
return node;
100+
}
101+
102+
// The presence of std::unique_ptr in AddNode, MultiplyNode, and PowerNode
103+
// makes these types non-copyable. Consequently, ArithmeticTreeNode
104+
// itself is move-only by default.
105+
ArithmeticTreeNode(const ArithmeticTreeNode&) = delete; // No copying
106+
ArithmeticTreeNode& operator=(const ArithmeticTreeNode&) =
107+
delete; // No copy assignment
108+
109+
ArithmeticTreeNode(ArithmeticTreeNode&& other) noexcept =
110+
default; // Allow move construction
111+
ArithmeticTreeNode& operator=(ArithmeticTreeNode&& other) noexcept =
112+
default; // Allow move assignment
113+
114+
// Visitor pattern
115+
template <typename VisitorFunc>
116+
decltype(auto) visit(VisitorFunc&& visitor) {
117+
return std::visit(std::forward<VisitorFunc>(visitor), node_variant);
118+
}
119+
120+
template <typename VisitorFunc>
121+
decltype(auto) visit(VisitorFunc&& visitor) const {
122+
return std::visit(std::forward<VisitorFunc>(visitor), node_variant);
123+
}
124+
};
125+
126+
} // namespace heir
127+
} // namespace mlir
128+
129+
#endif // LIB_UTILS_ARITHMETICTREE_H_

lib/Utils/ArithmeticTreeTest.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#include <iomanip>
2+
3+
#include "gmock/gmock.h" // from @googletest
4+
#include "gtest/gtest.h" // from @googletest
5+
#include "lib/Utils/ArithmeticTree.h"
6+
7+
namespace mlir {
8+
namespace heir {
9+
namespace {
10+
11+
using StringLeavedTree = ArithmeticTreeNode<std::string>;
12+
13+
struct FlattenedStringVisitor {
14+
std::string operator()(const ConstantNode& node) const {
15+
std::stringstream ss;
16+
ss << std::fixed << std::setprecision(2) << node.value;
17+
return ss.str();
18+
}
19+
20+
std::string operator()(const LeafNode<std::string>& node) const {
21+
return node.value;
22+
}
23+
24+
std::string operator()(const AddNode<std::string>& node) const {
25+
std::stringstream ss;
26+
ss << "(" << node.left->visit(*this) << " + " << node.right->visit(*this)
27+
<< ")";
28+
return ss.str();
29+
}
30+
31+
std::string operator()(const MultiplyNode<std::string>& node) const {
32+
std::stringstream ss;
33+
ss << node.left->visit(*this) << " * " << node.right->visit(*this);
34+
return ss.str();
35+
}
36+
37+
std::string operator()(const PowerNode<std::string>& node) const {
38+
std::stringstream ss;
39+
ss << "(" << node.base->visit(*this) << " ^ " << node.exponent << ")";
40+
return ss.str();
41+
}
42+
};
43+
44+
TEST(ArithmeticTreeTest, TestPrint) {
45+
auto root = StringLeavedTree::mul(
46+
StringLeavedTree::add(StringLeavedTree::leaf("x"),
47+
StringLeavedTree::constant(3.0)),
48+
StringLeavedTree::power(StringLeavedTree::leaf("y"), 2));
49+
50+
FlattenedStringVisitor visitor;
51+
std::string result = root->visit(visitor);
52+
EXPECT_EQ(result, "(x + 3.00) * (y ^ 2)");
53+
}
54+
55+
} // namespace
56+
} // namespace heir
57+
} // namespace mlir

lib/Utils/BUILD

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,18 @@ cc_library(
153153
"@llvm-project//mlir:Support",
154154
],
155155
)
156+
157+
cc_library(
158+
name = "ArithmeticTree",
159+
srcs = ["ArithmeticTree.h"],
160+
hdrs = ["ArithmeticTree.h"],
161+
)
162+
163+
cc_test(
164+
name = "ArithmeticTreeTest",
165+
srcs = ["ArithmeticTreeTest.cpp"],
166+
deps = [
167+
":ArithmeticTree",
168+
"@googletest//:gtest_main",
169+
],
170+
)

0 commit comments

Comments
 (0)