Skip to content

Commit 9e53a85

Browse files
code-perspectivecopybara-github
authored andcommitted
Add dot_product_128 test to openfhe bgv tests
PiperOrigin-RevId: 734349383
1 parent ac31b6f commit 9e53a85

File tree

6 files changed

+249
-41
lines changed

6 files changed

+249
-41
lines changed
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# See README.md for setup required to run these tests
2+
3+
load("@heir//tests/Examples/openfhe:test.bzl", "openfhe_end_to_end_test")
4+
5+
package(default_applicable_licenses = ["@heir//:license"])
6+
7+
openfhe_end_to_end_test(
8+
name = "dot_product_128_debug_test",
9+
generated_lib_header = "dot_product_128_debug_lib.h",
10+
heir_opt_flags = [
11+
"--mlir-to-bgv=ciphertext-degree=128 annotate-noise-bound=true plaintext-modulus=4194305",
12+
"--scheme-to-openfhe=insert-debug-handler-calls=true",
13+
],
14+
mlir_src = "dot_product_128.mlir",
15+
tags = ["notap"],
16+
test_src = "dot_product_128_debug_test.cpp",
17+
)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
func.func @dot_product(%arg0: tensor<128xi32> {secret.secret}, %arg1: tensor<128xi32> {secret.secret}) -> i32 {
2+
%c0 = arith.constant 0 : index
3+
%c0_si32 = arith.constant 0 : i32
4+
%0 = affine.for %arg2 = 0 to 128 iter_args(%iter = %c0_si32) -> (i32) {
5+
%1 = tensor.extract %arg0[%arg2] : tensor<128xi32>
6+
%2 = tensor.extract %arg1[%arg2] : tensor<128xi32>
7+
%3 = arith.muli %1, %2 : i32
8+
%4 = arith.addi %iter, %3 : i32
9+
affine.yield %4 : i32
10+
}
11+
return %0 : i32
12+
}
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
#include <cmath>
2+
#include <cstddef>
3+
#include <cstdint>
4+
#include <iostream>
5+
#include <map>
6+
#include <ostream>
7+
#include <string>
8+
#include <vector>
9+
10+
#include "gtest/gtest.h" // from @googletest
11+
#include "src/core/include/lattice/hal/lat-backend.h" // from @openfhe
12+
#include "src/core/include/utils/inttypes.h" // from @openfhe
13+
#include "src/pke/include/key/privatekey-fwd.h" // from @openfhe
14+
15+
// Generated headers (block clang-format from messing up order)
16+
#include "tests/Examples/openfhe/bgv/dot_product_128_debug/dot_product_128_debug_lib.h"
17+
18+
// DecryptCore not accessible from CryptoContext
19+
// so copy from @openfhe//src/pke/lib/schemerns/rns-pke.cpp
20+
DCRTPoly DecryptCore(const std::vector<DCRTPoly>& cv,
21+
const PrivateKey<DCRTPoly> privateKey) {
22+
const DCRTPoly& s = privateKey->GetPrivateElement();
23+
24+
size_t sizeQ = s.GetParams()->GetParams().size();
25+
size_t sizeQl = cv[0].GetParams()->GetParams().size();
26+
27+
size_t diffQl = sizeQ - sizeQl;
28+
29+
auto scopy(s);
30+
scopy.DropLastElements(diffQl);
31+
32+
DCRTPoly sPower(scopy);
33+
34+
DCRTPoly b(cv[0]);
35+
b.SetFormat(Format::EVALUATION);
36+
37+
DCRTPoly ci;
38+
for (size_t i = 1; i < cv.size(); i++) {
39+
ci = cv[i];
40+
ci.SetFormat(Format::EVALUATION);
41+
42+
b += sPower * ci;
43+
sPower *= scopy;
44+
}
45+
return b;
46+
}
47+
48+
#define OP
49+
#define DECRYPT
50+
#define NOISE
51+
52+
void __heir_debug(CryptoContextT cc, PrivateKeyT sk, CiphertextT ct,
53+
const std::map<std::string, std::string>& debugAttrMap) {
54+
#ifdef OP
55+
auto isBlockArgument = debugAttrMap.at("asm.is_block_arg");
56+
if (isBlockArgument == "1") {
57+
std::cout << "Input" << std::endl;
58+
} else {
59+
std::cout << debugAttrMap.at("asm.op_name") << std::endl;
60+
}
61+
#endif
62+
63+
#ifdef DECRYPT
64+
PlaintextT ptxt;
65+
cc->Decrypt(sk, ct, &ptxt);
66+
ptxt->SetLength(std::stod(debugAttrMap.at("message.size")));
67+
std::cout << " " << ptxt << std::endl;
68+
#endif
69+
70+
#ifdef NOISE
71+
auto cv = ct->GetElements();
72+
size_t sizeQl = cv[0].GetParams()->GetParams().size();
73+
74+
auto b = DecryptCore(cv, sk);
75+
b.SetFormat(Format::COEFFICIENT);
76+
77+
double noise = (log2(b.Norm()));
78+
79+
double logQ = 0;
80+
std::vector<double> logqi_v;
81+
for (usint i = 0; i < sizeQl; i++) {
82+
double logqi =
83+
log2(cv[0].GetParams()->GetParams()[i]->GetModulus().ConvertToInt());
84+
logqi_v.push_back(logqi);
85+
logQ += logqi;
86+
}
87+
88+
std::cout << " cv " << cv.size() << " Ql " << sizeQl << " logQ: " << logQ
89+
<< " logqi: " << logqi_v << " budget " << logQ - noise - 1
90+
<< " noise: " << noise << std::endl;
91+
92+
// print the predicted bound by analysis
93+
if (debugAttrMap.find("noise.bound") != debugAttrMap.end()) {
94+
double noiseBound = std::stod(debugAttrMap.at("noise.bound"));
95+
96+
std::cout << " noise bound: " << noiseBound
97+
<< " gap: " << noiseBound - noise << std::endl;
98+
}
99+
#endif
100+
}
101+
102+
namespace mlir {
103+
namespace heir {
104+
namespace openfhe {
105+
106+
int32_t DotProductPlaintext(std::vector<int32_t> arg0,
107+
std::vector<int32_t> arg1) {
108+
int32_t dot_product = 0;
109+
for (int i = 0; i < arg0.size(); ++i) {
110+
dot_product += arg0[i] * arg1[i];
111+
}
112+
return dot_product;
113+
}
114+
115+
TEST(DotProduct128Test, RunTest) {
116+
auto cryptoContext = dot_product__generate_crypto_context();
117+
auto keyPair = cryptoContext->KeyGen();
118+
auto publicKey = keyPair.publicKey;
119+
auto secretKey = keyPair.secretKey;
120+
cryptoContext =
121+
dot_product__configure_crypto_context(cryptoContext, secretKey);
122+
std::cout << *cryptoContext->GetCryptoParameters() << std::endl;
123+
124+
std::vector<int32_t> arg0;
125+
std::vector<int32_t> arg1;
126+
for (int i = 0; i < 128; ++i) {
127+
arg0.push_back(22);
128+
arg1.push_back((22));
129+
}
130+
131+
int32_t expected = DotProductPlaintext(arg0, arg1);
132+
133+
auto arg0Encrypted =
134+
dot_product__encrypt__arg0(cryptoContext, arg0, publicKey);
135+
auto arg1Encrypted =
136+
dot_product__encrypt__arg1(cryptoContext, arg1, publicKey);
137+
auto outputEncrypted =
138+
dot_product(cryptoContext, secretKey, arg0Encrypted, arg1Encrypted);
139+
auto actual =
140+
dot_product__decrypt__result0(cryptoContext, outputEncrypted, secretKey);
141+
142+
EXPECT_EQ(expected, actual);
143+
}
144+
145+
} // namespace openfhe
146+
} // namespace heir
147+
} // namespace mlir

