Skip to content

Commit 16752d2

Browse files
dvadymcopybara-github
authored andcommitted
Chebyshev Decomposition
Implemented the decomposition of the polynomial in Chebyshev bases by k-th Chebyshev polynomial. See docstring in ChebyshevDecomposition.h for more detailes. The decomposition will be used in HE evaluation of polynomials Chebyshev basis. PiperOrigin-RevId: 752750648
1 parent ef7857d commit 16752d2

File tree

4 files changed

+221
-0
lines changed

4 files changed

+221
-0
lines changed

lib/Utils/Polynomial/BUILD

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,20 @@ cc_test(
2424
"@llvm-project//mlir:Support",
2525
],
2626
)
27+
28+
cc_library(
29+
name = "ChebyshevDecomposition",
30+
srcs = ["ChebyshevDecomposition.cpp"],
31+
hdrs = ["ChebyshevDecomposition.h"],
32+
deps = ["@llvm-project//llvm:Support"],
33+
)
34+
35+
cc_test(
36+
name = "ChebyshevDecompositionTest",
37+
srcs = ["ChebyshevDecompositionTest.cpp"],
38+
deps = [
39+
":ChebyshevDecomposition",
40+
"@googletest//:gtest_main",
41+
"@llvm-project//llvm:Support",
42+
],
43+
)
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#include "lib/Utils/Polynomial/ChebyshevDecomposition.h"
2+
3+
#include <cstdlib>
4+
#include <utility>
5+
6+
namespace mlir {
7+
namespace heir {
8+
namespace polynomial {
9+
10+
namespace {
11+
12+
// Finds polynomials q, r in the Chebyshev basis, such that
13+
// p = q*T_k + r.
14+
std::pair<ChebyshevBasisPolynomial, ChebyshevBasisPolynomial> dividePolynomials(
15+
ChebyshevBasisPolynomial p, int k) {
16+
if (k >= p.size()) {
17+
return {{}, p};
18+
}
19+
ChebyshevBasisPolynomial q(p.size() - k, APFloat(0.0));
20+
for (int i = p.size() - 1; i >= k; --i) {
21+
if (p[i].isZero()) {
22+
continue;
23+
}
24+
if (i == k) {
25+
q[0] = p[i];
26+
} else {
27+
// Formula: 2T_m(x)T_n(x) = T_{m+n}(x) + T_{|m-n|}(x) is used
28+
// https://en.wikipedia.org/wiki/Chebyshev_polynomials#Products_of_Chebyshev_polynomials
29+
// Namely
30+
// p_iT_i + ... = p_i(2T_kT_{i-k}-T_{|i-2k|}) + ... =
31+
// T_k*(2p_iT_{i-k}) + // this term goes to q
32+
// -p_iT_{|i-2k|} + ... // the rest stays in p
33+
// As a result on each iteration we decrease the degree of the polynomial
34+
// which we divide.
35+
q[i - k] = p[i] * APFloat(2.0);
36+
p[std::abs(i - 2 * k)].subtract(p[i], llvm::APFloat::rmNearestTiesToEven);
37+
38+
p[i] = APFloat(0.0);
39+
}
40+
}
41+
ChebyshevBasisPolynomial r(p.begin(), p.begin() + k);
42+
return {q, r};
43+
}
44+
} // namespace
45+
46+
ChebyshevDecomposition decompose(
47+
const ChebyshevBasisPolynomial &cheb_polynomial, int decomposition_degree) {
48+
ChebyshevBasisPolynomial p = cheb_polynomial;
49+
ChebyshevDecomposition result{.generatorDegree = decomposition_degree};
50+
while (!p.empty()) {
51+
auto [next_p, q] = dividePolynomials(p, decomposition_degree);
52+
result.coeffs.push_back(q);
53+
p = next_p;
54+
}
55+
return result;
56+
}
57+
58+
} // namespace polynomial
59+
} // namespace heir
60+
} // namespace mlir
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#ifndef LIB_UTILS_POLYNOMIAL_CHEBYSHEVDECOMPOSITION_H_
2+
#define LIB_UTILS_POLYNOMIAL_CHEBYSHEVDECOMPOSITION_H_
3+
4+
#include <vector>
5+
6+
#include "llvm/include/llvm/ADT/APFloat.h" // from @llvm-project
7+
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
8+
9+
namespace mlir {
10+
namespace heir {
11+
namespace polynomial {
12+
13+
using ::llvm::APFloat;
14+
using ::llvm::SmallVector;
15+
16+
// Represents the polynomial in the Chebyshev polynomials basis. Namely,
17+
// let p is ChebyshevBasisPolynomial, then it represents the polynomial:
18+
// P(x) = p[0]*T_0(x) + p[1]T_1(x) + ... + p[k] T_k(x).
19+
using ChebyshevBasisPolynomial = SmallVector<APFloat>;
20+
21+
// Represents the polynomial:
22+
// p = coeffs[0] + coeffs[1]*T_k + coeffs[2]*T_k^2 + ... + coeffs[l]*T_k^l
23+
// where, k = generatorDegree, T_k is k-th Chebyshev polynomial and qs are
24+
// polynomials in the basis of the Chebyshev polynomials.
25+
struct ChebyshevDecomposition {
26+
int generatorDegree;
27+
std::vector<ChebyshevBasisPolynomial> coeffs;
28+
};
29+
30+
// Decomposes the polynomial in the Chebyshev polynomials basis.
31+
ChebyshevDecomposition decompose(
32+
const ChebyshevBasisPolynomial &cheb_polynomial, int decomposition_degree);
33+
34+
} // namespace polynomial
35+
} // namespace heir
36+
} // namespace mlir
37+
38+
#endif // LIB_UTILS_POLYNOMIAL_CHEBYSHEVDECOMPOSITION_H_
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
#include "gmock/gmock.h" // from @googletest
2+
#include "gtest/gtest.h" // from @googletest
3+
#include "lib/Utils/Polynomial/ChebyshevDecomposition.h"
4+
#include "llvm/include/llvm/ADT/APFloat.h" // from @llvm-project
5+
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
6+
7+
namespace mlir {
8+
namespace heir {
9+
namespace polynomial {
10+
namespace {
11+
12+
using ::llvm::APFloat;
13+
using ::llvm::SmallVector;
14+
using ::testing::ElementsAre;
15+
16+
TEST(ChebyshevDecompositionTest, EmptyPolynomial) {
17+
ChebyshevBasisPolynomial p;
18+
ChebyshevDecomposition decomposition = decompose(p, 1);
19+
EXPECT_EQ(decomposition.generatorDegree, 1);
20+
EXPECT_TRUE(decomposition.coeffs.empty());
21+
}
22+
23+
TEST(ChebyshevDecompositionTest, ConstantPolynomial) {
24+
ChebyshevBasisPolynomial p = {APFloat(5.0)};
25+
ChebyshevDecomposition decomposition = decompose(p, 1);
26+
ASSERT_EQ(decomposition.coeffs.size(), 1);
27+
EXPECT_THAT(decomposition.coeffs[0], ElementsAre(APFloat(5.0)));
28+
}
29+
30+
TEST(ChebyshevDecompositionTest, LinearPolynomialK1) {
31+
ChebyshevBasisPolynomial p = {APFloat(-1.0), APFloat(-3.0)};
32+
ChebyshevDecomposition decomposition = decompose(p, 1);
33+
ASSERT_EQ(decomposition.coeffs.size(), 2);
34+
EXPECT_THAT(decomposition.coeffs[0], ElementsAre(APFloat(-1.0)));
35+
EXPECT_THAT(decomposition.coeffs[1], ElementsAre(APFloat(-3.0)));
36+
}
37+
38+
TEST(ChebyshevDecompositionTest, LinearPolynomialK3) {
39+
ChebyshevBasisPolynomial p = {APFloat(-1.0), APFloat(-3.0)};
40+
ChebyshevDecomposition decomposition = decompose(p, 3);
41+
EXPECT_EQ(decomposition.generatorDegree, 3);
42+
ASSERT_EQ(decomposition.coeffs.size(), 1);
43+
EXPECT_THAT(decomposition.coeffs[0],
44+
ElementsAre(APFloat(-1.0), APFloat(-3.0)));
45+
}
46+
47+
TEST(ChebyshevDecompositionTest, QuadraticPolynomial) {
48+
ChebyshevBasisPolynomial p = {APFloat(1.0), APFloat(-2.0), APFloat(3.0)};
49+
ChebyshevDecomposition decomposition = decompose(p, 2);
50+
ASSERT_EQ(decomposition.coeffs.size(), 2);
51+
EXPECT_THAT(decomposition.coeffs[0],
52+
ElementsAre(APFloat(1.0), APFloat(-2.0)));
53+
EXPECT_THAT(decomposition.coeffs[1], ElementsAre(APFloat(3.0)));
54+
}
55+
56+
// The expected output was found with the reference Python implementation and it
57+
// was verified independently by evaluating the polynomials on one of the
58+
// points.
59+
TEST(ChebyshevDecompositionTest, Degree7Polynomial) {
60+
ChebyshevBasisPolynomial p = {APFloat(1.0), APFloat(-2.0), APFloat(3.0),
61+
APFloat(4.0), APFloat(5.0), APFloat(6.0),
62+
APFloat(-7.0), APFloat(8.0)};
63+
ChebyshevDecomposition decomposition = decompose(p, 3);
64+
EXPECT_EQ(decomposition.generatorDegree, 3);
65+
ASSERT_EQ(decomposition.coeffs.size(), 3);
66+
EXPECT_THAT(decomposition.coeffs[0],
67+
ElementsAre(APFloat(8.0), APFloat(-16.0), APFloat(-2.0)));
68+
EXPECT_THAT(decomposition.coeffs[1],
69+
ElementsAre(APFloat(4.0), APFloat(10.0), APFloat(-4.0)));
70+
EXPECT_THAT(decomposition.coeffs[2],
71+
ElementsAre(APFloat(-14.0), APFloat(32.0)));
72+
}
73+
74+
TEST(ChebyshevDecompositionTest, Degree20Polynomial) {
75+
ChebyshevBasisPolynomial p = {APFloat(-1.0), APFloat(2.0), APFloat(3.0),
76+
APFloat(4.0), APFloat(5.0), APFloat(-6.0),
77+
APFloat(7.0), APFloat(-8.0), APFloat(9.0),
78+
APFloat(10.0), APFloat(11.0), APFloat(12.0),
79+
APFloat(-13.0), APFloat(14.0), APFloat(15.0),
80+
APFloat(-16.0), APFloat(17.0), APFloat(18.0),
81+
APFloat(19.0), APFloat(20.0), APFloat(21.0)};
82+
ChebyshevDecomposition decomposition = decompose(p, 4);
83+
EXPECT_EQ(decomposition.generatorDegree, 4);
84+
ASSERT_EQ(decomposition.coeffs.size(), 6);
85+
EXPECT_THAT(
86+
decomposition.coeffs[0],
87+
ElementsAre(APFloat(7.0), APFloat(2.0), APFloat(19.0), APFloat(32.0)));
88+
EXPECT_THAT(decomposition.coeffs[1],
89+
ElementsAre(APFloat(149.0), APFloat(-12.0), APFloat(8.0),
90+
APFloat(100.0)));
91+
EXPECT_THAT(decomposition.coeffs[2],
92+
ElementsAre(APFloat(-118.0), APFloat(-112.0), APFloat(-244.0),
93+
APFloat(-248.0)));
94+
EXPECT_THAT(decomposition.coeffs[3],
95+
ElementsAre(APFloat(-472.0), APFloat(-48.0), APFloat(-32.0),
96+
APFloat(-272.0)));
97+
EXPECT_THAT(decomposition.coeffs[4],
98+
ElementsAre(APFloat(136.0), APFloat(288.0), APFloat(304.0),
99+
APFloat(320.0)));
100+
EXPECT_THAT(decomposition.coeffs[5], ElementsAre(APFloat(336.0)));
101+
}
102+
103+
} // namespace
104+
} // namespace polynomial
105+
} // namespace heir
106+
} // namespace mlir

0 commit comments

Comments
 (0)