Skip to content

Commit 6ce33c2

Browse files
j2kuncopybara-github
authored andcommitted
Add a data structure for a DAG of arithmetic operations
The basic structure just represents a (leaf-type-agnostic) DAG of operations and provides a mechanism to create a visitor using std::visit (with a base class for a caching visitor). This is needed for lower_eval to separate the construction of the lowered arithmetic tree from the materialization of the IR. However, I think it will (with adaptations) be useful for other situations as well: - Symbolic noise analysis in #1817 - To simplify the core routine in operation-balancer (https://github.com/google/heir/blob/b0cf72da113e6c7282733f8ba6bfcb7754a7495c/lib/Transforms/OperationBalancer/OperationBalancer.cpp#L74) PiperOrigin-RevId: 770267746
1 parent 0fe389a commit 6ce33c2

File tree

3 files changed

+310
-0
lines changed

3 files changed

+310
-0
lines changed

lib/Utils/ArithmeticDag.h

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
#ifndef LIB_UTILS_ARITHMETICDAG_H_
2+
#define LIB_UTILS_ARITHMETICDAG_H_
3+
4+
#include <cassert>
5+
#include <cstddef>
6+
#include <memory>
7+
#include <unordered_map>
8+
#include <utility>
9+
#include <variant>
10+
11+
namespace mlir {
12+
namespace heir {
13+
14+
// This file contains a generic DAG structure that can be used for representing
15+
// arithmetic DAGs with leaf nodes of various types.
16+
template <typename T>
17+
struct ArithmeticDagNode;
18+
19+
// A leaf node for the DAG
20+
template <typename T>
21+
struct LeafNode {
22+
T value;
23+
};
24+
25+
struct ConstantNode {
26+
double value;
27+
};
28+
29+
template <typename T>
30+
struct AddNode {
31+
std::shared_ptr<ArithmeticDagNode<T>> left;
32+
std::shared_ptr<ArithmeticDagNode<T>> right;
33+
};
34+
35+
template <typename T>
36+
struct MultiplyNode {
37+
std::shared_ptr<ArithmeticDagNode<T>> left;
38+
std::shared_ptr<ArithmeticDagNode<T>> right;
39+
};
40+
41+
template <typename T>
42+
struct PowerNode {
43+
std::shared_ptr<ArithmeticDagNode<T>> base;
44+
size_t exponent;
45+
};
46+
47+
template <typename T>
48+
struct ArithmeticDagNode {
49+
public:
50+
std::variant<ConstantNode, LeafNode<T>, AddNode<T>, MultiplyNode<T>,
51+
PowerNode<T>>
52+
node_variant;
53+
54+
private:
55+
ArithmeticDagNode() = default;
56+
57+
public:
58+
// Static factory methods
59+
static std::shared_ptr<ArithmeticDagNode<T>> leaf(const T& value) {
60+
auto node =
61+
std::shared_ptr<ArithmeticDagNode<T>>(new ArithmeticDagNode<T>());
62+
// Note, to satisfy variant we need to use aggregate initialization inside
63+
// emplace
64+
node->node_variant.template emplace<LeafNode<T>>(new LeafNode<T>{value});
65+
return node;
66+
}
67+
68+
static std::shared_ptr<ArithmeticDagNode<T>> constant(double constant) {
69+
auto node =
70+
std::shared_ptr<ArithmeticDagNode<T>>(new ArithmeticDagNode<T>());
71+
node->node_variant.template emplace<ConstantNode>(ConstantNode{constant});
72+
return node;
73+
}
74+
75+
static std::shared_ptr<ArithmeticDagNode<T>> add(
76+
std::shared_ptr<ArithmeticDagNode<T>> lhs,
77+
std::shared_ptr<ArithmeticDagNode<T>> rhs) {
78+
assert(lhs && rhs && "invalid add");
79+
auto node =
80+
std::shared_ptr<ArithmeticDagNode<T>>(new ArithmeticDagNode<T>());
81+
node->node_variant.template emplace<AddNode<T>>(
82+
AddNode<T>{std::move(lhs), std::move(rhs)});
83+
return node;
84+
}
85+
86+
static std::shared_ptr<ArithmeticDagNode<T>> mul(
87+
std::shared_ptr<ArithmeticDagNode<T>> lhs,
88+
std::shared_ptr<ArithmeticDagNode<T>> rhs) {
89+
assert(lhs && rhs && "invalid mul");
90+
auto node =
91+
std::shared_ptr<ArithmeticDagNode<T>>(new ArithmeticDagNode<T>());
92+
node->node_variant.template emplace<MultiplyNode<T>>(
93+
MultiplyNode<T>{std::move(lhs), std::move(rhs)});
94+
return node;
95+
}
96+
97+
static std::shared_ptr<ArithmeticDagNode<T>> power(
98+
std::shared_ptr<ArithmeticDagNode<T>> base, size_t exponent) {
99+
assert(base && "invalid base for power");
100+
auto node =
101+
std::shared_ptr<ArithmeticDagNode<T>>(new ArithmeticDagNode<T>());
102+
node->node_variant.template emplace<PowerNode<T>>(
103+
PowerNode<T>{std::move(base), exponent});
104+
return node;
105+
}
106+
107+
ArithmeticDagNode(const ArithmeticDagNode&) = default;
108+
ArithmeticDagNode& operator=(const ArithmeticDagNode&) = default;
109+
ArithmeticDagNode(ArithmeticDagNode&&) noexcept = default;
110+
ArithmeticDagNode& operator=(ArithmeticDagNode&&) noexcept = default;
111+
112+
// Visitor pattern
113+
template <typename VisitorFunc>
114+
decltype(auto) visit(VisitorFunc&& visitor) {
115+
return std::visit(std::forward<VisitorFunc>(visitor), node_variant);
116+
}
117+
118+
template <typename VisitorFunc>
119+
decltype(auto) visit(VisitorFunc&& visitor) const {
120+
return std::visit(std::forward<VisitorFunc>(visitor), node_variant);
121+
}
122+
};
123+
124+
/// A base class for visitors that caches intermediate results.
125+
///
126+
/// Template parameters:
127+
/// T: The type of the leaf nodes.
128+
/// ResultType: The type of the result of the visit.
129+
template <typename T, typename ResultType>
130+
class CachingVisitor {
131+
public:
132+
virtual ~CachingVisitor() = default;
133+
134+
/// The main entry point that contains the caching logic.
135+
ResultType process(const std::shared_ptr<ArithmeticDagNode<T>>& node) {
136+
assert(node != nullptr && "invalid null node!");
137+
138+
const auto* node_ptr = node.get();
139+
if (auto it = cache.find(node_ptr); it != cache.end()) {
140+
return it->second;
141+
}
142+
143+
ResultType result = std::visit(*this, node->node_variant);
144+
cache[node_ptr] = result;
145+
return result;
146+
}
147+
148+
// --- Virtual Visit Methods ---
149+
// Derived classes must override these for the node types they support.
150+
151+
virtual ResultType operator()(const ConstantNode& node) {
152+
assert(false && "Visit logic for ConstantNode is not implemented.");
153+
}
154+
155+
virtual ResultType operator()(const LeafNode<T>& node) {
156+
assert(false && "Visit logic for LeafNode is not implemented.");
157+
}
158+
159+
virtual ResultType operator()(const AddNode<T>& node) {
160+
assert(false && "Visit logic for AddNode is not implemented.");
161+
}
162+
163+
virtual ResultType operator()(const MultiplyNode<T>& node) {
164+
assert(false && "Visit logic for MultiplyNode is not implemented.");
165+
}
166+
167+
virtual ResultType operator()(const PowerNode<T>& node) {
168+
assert(false && "Visit logic for PowerNode is not implemented.");
169+
}
170+
171+
private:
172+
std::unordered_map<const ArithmeticDagNode<T>*, ResultType> cache;
173+
};
174+
175+
} // namespace heir
176+
} // namespace mlir
177+
178+
#endif // LIB_UTILS_ARITHMETICDAG_H_