tests/Examples/openfhe/test.bzl

Lines changed: 2 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
"""A macro providing an end-to-end test for OpenFHE codegen."""
22

33
load("@heir//bazel/openfhe:copts.bzl", "MAYBE_OPENFHE_LINKOPTS", "MAYBE_OPENMP_COPTS")
4-
load("@heir//tools:heir-opt.bzl", "heir_opt")
5-
load("@heir//tools:heir-translate.bzl", "heir_translate")
4+
load("@heir//tools:heir-openfhe.bzl", "openfhe_lib")
65

76
def openfhe_end_to_end_test(name, mlir_src, test_src, generated_lib_header, heir_opt_flags = [], heir_translate_flags = [], data = [], tags = [], deps = [], **kwargs):
87
"""A rule for running generating OpenFHE and running a test on it.
@@ -20,46 +19,8 @@ def openfhe_end_to_end_test(name, mlir_src, test_src, generated_lib_header, heir
2019
deps: Deps to pass to cc_test and cc_library
2120
**kwargs: Keyword arguments to pass to cc_library and cc_test.
2221
"""
23-
cc_codegen_target = name + ".heir_translate_cc"
24-
h_codegen_target = name + ".heir_translate_h"
2522
cc_lib_target_name = "%s_cc_lib" % name
26-
generated_cc_filename = "%s_lib.inc.cc" % name
27-
heir_opt_name = "%s_heir_opt" % name
28-
generated_heir_opt_name = "%s_heir_opt.mlir" % name
29-
heir_translate_flags = heir_translate_flags + ["--emit-openfhe-pke", "--openfhe-include-type=source-relative"]
30-
31-
if heir_opt_flags:
32-
heir_opt(
33-
name = heir_opt_name,
34-
src = mlir_src,
35-
pass_flags = heir_opt_flags,
36-
generated_filename = generated_heir_opt_name,
37-
)
38-
else:
39-
generated_heir_opt_name = mlir_src
40-
41-
heir_translate(
42-
name = cc_codegen_target,
43-
src = generated_heir_opt_name,
44-
pass_flags = heir_translate_flags,
45-
generated_filename = generated_cc_filename,
46-
)
47-
heir_translate(
48-
name = h_codegen_target,
49-
src = generated_heir_opt_name,
50-
pass_flags = heir_translate_flags,
51-
generated_filename = generated_lib_header,
52-
)
53-
native.cc_library(
54-
name = cc_lib_target_name,
55-
srcs = [":" + generated_cc_filename],
56-
hdrs = [":" + generated_lib_header],
57-
deps = deps + ["@openfhe//:pke"],
58-
tags = tags,
59-
copts = MAYBE_OPENMP_COPTS,
60-
linkopts = MAYBE_OPENFHE_LINKOPTS,
61-
**kwargs
62-
)
23+
openfhe_lib(name, mlir_src, generated_lib_header, cc_lib_target_name, heir_opt_flags, heir_translate_flags, tags, deps, **kwargs)
6324
native.cc_test(
6425
name = name,
6526
srcs = [test_src],

tools/BUILD

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,3 +262,13 @@ bzl_library(
262262
"@rules_python//python:py_library_bzl",
263263
],
264264
)
265+
266+
bzl_library(
267+
name = "heir_openfhe_bzl",
268+
srcs = ["heir-openfhe.bzl"],
269+
visibility = ["//visibility:public"],
270+
deps = [
271+
":heir_opt_bzl",
272+
":heir_translate_bzl",
273+
],
274+
)