lib/Utils/ArithmeticDagTest.cpp

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
#include <cmath>
2+
#include <iomanip>
3+
#include <ios>
4+
#include <sstream>
5+
#include <string>
6+
7+
#include "gtest/gtest.h" // from @googletest
8+
#include "lib/Utils/ArithmeticDag.h"
9+
10+
namespace mlir {
11+
namespace heir {
12+
namespace {
13+
14+
using StringLeavedDag = ArithmeticDagNode<std::string>;
15+
using DoubleLeavedDag = ArithmeticDagNode<double>;
16+
17+
struct FlattenedStringVisitor {
18+
std::string operator()(const ConstantNode& node) const {
19+
std::stringstream ss;
20+
ss << std::fixed << std::setprecision(2) << node.value;
21+
return ss.str();
22+
}
23+
24+
std::string operator()(const LeafNode<std::string>& node) const {
25+
return node.value;
26+
}
27+
28+
std::string operator()(const AddNode<std::string>& node) const {
29+
std::stringstream ss;
30+
ss << "(" << node.left->visit(*this) << " + " << node.right->visit(*this)
31+
<< ")";
32+
return ss.str();
33+
}
34+
35+
std::string operator()(const MultiplyNode<std::string>& node) const {
36+
std::stringstream ss;
37+
ss << node.left->visit(*this) << " * " << node.right->visit(*this);
38+
return ss.str();
39+
}
40+
41+
std::string operator()(const PowerNode<std::string>& node) const {
42+
std::stringstream ss;
43+
ss << "(" << node.base->visit(*this) << " ^ " << node.exponent << ")";
44+
return ss.str();
45+
}
46+
};
47+
48+
class EvalVisitor : public CachingVisitor<double, double> {
49+
public:
50+
EvalVisitor() : CachingVisitor<double, double>(), callCount(0) {}
51+
52+
// To test that caching works as expected.
53+
int callCount;
54+
55+
double operator()(const ConstantNode& node) override {
56+
callCount += 1;
57+
return node.value;
58+
}
59+
60+
double operator()(const LeafNode<double>& node) override {
61+
callCount += 1;
62+
return node.value;
63+
}
64+
65+
double operator()(const AddNode<double>& node) override {
66+
// Recursive calls use the public `process` method from the base class
67+
// to ensure caching is applied at every step.
68+
callCount += 1;
69+
return this->process(node.left) + this->process(node.right);
70+
}
71+
72+
double operator()(const MultiplyNode<double>& node) override {
73+
callCount += 1;
74+
return this->process(node.left) * this->process(node.right);
75+
}
76+
77+
double operator()(const PowerNode<double>& node) override {
78+
callCount += 1;
79+
return std::pow(this->process(node.base), node.exponent);
80+
}
81+
};
82+
83+
TEST(ArithmeticDagTest, TestPrint) {
84+
auto root = StringLeavedDag::mul(
85+
StringLeavedDag::add(StringLeavedDag::leaf("x"),
86+
StringLeavedDag::constant(3.0)),
87+
StringLeavedDag::power(StringLeavedDag::leaf("y"), 2));
88+
89+
FlattenedStringVisitor visitor;
90+
std::string result = root->visit(visitor);
91+
EXPECT_EQ(result, "(x + 3.00) * (y ^ 2)");
92+
}
93+
94+
TEST(ArithmeticDagTest, TestProperDag) {
95+
auto shared = StringLeavedDag::power(StringLeavedDag::leaf("y"), 2);
96+
auto root =
97+
StringLeavedDag::mul(StringLeavedDag::add(shared, shared), shared);
98+
99+
FlattenedStringVisitor visitor;
100+
std::string result = root->visit(visitor);
101+
EXPECT_EQ(result, "((y ^ 2) + (y ^ 2)) * (y ^ 2)");
102+
}
103+
104+
TEST(ArithmeticDagTest, TestEvaluationVisitor) {
105+
auto shared = DoubleLeavedDag::power(DoubleLeavedDag::leaf(2.0), 2);
106+
auto root = DoubleLeavedDag::mul(DoubleLeavedDag::add(shared, shared),
107+
DoubleLeavedDag::constant(3.0));
108+
109+
EvalVisitor visitor;
110+
double result = root->visit(visitor);
111+
EXPECT_EQ(result, 24.0);
112+
EXPECT_EQ(visitor.callCount, 5);
113+
}
114+
115+
} // namespace
116+
} // namespace heir
117+
} // namespace mlir

lib/Utils/BUILD

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,18 @@ cc_library(
153153
"@llvm-project//mlir:Support",
154154
],
155155
)
156+
157+
cc_library(
158+
name = "ArithmeticDag",
159+
srcs = ["ArithmeticDag.h"],
160+
hdrs = ["ArithmeticDag.h"],
161+
)
162+
163+
cc_test(
164+
name = "ArithmeticDagTest",
165+
srcs = ["ArithmeticDagTest.cpp"],
166+
deps = [
167+
":ArithmeticDag",
168+
"@googletest//:gtest_main",
169+
],
170+
)

0 commit comments

Comments
 (0)