tools/heir-openfhe.bzl

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""A macro providing an end-to-end library for OpenFHE codegen."""
2+
3+
load("@heir//bazel/openfhe:copts.bzl", "MAYBE_OPENFHE_LINKOPTS", "MAYBE_OPENMP_COPTS")
4+
load("@heir//tools:heir-opt.bzl", "heir_opt")
5+
load("@heir//tools:heir-translate.bzl", "heir_translate")
6+
7+
def openfhe_lib(name, mlir_src, generated_lib_header, cc_lib_target_name, heir_opt_flags = [], heir_translate_flags = [], tags = [], deps = [], **kwargs):
8+
"""A rule for running generating OpenFHE and running a test on it.
9+
10+
Args:
11+
name: The name of the cc_test target and the generated .cc file basename.
12+
mlir_src: The source mlir file to run through heir-translate
13+
generated_lib_header: The name of the generated .h file (explicit
14+
because it needs to be manually #include'd in the test_src file)
15+
cc_lib_target_name: The name of the generated cc_library target
16+
heir_opt_flags: Flags to pass to heir-opt before heir-translate
17+
heir_translate_flags: Flags to pass to heir-translate
18+
tags: Tags to pass to cc_test and cc_library
19+
deps: Deps to pass to cc_test and cc_library
20+
**kwargs: Keyword arguments to pass to cc_library and cc_test.
21+
"""
22+
cc_codegen_target = name + ".heir_translate_cc"
23+
h_codegen_target = name + ".heir_translate_h"
24+
25+
generated_cc_filename = "%s_lib.inc.cc" % name
26+
heir_opt_name = "%s_heir_opt" % name
27+
generated_heir_opt_name = "%s_heir_opt.mlir" % name
28+
heir_translate_flags = heir_translate_flags + ["--emit-openfhe-pke", "--openfhe-include-type=source-relative"]
29+
30+
if heir_opt_flags:
31+
heir_opt(
32+
name = heir_opt_name,
33+
src = mlir_src,
34+
pass_flags = heir_opt_flags,
35+
generated_filename = generated_heir_opt_name,
36+
)
37+
else:
38+
generated_heir_opt_name = mlir_src
39+
40+
heir_translate(
41+
name = cc_codegen_target,
42+
src = generated_heir_opt_name,
43+
pass_flags = heir_translate_flags,
44+
generated_filename = generated_cc_filename,
45+
)
46+
heir_translate(
47+
name = h_codegen_target,
48+
src = generated_heir_opt_name,
49+
pass_flags = heir_translate_flags,
50+
generated_filename = generated_lib_header,
51+
)
52+
native.cc_library(
53+
name = cc_lib_target_name,
54+
srcs = [":" + generated_cc_filename],
55+
hdrs = [":" + generated_lib_header],
56+
deps = deps + ["@openfhe//:pke"],
57+
tags = tags,
58+
copts = MAYBE_OPENMP_COPTS,
59+
linkopts = MAYBE_OPENFHE_LINKOPTS,
60+
**kwargs
61+
)

0 commit comments

Comments
 (0)