diff --git a/CHANGELOGS.md b/CHANGELOGS.md index 0dae80f0..01a9cb3f 100644 --- a/CHANGELOGS.md +++ b/CHANGELOGS.md @@ -8,7 +8,7 @@ ## [Unreleased] -> please add your unreleased change here. +> - [Feature] Add LeichiPaillier algorithms to HEU. ## [0.4.4] diff --git a/heu/library/algorithms/leichi_paillier/BUILD.bazel b/heu/library/algorithms/leichi_paillier/BUILD.bazel new file mode 100644 index 00000000..181be8b0 --- /dev/null +++ b/heu/library/algorithms/leichi_paillier/BUILD.bazel @@ -0,0 +1,131 @@ +load("@yacl//bazel:yacl.bzl", "yacl_cc_library", "yacl_cc_test") + +package(default_visibility = ["//visibility:public"]) + +test_suite( + name = "leichi_paillier_tests", +) + +config_setting( + name = "use_leichi", + values = {"define": "enable_leichi=true"}, +) + +yacl_cc_library( + name = "leichi_paillier_defs", + hdrs = ["leichi.h"], + deps = [":leichi_paillier"], +) + +yacl_cc_library( + name = "leichi_paillier", + srcs = select({ + ":use_leichi":[ + "vector_decryptor.cc", + "vector_encryptor.cc", + "vector_evaluator.cc", + "key_generator.cc", + "public_key.cc", + "secret_key.cc", + "plaintext.cc", + "ciphertext.cc", + "utils.cc", + "runtime.cc", + ], + "//conditions:default":[], + }), + + hdrs = select({ + ":use_leichi":[ + "plaintext.h", + "ciphertext.h", + "vector_decryptor.h", + "vector_encryptor.h", + "vector_evaluator.h", + "key_generator.h", + "public_key.h", + "secret_key.h", + "leichi.h", + "utils.h", + "runtime.h", + ], + "//conditions:default":[], + }), + + visibility = ["//visibility:public"], + + deps = select({ + ":use_leichi":[ + "//heu/library/algorithms/util", + "@com_github_msgpack_msgpack//:msgpack", + "@com_github_openssl_openssl//:openssl", + "@com_github_uscilab_cereal//:cereal", + ":pcie", + "compiler", + ], + "//conditions:default":[], + }), + + defines = select({ + "use_leichi":["APPLY_LEICHI"], + "//conditions:default":[], + }) +) + + +yacl_cc_library( + name = "pcie", + srcs = ["pcie/pcie.cc"], + hdrs = ["pcie/pcie.h"], + deps = [ + + ], +) + +yacl_cc_library( + name = "compiler", + srcs = ["compiler/compiler.cc"], + hdrs = ["compiler/compiler.h"], + deps = [ + + ], +) + +yacl_cc_test( + name = "key_generator_test", + srcs = select({ + ":use_leichi":[ + "key_generator_test.cc"], + "//conditions:default":[], + }), + deps = select({ + ":use_leichi":[ + ":leichi_paillier", + "@com_github_openssl_openssl//:openssl"], + "//conditions:default":[], + }), + + defines = select({ + "use_leichi":["APPLY_LEICHI"], + "//conditions:default":[], + }) +) + +yacl_cc_test( + name = "leichi_test", + srcs = select({ + ":use_leichi":[ + "leichi_test.cc"], + "//conditions:default":[], + }), + + deps = select({ + ":use_leichi":[ + ":leichi_paillier"], + "//conditions:default":[], + }), + defines = select({ + "use_leichi":["APPLY_LEICHI"], + "//conditions:default":[], + }) +) diff --git a/heu/library/algorithms/leichi_paillier/README.md b/heu/library/algorithms/leichi_paillier/README.md new file mode 100644 index 00000000..71afca80 --- /dev/null +++ b/heu/library/algorithms/leichi_paillier/README.md @@ -0,0 +1,36 @@ +# 使用方法 + +## 简介 +``` +leichi_paillier 默认打开; +``` +## HEU编译 + +bazel build heu/... +``` +bazel test heu/... --test_output=all --cache_test_results=no + +使用 leichi_paillier +``` +bazel test heu/... --test_output=all --cache_test_results=no --define enable_leichi=true + + +## leichi_paillier相关单元测试 + +使用 leichi_paillier +``` +bazel test --test_output=all --cache_test_results=no heu/library/algorithms/leichi_paillier:encryptor_test --define enable_leichi=true +bazel test --test_output=all --cache_test_results=no heu/library/algorithms/leichi_paillier:key_generator_test --define enable_leichi=true + +``` +## Benchmark测试 + +使用 leichi_paillier +``` +scalar 场景性能测试 +``` +bazel run -c opt heu/library/benchmark:phe -- --schema=Leichi --define enable_leichi=true + +vector 场景性能测试 +``` +bazel run -c opt heu/library/benchmark:np -- --schema=Leichi --define enable_leichi=true \ No newline at end of file diff --git a/heu/library/algorithms/leichi_paillier/ciphertext.cc b/heu/library/algorithms/leichi_paillier/ciphertext.cc new file mode 100644 index 00000000..3f1ec95b --- /dev/null +++ b/heu/library/algorithms/leichi_paillier/ciphertext.cc @@ -0,0 +1,40 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "heu/library/algorithms/leichi_paillier/ciphertext.h" +#include + +namespace heu::lib::algorithms::leichi_paillier { + + std::string Ciphertext::ToString() const { + char* str = BN_bn2dec(bn_); + std::string result(str); + return result; + } + + std::ostream &operator<<(std::ostream &os, const Ciphertext &ct) { + char* str = BN_bn2dec(ct.bn_); + os << str; + return os; + } + + bool Ciphertext::operator==(const Ciphertext &other) const { + return (BN_cmp(bn_, other.bn_) == 0)?true:false; + } + + bool Ciphertext::operator!=(const Ciphertext &other) const { + return (BN_cmp(bn_, other.bn_) == 0)?true:false; + } + +} // namespace heu::lib::algorithms::leichi_paillier diff --git a/heu/library/algorithms/leichi_paillier/ciphertext.h b/heu/library/algorithms/leichi_paillier/ciphertext.h new file mode 100644 index 00000000..f1bccfff --- /dev/null +++ b/heu/library/algorithms/leichi_paillier/ciphertext.h @@ -0,0 +1,72 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "openssl/bn.h" +#include +#include +#include +#include "yacl/base/byte_container_view.h" +#include "cereal/archives/portable_binary.hpp" +#include +#pragma once +namespace heu::lib::algorithms::leichi_paillier { + class Ciphertext { + public: + BIGNUM* bn_; + public: + Ciphertext() { + bn_ = BN_new(); + } + ~Ciphertext(){ + BN_free(bn_); + } + + Ciphertext(const Ciphertext& other) { + bn_ = BN_dup(other.bn_); + } + + Ciphertext& operator=(const Ciphertext& other) { + if (this != &other) { + BN_copy(bn_, other.bn_); + } + return *this; + } + + explicit Ciphertext(BIGNUM *bn){ bn_ = bn;}//BN_dup(bn);} + + std::string ToString() const; + friend std::ostream &operator<<(std::ostream &os, const Ciphertext &ct); + + bool operator==(const Ciphertext &other) const; + bool operator!=(const Ciphertext &other) const; + + yacl::Buffer Serialize() const{ + uint32_t n_bits_len = BN_num_bits(bn_); + uint8_t* n_arr = new uint8_t[n_bits_len]; + std::vector vec_tmp; + BN_bn2bin(bn_, n_arr); + uint32_t bytes_len = std::ceil(n_bits_len/8.0); + for(uint32_t i=0;i CHIP_ORDER = { + "CHIP_0", + "CHIP_2", + "CHIP_4", + "CHIP_7", + "CHIP_9", + "CHIP_11", + "CHIP_12", + "CHIP_15", +}; + +std::map CHIP_TABLE = { + {"CHIP_0", (int) std::pow(2, 15)}, + {"CHIP_1", (int) std::pow(2, 14)}, + {"CHIP_2", (int) std::pow(2, 13)}, + {"CHIP_3", (int) std::pow(2, 12)}, + {"CHIP_4", (int) std::pow(2, 11)}, + {"CHIP_5", (int) std::pow(2, 10)}, + {"CHIP_6", (int) std::pow(2, 9)}, + {"CHIP_7", (int) std::pow(2, 8)}, + {"CHIP_8", (int) std::pow(2, 7)}, + {"CHIP_9", (int) std::pow(2, 6)}, + {"CHIP_10", (int) std::pow(2, 5)}, + {"CHIP_11", (int) std::pow(2, 4)}, + {"CHIP_12", (int) std::pow(2, 3)}, + {"CHIP_13", (int) std::pow(2, 2)}, + {"CHIP_14", (int) std::pow(2, 1)}, + {"CHIP_15", (int) std::pow(2, 0)}, +}; + +uint64_t gen_inst_l(uint32_t address, uint32_t length, uint8_t des_pe, uint8_t des_reg, uint16_t times, bool change_flag, uint8_t pe_gate) { + if (address >= 80*1024) { + throw std::runtime_error("Sram data in can't hold the data need"); + } + uint64_t inst = 0; + inst = inst + ( (0b01 % static_cast(std::pow(2, 2))) << 62 ); + inst = inst + ( (address % static_cast(std::pow(2, 18))) << 44 ); + inst = inst + ( (length % static_cast(std::pow(2, 18))) << 26 ); + inst = inst + ( (des_pe % static_cast(std::pow(2, 6))) << 20 ); + inst = inst + ( (des_reg % static_cast(std::pow(2, 4))) << 16 ); + inst = inst + ( (times % static_cast(std::pow(2, 10))) << 6 ); + inst = inst + ( (change_flag % static_cast(std::pow(2, 1))) << 5 ); + inst = inst + ( (pe_gate % static_cast(std::pow(2, 5))) << 0 ); + return inst; +} + +uint64_t gen_inst_c(uint8_t pe_state, bool cal_flag, uint8_t pe_n_state, uint32_t address, uint8_t pe_gate) { + uint64_t inst = 0; + inst = inst + ( (0b10 % static_cast(std::pow(2, 2))) << 62 ); + inst = inst + ( (pe_state % static_cast(std::pow(2, 8))) << 54 ); + inst = inst + ( (cal_flag % static_cast(std::pow(2, 1))) << 53 ); + inst = inst + ( (pe_n_state % static_cast(std::pow(2, 8))) << 45 ); + inst = inst + ( (address % static_cast(std::pow(2, 18))) << 27 ); + inst = inst + ( (pe_gate % static_cast(std::pow(2, 5))) << 22 ); + inst = inst + ( 0b0 << 0 ); + return inst; +} + +uint64_t gen_inst_i(uint64_t ddr_address,uint64_t ddr_length) { + uint64_t inst = 0; + inst = inst + ( (0b11 % static_cast(std::pow(2, 2))) << 62 ); + inst = inst + ( (ddr_address % static_cast(std::pow(2, 32))) << 30 ); + inst = inst + ( ((ddr_length-1) & 0x1fff % static_cast(std::pow(2, 13))) << 17 ); + return inst; +} + +uint64_t gen_inst_none(){ + uint128_t inst = 0; + return inst; +} + +void check_inst_sram_depth(std::vector _inst) +{ + std::stringstream ss; + if(_inst.size()>=LOCAL_INST_BUFFER_DEPTH) + { + ss << "Sram inst in can't hold the data need"; + throw std::runtime_error(ss.str()); + } +} + +uint128_t gen_inst_r(uint16_t chip, uint8_t data_type, uint32_t data_address, uint32_t data) { + uint128_t inst = 0; + inst = inst + ( (0b110 % static_cast(std::pow(2, 3))) << 125 ); + inst = inst + ( (chip % static_cast(std::pow(2, 16))) << 109 ); + inst = inst + ( (data_type % static_cast(std::pow(2, 3))) << 106 ); + inst = inst + ( (data_address % static_cast(std::pow(2, 32))) << 74 ); + inst = inst + ( (data % static_cast(std::pow(2, 32))) << 42 ); + return inst; +} + +uint128_t gen_inst_l1(uint16_t chip, uint32_t ddr_address, uint32_t ddr_length, uint8_t data_type, uint32_t data_address, bool bool_check) { + uint128_t inst = 0; + if (ddr_length == 0) { + ddr_length = 1; + } + inst = inst + ( (0b000 % static_cast(std::pow(2, 3))) << 125 ); + inst = inst + ( (chip % static_cast(std::pow(2, 16))) << 109 ); + inst = inst + ( (ddr_address % static_cast(std::pow(2, 32))) << 77 ); + inst = inst + ( ((ddr_length-1) % static_cast(std::pow(2, 32))) << 45 ); + inst = inst + ( (data_type % static_cast(std::pow(2, 3))) << 42 ); + inst = inst + ( (data_address % static_cast(std::pow(2, 32))) << 10 ); + inst = inst + ( (bool_check % static_cast(std::pow(2, 0))) << 9 ); + return inst; +} + +uint128_t gen_inst_l2(uint16_t chip, uint16_t times) { + uint128_t inst = 0; + inst = inst + ( (0b001 % static_cast(std::pow(2, 3))) << 125 ); + inst = inst + ( (chip % static_cast(std::pow(2, 16))) << 109 ); + inst = inst + ( ((times-1) % static_cast(std::pow(2, 16))) << 93 ); + return inst; +} + +uint8_t get_bit_of_data(uint64_t data, uint32_t start_bit, uint32_t end_bit) { + uint32_t length = start_bit - end_bit + 1; + uint64_t dat_tmp = static_cast(std::pow(2,end_bit)); + uint8_t result = static_cast(data/dat_tmp) % static_cast(pow(2,length)); + return result; +} + +generator_fpga::generator_fpga() +{ + chip_num = NUMBER_OF_CHIP; +} + +void generator_fpga::clear() +{ + inst.clear(); +} + +void generator_fpga::gen_inst(struct Program program,std::vector> task_split,struct memory_allocation_t memory_allocation, + std::vector>> inst,std::vector>> inst_split) +{ + clear(); + __gen_inst_none__(); + _gen_inst_vector_(program,task_split,memory_allocation,inst,inst_split); +} + + +void generator_fpga::_gen_inst_vector_(struct Program program,std::vector> task_split,struct memory_allocation_t memory_allocation, + std::vector>> inst,std::vector>> inst_split) +{ + __gen_inst_pll_lock__(); + __gen_inst_inst_reset__(); + __gen_inst_vector_l1_data_para(memory_allocation); + __gen_inst_l1_wait__(); + for(uint32_t task_split_first = 0; task_split_first < task_split[0].size(); task_split_first++) + { + __gen_inst_inst_reset__(); + for(uint32_t chip_sel = 0; chip_sel < task_split.size(); chip_sel++) + { + if(task_split[chip_sel].size() <= task_split_first){break;} + __gen_inst_vector_l1_data__(chip_sel,task_split_first,memory_allocation); + } + __gen_inst_l1_wait__(); + for(uint32_t chip_sel = 0; chip_sel < task_split.size(); chip_sel++) + { + if(task_split[chip_sel].size() <= task_split_first){break;} + __gen_inst_vector_l1_inst__(chip_sel,task_split_first,memory_allocation); + } + __gen_inst_l1_wait__(); + for(uint32_t chip_sel = 0; chip_sel < task_split.size(); chip_sel++) + { + if(task_split[chip_sel].size() <= task_split_first){break;} + __gen_inst_vector_r_inst_length_first__(chip_sel,task_split_first,inst_split); + } + __gen_inst_vector_r_inst_length_middle__(task_split_first,inst,inst_split,task_split); + __gen_inst_vector_r_inst_length_last__(task_split_first,inst,inst_split); + } + __gen_inst_inst_reset__(); + __gen_inst_i__(); +} + +void generator_fpga::__gen_inst_none__() +{ + uint128_t inst_temp=0; + inst_temp = gen_inst_none(); + inst.push_back(inst_temp); +} + +void generator_fpga::__gen_inst_pll_lock__() +{ + uint128_t inst_temp=0; + uint32_t inst_r_chip = 0; + uint32_t inst_r_data_type = w_reg; + uint32_t inst_r_data_address = ADDRESS_PIN_MUX; + uint32_t inst_r_data = DATA_PIN_MUX_PAD; + for(uint16_t chip_sel=0 ; chip_sel < NUMBER_OF_CHIP ; chip_sel++) + { + inst_r_chip += CHIP_TABLE[CHIP_ORDER[chip_sel]]; + } + inst_temp = gen_inst_r(inst_r_chip,inst_r_data_type,inst_r_data_address,inst_r_data); + inst.push_back(inst_temp); +} + + +void generator_fpga::__gen_inst_inst_reset__() +{ + uint128_t inst_temp=0; + uint32_t inst_r_chip = 0; + uint32_t inst_r_data_type = inst_reset; + uint32_t inst_r_data_address = 0x00000000; + uint32_t inst_r_data = 0x00000000; + + for(uint16_t chip_sel=0 ; chip_sel < NUMBER_OF_CHIP ; chip_sel++) + { + inst_r_chip += CHIP_TABLE[CHIP_ORDER[chip_sel]]; + } + inst_temp = gen_inst_r(inst_r_chip,inst_r_data_type,inst_r_data_address,inst_r_data); + inst.push_back(inst_temp); +} + +void generator_fpga::__gen_inst_inst_flag__(uint8_t chip_sel,uint32_t length) +{ + uint128_t inst_temp=0; + uint32_t inst_r_chip = 0; + uint32_t inst_r_data_type = w_reg; + uint32_t inst_r_data_address = ADDRESS_INST_FLAG; + uint32_t inst_r_data = length; + inst_r_chip = CHIP_TABLE[CHIP_ORDER[chip_sel]]; + inst_temp = gen_inst_r(inst_r_chip,inst_r_data_type,inst_r_data_address,inst_r_data); + inst.push_back(inst_temp); +} + +void generator_fpga::__gen_inst_w__(uint32_t address,uint32_t data) +{ + uint128_t inst_temp=0; + uint32_t inst_r_chip = 0; + uint32_t inst_r_data_type = w_reg; + uint32_t inst_r_data_address = address; + uint32_t inst_r_data = data; + for(uint16_t chip_sel=0 ; chip_sel < NUMBER_OF_CHIP ; chip_sel++) + { + inst_r_chip += CHIP_TABLE[CHIP_ORDER[chip_sel]]; + } + inst_temp = gen_inst_r(inst_r_chip,inst_r_data_type,inst_r_data_address,inst_r_data); + inst.push_back(inst_temp); +} + + +void generator_fpga::__gen_inst_l1_wait__() +{ + #if 0 + uint128_t inst_temp=0; + uint16_t chip = 0; + uint32_t ddr_address = 0; + uint32_t ddr_length = 0; + uint8_t data_type = 0; + uint32_t data_address = 0; + bool bool_check =1; + + inst_temp = gen_inst_l1(chip,ddr_address,ddr_length,data_type,data_address,bool_check); + inst.push_back(inst_temp); + #endif +} + + +void generator_fpga::__gen_inst_vector_l1_data_para(struct memory_allocation_t __memory_allocation) +{ + uint128_t inst_temp=0; + uint32_t inst_l1_chip = 0; + uint32_t inst_l1_ddr_address = __memory_allocation.in_para_ddr_address_total; + uint32_t inst_l1_ddr_length = __memory_allocation.in_para_ddr_length_total; + uint8_t inst_l1_data_type = w_data; + uint32_t inst_l1_data_address = 0x00000000; + for(uint16_t chip_sel=0 ; chip_sel < NUMBER_OF_CHIP ; chip_sel++) + { + inst_l1_chip += CHIP_TABLE[CHIP_ORDER[chip_sel]]; + } + inst_temp = gen_inst_l1(inst_l1_chip,inst_l1_ddr_address,inst_l1_ddr_length,inst_l1_data_type,inst_l1_data_address); + inst.push_back(inst_temp); +} + +void generator_fpga::__gen_inst_vector_l1_data__(uint16_t chip,uint32_t task_split_first,struct memory_allocation_t __memory_allocation) +{ + uint128_t inst_temp=0; + uint32_t last_l1_data_address = __memory_allocation.in_para_ddr_length_total * GLOBAL_BUFFER_WIDTH / LOCAL_BUFFER_WIDTH ; + uint32_t inst_l1_chip = CHIP_TABLE[CHIP_ORDER[chip]]; + uint32_t inst_l1_ddr_address = 0; + uint32_t inst_l1_ddr_length = 0; + uint8_t inst_l1_data_type = w_data; + uint32_t inst_l1_data_address = 0; + for(uint32_t i_of_op = 0; i_of_op< __memory_allocation.in_dat_ddr_mem_alloc.size(); i_of_op++) + { + inst_l1_ddr_address = __memory_allocation.in_dat_ddr_mem_alloc[i_of_op].in_data_detail[chip][task_split_first].out_ddr_addr; + inst_l1_ddr_length = __memory_allocation.in_dat_ddr_mem_alloc[i_of_op].in_data_detail[chip][task_split_first].out_ddr_length; + inst_l1_data_address = (int)last_l1_data_address; + inst_temp = gen_inst_l1(inst_l1_chip,inst_l1_ddr_address,inst_l1_ddr_length,inst_l1_data_type,inst_l1_data_address); + inst.push_back(inst_temp); + last_l1_data_address += inst_l1_ddr_length*GLOBAL_BUFFER_WIDTH/LOCAL_BUFFER_WIDTH; + } +} + +void generator_fpga::__gen_inst_vector_l1_inst__(uint16_t chip,uint32_t task_split_first,struct memory_allocation_t __memory_allocation) +{ + uint128_t inst_temp=0; + uint32_t inst_l1_chip = CHIP_TABLE[CHIP_ORDER[chip]]; + uint32_t inst_l1_ddr_address = __memory_allocation.inst_detail[chip][task_split_first].out_ddr_addr; + uint32_t inst_l1_ddr_length = __memory_allocation.inst_detail[chip][task_split_first].out_ddr_length; + uint8_t inst_l1_data_type = w_inst; + uint32_t inst_l1_data_address = 0x00000000; + inst_temp = gen_inst_l1(inst_l1_chip,inst_l1_ddr_address,inst_l1_ddr_length,inst_l1_data_type,inst_l1_data_address); + inst.push_back(inst_temp); +} + +void generator_fpga::__gen_inst_vector_r_inst_length_first__(uint16_t chip,uint32_t task_split_first,std::vector>> __inst_split_asic) +{ + uint128_t inst_temp=0; + uint32_t inst_r_chip = CHIP_TABLE[CHIP_ORDER[chip]]; + uint32_t inst_r_data_address = 0; + uint8_t inst_r_data_type = w_inst_length; + uint32_t inst_r_data = __inst_split_asic[chip][task_split_first][0]; + inst_temp = gen_inst_r(inst_r_chip,inst_r_data_type,inst_r_data_address,inst_r_data); + inst.push_back(inst_temp); +} + +void generator_fpga::__gen_inst_vector_r_inst_length_middle__(uint32_t task_split_first,std::vector>> __inst_asic,std::vector>> __inst_split_asic,std::vector> __task_split) +{ + uint128_t inst_temp=0; + uint32_t inst_r_chip = 0; + + uint32_t inst_r_data_address = 0; + uint8_t inst_r_data_type = w_inst; + uint32_t inst_r_data = 0x00000000; + + uint32_t inst_length_min_true = 0; + uint32_t inst_length_min_false = 0; + + inst_length_min_true = __inst_asic[0][task_split_first].size(); + inst_length_min_false = __inst_asic[0][task_split_first].size(); + min_chip_true = 0; + min_chip_false = 0; + + for(uint32_t chip_sel = 0; chip_sel < NUMBER_OF_CHIP; chip_sel++) + { + if(__task_split[chip_sel].size() <= task_split_first) + { + break; + } + if(__inst_asic[chip_sel][task_split_first].size() < inst_length_min_true) + { + inst_length_min_true = __inst_asic[chip_sel][task_split_first].size(); + min_chip_true = chip_sel; + } + if(__inst_asic[chip_sel][task_split_first].size() <= inst_length_min_false) + { + inst_length_min_false = __inst_asic[chip_sel][task_split_first].size(); + min_chip_false = chip_sel; + } + } + + for(uint16_t chip_sel=0 ; chip_sel < min_chip_false+1 ; chip_sel++) + { + inst_r_chip += CHIP_TABLE[CHIP_ORDER[chip_sel]]; + } + inst_r_data_type = w_inst_length; + inst_r_data_address = 0x00000000; + inst_r_data = __inst_asic[min_chip_false][task_split_first].size() - __inst_split_asic[min_chip_false][task_split_first][0]; + inst_temp = gen_inst_r(inst_r_chip,inst_r_data_type,inst_r_data_address,inst_r_data); + inst.push_back(inst_temp); + + uint32_t inst_l2_chip = inst_r_chip; + uint32_t inst_times = __inst_split_asic[min_chip_false][task_split_first].size()-1; + inst_temp = gen_inst_l2(inst_l2_chip,inst_times); + inst.push_back(inst_temp); +} + + + +void generator_fpga::__gen_inst_vector_r_inst_length_last__(uint32_t task_split_first,std::vector>> __inst_asic,std::vector>> __inst_split_asic) +{ + uint128_t inst_temp=0; + uint32_t inst_r_chip = 0; + uint32_t inst_l2_chip = 0; + uint32_t inst_times = 0; + uint32_t inst_r_data_address = 0; + uint8_t inst_r_data_type = w_inst_length; + uint32_t inst_r_data = 0x00000000; + + if(min_chip_true!=0) + { + for(uint16_t chip_sel = 0; chip_sel < min_chip_false; chip_sel++) + { + inst_r_chip += CHIP_TABLE[CHIP_ORDER[chip_sel]]; + } + inst_r_data = __inst_asic[0][task_split_first].size() - __inst_asic[min_chip_false][task_split_first].size(); + inst_temp = gen_inst_r(inst_r_chip,inst_r_data_type,inst_r_data_address,inst_r_data); + inst.push_back(inst_temp); + + inst_l2_chip =inst_r_chip; + inst_times = __inst_split_asic[0][task_split_first].size() - __inst_split_asic[min_chip_false][task_split_first].size(); + inst_temp = gen_inst_l2(inst_l2_chip,inst_times); + inst.push_back(inst_temp); + } +} + +void generator_fpga::__gen_inst_vector_r_inst_length_add__(uint16_t chip_sel) +{ + uint128_t inst_temp=0; + uint32_t inst_r_chip = CHIP_TABLE[CHIP_ORDER[chip_sel]]; + uint32_t inst_r_data_address = 0x00000000; + uint8_t inst_r_data_type = w_inst_length; + uint32_t inst_r_data = 1; + inst_temp = gen_inst_r(inst_r_chip,inst_r_data_type,inst_r_data_address,inst_r_data); + inst.push_back(inst_temp); +} + +Compiler::Compiler() +{ + num_of_chip = NUMBER_OF_CHIP; + num_of_pe = NUMBER_OF_PE; + ddr_address = 0; +} + +void Compiler::set_device(uint8_t number_of_chip,uint8_t number_of_pe,uint64_t ddr) +{ + this->num_of_chip= (number_of_chip!=0)?number_of_chip:this->num_of_chip; + this->num_of_pe= (number_of_chip!=0)?number_of_pe:this->num_of_pe; + this->ddr_size= (number_of_chip!=0)?ddr:this->ddr_size; +} + +void Compiler::clear() +{ + inst.clear(); + task_split.clear(); + inst_split.clear(); + inst_byte.clear(); + memory_allocation.out_detail.clear(); + memory_allocation.inst_detail.clear(); + memory_allocation.in_dat_ddr_mem_alloc.clear(); + executor.inst.clear(); + executor.inst_fpga.clear(); + memory_allocation.last_ddr_address = 0; + ddr_address = 0; +} + + +void Compiler::get_k_dat(struct Program program) +{ + if(program.operation_type == "MOD_MUL") + { + _k_dat = { + "MOD_MUL", + 2, + 0, + 2, + { + {{"P_BITCOUNT", 1}, {"E_BITCOUNT", 0}}, + {{"P_BITCOUNT", 1}, {"E_BITCOUNT", 0}} + } + }; + } + else if(program.operation_type == "PAILLIER_ENC") + { + _k_dat = { + "PAILLIER_ENC", + 5, + 0, + 2, + { + {{"P_BITCOUNT", 0}, {"E_BITCOUNT", 1}}, + {{"P_BITCOUNT", 0.5}, {"E_BITCOUNT", 0}} + } + }; + } + else if(program.operation_type == "MOD_EXP_CONST_E") + { + _k_dat = { + "MOD_EXP_CONST_E", + 2, + 1, + 2, + { + {{"P_BITCOUNT", 1}, {"E_BITCOUNT", 0}}, + } + }; + } + else if(program.operation_type == "MOD_EXP") + { + _k_dat = { + "MOD_EXP", + 2, + 0, + 2, + { + {{"P_BITCOUNT", 1}, {"E_BITCOUNT", 0}}, + {{"P_BITCOUNT", 0}, {"E_BITCOUNT", 1}} + } + }; + } + else if(program.operation_type == "MOD_MUL_CONST") + { + _k_dat = { + "MOD_MUL_CONST", + 3, + 0, + 2, + { + {{"P_BITCOUNT", 1}, {"E_BITCOUNT", 0}}, + } + }; + } + else if(program.operation_type == "MOD_EXP_CONST_A") + { + _k_dat = { + "MOD_EXP_CONST_A", + 3, + 0, + 2, + { + {{"P_BITCOUNT", 0}, {"E_BITCOUNT", 1}}, + } + }; + } + else if(program.operation_type == "MOD_ADD") + { + _k_dat = { + "MOD_ADD", + 1, + 0, + 1, + { + {{"P_BITCOUNT", 1}, {"E_BITCOUNT", 0}}, + {{"P_BITCOUNT", 1}, {"E_BITCOUNT", 0}} + } + }; + } + else if(program.operation_type == "MOD_ADD_CONST") + { + _k_dat = { + "MOD_ADD_CONST", + 2, + 0, + 1, + { + {{"P_BITCOUNT", 1}, {"E_BITCOUNT", 0}}, + } + }; + } + else if(program.operation_type == "MONT") + { + _k_dat = { + "MONT", + 1, + 0, + 2, + { + {{"P_BITCOUNT", 1}, {"E_BITCOUNT", 0}}, + {{"P_BITCOUNT", 1}, {"E_BITCOUNT", 0}} + } + }; + } + else if(program.operation_type == "MONT_CONST") + { + _k_dat = { + "MONT_CONST", + 2, + 0, + 2, + { + {{"P_BITCOUNT", 1}, {"E_BITCOUNT", 0}}, + } + }; + } + else if(program.operation_type == "MOD_INV_CONST_P") + { + _k_dat = { + "MOD_INV_CONST_P", + 1, + 0, + 1, + { + {{"P_BITCOUNT", 1}, {"E_BITCOUNT", 0}}, + } + }; + } + +} + +void Compiler::compile() +{ + clear(); + get_k_dat(_program); + _task_split_(_program); + _memory_allocation_(_program); + _gen_fpga_inst_(_program); + check_mem(_program); + _inst_reshape_(_program); + get_executor(); +} + +std::string Compiler::to_string(uint128_t x) +{ + if (x == 0) return "0"; + std::string s = ""; + while (x > 0) { + s += char(x % 10 + '0'); + x /= 10; + } + reverse(s.begin(), s.end()); + return s; +} + +void Compiler::data_inverse(uint8_t* data,uint32_t len) +{ + uint32_t i = 0; + uint8_t tmp; + + for(i=0;i inst_fpga_bytes; + for(__uint128_t &i_dat:_generator_fpga.inst) + { + memcpy(bytes, &i_dat, sizeof(i_dat)); + data_inverse(bytes,16); + for(uint32_t i = 0;i<16;i++) + { + inst_fpga_bytes.push_back(bytes[i]); + } + + } + executor.inst= inst_byte; + executor.inst_fpga = inst_fpga_bytes; + executor.in_para_address = memory_allocation.in_para_ddr_address_total; + executor.inst_address = memory_allocation.inst_ddr_address_total; + executor.out_address = memory_allocation.out_ddr_address_total; + executor.out_length = memory_allocation.out_ddr_length_total; +} + +void Compiler::__task_split_vector__(struct Program program) +{ + float x=(float)(program.vec_size)/(float)MAX_ON_CHIP_VECTOR; + uint32_t task_num = std::ceil(x); + uint32_t task_left = program.vec_size; + std::vector> _task_spilit(num_of_chip,std::vector()); + uint32_t chip_sel = 0; + uint32_t task = 0; + for(uint32_t i = 0;i>> m; + uint32_t _p_bitcount = program.p_bitcount; + uint32_t task_num = 0; + uint32_t ddr_length_total = 0; + uint32_t ddr_length = 0; + uint32_t ddr_address_total= ddr_address; + + memory_allocation.out_detail.clear(); + memory_allocation.inst_detail.clear(); + memory_allocation.in_dat_ddr_mem_alloc.clear(); + + for (uint32_t chip = 0; chip < task_split.size(); chip++) { + std::vector> first_vector; + for (uint32_t first = 0; first < task_split[chip].size(); first++) { + std::vector inner_vector(std::ceil(float(task_split[chip][first]) / float(NUMBER_OF_PE))); + first_vector.push_back(inner_vector); + } + m.push_back(first_vector); + } + + for (uint32_t chip = 0; chip < m.size(); chip++) { + for (uint32_t first = 0; first < m[chip].size(); first++) { + for (uint32_t i = 0; i < m[chip][first].size(); i++) { + + if(i ==std::ceil((float)(task_split[chip][first]) / (float)(NUMBER_OF_PE)) - 1){ + task_num = (task_split[chip][first] % NUMBER_OF_PE != 0) ? (task_split[chip][first] % NUMBER_OF_PE) : NUMBER_OF_PE; + } + else{ + task_num = NUMBER_OF_PE; + } + ddr_length = task_num * _p_bitcount / 8; + ddr_length_total += ddr_length; + m[chip][first][i].out_ddr_addr = ddr_address; + m[chip][first][i].out_ddr_length = ddr_length; + ddr_address += ddr_length; + } + } + } + memory_allocation.out_ddr_address_total = ddr_address_total; + memory_allocation.out_ddr_length_total = ddr_length_total; + memory_allocation.out_detail = m; +} + +void Compiler::__memory_allocation_vector_inst__(struct Program program) +{ + uint64_t ddr_length = 0; + uint64_t inst_ddr_address_total = ddr_address; + uint64_t inst_ddr_length_total = 0; + struct out_ddr_detail_t ddr_detail; + + std::vector> memory_allocation_detail(task_split.size(),std::vector()); + + for(uint8_t chip =0;chip < task_split.size();chip++) + { + for(uint32_t task_split_first=0;task_split_first> memory_allocation_detail(task_split.size(),std::vector()); + + try{ + p_b = _k_dat.data_in[i_of_op].at("P_BITCOUNT"); + e_b = _k_dat.data_in[i_of_op].at("E_BITCOUNT"); + }catch (...) {sta = false;} + + uint64_t in_data_ddr_address_total = ddr_address; + uint64_t in_data_ddr_length_total = 0; + for (uint32_t task_split_first = 0; task_split_first < task_split[0].size(); task_split_first++) { + for (uint32_t chip = 0; chip < task_split.size(); chip++) { + if(task_split[chip].empty()){ + continue; + } + ddr_length = (uint32_t)(p_b * _p_bitcount * task_split[chip][task_split_first] ); + ddr_length += (uint32_t)(e_b * _e_bitcount * task_split[chip][task_split_first]); + ddr_length = ddr_length / 8 ; + in_data_ddr_length_total +=ddr_length; + ddr_detail.out_ddr_addr = ddr_address; + ddr_detail.out_ddr_length = ddr_length; + memory_allocation_detail[chip].push_back(ddr_detail); + ddr_address += int(ddr_length); + } + } + in_data_mem_alloc_t in_data_mem_alloc; + in_data_mem_alloc.in_data_ddr_address_total = in_data_ddr_address_total; + in_data_mem_alloc.in_data_ddr_length_total = in_data_ddr_length_total; + in_data_mem_alloc.in_data_detail = memory_allocation_detail; + + if(i_of_op == 0) + { + memory_allocation.in_dat_ddr_mem_alloc.push_back(in_data_mem_alloc); + } + else{ + memory_allocation.in_dat_ddr_mem_alloc.push_back(in_data_mem_alloc); + } + + memory_allocation.last_ddr_address = ddr_address; + sta = true; + return sta; +} +void Compiler::__gen_asic_inst__(struct Program program) +{ + std::string type = program.type; + if(type == "vector") + { + ___gen_asic_inst_vector___(); + } +} + +void Compiler::___gen_asic_inst_vector___() +{ + std::vector>> _inst; + struct generator_inst_t _g_inst; + for (uint32_t chip = 0; chip < task_split.size(); ++chip) { + std::vector> chip_inst; + for (uint32_t i = 0; i < task_split[chip].size(); ++i) { + std::vector task_inst; + chip_inst.push_back(task_inst); + } + _inst.push_back(chip_inst); + } + + std::vector>> _inst_split; + for (uint32_t chip = 0; chip < task_split.size(); ++chip) { + std::vector> chip_inst; + for (uint32_t i = 0; i < task_split[chip].size(); ++i) { + std::vector task_inst; + chip_inst.push_back(task_inst); + } + _inst_split.push_back(chip_inst); + } + for (uint32_t chip = 0; chip < task_split.size(); chip++) + for (uint32_t task_split_first = 0; task_split_first < task_split[chip].size(); task_split_first++) { + { + _g_inst = _generator_asic.gen_inst(_program,task_split[chip][task_split_first],memory_allocation.out_detail[chip][task_split_first]); + _inst[chip][task_split_first] = _g_inst.inst; + _inst_split[chip][task_split_first] = _g_inst.inst_length; + } + } + inst = _inst; + inst_split = _inst_split; +} + +void Compiler::_gen_fpga_inst_(struct Program program) +{ + std::string type = program.type; + if(type == "vector") + { + _generator_fpga.gen_inst(program,task_split,memory_allocation,inst,inst_split); + } +} + +void Compiler::check_mem(struct Program program) +{ + std::string type = program.type; + std::stringstream ss; + uint64_t ddr_depth = (uint64_t)2*1024*1024*1024; + if(memory_allocation.last_ddr_address >= (uint64_t)ddr_depth) + { + ss << "memory allocation address >= 1024*1024*1024*2!"; + throw std::runtime_error(ss.str()); + } + + if(_generator_fpga.inst.size() > GLOBAL_INST_BUFFER_DEPTH) + { + ss << "inst fpga length > GLOBAL_INST_BUFFER_DEPTH!"; + throw std::runtime_error(ss.str()); + } +} + +void Compiler::_inst_reshape_(struct Program program) +{ + std::string type = program.type; + uint32_t length_of_inst = LOCAL_INST_BUFFER_WIDTH; + uint32_t bit_of_bit = 8; + if(type == "vector") + { + for(uint8_t chip=0;chip LOAD_PARA + inst_temp = gen_inst_c(IDLE,CAL,LOAD_PARA,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: LOAD_PARA -> ELE_MOD_MUL + inst_temp = gen_inst_c(LOAD_PARA,CAL,ELE_MOD_MUL,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: ELE_MOD_MUL and WAIT + inst_temp = gen_inst_c(ELE_MOD_MUL,WAIT,ELE_MOD_MUL,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // L : a + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_l_times = actual_pe; + inst_temp = gen_inst_l(load_a_address,inst_l_length,pe_0,a,inst_l_times,CHANGE,actual_pe); + inst.push_back(inst_temp); + load_a_address +=inst_l_length*inst_l_times; + + // L : b + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_l_times = actual_pe; + inst_temp = gen_inst_l(load_b_address,inst_l_length,pe_0,b,inst_l_times,CHANGE,actual_pe); + inst.push_back(inst_temp); + load_b_address +=inst_l_length*inst_l_times; + + // I + inst_temp = gen_inst_i(ddr_address,ddr_length); + inst.push_back(inst_temp); + + // inst split + inst_length.push_back(inst.size()-inst_length_last); + inst_length_last = inst.size(); + + // C : ELE_MOD_MUL -> VEC_OUTPUT_RESULTS + inst_temp = gen_inst_c(ELE_MOD_MUL,CAL,VEC_OUTPUT_RESULTS,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C : VEC_OUTPUT_RESULTS -> IDLE + inst_temp = gen_inst_c(VEC_OUTPUT_RESULTS,CAL,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); +} + +void generator_mod_mul::_inst_gen_end_(uint8_t actual_pe) +{ + uint64_t inst_temp = 0; + // C: IDLE and WAIT + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: IDLE and WAIT + actual_pe = 0; + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + inst_length.push_back(inst.size()-inst_length_last); +} + +struct generator_inst_t generator_mod_mul::gen_inst(uint32_t p_bitcount,uint32_t e_bitcount,uint32_t task_split,std::vector mem_alloc) +{ + uint32_t _task_split = task_split; + uint64_t ddr_address = 0; + uint64_t ddr_length = 0; + uint8_t actual_pe = 0; + struct generator_inst_t _inst_tmp; + this->clear(); + task_split_n = task_split; + actual_pe = (_task_split>=NUMBER_OF_PE)?NUMBER_OF_PE:_task_split; + _inst_gen_start_(p_bitcount,e_bitcount,actual_pe); + uint32_t number_of_outs = std::ceil((float)_task_split/(float)NUMBER_OF_PE); + for(uint8_t t = 0;t < number_of_outs;t++) + { + ddr_address = mem_alloc[t].out_ddr_addr; + ddr_length = mem_alloc[t].out_ddr_length; + if(t==number_of_outs-1) + { + actual_pe = ((_task_split%NUMBER_OF_PE)==0)?NUMBER_OF_PE:(_task_split%NUMBER_OF_PE); + _inst_gen_middle_(p_bitcount,e_bitcount,actual_pe,ddr_address,ddr_length); + } + else + { + actual_pe = (_task_split>=NUMBER_OF_PE)?NUMBER_OF_PE:_task_split; + _inst_gen_middle_(p_bitcount,e_bitcount,actual_pe,ddr_address,ddr_length); + } + } + _inst_gen_end_(actual_pe); + check_inst_sram_depth(inst); + _inst_tmp.inst = inst; + _inst_tmp.inst_length = inst_length; + return _inst_tmp; +} + +generator_mod_mul_const::generator_mod_mul_const() +{ + _k_dat = { + "MOD_MUL_CONST", + 3, + 0, + 2, + { + {{"P_BITCOUNT", 1}, {"E_BITCOUNT", 0}}, + } + }; +} + +void generator_mod_mul_const::clear() +{ + inst.clear(); + task_split.clear(); + inst_length.clear(); + inst_c_address = DATA_OUT_TO_OUT_CTRL; + inst_l_address = 0; + task_split_n = 0; + load_a_address = 0; + load_b_address = 0; + inst_length_last = 0; +} + +void generator_mod_mul_const::_inst_gen_start_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe) +{ + uint64_t inst_temp = 0; + uint32_t inst_l_length =1; + uint32_t inst_l_times = 1; + + // C: IDLE and WAIT + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + // L: param + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,param,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + // L: p + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,p,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + // L: n_prime + inst_l_length = 1; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,n_prime,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + // L : rsquare + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,r_square,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + // L : b + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_l_times = 1; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,b,inst_l_times,CHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length*inst_l_times; +} + +void generator_mod_mul_const::_inst_gen_middle_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe,uint32_t ddr_address,uint32_t ddr_length) +{ + uint32_t inst_l_length =1; + uint32_t inst_l_times = 1; + uint64_t inst_temp = 0; + // C: IDLE -> LOAD_PARA + inst_temp = gen_inst_c(IDLE,CAL,LOAD_PARA,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: LOAD_PARA -> ELE_MOD_MUL + inst_temp = gen_inst_c(LOAD_PARA,CAL,ELE_MOD_MUL,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: ELE_MOD_MUL and WAIT + inst_temp = gen_inst_c(ELE_MOD_MUL,WAIT,ELE_MOD_MUL,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // L : a + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_l_times = actual_pe; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,pe_0,a,inst_l_times,CHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length*inst_l_times; + + // I + inst_temp = gen_inst_i(ddr_address,ddr_length); + inst.push_back(inst_temp); + + // inst split + inst_length.push_back(inst.size()-inst_length_last); + inst_length_last = inst.size(); + // C : ELE_MOD_MUL -> VEC_OUTPUT_RESULTS + inst_temp = gen_inst_c(ELE_MOD_MUL,CAL,VEC_OUTPUT_RESULTS,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C : VEC_OUTPUT_RESULTS -> IDLE + inst_temp = gen_inst_c(VEC_OUTPUT_RESULTS,CAL,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); +} + +void generator_mod_mul_const::_inst_gen_end_(uint8_t actual_pe) +{ + uint64_t inst_temp = 0; + // C: IDLE and WAIT + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: IDLE and WAIT + actual_pe = 0; + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + inst_length.push_back(inst.size()-inst_length_last); +} + +struct generator_inst_t generator_mod_mul_const::gen_inst(uint32_t p_bitcount,uint32_t e_bitcount,uint32_t task_split,std::vector mem_alloc) +{ + uint32_t _task_split = task_split; + uint64_t ddr_address = 0; + uint64_t ddr_length = 0; + uint8_t actual_pe = 0; + struct generator_inst_t _inst_tmp; + this->clear(); + task_split_n = task_split; + actual_pe = (_task_split>=NUMBER_OF_PE)?NUMBER_OF_PE:_task_split; + _inst_gen_start_(p_bitcount,e_bitcount,actual_pe); + uint32_t number_of_outs = std::ceil((float)_task_split/(float)NUMBER_OF_PE); + for(uint8_t t = 0;t < number_of_outs;t++) + { + ddr_address = mem_alloc[t].out_ddr_addr; + ddr_length = mem_alloc[t].out_ddr_length; + if(t==number_of_outs-1) + { + actual_pe = ((_task_split%NUMBER_OF_PE)==0)?NUMBER_OF_PE:(_task_split%NUMBER_OF_PE); + _inst_gen_middle_(p_bitcount,e_bitcount,actual_pe,ddr_address,ddr_length); + } + else + { + actual_pe = (_task_split>=NUMBER_OF_PE)?NUMBER_OF_PE:_task_split; + _inst_gen_middle_(p_bitcount,e_bitcount,actual_pe,ddr_address,ddr_length); + } + } + _inst_gen_end_(actual_pe); + check_inst_sram_depth(inst); + _inst_tmp.inst = inst; + _inst_tmp.inst_length = inst_length; + return _inst_tmp; +} + +generator_mod_exp::generator_mod_exp() +{ + _k_dat = { + "MOD_EXP", + 2, + 0, + 2, + { + {{"P_BITCOUNT", 1}, {"E_BITCOUNT", 0}}, + {{"P_BITCOUNT", 0}, {"E_BITCOUNT", 1}} + } + }; +} + +void generator_mod_exp::clear() +{ + inst.clear(); + task_split.clear(); + inst_length.clear(); + inst_c_address = DATA_OUT_TO_OUT_CTRL; + inst_l_address = 0; + task_split_n = 0; + load_a_address = 0; + load_b_address = 0; + inst_length_last = 0; +} + +void generator_mod_exp::_inst_gen_start_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe) +{ + uint64_t inst_temp = 0; + uint32_t inst_l_length =1; + uint32_t inst_l_times = 1; + float p_bitcount_k = 0; + float e_bitcount_k = 0; + + // C: IDLE and WAIT + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + // L: param + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,param,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + // L: p + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,p,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + // L: n_prime + inst_l_length = 1; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,n_prime,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + // L : rsquare + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,r_square,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + p_bitcount_k = _k_dat.data_in[0].at("P_BITCOUNT"); + e_bitcount_k = _k_dat.data_in[0].at("E_BITCOUNT"); + load_a_address = inst_l_address; + load_b_address = inst_l_address + task_split_n*(p_bitcount_k*p_bitcount/LOCAL_BUFFER_WIDTH + e_bitcount_k*e_bitcount/LOCAL_BUFFER_WIDTH ); +} + +void generator_mod_exp::_inst_gen_middle_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe,uint32_t ddr_address,uint32_t ddr_length) +{ + uint32_t inst_l_length =1; + uint32_t inst_l_times = 1; + uint64_t inst_temp = 0; + // C: IDLE -> LOAD_PARA + inst_temp = gen_inst_c(IDLE,CAL,LOAD_PARA,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: LOAD_PARA -> ELE_MOD_EXP + inst_temp = gen_inst_c(LOAD_PARA,CAL,ELE_MOD_EXP,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: ELE_MOD_EXP and WAIT + inst_temp = gen_inst_c(ELE_MOD_EXP,WAIT,ELE_MOD_EXP,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // L : a + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_l_times = actual_pe; + inst_temp = gen_inst_l(load_a_address,inst_l_length,pe_0,a,inst_l_times,CHANGE,actual_pe); + inst.push_back(inst_temp); + load_a_address +=inst_l_length*inst_l_times; + + // L : b + inst_l_length = e_bitcount / LOCAL_BUFFER_WIDTH; + inst_l_times = actual_pe; + inst_temp = gen_inst_l(load_b_address,inst_l_length,pe_0,b,inst_l_times,CHANGE,actual_pe); + inst.push_back(inst_temp); + load_b_address +=inst_l_length*inst_l_times; + + // I + inst_temp = gen_inst_i(ddr_address,ddr_length); + inst.push_back(inst_temp); + + // inst split + inst_length.push_back(inst.size()-inst_length_last); + inst_length_last = inst.size(); + // C : ELE_MOD_EXP -> VEC_OUTPUT_RESULTS + inst_temp = gen_inst_c(ELE_MOD_EXP,CAL,VEC_OUTPUT_RESULTS,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C : VEC_OUTPUT_RESULTS -> IDLE + inst_temp = gen_inst_c(VEC_OUTPUT_RESULTS,CAL,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); +} + +void generator_mod_exp::_inst_gen_end_(uint8_t actual_pe) +{ + uint64_t inst_temp = 0; + // C: IDLE and WAIT + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: IDLE and WAIT + actual_pe = 0; + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + inst_length.push_back(inst.size()-inst_length_last); +} + +struct generator_inst_t generator_mod_exp::gen_inst(uint32_t p_bitcount,uint32_t e_bitcount,uint32_t task_split,std::vector mem_alloc) +{ + uint32_t _task_split = task_split; + uint64_t ddr_address = 0; + uint64_t ddr_length = 0; + uint8_t actual_pe = 0; + struct generator_inst_t _inst_tmp; + this->clear(); + task_split_n = task_split; + actual_pe = (_task_split>=NUMBER_OF_PE)?NUMBER_OF_PE:_task_split; + _inst_gen_start_(p_bitcount,e_bitcount,actual_pe); + uint32_t number_of_outs = std::ceil((float)_task_split/(float)NUMBER_OF_PE); + for(uint8_t t = 0;t < number_of_outs;t++) + { + ddr_address = mem_alloc[t].out_ddr_addr; + ddr_length = mem_alloc[t].out_ddr_length; + if(t==number_of_outs-1) + { + actual_pe = ((_task_split%NUMBER_OF_PE)==0)?NUMBER_OF_PE:(_task_split%NUMBER_OF_PE); + _inst_gen_middle_(p_bitcount,e_bitcount,actual_pe,ddr_address,ddr_length); + } + else + { + actual_pe = (_task_split>=NUMBER_OF_PE)?NUMBER_OF_PE:_task_split; + _inst_gen_middle_(p_bitcount,e_bitcount,actual_pe,ddr_address,ddr_length); + } + } + _inst_gen_end_(actual_pe); + check_inst_sram_depth(inst); + _inst_tmp.inst = inst; + _inst_tmp.inst_length = inst_length; + return _inst_tmp; +} + +generator_mod_inv_const_p::generator_mod_inv_const_p() +{ + _k_dat = { + "MOD_INV_CONST_P", + 1, + 0, + 1, + { + {{"P_BITCOUNT", 1}, {"E_BITCOUNT", 0}}, + } + }; +} + +void generator_mod_inv_const_p::clear() +{ + inst.clear(); + task_split.clear(); + inst_length.clear(); + inst_c_address = DATA_OUT_TO_OUT_CTRL; + inst_l_address = 0; + task_split_n = 0; + load_a_address = 0; + load_b_address = 0; + inst_length_last = 0; +} + +void generator_mod_inv_const_p::_inst_gen_start_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe) +{ + uint64_t inst_temp = 0; + uint32_t inst_l_length =1; + uint32_t inst_l_times = 1; + + // C: IDLE and WAIT + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + // L: param + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,param,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + // L: p + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,p,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; +} + +void generator_mod_inv_const_p::_inst_gen_middle_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe,uint32_t ddr_address,uint32_t ddr_length) +{ + uint32_t inst_l_length =1; + uint32_t inst_l_times = 1; + uint64_t inst_temp = 0; + // C: IDLE -> LOAD_PARA + inst_temp = gen_inst_c(IDLE,CAL,LOAD_PARA,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: LOAD_PARA -> ELE_MOD_INV_P + inst_temp = gen_inst_c(LOAD_PARA,CAL,ELE_MOD_INV_P,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: ELE_MOD_INV_P and WAIT + inst_temp = gen_inst_c(ELE_MOD_INV_P,WAIT,ELE_MOD_INV_P,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // L : a + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_l_times = actual_pe; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,pe_0,a,inst_l_times,CHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length*inst_l_times; + + // I + inst_temp = gen_inst_i(ddr_address,ddr_length); + inst.push_back(inst_temp); + + // inst split + inst_length.push_back(inst.size()-inst_length_last); + inst_length_last = inst.size(); + // C : ELE_MOD_INV_P -> VEC_OUTPUT_RESULTS + inst_temp = gen_inst_c(ELE_MOD_INV_P,CAL,VEC_OUTPUT_RESULTS,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C : VEC_OUTPUT_RESULTS -> IDLE + inst_temp = gen_inst_c(VEC_OUTPUT_RESULTS,CAL,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); +} + +void generator_mod_inv_const_p::_inst_gen_end_(uint8_t actual_pe) +{ + uint64_t inst_temp = 0; + // C: IDLE and WAIT + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: IDLE and WAIT + actual_pe = 0; + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + inst_length.push_back(inst.size()-inst_length_last); +} + +struct generator_inst_t generator_mod_inv_const_p::gen_inst(uint32_t p_bitcount,uint32_t e_bitcount,uint32_t task_split,std::vector mem_alloc) +{ + uint32_t _task_split = task_split; + uint64_t ddr_address = 0; + uint64_t ddr_length = 0; + uint8_t actual_pe = 0; + struct generator_inst_t _inst_tmp; + this->clear(); + task_split_n = task_split; + actual_pe = (_task_split>=NUMBER_OF_PE)?NUMBER_OF_PE:_task_split; + + _inst_gen_start_(p_bitcount,e_bitcount,actual_pe); + uint32_t number_of_outs = std::ceil((float)_task_split/(float)NUMBER_OF_PE); + + for(uint8_t t = 0;t < number_of_outs;t++) + { + ddr_address = mem_alloc[t].out_ddr_addr; + ddr_length = mem_alloc[t].out_ddr_length; + + if(t==number_of_outs-1) + { + actual_pe = ((_task_split%NUMBER_OF_PE)==0)?NUMBER_OF_PE:(_task_split%NUMBER_OF_PE); + _inst_gen_middle_(p_bitcount,e_bitcount,actual_pe,ddr_address,ddr_length); + } + else + { + actual_pe = (_task_split>=NUMBER_OF_PE)?NUMBER_OF_PE:_task_split; + _inst_gen_middle_(p_bitcount,e_bitcount,actual_pe,ddr_address,ddr_length); + } + } + _inst_gen_end_(actual_pe); + check_inst_sram_depth(inst); + _inst_tmp.inst = inst; + _inst_tmp.inst_length = inst_length; + return _inst_tmp; +} + +generator_paillier_encrypt::generator_paillier_encrypt() +{ + _k_dat = { + "PAILLIER_ENC", + 5, + 0, + 2, + { + {{"P_BITCOUNT", 0}, {"E_BITCOUNT", 1}}, + {{"P_BITCOUNT", 0.5}, {"E_BITCOUNT", 0}} + } + }; +} + +void generator_paillier_encrypt::clear() +{ + inst.clear(); + task_split.clear(); + inst_length.clear(); + inst_c_address = DATA_OUT_TO_OUT_CTRL; + inst_l_address = 0; + task_split_n = 0; + load_a_address = 0; + load_b_address = 0; + inst_length_last = 0; +} + +void generator_paillier_encrypt::_inst_gen_start_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe) +{ + uint64_t inst_temp = 0; + uint32_t inst_l_length =1; + uint32_t inst_l_times = 1; + float p_bitcount_k = 0; + float e_bitcount_k = 0; + + // C: IDLE and WAIT + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + // L: param + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,param,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + // L: p + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,p,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + // L: n_prime + inst_l_length = 1; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,n_prime,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + // L : rsquare + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,r_square,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + + // L : r_mont + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,r_mont,inst_l_times,NOCHANGE,actual_pe); + + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + address_n = inst_l_address; + address_g = address_n + inst_l_length; + inst_l_address = address_g + inst_l_length; + + + p_bitcount_k = _k_dat.data_in[0].at("P_BITCOUNT"); + e_bitcount_k = _k_dat.data_in[0].at("E_BITCOUNT"); + load_a_address = inst_l_address; + load_b_address = inst_l_address + (uint32_t)(task_split_n*(p_bitcount_k*p_bitcount/LOCAL_BUFFER_WIDTH + e_bitcount_k*e_bitcount/LOCAL_BUFFER_WIDTH )); +} + +void generator_paillier_encrypt::_inst_gen_middle_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe,uint32_t ddr_address,uint32_t ddr_length) +{ + uint32_t inst_l_length =1; + uint32_t inst_l_times = 1; + uint64_t inst_temp = 0; + // C: IDLE -> LOAD_PARA + inst_temp = gen_inst_c(IDLE,CAL,LOAD_PARA,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: LOAD_PARA -> ELE_MOD_EXP_INV_EXP + inst_temp = gen_inst_c(LOAD_PARA,CAL,ELE_MOD_EXP_INV_EXP,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: ELE_MOD_EXP_INV_EXP and WAIT + inst_temp = gen_inst_c(ELE_MOD_EXP_INV_EXP,WAIT,ELE_MOD_EXP_INV_EXP,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // L : g + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_l_times = 1; + inst_temp = gen_inst_l(address_g,inst_l_length,all,a,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length*inst_l_times; + + // L : plaintext + inst_l_length = e_bitcount / LOCAL_BUFFER_WIDTH; + inst_l_times = actual_pe; + inst_temp = gen_inst_l(load_a_address,inst_l_length,pe_0,b,inst_l_times,CHANGE,actual_pe); + inst.push_back(inst_temp); + load_a_address +=inst_l_length*inst_l_times; + + // C: ELE_MOD_EXP_INV_EXP -> ELE_MOD_EXP_INV_MUL + inst_temp = gen_inst_c(ELE_MOD_EXP_INV_EXP,CAL,ELE_MOD_EXP_INV_MUL,inst_c_address,actual_pe); + inst.push_back(inst_temp); + // C: ELE_MOD_EXP_INV_MUL -> ELE_MOD_EXP_INV_EXP + inst_temp = gen_inst_c(ELE_MOD_EXP_INV_MUL,CAL,ELE_MOD_EXP_INV_EXP,inst_c_address,actual_pe); + inst.push_back(inst_temp); + // C: ELE_MOD_EXP_INV_EXP and WAIT + inst_temp = gen_inst_c(ELE_MOD_EXP_INV_EXP,WAIT,ELE_MOD_EXP_INV_EXP,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // L : r + inst_l_length = (p_bitcount/ 2) / LOCAL_BUFFER_WIDTH; + inst_l_times = actual_pe; + inst_temp = gen_inst_l(load_b_address,inst_l_length,pe_0,a,inst_l_times,CHANGE,actual_pe); + inst.push_back(inst_temp); + load_b_address +=inst_l_length*inst_l_times; + + // L : n + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_l_times = 1; + inst_temp = gen_inst_l(address_n,inst_l_length,all,b,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + + // C: ELE_MOD_EXP_INV_EXP -> ELE_MOD_EXP_INV_MUL + inst_temp = gen_inst_c(ELE_MOD_EXP_INV_EXP,CAL,ELE_MOD_EXP_INV_MUL,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: ELE_MOD_EXP_INV_MUL -> ELE_MOD_EXP_INV_COM + inst_temp = gen_inst_c(ELE_MOD_EXP_INV_MUL,CAL,ELE_MOD_EXP_INV_COM,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // I + inst_temp = gen_inst_i(ddr_address,ddr_length); + inst.push_back(inst_temp); + + // inst split + inst_length.push_back(inst.size()-inst_length_last); + inst_length_last = inst.size(); + // C : ELE_MOD_EXP_INV_COM -> VEC_OUTPUT_RESULTS + inst_temp = gen_inst_c(ELE_MOD_EXP_INV_COM,CAL,VEC_OUTPUT_RESULTS,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C : VEC_OUTPUT_RESULTS -> IDLE + inst_temp = gen_inst_c(VEC_OUTPUT_RESULTS,CAL,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); +} + +void generator_paillier_encrypt::_inst_gen_end_(uint8_t actual_pe) +{ + uint64_t inst_temp = 0; + // C: IDLE and WAIT + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: IDLE and WAIT + actual_pe = 0; + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + inst_length.push_back(inst.size()-inst_length_last); +} + +struct generator_inst_t generator_paillier_encrypt::gen_inst(uint32_t p_bitcount,uint32_t e_bitcount,uint32_t task_split,std::vector mem_alloc) +{ + uint32_t _task_split = task_split; + uint64_t ddr_address = 0; + uint64_t ddr_length = 0; + uint8_t actual_pe = 0; + struct generator_inst_t _inst_tmp; + this->clear(); + task_split_n = task_split; + actual_pe = (_task_split>=NUMBER_OF_PE)?NUMBER_OF_PE:_task_split; + + _inst_gen_start_(p_bitcount,e_bitcount,actual_pe); + uint32_t number_of_outs = std::ceil((float)_task_split/(float)NUMBER_OF_PE); + for(uint8_t t = 0;t < number_of_outs;t++) + { + ddr_address = mem_alloc[t].out_ddr_addr; + ddr_length = mem_alloc[t].out_ddr_length; + + if(t==number_of_outs-1) + { + actual_pe = ((_task_split%NUMBER_OF_PE)==0)?NUMBER_OF_PE:(_task_split%NUMBER_OF_PE); + _inst_gen_middle_(p_bitcount,e_bitcount,actual_pe,ddr_address,ddr_length); + } + else + { + actual_pe = (_task_split>=NUMBER_OF_PE)?NUMBER_OF_PE:_task_split; + _inst_gen_middle_(p_bitcount,e_bitcount,actual_pe,ddr_address,ddr_length); + } + } + _inst_gen_end_(actual_pe); + check_inst_sram_depth(inst); + _inst_tmp.inst = inst; + _inst_tmp.inst_length = inst_length; + return _inst_tmp; +} + +generator_mod_exp_const_e::generator_mod_exp_const_e() +{ + _k_dat = { + "MOD_EXP_CONST_E", + 2, + 1, + 2, + { + {{"P_BITCOUNT", 1}, {"E_BITCOUNT", 0}}, + } + }; +} + +void generator_mod_exp_const_e::clear() +{ + inst.clear(); + task_split.clear(); + inst_length.clear(); + inst_c_address = DATA_OUT_TO_OUT_CTRL; + inst_l_address = 0; + task_split_n = 0; + load_a_address = 0; + load_b_address = 0; + inst_length_last = 0; +} + +void generator_mod_exp_const_e::_inst_gen_start_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe) +{ + uint64_t inst_temp = 0; + uint32_t inst_l_length =1; + uint32_t inst_l_times = 1; + + // C: IDLE and WAIT + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + // L: param + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,param,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + // L: p + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,p,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + // L: n_prime + inst_l_length = 1; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,n_prime,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + // L : rsquare + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,r_square,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + // L : b + inst_l_length = e_bitcount / LOCAL_BUFFER_WIDTH; + inst_l_times = 1; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,b,inst_l_times,CHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length*inst_l_times; +} + +void generator_mod_exp_const_e::_inst_gen_middle_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe,uint32_t ddr_address,uint32_t ddr_length) +{ + uint32_t inst_l_length =1; + uint32_t inst_l_times = 1; + uint64_t inst_temp = 0; + // C: IDLE -> LOAD_PARA + inst_temp = gen_inst_c(IDLE,CAL,LOAD_PARA,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: LOAD_PARA -> ELE_MOD_EXP + inst_temp = gen_inst_c(LOAD_PARA,CAL,ELE_MOD_EXP,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: ELE_MOD_EXP and WAIT + inst_temp = gen_inst_c(ELE_MOD_EXP,WAIT,ELE_MOD_EXP,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // L : a + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_l_times = actual_pe; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,pe_0,a,inst_l_times,CHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length*inst_l_times; + + // I + inst_temp = gen_inst_i(ddr_address,ddr_length); + inst.push_back(inst_temp); + + // inst split + inst_length.push_back(inst.size()-inst_length_last); + inst_length_last = inst.size(); + // C : ELE_MOD_EXP -> VEC_OUTPUT_RESULTS + inst_temp = gen_inst_c(ELE_MOD_EXP,CAL,VEC_OUTPUT_RESULTS,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C : VEC_OUTPUT_RESULTS -> IDLE + inst_temp = gen_inst_c(VEC_OUTPUT_RESULTS,CAL,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); +} + +void generator_mod_exp_const_e::_inst_gen_end_(uint8_t actual_pe) +{ + uint64_t inst_temp = 0; + // C: IDLE and WAIT + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: IDLE and WAIT + actual_pe = 0; + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + inst_length.push_back(inst.size()-inst_length_last); +} + +struct generator_inst_t generator_mod_exp_const_e::gen_inst(uint32_t p_bitcount,uint32_t e_bitcount,uint32_t task_split,std::vector mem_alloc) +{ + uint32_t _task_split = task_split; + uint64_t ddr_address = 0; + uint64_t ddr_length = 0; + uint8_t actual_pe = 0; + struct generator_inst_t _inst_tmp; + this->clear(); + task_split_n = task_split; + actual_pe = (_task_split>=NUMBER_OF_PE)?NUMBER_OF_PE:_task_split; + _inst_gen_start_(p_bitcount,e_bitcount,actual_pe); + uint32_t number_of_outs = std::ceil((float)_task_split/(float)NUMBER_OF_PE); + for(uint8_t t = 0;t < number_of_outs;t++) + { + ddr_address = mem_alloc[t].out_ddr_addr; + ddr_length = mem_alloc[t].out_ddr_length; + + if(t==number_of_outs-1) + { + actual_pe = ((_task_split%NUMBER_OF_PE)==0)?NUMBER_OF_PE:(_task_split%NUMBER_OF_PE); + _inst_gen_middle_(p_bitcount,e_bitcount,actual_pe,ddr_address,ddr_length); + } + else + { + actual_pe = (_task_split>=NUMBER_OF_PE)?NUMBER_OF_PE:_task_split; + _inst_gen_middle_(p_bitcount,e_bitcount,actual_pe,ddr_address,ddr_length); + } + } + _inst_gen_end_(actual_pe); + check_inst_sram_depth(inst); + _inst_tmp.inst = inst; + _inst_tmp.inst_length = inst_length; + return _inst_tmp; +} + +generator_mod_exp_const_a::generator_mod_exp_const_a() +{ + _k_dat = { + "MOD_EXP_CONST_A", + 3, + 0, + 2, + { + {{"P_BITCOUNT", 0}, {"E_BITCOUNT", 1}}, + } + }; +} + +void generator_mod_exp_const_a::clear() +{ + inst.clear(); + task_split.clear(); + inst_length.clear(); + inst_c_address = DATA_OUT_TO_OUT_CTRL; + inst_l_address = 0; + task_split_n = 0; + load_a_address = 0; + load_b_address = 0; + inst_length_last = 0; +} + +void generator_mod_exp_const_a::_inst_gen_start_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe) +{ + uint64_t inst_temp = 0; + uint32_t inst_l_length =1; + uint32_t inst_l_times = 1; + + // C: IDLE and WAIT + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + // L: param + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,param,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + + // L: p + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,p,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + // L: n_prime + inst_l_length = 1; + + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,n_prime,inst_l_times,NOCHANGE,actual_pe); + + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + // L : rsquare + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,r_square,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + + // L : a + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_l_times = 1; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,a,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length*inst_l_times; +} + +void generator_mod_exp_const_a::_inst_gen_middle_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe,uint32_t ddr_address,uint32_t ddr_length) +{ + uint32_t inst_l_length =1; + uint32_t inst_l_times = 1; + uint64_t inst_temp = 0; + // C: IDLE -> LOAD_PARA + inst_temp = gen_inst_c(IDLE,CAL,LOAD_PARA,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: LOAD_PARA -> ELE_MOD_EXP + inst_temp = gen_inst_c(LOAD_PARA,CAL,ELE_MOD_EXP,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: ELE_MOD_EXP and WAIT + inst_temp = gen_inst_c(ELE_MOD_EXP,WAIT,ELE_MOD_EXP,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // L : b + inst_l_length = e_bitcount / LOCAL_BUFFER_WIDTH; + inst_l_times = actual_pe; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,pe_0,b,inst_l_times,CHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length*inst_l_times; + + // I + inst_temp = gen_inst_i(ddr_address,ddr_length); + inst.push_back(inst_temp); + + // inst split + inst_length.push_back(inst.size()-inst_length_last); + inst_length_last = inst.size(); + // C : ELE_MOD_EXP -> VEC_OUTPUT_RESULTS + inst_temp = gen_inst_c(ELE_MOD_EXP,CAL,VEC_OUTPUT_RESULTS,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C : VEC_OUTPUT_RESULTS -> IDLE + inst_temp = gen_inst_c(VEC_OUTPUT_RESULTS,CAL,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); +} + +void generator_mod_exp_const_a::_inst_gen_end_(uint8_t actual_pe) +{ + uint64_t inst_temp = 0; + // C: IDLE and WAIT + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: IDLE and WAIT + actual_pe = 0; + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + inst_length.push_back(inst.size()-inst_length_last); +} + +struct generator_inst_t generator_mod_exp_const_a::gen_inst(uint32_t p_bitcount,uint32_t e_bitcount,uint32_t task_split,std::vector mem_alloc) +{ + uint32_t _task_split = task_split; + uint64_t ddr_address = 0; + uint64_t ddr_length = 0; + uint8_t actual_pe = 0; + struct generator_inst_t _inst_tmp; + this->clear(); + task_split_n = task_split; + actual_pe = (_task_split>=NUMBER_OF_PE)?NUMBER_OF_PE:_task_split; + _inst_gen_start_(p_bitcount,e_bitcount,actual_pe); + uint32_t number_of_outs = std::ceil((float)_task_split/(float)NUMBER_OF_PE); + + for(uint8_t t = 0;t < number_of_outs;t++) + { + ddr_address = mem_alloc[t].out_ddr_addr; + ddr_length = mem_alloc[t].out_ddr_length; + if(t==number_of_outs-1) + { + actual_pe = ((_task_split%NUMBER_OF_PE)==0)?NUMBER_OF_PE:(_task_split%NUMBER_OF_PE); + + _inst_gen_middle_(p_bitcount,e_bitcount,actual_pe,ddr_address,ddr_length); + } + else + { + actual_pe = (_task_split>=NUMBER_OF_PE)?NUMBER_OF_PE:_task_split; + _inst_gen_middle_(p_bitcount,e_bitcount,actual_pe,ddr_address,ddr_length); + } + } + _inst_gen_end_(actual_pe); + check_inst_sram_depth(inst); + _inst_tmp.inst = inst; + _inst_tmp.inst_length = inst_length; + return _inst_tmp; +} + +generator_mod_add::generator_mod_add() +{ + _k_dat = { + "MOD_ADD", + 1, + 0, + 1, + { + {{"P_BITCOUNT", 1}, {"E_BITCOUNT", 0}}, + {{"P_BITCOUNT", 1}, {"E_BITCOUNT", 0}} + } + }; +} + +void generator_mod_add::clear() +{ + inst.clear(); + task_split.clear(); + inst_length.clear(); + inst_c_address = DATA_OUT_TO_OUT_CTRL; + inst_l_address = 0; + task_split_n = 0; + load_a_address = 0; + load_b_address = 0; + inst_length_last = 0; +} + +void generator_mod_add::_inst_gen_start_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe) +{ + uint64_t inst_temp = 0; + uint32_t inst_l_length =1; + uint32_t inst_l_times = 1; + float p_bitcount_k = 0; + float e_bitcount_k = 0; + + // C: IDLE and WAIT + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + // L: param + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,param,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + // L: p + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,p,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + p_bitcount_k = _k_dat.data_in[0].at("P_BITCOUNT"); + e_bitcount_k = _k_dat.data_in[0].at("E_BITCOUNT"); + load_a_address = inst_l_address; + load_b_address = inst_l_address + task_split_n*(p_bitcount_k*p_bitcount/LOCAL_BUFFER_WIDTH + e_bitcount_k*e_bitcount/LOCAL_BUFFER_WIDTH ); +} + +void generator_mod_add::_inst_gen_middle_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe,uint32_t ddr_address,uint32_t ddr_length) +{ + uint32_t inst_l_length =1; + uint32_t inst_l_times = 1; + uint64_t inst_temp = 0; + // C: IDLE -> LOAD_PARA + inst_temp = gen_inst_c(IDLE,CAL,LOAD_PARA,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: LOAD_PARA -> ELE_MOD_INV + inst_temp = gen_inst_c(LOAD_PARA,CAL,ELE_MOD_INV,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: ELE_MOD_INV and WAIT + inst_temp = gen_inst_c(ELE_MOD_INV,WAIT,ELE_MOD_INV,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // L : a + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_l_times = actual_pe; + inst_temp = gen_inst_l(load_a_address,inst_l_length,pe_0,a,inst_l_times,CHANGE,actual_pe); + inst.push_back(inst_temp); + load_a_address +=inst_l_length*inst_l_times; + + // L : b + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_l_times = actual_pe; + inst_temp = gen_inst_l(load_b_address,inst_l_length,pe_0,b,inst_l_times,CHANGE,actual_pe); + inst.push_back(inst_temp); + load_b_address +=inst_l_length*inst_l_times; + + // I + inst_temp = gen_inst_i(ddr_address,ddr_length); + inst.push_back(inst_temp); + + // inst split + inst_length.push_back(inst.size()-inst_length_last); + inst_length_last = inst.size(); + // C : ELE_MOD_INV -> VEC_OUTPUT_RESULTS + inst_temp = gen_inst_c(ELE_MOD_INV,CAL,VEC_OUTPUT_RESULTS,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C : VEC_OUTPUT_RESULTS -> IDLE + inst_temp = gen_inst_c(VEC_OUTPUT_RESULTS,CAL,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); +} + +void generator_mod_add::_inst_gen_end_(uint8_t actual_pe) +{ + uint64_t inst_temp = 0; + // C: IDLE and WAIT + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: IDLE and WAIT + actual_pe = 0; + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + inst_length.push_back(inst.size()-inst_length_last); +} + +struct generator_inst_t generator_mod_add::gen_inst(uint32_t p_bitcount,uint32_t e_bitcount,uint32_t task_split,std::vector mem_alloc) +{ + uint32_t _task_split = task_split; + uint64_t ddr_address = 0; + uint64_t ddr_length = 0; + uint8_t actual_pe = 0; + struct generator_inst_t _inst_tmp; + this->clear(); + task_split_n = task_split; + actual_pe = (_task_split>=NUMBER_OF_PE)?NUMBER_OF_PE:_task_split; + + _inst_gen_start_(p_bitcount,e_bitcount,actual_pe); + uint32_t number_of_outs = std::ceil((float)_task_split/(float)NUMBER_OF_PE); + + for(uint8_t t = 0;t < number_of_outs;t++) + { + ddr_address = mem_alloc[t].out_ddr_addr; + ddr_length = mem_alloc[t].out_ddr_length; + + if(t==number_of_outs-1) + { + actual_pe = ((_task_split%NUMBER_OF_PE)==0)?NUMBER_OF_PE:(_task_split%NUMBER_OF_PE); + _inst_gen_middle_(p_bitcount,e_bitcount,actual_pe,ddr_address,ddr_length); + } + else + { + actual_pe = (_task_split>=NUMBER_OF_PE)?NUMBER_OF_PE:_task_split; + _inst_gen_middle_(p_bitcount,e_bitcount,actual_pe,ddr_address,ddr_length); + } + } + _inst_gen_end_(actual_pe); + check_inst_sram_depth(inst); + _inst_tmp.inst = inst; + _inst_tmp.inst_length = inst_length; + return _inst_tmp; +} + +void generator_mod_add_const::clear() +{ + inst.clear(); + task_split.clear(); + inst_length.clear(); + inst_c_address = DATA_OUT_TO_OUT_CTRL; + inst_l_address = 0; + task_split_n = 0; + load_a_address = 0; + load_b_address = 0; + inst_length_last = 0; +} + +generator_mod_add_const::generator_mod_add_const() +{ + _k_dat = { + "MOD_ADD", + 2, + 0, + 1, + { + {{"P_BITCOUNT", 1}, {"E_BITCOUNT", 0}}, + } + }; +} + +void generator_mod_add_const::_inst_gen_start_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe) +{ + uint64_t inst_temp = 0; + uint32_t inst_l_length =1; + uint32_t inst_l_times = 1; + + // C: IDLE and WAIT + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + // L: param + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,param,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + + // L: p + + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,p,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + // L: b + inst_l_times = 1; + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,b,inst_l_times,CHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; +} + +void generator_mod_add_const::_inst_gen_middle_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe,uint32_t ddr_address,uint32_t ddr_length) +{ + uint32_t inst_l_length =1; + uint32_t inst_l_times = 1; + uint64_t inst_temp = 0; + // C: IDLE -> LOAD_PARA + inst_temp = gen_inst_c(IDLE,CAL,LOAD_PARA,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: LOAD_PARA -> ELE_MOD_INV + inst_temp = gen_inst_c(LOAD_PARA,CAL,ELE_MOD_INV,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: ELE_MOD_INV and WAIT + inst_temp = gen_inst_c(ELE_MOD_INV,WAIT,ELE_MOD_INV,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // L : a + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_l_times = actual_pe; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,pe_0,a,inst_l_times,CHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length*inst_l_times; + + // I + inst_temp = gen_inst_i(ddr_address,ddr_length); + inst.push_back(inst_temp); + + // inst split + inst_length.push_back(inst.size()-inst_length_last); + inst_length_last = inst.size(); + // C : ELE_MOD_INV -> VEC_OUTPUT_RESULTS + inst_temp = gen_inst_c(ELE_MOD_INV,CAL,VEC_OUTPUT_RESULTS,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C : VEC_OUTPUT_RESULTS -> IDLE + inst_temp = gen_inst_c(VEC_OUTPUT_RESULTS,CAL,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); +} + +void generator_mod_add_const::_inst_gen_end_(uint8_t actual_pe) +{ + uint64_t inst_temp = 0; + // C: IDLE and WAIT + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: IDLE and WAIT + actual_pe = 0; + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + inst_length.push_back(inst.size()-inst_length_last); +} + +struct generator_inst_t generator_mod_add_const::gen_inst(uint32_t p_bitcount,uint32_t e_bitcount,uint32_t task_split,std::vector mem_alloc) +{ + uint32_t _task_split = task_split; + uint64_t ddr_address = 0; + uint64_t ddr_length = 0; + uint8_t actual_pe = 0; + struct generator_inst_t _inst_tmp; + this->clear(); + task_split_n = task_split; + actual_pe = (_task_split>=NUMBER_OF_PE)?NUMBER_OF_PE:_task_split; + + _inst_gen_start_(p_bitcount,e_bitcount,actual_pe); + uint32_t number_of_outs = std::ceil((float)_task_split/(float)NUMBER_OF_PE); + + for(uint8_t t = 0;t < number_of_outs;t++) + { + ddr_address = mem_alloc[t].out_ddr_addr; + ddr_length = mem_alloc[t].out_ddr_length; + + if(t==number_of_outs-1) + { + actual_pe = ((_task_split%NUMBER_OF_PE)==0)?NUMBER_OF_PE:(_task_split%NUMBER_OF_PE); + + _inst_gen_middle_(p_bitcount,e_bitcount,actual_pe,ddr_address,ddr_length); + } + else + { + actual_pe = (_task_split>=NUMBER_OF_PE)?NUMBER_OF_PE:_task_split; + _inst_gen_middle_(p_bitcount,e_bitcount,actual_pe,ddr_address,ddr_length); + } + } + _inst_gen_end_(actual_pe); + check_inst_sram_depth(inst); + _inst_tmp.inst = inst; + _inst_tmp.inst_length = inst_length; + return _inst_tmp; +} + +generator_mont::generator_mont() +{ + _k_dat = { + "MONT", + 1, + 0, + 2, + { + {{"P_BITCOUNT", 1}, {"E_BITCOUNT", 0}}, + {{"P_BITCOUNT", 1}, {"E_BITCOUNT", 0}} + } + }; +} + +void generator_mont::clear() +{ + inst.clear(); + task_split.clear(); + inst_length.clear(); + inst_c_address = DATA_OUT_TO_OUT_CTRL; + inst_l_address = 0; + task_split_n = 0; + load_a_address = 0; + load_b_address = 0; + inst_length_last = 0; +} + +void generator_mont::_inst_gen_start_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe) +{ + uint64_t inst_temp = 0; + uint32_t inst_l_length =1; + uint32_t inst_l_times = 1; + float p_bitcount_k = 0; + float e_bitcount_k = 0; + + // C: IDLE and WAIT + + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + // L: param + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,param,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + // L: p + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,p,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + // L: n_prime + inst_l_length = 1; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,n_prime,inst_l_times,NOCHANGE,actual_pe); + + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + + p_bitcount_k = _k_dat.data_in[0].at("P_BITCOUNT"); + e_bitcount_k = _k_dat.data_in[0].at("E_BITCOUNT"); + load_a_address = inst_l_address; + load_b_address = inst_l_address + task_split_n*(p_bitcount_k*p_bitcount/LOCAL_BUFFER_WIDTH + e_bitcount_k*e_bitcount/LOCAL_BUFFER_WIDTH ); +} + +void generator_mont::_inst_gen_middle_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe,uint32_t ddr_address,uint32_t ddr_length) +{ + uint32_t inst_l_length =1; + uint32_t inst_l_times = 1; + uint64_t inst_temp = 0; + // C: IDLE -> LOAD_PARA + inst_temp = gen_inst_c(IDLE,CAL,LOAD_PARA,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: LOAD_PARA -> ELEMENT_RSQUARE + inst_temp = gen_inst_c(LOAD_PARA,CAL,ELEMENT_RSQUARE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: ELEMENT_RSQUARE and WAIT + inst_temp = gen_inst_c(ELEMENT_RSQUARE,WAIT,ELEMENT_RSQUARE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // L : a + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_l_times = actual_pe; + inst_temp = gen_inst_l(load_a_address,inst_l_length,pe_0,a,inst_l_times,CHANGE,actual_pe); + inst.push_back(inst_temp); + load_a_address +=inst_l_length*inst_l_times; + + // L : b + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_l_times = actual_pe; + inst_temp = gen_inst_l(load_b_address,inst_l_length,pe_0,r_square,inst_l_times,CHANGE,actual_pe); + inst.push_back(inst_temp); + load_b_address +=inst_l_length*inst_l_times; + + // I + inst_temp = gen_inst_i(ddr_address,ddr_length); + inst.push_back(inst_temp); + + // inst split + inst_length.push_back(inst.size()-inst_length_last); + inst_length_last = inst.size(); + // C : ELEMENT_RSQUARE -> VEC_OUTPUT_RESULTS + inst_temp = gen_inst_c(ELEMENT_RSQUARE,CAL,VEC_OUTPUT_RESULTS,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C : VEC_OUTPUT_RESULTS -> IDLE + inst_temp = gen_inst_c(VEC_OUTPUT_RESULTS,CAL,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); +} + +void generator_mont::_inst_gen_end_(uint8_t actual_pe) +{ + uint64_t inst_temp = 0; + // C: IDLE and WAIT + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: IDLE and WAIT + actual_pe = 0; + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + inst_length.push_back(inst.size()-inst_length_last); +} + +struct generator_inst_t generator_mont::gen_inst(uint32_t p_bitcount,uint32_t e_bitcount,uint32_t task_split,std::vector mem_alloc) +{ + uint32_t _task_split = task_split; + uint64_t ddr_address = 0; + uint64_t ddr_length = 0; + uint8_t actual_pe = 0; + struct generator_inst_t _inst_tmp; + this->clear(); + task_split_n = task_split; + actual_pe = (_task_split>=NUMBER_OF_PE)?NUMBER_OF_PE:_task_split; + + _inst_gen_start_(p_bitcount,e_bitcount,actual_pe); + uint32_t number_of_outs = std::ceil((float)_task_split/(float)NUMBER_OF_PE); + + for(uint8_t t = 0;t < number_of_outs;t++) + { + ddr_address = mem_alloc[t].out_ddr_addr; + ddr_length = mem_alloc[t].out_ddr_length; + + if(t==number_of_outs-1) + { + actual_pe = ((_task_split%NUMBER_OF_PE)==0)?NUMBER_OF_PE:(_task_split%NUMBER_OF_PE); + _inst_gen_middle_(p_bitcount,e_bitcount,actual_pe,ddr_address,ddr_length); + } + else + { + actual_pe = (_task_split>=NUMBER_OF_PE)?NUMBER_OF_PE:_task_split; + _inst_gen_middle_(p_bitcount,e_bitcount,actual_pe,ddr_address,ddr_length); + } + } + _inst_gen_end_(actual_pe); + check_inst_sram_depth(inst); + _inst_tmp.inst = inst; + _inst_tmp.inst_length = inst_length; + return _inst_tmp; +} + +generator_mont_const::generator_mont_const() +{ + _k_dat = { + "MONT_CONST", + 2, + 0, + 2, + { + {{"P_BITCOUNT", 1}, {"E_BITCOUNT", 0}}, + } + }; +} + +void generator_mont_const::clear() +{ + inst.clear(); + task_split.clear(); + inst_length.clear(); + inst_c_address = DATA_OUT_TO_OUT_CTRL; + inst_l_address = 0; + task_split_n = 0; + load_a_address = 0; + load_b_address = 0; + inst_length_last = 0; +} + +void generator_mont_const::_inst_gen_start_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe) +{ + uint64_t inst_temp = 0; + uint32_t inst_l_length =1; + uint32_t inst_l_times = 1; + + // C: IDLE and WAIT + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + // L: param + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,param,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + // L: p + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,p,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + // L: n_prime + inst_l_length = 1; + + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,n_prime,inst_l_times,NOCHANGE,actual_pe); + + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; + // L : rsquare + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,all,r_square,inst_l_times,NOCHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length; +} + +void generator_mont_const::_inst_gen_middle_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe,uint32_t ddr_address,uint32_t ddr_length) +{ + uint32_t inst_l_length =1; + uint32_t inst_l_times = 1; + uint64_t inst_temp = 0; + // C: IDLE -> LOAD_PARA + inst_temp = gen_inst_c(IDLE,CAL,LOAD_PARA,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: LOAD_PARA -> ELEMENT_RSQUARE + inst_temp = gen_inst_c(LOAD_PARA,CAL,ELEMENT_RSQUARE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: ELEMENT_RSQUARE and WAIT + inst_temp = gen_inst_c(ELEMENT_RSQUARE,WAIT,ELEMENT_RSQUARE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // L : a + inst_l_length = p_bitcount / LOCAL_BUFFER_WIDTH; + inst_l_times = actual_pe; + inst_temp = gen_inst_l(inst_l_address,inst_l_length,pe_0,a,inst_l_times,CHANGE,actual_pe); + inst.push_back(inst_temp); + inst_l_address +=inst_l_length*inst_l_times; + + // I + inst_temp = gen_inst_i(ddr_address,ddr_length); + inst.push_back(inst_temp); + + // inst split + inst_length.push_back(inst.size()-inst_length_last); + inst_length_last = inst.size(); + // C : ELEMENT_RSQUARE -> VEC_OUTPUT_RESULTS + inst_temp = gen_inst_c(ELEMENT_RSQUARE,CAL,VEC_OUTPUT_RESULTS,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C : VEC_OUTPUT_RESULTS -> IDLE + inst_temp = gen_inst_c(VEC_OUTPUT_RESULTS,CAL,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); +} + +void generator_mont_const::_inst_gen_end_(uint8_t actual_pe) +{ + uint64_t inst_temp = 0; + // C: IDLE and WAIT + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + // C: IDLE and WAIT + actual_pe = 0; + inst_temp = gen_inst_c(IDLE,WAIT,IDLE,inst_c_address,actual_pe); + inst.push_back(inst_temp); + + inst_length.push_back(inst.size()-inst_length_last); +} + + +struct generator_inst_t generator_mont_const::gen_inst(uint32_t p_bitcount,uint32_t e_bitcount,uint32_t task_split,std::vector mem_alloc) +{ + uint32_t _task_split = task_split; + uint64_t ddr_address = 0; + uint64_t ddr_length = 0; + uint8_t actual_pe = 0; + struct generator_inst_t _inst_tmp; + this->clear(); + task_split_n = task_split; + actual_pe = (_task_split>=NUMBER_OF_PE)?NUMBER_OF_PE:_task_split; + + _inst_gen_start_(p_bitcount,e_bitcount,actual_pe); + uint32_t number_of_outs = std::ceil((float)_task_split/(float)NUMBER_OF_PE); + + for(uint8_t t = 0;t < number_of_outs;t++) + { + ddr_address = mem_alloc[t].out_ddr_addr; + ddr_length = mem_alloc[t].out_ddr_length; + + if(t==number_of_outs-1) + { + actual_pe = ((_task_split%NUMBER_OF_PE)==0)?NUMBER_OF_PE:(_task_split%NUMBER_OF_PE); + _inst_gen_middle_(p_bitcount,e_bitcount,actual_pe,ddr_address,ddr_length); + } + else + { + actual_pe = (_task_split>=NUMBER_OF_PE)?NUMBER_OF_PE:_task_split; + _inst_gen_middle_(p_bitcount,e_bitcount,actual_pe,ddr_address,ddr_length); + } + } + _inst_gen_end_(actual_pe); + check_inst_sram_depth(inst); + _inst_tmp.inst = inst; + _inst_tmp.inst_length = inst_length; + return _inst_tmp; +} + +struct generator_inst_t generator_asic::gen_inst(struct Program program,uint32_t task_split,std::vector mem_alloc) +{ + struct generator_inst_t _g_inst; + if(program.operation_type == "MOD_MUL") + { + generator_mod_mul _generator_mod_mul; + _g_inst =_generator_mod_mul.gen_inst(program.p_bitcount,program.e_bitcount,task_split,mem_alloc); + } + else if(program.operation_type == "PAILLIER_ENC") + { + generator_paillier_encrypt _generator_paillier_encrypt; + _g_inst =_generator_paillier_encrypt.gen_inst(program.p_bitcount,program.e_bitcount,task_split,mem_alloc); + } + else if(program.operation_type == "MOD_EXP_CONST_E") + { + generator_mod_exp_const_e _generator_mod_exp_const_e; + _g_inst =_generator_mod_exp_const_e.gen_inst(program.p_bitcount,program.e_bitcount,task_split,mem_alloc); + } + else if(program.operation_type == "MOD_EXP") + { + generator_mod_exp _generator_mod_exp; + _g_inst =_generator_mod_exp.gen_inst(program.p_bitcount,program.e_bitcount,task_split,mem_alloc); + } + else if(program.operation_type == "MOD_MUL_CONST") + { + generator_mod_mul_const _generator_mod_mul_const; + _g_inst =_generator_mod_mul_const.gen_inst(program.p_bitcount,program.e_bitcount,task_split,mem_alloc); + } + else if(program.operation_type == "MOD_EXP_CONST_A") + { + generator_mod_exp_const_a _generator_mod_exp_const_a; + _g_inst =_generator_mod_exp_const_a.gen_inst(program.p_bitcount,program.e_bitcount,task_split,mem_alloc); + } + else if(program.operation_type == "MOD_ADD") + { + generator_mod_add _generator_mod_add; + _g_inst =_generator_mod_add.gen_inst(program.p_bitcount,program.e_bitcount,task_split,mem_alloc); + } + else if(program.operation_type == "MOD_ADD_CONST") + { + generator_mod_add_const _generator_mod_add_const; + _g_inst =_generator_mod_add_const.gen_inst(program.p_bitcount,program.e_bitcount,task_split,mem_alloc); + } + else if(program.operation_type == "MONT") + { + generator_mont _generator_mont; + _g_inst =_generator_mont.gen_inst(program.p_bitcount,program.e_bitcount,task_split,mem_alloc); + } + else if(program.operation_type == "MONT_CONST") + { + generator_mont_const _generator_mont_const; + _g_inst =_generator_mont_const.gen_inst(program.p_bitcount,program.e_bitcount,task_split,mem_alloc); + } + else if(program.operation_type == "MOD_INV_CONST_P") + { + generator_mod_inv_const_p _generator_mod_inv_const_p; + _g_inst =_generator_mod_inv_const_p.gen_inst(program.p_bitcount,program.e_bitcount,task_split,mem_alloc); + } + return _g_inst; +} + +void generator_asic::gen_inst_para() +{ + inst.clear(); + inst_split.clear(); +} \ No newline at end of file diff --git a/heu/library/algorithms/leichi_paillier/compiler/compiler.h b/heu/library/algorithms/leichi_paillier/compiler/compiler.h new file mode 100755 index 00000000..1ade9690 --- /dev/null +++ b/heu/library/algorithms/leichi_paillier/compiler/compiler.h @@ -0,0 +1,600 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using int128_t = __int128_t; +using uint128_t = __uint128_t; + +#define NUMBER_OF_CHIP 8 +#define NUMBER_OF_PE 16 +#define MAX_ON_CHIP_VECTOR 2048 +#define GLOBAL_BUFFER_DEPTH 1024*1024*1024*2 +#define LOCAL_INST_BUFFER_WIDTH 64 +#define LOCAL_BUFFER_WIDTH 256 +#define LOCAL_INST_BUFFER_DEPTH 16*1024 +#define GLOBAL_BUFFER_WIDTH 8 +#define GLOBAL_INST_BUFFER_DEPTH 4096 + +#define SOF 0x00ffffff +#define ADDRESS_PIN_MUX 0x00000068 +#define ADDRESS_INST_FLAG 0x000000D0 +#define DATA_PIN_MUX_PAD 0x00000000 +#define DATA_OUT_TO_OUT_CTRL 0b111111111111111111 + +enum PE_STATE { + IDLE = 0b00000000, + LOAD_PARA = 0b00000001, + ELEMENT_RSQUARE = 0b00000010, + ELE_MOD_MUL = 0b00000011, + ELE_MOD_EXP = 0b00000100, + ELE_MOD_INV = 0b00000101, + ELE_MOD_INV_P = 0b00000110, + ELE_MOD_EXP_INV_EXP = 0b00110001, + ELE_MOD_EXP_INV_MUL = 0b00110010, + ELE_MOD_EXP_INV_COM = 0b00110011, + ELE_MOD_EXP_EXP_EXP = 0b00110100, + ELE_MOD_EXP_EXP_MUL = 0b00110101, + ELE_MOD_EXP_EXP_COM = 0b00110110, + VEC_OUTPUT_RESULTS = 0b00000111, + SET_PSUM_CONV_ONLY = 0b00001000, + CONV_ONLY = 0b00001001, + CONV_REDUCE_SINGLE_BIT = 0b00001010, + CONV_REDUCE_MULTI_PO_BIT = 0b00001011, + CONV_REDUCE_MULTI_NA_BIT = 0b00001100, + CONV_REDUCE_MULTI_NA_1_BIT = 0b00001101, + SET_PSUM_INV_PLUS_CONV = 0b00001110, + MOD_INV_PLUS_CONV = 0b00001111, + MOD_INV_CONV_REDUCE_SIGNLE_BIT = 0b00010000, + MOD_INV_CONV_REDUCE_MULTI_PO_BIT = 0b00010001, + MOD_INV_CONV_REDUCE_MULTI_NA_BIT = 0b00010010, + MOD_INV_CONV_REDUCE_MULTI_NA_1_BIT = 0b00010011, + CHECK_MOD_INV_FINISH = 0b00010100, + MOD_INV_REDUCE_PLUS_CONV = 0b00010101, + OUTPUT_RESULTS = 0b00010110, + MOD_INV_ONLY = 0b00010111, + MOD_INV_REDUCE_ONLY = 0b00011000, + OUTPUT_RESULT = 0b00011001, + COMPLETED = 0b11111111, +} ; + +enum DES_PE_TABLE { + all = 0b101111, + pe_group0 = 0b100001, + pe_group1 = 0b100010, + pe_group2 = 0b100100, + pe_group3 = 0b101000, + pe_0 = 0b000000, + pe_1 = 0b000001, + pe_2 = 0b000010, + pe_3 = 0b000011, + pe_4 = 0b000100, + pe_5 = 0b000101, + pe_6 = 0b000110, + pe_7 = 0b000111, + pe_8 = 0b001000, + pe_9 = 0b001001, + pe_10 = 0b001010, + pe_11 = 0b001011, + pe_12 = 0b001100, + pe_13 = 0b001101, + pe_14 = 0b001110, + pe_15 = 0b001111 +}; + +enum DES_REG_TABLE { + activation = 0b0000, + a = 0b0001, + b = 0b0010, + weight = 0b0011, + r_square = 0b0100, + r_mont = 0b0101, + n_prime = 0b0110, + p = 0b0111, + param = 0b1000 +}; + +enum DATA_TYPE { + w_data = 0b000, + w_inst = 0b001, + w_inst_length = 0b010, + inst_reset = 0b011, + w_reg = 0b100, + r_reg = 0b101 +}; + +enum CAL_FLAG { + CAL = 1, + WAIT = 0, +}; + +enum REPEAT_TABLE { + CHANGE =0b1, + NOCHANGE =0b0 +}; + +struct DATA_IN{ + uint8_t p_bitcount; + uint8_t e_bitcount; +}; + +struct MEM_CFG{ + uint8_t operation_type; + uint8_t const_p_bitcount; + uint8_t const_e_bitcount; + uint8_t const_sram_data_width; + uint8_t data_in_len; + struct DATA_IN data_in[2]; +}; + +struct INST_C_TABLE { + uint8_t operand; + PE_STATE pe_state; + CAL_FLAG cal_flag; + PE_STATE pe_n_state; + uint32_t address; + uint8_t pe_gate; + uint32_t reserver; +}; + +struct Data_k_Info_t { + std::string name; + int const_p_bitcount; + int const_e_bitcount; + int const_sram_data_width; + std::vector> data_in; +}; + +uint64_t gen_inst_l(uint32_t address, uint32_t length, uint8_t des_pe, uint8_t des_reg, uint16_t times, bool change_flag, uint8_t pe_gate=0b1111); +uint64_t gen_inst_c(uint8_t pe_state, bool cal_flag, uint8_t pe_n_state, uint32_t address, uint8_t pe_gate=0b1111); +uint64_t gen_inst_i(uint64_t ddr_address,uint64_t ddr_length); +void check_inst_sram_depth(std::vector _inst); +uint64_t gen_inst_none(); +uint128_t gen_inst_r(uint16_t chip, uint8_t data_type, uint32_t data_address, uint32_t data); +uint128_t gen_inst_l1(uint16_t chip, uint32_t ddr_address, uint32_t ddr_length, uint8_t data_type, uint32_t data_address, bool bool_check=0); +uint128_t gen_inst_l2(uint16_t chip, uint16_t times); + +struct out_ddr_detail_t{ + uint64_t out_ddr_addr; + uint64_t out_ddr_length; +}; + +struct in_data_mem_alloc_t{ + uint64_t in_data_ddr_address_total; + uint64_t in_data_ddr_length_total; + std::vector> in_data_detail; +}; + +struct memory_allocation_t{ + uint64_t out_ddr_address_total; + uint64_t out_ddr_length_total; + std::vector>> out_detail; + uint64_t inst_ddr_address_total; + uint64_t inst_ddr_length_total; + std::vector> inst_detail; + uint64_t in_para_ddr_address_total; + uint64_t in_para_ddr_length_total; + std::vector in_dat_ddr_mem_alloc; + uint64_t last_ddr_address; +}; + +struct Program { + std::string type; + std::string operation_type; + int vec_size; + int p_bitcount; + int e_bitcount; + int start_frequency; +}; + +struct generator_inst_t{ + std::vector inst; + std::vector inst_length; +}; + +struct _executor +{ + std::vector inst; + std::vector inst_fpga; + uint32_t in_para_address; + uint32_t inst_address; + uint32_t out_address; + uint32_t out_length; +}; + +class generator_fpga +{ + public: + generator_fpga(); + ~generator_fpga(){}; + void clear(); + void gen_inst(struct Program program,std::vector> task_split,struct memory_allocation_t memory_allocation, + std::vector>> inst,std::vector>> inst_split); + void _gen_inst_vector_(struct Program program,std::vector> task_split,struct memory_allocation_t memory_allocation, + std::vector>> inst,std::vector>> inst_split); + void __gen_inst_none__(); + void __gen_inst_pll_lock__(); + void __gen_inst_inst_reset__(); + void __gen_inst_inst_flag__(uint8_t chip_sel,uint32_t length); + void __gen_inst_i__(){}; + void __gen_inst_w__(uint32_t address,uint32_t data); + void __gen_inst_l1_wait__(); + void __gen_inst_vector_l1_data_para(struct memory_allocation_t __memory_allocation); + void __gen_inst_vector_l1_data__(uint16_t chip,uint32_t task_split_first,struct memory_allocation_t __memory_allocation); + void __gen_inst_vector_l1_inst__(uint16_t chip,uint32_t task_split_first,struct memory_allocation_t __memory_allocation); + void __gen_inst_vector_r_inst_length_first__(uint16_t chip,uint32_t task_split_first,std::vector>> __inst_split_asic); + void __gen_inst_vector_r_inst_length_middle__(uint32_t task_split_first,std::vector>> __inst_asic,std::vector>> __inst_split_asic,std::vector> __task_split); + void __gen_inst_vector_r_inst_length_last__(uint32_t task_split_first,std::vector>> __inst_asic,std::vector>> __inst_split_asic); + void __gen_inst_vector_r_inst_length_add__(uint16_t chip_sel); + void __gen_inst_c__(uint32_t chip_act_num,uint32_t pll_lock,uint32_t inst_flag); + private: + uint16_t chip_num; + uint32_t ddr; + uint32_t min_chip_true; + uint32_t min_chip_false; + public: + std::vector inst; +}; + +class generator_mont +{ + public: + generator_mont(); + ~generator_mont(){} + void Init(uint32_t& p_bitcount); + void clear() ; + struct generator_inst_t gen_inst(uint32_t p_bitcount,uint32_t e_bitcount,uint32_t task_split,std::vector mem_alloc) ; + void _inst_gen_start_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe) ; + void _inst_gen_middle_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe,uint32_t ddr_address,uint32_t ddr_length) ; + void _inst_gen_end_(uint8_t actual_pe) ; + private: + std::vector inst; + std::vector> task_split; + uint32_t inst_c_address; + uint32_t inst_l_address; + uint32_t task_split_n; + uint32_t load_a_address; + uint32_t load_b_address; + std::vector inst_length; + uint32_t inst_length_last; + public: + Data_k_Info_t _k_dat; +}; + +class generator_mont_const +{ + public: + generator_mont_const(); + ~generator_mont_const(){} + void Init(uint32_t& p_bitcount); + void clear() ; + struct generator_inst_t gen_inst(uint32_t p_bitcount,uint32_t e_bitcount,uint32_t task_split,std::vector mem_alloc) ; + void _inst_gen_start_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe) ; + void _inst_gen_middle_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe,uint32_t ddr_address,uint32_t ddr_length) ; + void _inst_gen_end_(uint8_t actual_pe) ; + private: + std::vector inst; + std::vector> task_split; + uint32_t inst_c_address; + uint32_t inst_l_address; + uint32_t task_split_n; + uint32_t load_a_address; + uint32_t load_b_address; + std::vector inst_length; + uint32_t inst_length_last; + public: + Data_k_Info_t _k_dat; +}; + +class generator_mod_mul +{ + public: + generator_mod_mul(); + ~generator_mod_mul(){} + void clear() ; + struct generator_inst_t gen_inst(uint32_t p_bitcount,uint32_t e_bitcount,uint32_t task_split,std::vector mem_alloc) ; + void _inst_gen_start_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe) ; + void _inst_gen_middle_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe,uint32_t ddr_address,uint32_t ddr_length) ; + void _inst_gen_end_(uint8_t actual_pe) ; + private: + std::vector inst; + std::vector> task_split; + uint32_t inst_c_address; + uint32_t inst_l_address; + uint32_t task_split_n; + uint32_t load_a_address; + uint32_t load_b_address; + std::vector inst_length; + uint32_t inst_length_last; + public: + Data_k_Info_t _k_dat; +}; + +class generator_mod_mul_const +{ + public: + generator_mod_mul_const(); + ~generator_mod_mul_const(){} + void Init(uint32_t& p_bitcount); + void clear() ; + struct generator_inst_t gen_inst(uint32_t p_bitcount,uint32_t e_bitcount,uint32_t task_split,std::vector mem_alloc) ; + void _inst_gen_start_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe) ; + void _inst_gen_middle_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe,uint32_t ddr_address,uint32_t ddr_length) ; + void _inst_gen_end_(uint8_t actual_pe) ; + private: + std::vector inst; + std::vector> task_split; + uint32_t inst_c_address; + uint32_t inst_l_address; + uint32_t task_split_n; + uint32_t load_a_address; + uint32_t load_b_address; + std::vector inst_length; + uint32_t inst_length_last; + public: + Data_k_Info_t _k_dat; +}; + +class generator_mod_exp +{ + public: + generator_mod_exp(); + ~generator_mod_exp(){} + void Init(uint32_t& p_bitcount); + void clear() ; + struct generator_inst_t gen_inst(uint32_t p_bitcount,uint32_t e_bitcount,uint32_t task_split,std::vector mem_alloc) ; + void _inst_gen_start_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe) ; + void _inst_gen_middle_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe,uint32_t ddr_address,uint32_t ddr_length) ; + void _inst_gen_end_(uint8_t actual_pe) ; + private: + std::vector inst; + std::vector> task_split; + uint32_t inst_c_address; + uint32_t inst_l_address; + uint32_t task_split_n; + uint32_t load_a_address; + uint32_t load_b_address; + std::vector inst_length; + uint32_t inst_length_last; + public: + Data_k_Info_t _k_dat; +}; + +class generator_mod_exp_const_e +{ + public: + generator_mod_exp_const_e(); + ~generator_mod_exp_const_e(){} + void Init(uint32_t& p_bitcount); + void clear() ; + struct generator_inst_t gen_inst(uint32_t p_bitcount,uint32_t e_bitcount,uint32_t task_split,std::vector mem_alloc) ; + void _inst_gen_start_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe) ; + void _inst_gen_middle_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe,uint32_t ddr_address,uint32_t ddr_length) ; + void _inst_gen_end_(uint8_t actual_pe) ; + private: + std::vector inst; + std::vector> task_split; + uint32_t inst_c_address; + uint32_t inst_l_address; + uint32_t task_split_n; + uint32_t load_a_address; + uint32_t load_b_address; + std::vector inst_length; + uint32_t inst_length_last; + public: + Data_k_Info_t _k_dat; +}; + +class generator_mod_exp_const_a +{ + public: + generator_mod_exp_const_a(); + ~generator_mod_exp_const_a(){} + void Init(uint32_t& p_bitcount); + void clear() ; + struct generator_inst_t gen_inst(uint32_t p_bitcount,uint32_t e_bitcount,uint32_t task_split,std::vector mem_alloc) ; + void _inst_gen_start_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe) ; + void _inst_gen_middle_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe,uint32_t ddr_address,uint32_t ddr_length) ; + void _inst_gen_end_(uint8_t actual_pe) ; + private: + std::vector inst; + std::vector> task_split; + uint32_t inst_c_address; + uint32_t inst_l_address; + uint32_t task_split_n; + uint32_t load_a_address; + uint32_t load_b_address; + std::vector inst_length; + uint32_t inst_length_last; + public: + Data_k_Info_t _k_dat; +}; + +class generator_mod_inv_const_p +{ + public: + generator_mod_inv_const_p(); + ~generator_mod_inv_const_p(){} + void Init(uint32_t& p_bitcount); + void clear() ; + struct generator_inst_t gen_inst(uint32_t p_bitcount,uint32_t e_bitcount,uint32_t task_split,std::vector mem_alloc) ; + void _inst_gen_start_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe) ; + void _inst_gen_middle_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe,uint32_t ddr_address,uint32_t ddr_length) ; + void _inst_gen_end_(uint8_t actual_pe) ; + private: + std::vector inst; + std::vector> task_split; + uint32_t inst_c_address; + uint32_t inst_l_address; + uint32_t task_split_n; + uint32_t load_a_address; + uint32_t load_b_address; + std::vector inst_length; + uint32_t inst_length_last; + public: + Data_k_Info_t _k_dat; +}; + +class generator_paillier_encrypt +{ + public: + generator_paillier_encrypt(); + ~generator_paillier_encrypt(){} + void clear() ; + struct generator_inst_t gen_inst(uint32_t p_bitcount,uint32_t e_bitcount,uint32_t task_split,std::vector mem_alloc) ; + void _inst_gen_start_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe) ; + void _inst_gen_middle_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe,uint32_t ddr_address,uint32_t ddr_length) ; + void _inst_gen_end_(uint8_t actual_pe) ; + private: + std::vector inst; + std::vector> task_split; + uint32_t inst_c_address; + uint32_t inst_l_address; + uint32_t task_split_n; + uint32_t load_a_address; + uint32_t load_b_address; + std::vector inst_length; + uint32_t inst_length_last; + uint32_t address_n; + uint32_t address_g; + public: + Data_k_Info_t _k_dat; +}; + +class generator_mod_add +{ + public: + generator_mod_add(); + ~generator_mod_add(){} + void Init(uint32_t& p_bitcount); + void clear() ; + struct generator_inst_t gen_inst(uint32_t p_bitcount,uint32_t e_bitcount,uint32_t task_split,std::vector mem_alloc) ; + void _inst_gen_start_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe) ; + void _inst_gen_middle_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe,uint32_t ddr_address,uint32_t ddr_length) ; + void _inst_gen_end_(uint8_t actual_pe) ; + private: + std::vector inst; + std::vector> task_split; + uint32_t inst_c_address; + uint32_t inst_l_address; + uint32_t task_split_n; + uint32_t load_a_address; + uint32_t load_b_address; + std::vector inst_length; + uint32_t inst_length_last; + public: + Data_k_Info_t _k_dat; +}; + +class generator_mod_add_const +{ + public: + generator_mod_add_const(); + ~generator_mod_add_const(){} + void Init(uint32_t& p_bitcount); + void clear() ; + struct generator_inst_t gen_inst(uint32_t p_bitcount,uint32_t e_bitcount,uint32_t task_split,std::vector mem_alloc) ; + void _inst_gen_start_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe) ; + void _inst_gen_middle_(uint32_t p_bitcount,uint32_t e_bitcount,uint8_t actual_pe,uint32_t ddr_address,uint32_t ddr_length) ; + void _inst_gen_end_(uint8_t actual_pe) ; + private: + std::vector inst; + std::vector> task_split; + uint32_t inst_c_address; + uint32_t inst_l_address; + uint32_t task_split_n; + uint32_t load_a_address; + uint32_t load_b_address; + std::vector inst_length; + uint32_t inst_length_last; + public: + Data_k_Info_t _k_dat; +}; + +class generator_asic +{ + public: + generator_asic() = default; + struct generator_inst_t gen_inst(struct Program program,uint32_t task_split,std::vector mem_alloc); + void gen_inst_para(); + private: + std::vector inst; + std::vector inst_split; + public: + std::string operation_type; + auto _generator_sel_(std::string choice_generator_num); +}; + +class Compiler +{ + public: + Compiler(); + ~Compiler(){ + memory_allocation.last_ddr_address = 0; + ddr_address = 0; + } + void clear(); + void get_k_dat(struct Program program); + void set_device(uint8_t number_of_chip,uint8_t number_of_pe,uint64_t ddr); + void compile(); + void _task_split_(struct Program program); + void __task_split_vector__(struct Program program); + void _memory_allocation_(struct Program program); + void __memory_allocation_vector_out__(struct Program program); + void __memory_allocation_vector_inst__(struct Program program); + void __memory_allocation_vector_in_para__(struct Program program); + bool __memory_allocation_vector_in_data_i__(struct Program program,uint32_t i_of_op); + void __memory_allocation_vector_in_para___bak(MEM_CFG mem_cfg); + void __gen_asic_inst__(struct Program program); + void ___gen_asic_inst_vector___(); + void _gen_fpga_inst_(struct Program program); + void check_mem(struct Program program); + void _inst_reshape_(struct Program program); + void get_executor(); + void data_inverse(uint8_t* data,uint32_t len); + std::string to_string(uint128_t x); + public: + generator_asic _generator_asic; + generator_fpga _generator_fpga; + public: + uint8_t num_of_chip; + uint8_t num_of_pe; + uint64_t ddr_size; + uint64_t ddr_address; + Data_k_Info_t _k_dat; + public: + struct Program _program; + std::vector> task_split; + std::vector>> inst; + std::vector>> inst_split; + struct memory_allocation_t memory_allocation; + std::vector inst_byte; + struct _executor executor; +}; diff --git a/heu/library/algorithms/leichi_paillier/key_generator.cc b/heu/library/algorithms/leichi_paillier/key_generator.cc new file mode 100644 index 00000000..384fce4b --- /dev/null +++ b/heu/library/algorithms/leichi_paillier/key_generator.cc @@ -0,0 +1,43 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "heu/library/algorithms/leichi_paillier/key_generator.h" +#include +#include +#include +#include +#include +namespace heu::lib::algorithms::leichi_paillier { + void get_prime(size_t bit_len,const Plaintext &op) + { + BN_generate_prime_ex(op.bn_,bit_len,1,NULL,NULL,NULL); + } + void KeyGenerator::Generate(size_t key_size, SecretKey* sk, PublicKey* pk){ + do{ + get_prime(key_size/2,sk->p_); + get_prime(key_size/2,sk->q_); + }while(sk->p_ == sk->q_); + + pk->n_ = sk->p_*sk->q_; + Plaintext one; + uint32_t a =1; + one.Set(a); + pk->g_ = pk->n_ + one; + pk->max_plaintext_ = pk->n_; + } + + void KeyGenerator::Generate(SecretKey* sk, PublicKey* pk) { + Generate(2048, sk, pk); + } +} \ No newline at end of file diff --git a/heu/library/algorithms/leichi_paillier/key_generator.h b/heu/library/algorithms/leichi_paillier/key_generator.h new file mode 100644 index 00000000..0a759aad --- /dev/null +++ b/heu/library/algorithms/leichi_paillier/key_generator.h @@ -0,0 +1,24 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "heu/library/algorithms/leichi_paillier/public_key.h" +#include "heu/library/algorithms/leichi_paillier/secret_key.h" + +namespace heu::lib::algorithms::leichi_paillier { + class KeyGenerator { + public: + static void Generate(size_t key_size, SecretKey* sk, PublicKey* pk); + static void Generate(SecretKey* sk, PublicKey* pk); + }; +} \ No newline at end of file diff --git a/heu/library/algorithms/leichi_paillier/key_generator_test.cc b/heu/library/algorithms/leichi_paillier/key_generator_test.cc new file mode 100644 index 00000000..58dd597d --- /dev/null +++ b/heu/library/algorithms/leichi_paillier/key_generator_test.cc @@ -0,0 +1,28 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "heu/library/algorithms/leichi_paillier/key_generator.h" +#include "gtest/gtest.h" + +namespace heu::lib::algorithms::leichi_paillier::test { + class KeyGenTest : public ::testing::TestWithParam {}; + INSTANTIATE_TEST_SUITE_P(SubTest, KeyGenTest, + ::testing::Values(1024)); + + TEST_P(KeyGenTest, SubTest) { + SecretKey sk; + PublicKey pk; + KeyGenerator::Generate(GetParam(), &sk, &pk); + } +} // namespace heu::lib::algorithms::leichi_paillier::test diff --git a/heu/library/algorithms/leichi_paillier/leichi.h b/heu/library/algorithms/leichi_paillier/leichi.h new file mode 100644 index 00000000..96d9ee9c --- /dev/null +++ b/heu/library/algorithms/leichi_paillier/leichi.h @@ -0,0 +1,32 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#ifdef APPLY_LEICHI +#define ENABLE_LEICHI true +#else +#define ENABLE_LEICHI false +#endif + +#if ENABLE_LEICHI == true + +#include "heu/library/algorithms/leichi_paillier/vector_decryptor.h" +#include "heu/library/algorithms/leichi_paillier/vector_encryptor.h" +#include "heu/library/algorithms/leichi_paillier/vector_evaluator.h" +#include "heu/library/algorithms/leichi_paillier/key_generator.h" +#include "heu/library/algorithms/leichi_paillier/public_key.h" +#include "heu/library/algorithms/leichi_paillier/secret_key.h" + +#endif \ No newline at end of file diff --git a/heu/library/algorithms/leichi_paillier/leichi_test.cc b/heu/library/algorithms/leichi_paillier/leichi_test.cc new file mode 100644 index 00000000..a5581db8 --- /dev/null +++ b/heu/library/algorithms/leichi_paillier/leichi_test.cc @@ -0,0 +1,330 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "gtest/gtest.h" +#include "heu/library/algorithms/leichi_paillier/vector_encryptor.h" +#include "heu/library/algorithms/leichi_paillier/vector_decryptor.h" +#include "heu/library/algorithms/leichi_paillier/key_generator.h" +#include "heu/library/algorithms/leichi_paillier/vector_evaluator.h" + +namespace heu::lib::algorithms::leichi_paillier::test { + + class LEICHITest : public testing::Test { + protected: + void SetUp() override { + KeyGenerator::Generate(2048, &sk_, &pk_); + encryptor_ = std::make_shared(pk_); + evaluator_ = std::make_shared(pk_); + decryptor_ = std::make_shared(pk_, sk_); + } + + protected: + SecretKey sk_; + PublicKey pk_; + std::shared_ptr encryptor_; + std::shared_ptr evaluator_; + std::shared_ptr decryptor_; + }; + + int CompareBignum(const BIGNUM* bn1, const BIGNUM* bn2) { + return BN_cmp(bn1, bn2); + } + + TEST_F(LEICHITest, DISABLE_EncDec) { + Encryptor encryptor(pk_); + Decryptor decryptor(pk_, sk_); + std::vector a_vec{25, 13}; + std::vector a_pt_vec; + auto vec_size = a_vec.size(); + for (size_t i = 0; i < vec_size; i++) { + Plaintext a; + a.Set(a_vec[i]); + a_pt_vec.push_back(a); + } + + std::vector<Plaintext *> a_pt_pts; + ValueVecToPtsVec(a_pt_vec, a_pt_pts); + auto a_pt_span = absl::MakeConstSpan(a_pt_pts.data(), vec_size); + auto a_ct_vec = encryptor_->Encrypt(a_pt_span); + std::vector<Ciphertext *> a_ct_pts; + ValueVecToPtsVec(a_ct_vec, a_ct_pts); + auto a_ct_span = absl::MakeConstSpan(a_ct_pts.data(), vec_size); + std::vector<Plaintext> res_pt_vec = decryptor_->Decrypt(a_ct_span); + } + + TEST_F(LEICHITest, DISABLE_CTPlusCT) { + std::vector<int32_t> a_vec{25, 13}; + std::vector<int32_t> b_vec{-25, 13}; + + std::vector<Plaintext> a_pt_vec; + std::vector<Plaintext> b_pt_vec; + auto vec_size = a_vec.size(); + + for (size_t i = 0; i < vec_size; i++) { + Plaintext a; + a.Set(a_vec[i]); + a_pt_vec.push_back(a); + + Plaintext b; + b.Set(b_vec[i]); + b_pt_vec.push_back(b); + } + + std::vector<Plaintext *> a_pt_pts; + std::vector<Plaintext *> b_pt_pts; + ValueVecToPtsVec(a_pt_vec, a_pt_pts); + ValueVecToPtsVec(b_pt_vec, b_pt_pts); + auto a_pt_span = absl::MakeConstSpan(a_pt_pts.data(), vec_size); + auto b_pt_span = absl::MakeConstSpan(b_pt_pts.data(), vec_size); + auto a_ct_vec = encryptor_->Encrypt(a_pt_span); + auto b_ct_vec = encryptor_->Encrypt(b_pt_span); + + std::vector<Ciphertext *> a_ct_pts; + std::vector<Ciphertext *> b_ct_pts; + ValueVecToPtsVec(a_ct_vec, a_ct_pts); + ValueVecToPtsVec(b_ct_vec, b_ct_pts); + auto a_ct_span = absl::MakeConstSpan(a_ct_pts.data(), vec_size); + auto b_ct_span = absl::MakeConstSpan(b_ct_pts.data(), vec_size); + std::vector<Ciphertext> res_ct_vec = evaluator_->Add(a_ct_span, b_ct_span); + std::vector<Ciphertext *> res_ct_pts; + ValueVecToPtsVec(res_ct_vec, res_ct_pts); + auto res_ct_span = absl::MakeConstSpan(res_ct_pts.data(), vec_size); + std::vector<Plaintext> res_pt_vec = decryptor_->Decrypt(res_ct_span); + } + + TEST_F(LEICHITest, DISABLE_CTPlusPT) { + std::vector<int32_t> a_vec{25, 13, 15}; + std::vector<int32_t> b_vec{-25, 13, 15}; + + std::vector<Plaintext> a_pt_vec; + std::vector<Plaintext> b_pt_vec; + auto vec_size = a_vec.size(); + + for (size_t i = 0; i < vec_size; i++) { + Plaintext a; + a.Set(a_vec[i]); + a_pt_vec.push_back(a); + + Plaintext b; + b.Set(b_vec[i]); + b_pt_vec.push_back(b); + } + + std::vector<Plaintext *> a_pt_pts; + std::vector<Plaintext *> b_pt_pts; + ValueVecToPtsVec(a_pt_vec, a_pt_pts); + ValueVecToPtsVec(b_pt_vec, b_pt_pts); + auto a_pt_span = absl::MakeConstSpan(a_pt_pts.data(), vec_size); + auto b_pt_span = absl::MakeConstSpan(b_pt_pts.data(), vec_size); + auto a_ct_vec = encryptor_->Encrypt(a_pt_span); + + std::vector<Ciphertext *> a_ct_pts; + ValueVecToPtsVec(a_ct_vec, a_ct_pts); + auto a_ct_span = absl::MakeConstSpan(a_ct_pts.data(), vec_size); + + std::vector<Ciphertext> res_ct_vec = evaluator_->Add(a_ct_span, b_pt_span); + + std::vector<Ciphertext *> res_ct_pts; + ValueVecToPtsVec(res_ct_vec, res_ct_pts); + auto res_ct_span = absl::MakeConstSpan(res_ct_pts.data(), vec_size); + std::vector<Plaintext> res_pt_vec = decryptor_->Decrypt(res_ct_span); + } + + TEST_F(LEICHITest, DISABLE_PTPlusCT) { + std::vector<int32_t> a_vec{25, 13, 15}; + std::vector<int32_t> b_vec{25, 13, 15}; + + std::vector<Plaintext> a_pt_vec; + std::vector<Plaintext> b_pt_vec; + auto vec_size = a_vec.size(); + + for (size_t i = 0; i < vec_size; i++) { + Plaintext a; + a.Set(a_vec[i]); + a_pt_vec.push_back(a); + + Plaintext b; + b.Set(b_vec[i]); + b_pt_vec.push_back(b); + } + std::vector<Plaintext *> a_pt_pts; + std::vector<Plaintext *> b_pt_pts; + ValueVecToPtsVec(a_pt_vec, a_pt_pts); + ValueVecToPtsVec(b_pt_vec, b_pt_pts); + auto a_pt_span = absl::MakeConstSpan(a_pt_pts.data(), vec_size); + auto b_pt_span = absl::MakeConstSpan(b_pt_pts.data(), vec_size); + auto b_ct_vec = encryptor_->Encrypt(b_pt_span); + + std::vector<Ciphertext *> b_ct_pts; + ValueVecToPtsVec(b_ct_vec, b_ct_pts); + auto b_ct_span = absl::MakeConstSpan(b_ct_pts.data(), vec_size); + + std::vector<Ciphertext> res_ct_vec = evaluator_->Add(a_pt_span, b_ct_span); + + std::vector<Ciphertext *> res_ct_pts; + ValueVecToPtsVec(res_ct_vec, res_ct_pts); + auto res_ct_span = absl::MakeConstSpan(res_ct_pts.data(), vec_size); + std::vector<Plaintext> res_pt_vec = decryptor_->Decrypt(res_ct_span); + } + + + TEST_F(LEICHITest, DISABLE_CTSubCT) { + std::vector<int32_t> a_vec{50, 30}; + std::vector<int32_t> b_vec{-20, 10}; + std::vector<int32_t> expect_res{30, 20}; + + std::vector<Plaintext> a_pt_vec; + std::vector<Plaintext> b_pt_vec; + auto vec_size = a_vec.size(); + for (size_t i = 0; i < vec_size; i++) { + Plaintext a; + a.Set(a_vec[i]); + a_pt_vec.push_back(a); + + Plaintext b; + b.Set(b_vec[i]); + b_pt_vec.push_back(b); + } + std::vector<Plaintext *> a_pt_pts; + std::vector<Plaintext *> b_pt_pts; + ValueVecToPtsVec(a_pt_vec, a_pt_pts); + ValueVecToPtsVec(b_pt_vec, b_pt_pts); + auto a_pt_span = absl::MakeConstSpan(a_pt_pts.data(), vec_size); + auto b_pt_span = absl::MakeConstSpan(b_pt_pts.data(), vec_size); + auto a_ct_vec = encryptor_->Encrypt(a_pt_span); + auto b_ct_vec = encryptor_->Encrypt(b_pt_span); + + std::vector<Ciphertext *> a_ct_pts; + std::vector<Ciphertext *> b_ct_pts; + ValueVecToPtsVec(a_ct_vec, a_ct_pts); + ValueVecToPtsVec(b_ct_vec, b_ct_pts); + auto a_ct_span = absl::MakeConstSpan(a_ct_pts.data(), vec_size); + auto b_ct_span = absl::MakeConstSpan(b_ct_pts.data(), vec_size); + std::vector<Ciphertext> res_ct_vec = evaluator_->Sub(a_ct_span, b_ct_span); + + std::vector<Ciphertext *> res_ct_pts; + ValueVecToPtsVec(res_ct_vec, res_ct_pts); + auto res_ct_span = absl::MakeConstSpan(res_ct_pts.data(), vec_size); + std::vector<Plaintext> res_pt_vec = decryptor_->Decrypt(res_ct_span); + } + + TEST_F(LEICHITest, DISABLE_CTSubPT) { + std::vector<int32_t> a_vec{25, 13, 15}; + std::vector<int32_t> b_vec{20, 10, 10}; + + std::vector<Plaintext> a_pt_vec; + std::vector<Plaintext> b_pt_vec; + auto vec_size = a_vec.size(); + + for (size_t i = 0; i < vec_size; i++) { + Plaintext a; + a.Set(a_vec[i]); + a_pt_vec.push_back(a); + + Plaintext b; + b.Set(b_vec[i]); + b_pt_vec.push_back(b); + } + + std::vector<Plaintext *> a_pt_pts; + std::vector<Plaintext *> b_pt_pts; + ValueVecToPtsVec(a_pt_vec, a_pt_pts); + ValueVecToPtsVec(b_pt_vec, b_pt_pts); + auto a_pt_span = absl::MakeConstSpan(a_pt_pts.data(), vec_size); + auto b_pt_span = absl::MakeConstSpan(b_pt_pts.data(), vec_size); + auto a_ct_vec = encryptor_->Encrypt(a_pt_span); + + std::vector<Ciphertext *> a_ct_pts; + ValueVecToPtsVec(a_ct_vec, a_ct_pts); + auto a_ct_span = absl::MakeConstSpan(a_ct_pts.data(), vec_size); + std::vector<Ciphertext> res_ct_vec = evaluator_->Sub(a_ct_span, b_pt_span); + std::vector<Ciphertext *> res_ct_pts; + ValueVecToPtsVec(res_ct_vec, res_ct_pts); + auto res_ct_span = absl::MakeConstSpan(res_ct_pts.data(), vec_size); + std::vector<Plaintext> res_pt_vec = decryptor_->Decrypt(res_ct_span); + } + + TEST_F(LEICHITest, DISABLE_PTSubCT) { + std::vector<int32_t> a_vec{25, 13, 15}; + std::vector<int32_t> b_vec{20, 10, 10}; + + std::vector<Plaintext> a_pt_vec; + std::vector<Plaintext> b_pt_vec; + auto vec_size = a_vec.size(); + for (size_t i = 0; i < vec_size; i++) { + Plaintext a; + a.Set(a_vec[i]); + a_pt_vec.push_back(a); + + Plaintext b; + b.Set(b_vec[i]); + b_pt_vec.push_back(b); + } + std::vector<Plaintext *> a_pt_pts; + std::vector<Plaintext *> b_pt_pts; + ValueVecToPtsVec(a_pt_vec, a_pt_pts); + ValueVecToPtsVec(b_pt_vec, b_pt_pts); + auto a_pt_span = absl::MakeConstSpan(a_pt_pts.data(), vec_size); + auto b_pt_span = absl::MakeConstSpan(b_pt_pts.data(), vec_size); + auto b_ct_vec = encryptor_->Encrypt(b_pt_span); + + std::vector<Ciphertext *> b_ct_pts; + ValueVecToPtsVec(b_ct_vec, b_ct_pts); + auto b_ct_span = absl::MakeConstSpan(b_ct_pts.data(), vec_size); + std::vector<Ciphertext> res_ct_vec = evaluator_->Sub(a_pt_span, b_ct_span); + std::vector<Ciphertext *> res_ct_pts; + ValueVecToPtsVec(res_ct_vec, res_ct_pts); + auto res_ct_span = absl::MakeConstSpan(res_ct_pts.data(), vec_size); + std::vector<Plaintext> res_pt_vec = decryptor_->Decrypt(res_ct_span); + } + + TEST_F(LEICHITest, DISABLE_CTMultiplyPT) { + std::vector<int32_t> a_vec{-5, 3}; + std::vector<int32_t> b_vec{2, 1}; + std::vector<int32_t> expect_res{10, 3}; + + std::vector<Plaintext> a_pt_vec; + std::vector<Plaintext> b_pt_vec; + + for (size_t i = 0; i < a_vec.size(); i++) { + Plaintext a; + a.Set(a_vec[i]); + a_pt_vec.push_back(a); + } + for (size_t i = 0; i < b_vec.size(); i++) { + Plaintext b; + b.Set(b_vec[i]); + b_pt_vec.push_back(b); + } + std::vector<Plaintext *> a_pt_pts; + std::vector<Plaintext *> b_pt_pts; + ValueVecToPtsVec(a_pt_vec, a_pt_pts); + ValueVecToPtsVec(b_pt_vec, b_pt_pts); + auto a_pt_span = absl::MakeConstSpan(a_pt_pts.data(), a_vec.size()); + auto b_pt_span = absl::MakeConstSpan(b_pt_pts.data(), b_vec.size()); + auto a_ct_vec = encryptor_->Encrypt(a_pt_span); + + std::vector<Ciphertext *> a_ct_pts; + ValueVecToPtsVec(a_ct_vec, a_ct_pts); + auto a_ct_span = absl::MakeConstSpan(a_ct_pts.data(), a_vec.size()); + std::vector<Ciphertext> res_ct_vec = evaluator_->Mul(a_ct_span, b_pt_span); + std::vector<Ciphertext *> res_ct_pts; + ValueVecToPtsVec(res_ct_vec, res_ct_pts); + auto res_ct_span = absl::MakeConstSpan(res_ct_pts.data(), res_ct_pts.size()); + std::vector<Plaintext> res_pt_vec = decryptor_->Decrypt(res_ct_span); + } +} // namespace heu::lib::algorithms::leichi_paillier::test diff --git a/heu/library/algorithms/leichi_paillier/pcie/pcie.cc b/heu/library/algorithms/leichi_paillier/pcie/pcie.cc new file mode 100644 index 00000000..d99a177c --- /dev/null +++ b/heu/library/algorithms/leichi_paillier/pcie/pcie.cc @@ -0,0 +1,124 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include "pcie.h" + +CPcieComm::CPcieComm() +{ + +} + +CPcieComm::~CPcieComm() +{ +} + +int CPcieComm::open_device() +{ + wr_fd = open("/dev/xdma0_h2c_0", O_RDWR | O_NONBLOCK); // xdma write channel. + rd_fd = open("/dev/xdma0_c2h_0", O_RDWR | O_NONBLOCK); // xdma read channel. + bp_fd = open("/dev/xdma0_bypass", O_RDWR | O_NONBLOCK); // xdma bypass channel. + + if (wr_fd < 0 || rd_fd < 0 || bp_fd < 0) + { + // printf("CPcieComm::pcie:open device error!wr_fd:%d rd_fd:%d bp_fd:%d\n", wr_fd, rd_fd, bp_fd); + return -1; + } + + map_base = (int *)mmap(0, MAP_SIZE, PROT_READ | PROT_WRITE, MAP_SHARED, bp_fd, 0); + if (!map_base) + { + return -1; + } + + b_pcie_open = true; + return 1; +} + +int CPcieComm::close_device() +{ + close(wr_fd); // write channel. + close(rd_fd); // read channel. + close(bp_fd); // bypass channel. + b_pcie_open = false; + munmap(map_base, MAP_SIZE); + return 1; +} + +bool CPcieComm::pcie_is_open() +{ + return b_pcie_open; +} + +int CPcieComm::write_data(unsigned int addr, unsigned char *data, unsigned int len) +{ + int res = 0; + // adjust to specified address. + lseek(wr_fd, addr, SEEK_SET); + + res = write(wr_fd, (void *)data, len); + + if (res <= 0) + { + printf("CPcieComm:send data error.res=%d\n", res); + return -1; + } + + return res; +} + +int CPcieComm::read_data(unsigned int addr, unsigned char *data, unsigned int len) +{ + int res = 0; + // Read calculation result data. + lseek(rd_fd, addr, SEEK_SET); + + res = read(rd_fd, (void *)data, len); + if (res <= 0) + { + printf("read data error,res=%d\n", res); + return -1; + } + return len; +} + +int CPcieComm::write_data_bypass(unsigned int addr, unsigned char *data, unsigned int len) +{ + for (unsigned int i = 0; i < len; i++) + { + *((volatile uint8_t *)((unsigned char *)map_base + addr + i)) = *(data + i); + } + return len; +} + +int CPcieComm::read_data_bypass(unsigned int addr, unsigned char *data, unsigned int len) +{ + for (unsigned int i = 0; i < len; i++) + { + *(data + i) = *((volatile uint8_t *)((unsigned char *)map_base + addr + i)); + } + return len; +} + +int CPcieComm::write_reg(unsigned int addr, unsigned int data) +{ + *((volatile unsigned int *)((unsigned char *)map_base+addr)) = data; + return 1; +} + +int CPcieComm::read_reg(unsigned int addr, unsigned int *data) +{ + *data = *((volatile unsigned int *)((unsigned char *)map_base+addr)); + return 1; +} diff --git a/heu/library/algorithms/leichi_paillier/pcie/pcie.h b/heu/library/algorithms/leichi_paillier/pcie/pcie.h new file mode 100644 index 00000000..11d946d2 --- /dev/null +++ b/heu/library/algorithms/leichi_paillier/pcie/pcie.h @@ -0,0 +1,53 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include <stdlib.h> +#include <string.h> +#include <unistd.h> +#include <fcntl.h> +#include <stdio.h> +#include <ctype.h> +#include <string.h> +#include<sys/types.h> +#include<sys/stat.h> +#include<fcntl.h> +#include<fcntl.h> +#include<sys/mman.h> +#include<arpa/inet.h> + +#define MAP_SIZE (2*1024*1024UL) + +class CPcieComm +{ +public: + CPcieComm(); + ~CPcieComm(); +public: + int open_device(); + int close_device(); + bool pcie_is_open(); + int write_data(unsigned int addr,unsigned char *data,unsigned int len); + int read_data(unsigned int addr,unsigned char *data,unsigned int len); + int write_data_bypass(unsigned int addr,unsigned char *data,unsigned int len); + int read_data_bypass(unsigned int addr,unsigned char *data,unsigned int len); + int write_reg(unsigned int addr,unsigned int data); + int read_reg(unsigned int addr, unsigned int *data); +private: + int wr_fd; //write channel fd. + int rd_fd; //read channel fd. + int bp_fd; //bypass channel fd. + int *map_base; + bool b_pcie_open =false; +}; \ No newline at end of file diff --git a/heu/library/algorithms/leichi_paillier/plaintext.cc b/heu/library/algorithms/leichi_paillier/plaintext.cc new file mode 100644 index 00000000..b1a5d1f9 --- /dev/null +++ b/heu/library/algorithms/leichi_paillier/plaintext.cc @@ -0,0 +1,537 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "heu/library/algorithms/leichi_paillier/plaintext.h" +#include "cereal/archives/portable_binary.hpp" +#include "heu/library/algorithms/leichi_paillier/utils.h" +namespace heu::lib::algorithms::leichi_paillier { + void vector_to_char_array(std::vector<uint8_t>& vec, unsigned char* &arr) + { + arr = reinterpret_cast<unsigned char*>(vec.data()); + } + + std::ostream& operator<<(std::ostream& os, const Plaintext& pt) { + char* str = BN_bn2dec(pt.bn_); + os << str; + return os; + } + bool Plaintext::operator!=(const Plaintext& other) const { + return (BN_cmp(bn_, other.bn_))?true:false; + } + + std::string Plaintext::ToString() const { + char* str = BN_bn2dec(bn_); + std::string result(str); + return result; + } + + std::string Plaintext::ToHexString() const { + char* str = BN_bn2hex(bn_); + std::string result(str); + return result; + } + + Plaintext Plaintext::operator-() const{ + Plaintext result; + BN_copy(result.bn_,bn_); + BN_set_negative(result.bn_,!BN_is_negative(bn_)); + return result; + } + + template <> + void Plaintext::Set(Plaintext value) { + bn_ = BN_dup(value.bn_); + } + + template <> + void Plaintext::Set(uint8_t value) { + BN_set_word(bn_,(BN_ULONG)value); + } + + template <> + void Plaintext::Set(int8_t value) { + BN_set_word(bn_,(BN_ULONG)abs(value)); + if(value < 0) + { + BN_set_negative(bn_,1); + } + } + + template <> + void Plaintext::Set(uint16_t value) { + BN_set_word(bn_,(BN_ULONG)value); + } + + template <> + void Plaintext::Set(int16_t value) { + BN_set_word(bn_,(BN_ULONG)abs(value)); + if(value < 0) + { + BN_set_negative(bn_,1); + } + } + + template <> + void Plaintext::Set(uint32_t value) { + BN_set_word(bn_,(BN_ULONG)value); + } + + template <> + void Plaintext::Set(int32_t value) { + BN_set_word(bn_,(BN_ULONG)abs(value)); + if(value < 0) + { + BN_set_negative(bn_,1); + } + } + + template <> + void Plaintext::Set(int64_t value) { + BN_set_word(bn_,(BN_ULONG)abs(value)); + if(value < 0){ + BN_set_negative(bn_,1); + } + } + + template <> + void Plaintext::Set(uint64_t value) { + BN_set_word(bn_,(BN_ULONG)value); + } + + template <> + void Plaintext::Set(int128_t value) { + uint128_t avalue = value > 0 ? value : -value; + auto sign = value > 0 ? 0 : 1; + unsigned char data[16] = {0}; + for (size_t i = 0; i < sizeof(value); ++i) { + data[sizeof(value) - i - 1] = (unsigned char)(avalue >> (i * 8)); + } + bn_ = BN_bin2bn(data, sizeof(data), NULL); + BN_set_negative(bn_,sign); + } + + template <> + void Plaintext::Set(uint128_t value) { + unsigned char data[16] = {0}; + for (size_t i = 0; i < sizeof(value); ++i) { + data[sizeof(value) - i - 1] = (unsigned char)(value >> (i * 8)); + } + + bn_ = BN_bin2bn(data, sizeof(data), NULL); + } + + template <> + BIGNUM * Plaintext::Get() const{ + return bn_; + } + + template <> + int8_t Plaintext::Get() const{ + int8_t value = 0; + if(BN_is_negative(bn_)) + { + BN_set_negative(bn_,1); + } + value = BN_get_word(bn_); + return value; + } + + template <> + uint8_t Plaintext::Get() const{ + uint8_t value = 0; + value = BN_get_word(bn_); + return value; + } + + template <> + int16_t Plaintext::Get() const{ + char *str = BN_bn2dec(bn_); + int16_t value = (int16_t)strtol(str, NULL, 10); + return value; + } + + template <> + uint16_t Plaintext::Get() const{ + uint16_t value = 0; + value = BN_get_word(bn_); + return value; + } + + template <> + int32_t Plaintext::Get() const{ + char *str = BN_bn2dec(bn_); + int32_t value = (int32_t)strtol(str, NULL, 10); + return value; + } + + template <> + uint32_t Plaintext::Get() const{ + uint32_t value = 0; + value = BN_get_word(bn_); + return value; + } + + template <> + int64_t Plaintext::Get() const{ + char *str = BN_bn2dec(bn_); + int64_t value = (int64_t)strtol(str, NULL, 10); + return value; + } + + template <> + uint64_t Plaintext::Get() const{ + uint64_t value = 0; + value = BN_get_word(bn_); + return value; + } + + template <> + int128_t Plaintext::Get() const{ + std::vector<uint8_t> vec_tmp; + uint8_t * _buff; + int128_t value = 0; + vec_tmp = bnTobin(this->bn_); + std::reverse(vec_tmp.begin(), vec_tmp.end()); + vector_to_char_array(vec_tmp,_buff); + std::memcpy(&value,_buff,vec_tmp.size()); + + if(BN_is_negative(bn_)) + { + value = -value; + } + return value; + } + + template <> + uint128_t Plaintext::Get() const{ + std::vector<uint8_t> vec_tmp; + uint8_t * _buff; + uint128_t value = 0; + vec_tmp = bnTobin(this->bn_); + std::reverse(vec_tmp.begin(), vec_tmp.end()); + vector_to_char_array(vec_tmp,_buff); + std::memcpy(&value,_buff,vec_tmp.size()); + return value; + } + + template <> + void Plaintext::Set(double value) { + int64_t int_val = static_cast<int64_t>(value); + Set(int_val); + } + + template <> + double Plaintext::Get() const { + int64_t ret = this->Get<int64_t>(); + return (double)ret; + } + + template <> + void Plaintext::Set(float value) { + int64_t int_val = static_cast<int64_t>(value); + Set(int_val); + } + + template <> + float Plaintext::Get() const { + int64_t ret = this->Get<int64_t>(); + return (float)ret; + } + + void Plaintext::RandomExactBits(size_t bit_size, Plaintext *r){ + BN_rand(r->bn_,bit_size,0,0); + } + + void Plaintext::RandomLtN(const Plaintext &n, Plaintext *r){ + BN_rand_range(r->bn_, n.bn_); + } + + Plaintext Plaintext::operator+=(const Plaintext &op2) { + BN_add(bn_,bn_,op2.bn_); + return *this; + } + + Plaintext Plaintext::operator-=(const Plaintext &op2) { + BN_sub(bn_,bn_,op2.bn_); + return *this; + } + + Plaintext Plaintext::operator*=(const Plaintext &op2) { + BN_CTX *bn_ctx = BN_CTX_new(); + BN_mul(bn_,bn_,op2.bn_,bn_ctx); + BN_CTX_free(bn_ctx); + return *this; + } + + Plaintext Plaintext::operator/=(const Plaintext &op2) { + Plaintext rem; + BN_CTX *bn_ctx = BN_CTX_new(); + BN_div(bn_,rem.bn_,bn_,op2.bn_,bn_ctx); + BN_CTX_free(bn_ctx); + return *this; + } + + Plaintext Plaintext::operator%=(const Plaintext &op2) { + BN_CTX *bn_ctx = BN_CTX_new(); + BN_mod(bn_,bn_,op2.bn_,bn_ctx); + BN_CTX_free(bn_ctx); + return *this; + } + + Plaintext Plaintext::operator&(const Plaintext &op2) const + { + Plaintext result; + std::vector<uint8_t> a_vec; + std::vector<uint8_t> b_vec; + std::vector<uint8_t> c_vec; + + bool is_res_negtive; + if ((this->IsNegative() && !op2.IsNegative()) || + (!this->IsNegative() && op2.IsNegative())) { + is_res_negtive = true; + } else { + is_res_negtive = false; + } + + a_vec = bnTobin(bn_); + b_vec = bnTobin(op2.bn_); + + int size = std::max(a_vec.size(), b_vec.size()); + + for (int i = 0; i < size; i++) + { + c_vec.push_back(a_vec[i] & b_vec[i]); + } + uint8_t * _buff; + vector_to_char_array(c_vec,_buff); + BN_bin2bn(_buff, c_vec.size(),result.bn_); + if(is_res_negtive) + { + BN_set_negative(result.bn_,1); + } + return result; + } + Plaintext Plaintext::operator|(const Plaintext &op2) const + { + Plaintext result; + + std::vector<uint8_t> a_vec; + std::vector<uint8_t> b_vec; + std::vector<uint8_t> c_vec; + + bool is_res_negtive; + if ((this->IsNegative() && !op2.IsNegative()) || + (!this->IsNegative() && op2.IsNegative())) { + is_res_negtive = true; + } else { + is_res_negtive = false; + } + a_vec = bnTobin(bn_); + b_vec = bnTobin(op2.bn_); + + int size = std::max(a_vec.size(), b_vec.size()); + + for (int i = 0; i < size; i++) + { + c_vec.push_back(a_vec[i] | b_vec[i]); + } + uint8_t * _buff; + vector_to_char_array(c_vec,_buff); + BN_bin2bn(_buff, c_vec.size(),result.bn_); + if(is_res_negtive) + { + BN_set_negative(result.bn_,1); + } + return result; + } + Plaintext Plaintext::operator^(const Plaintext &op2) const + { + Plaintext result; + std::vector<uint8_t> a_vec; + std::vector<uint8_t> b_vec; + std::vector<uint8_t> c_vec; + + bool is_res_negtive; + if ((this->IsNegative() && !op2.IsNegative()) || + (!this->IsNegative() && op2.IsNegative())) { + is_res_negtive = true; + } else { + is_res_negtive = false; + } + a_vec = bnTobin(bn_); + b_vec = bnTobin(op2.bn_); + + int size = std::max(a_vec.size(), b_vec.size()); + + for (int i = 0; i < size; i++) + { + c_vec.push_back(a_vec[i] ^ b_vec[i]); + } + + uint8_t * _buff; + vector_to_char_array(c_vec,_buff); + BN_bin2bn(_buff, c_vec.size(),result.bn_); + if(is_res_negtive) + { + BN_set_negative(result.bn_,1); + } + return result; + } + Plaintext Plaintext::operator<<(size_t op2) const + { + Plaintext result; + BN_lshift(result.bn_,bn_,op2); + return result; + } + Plaintext Plaintext::operator>>(size_t op2) const + { + Plaintext result; + BN_rshift(result.bn_,bn_,op2); + return result; + } + + Plaintext Plaintext::operator&=(const Plaintext &op2) + { + std::vector<uint8_t> a_vec; + std::vector<uint8_t> b_vec; + std::vector<uint8_t> c_vec; + + bool is_res_negtive; + if ((this->IsNegative() && !op2.IsNegative()) || + (!this->IsNegative() && op2.IsNegative())) { + is_res_negtive = true; + } else { + is_res_negtive = false; + } + + a_vec = bnTobin(bn_); + std::reverse(a_vec.begin(), a_vec.end()); + b_vec = bnTobin(op2.bn_); + std::reverse(b_vec.begin(), b_vec.end()); + + int size = std::max(a_vec.size(), b_vec.size()); + + for (int i = 0; i < size; i++) + { + a_vec[i] = ((size_t)i < a_vec.size()) ? a_vec[i] : (uint32_t)0; + a_vec[i] &= b_vec[i]; + c_vec.push_back(a_vec[i]); + } + + std::reverse(c_vec.begin(), c_vec.end()); + uint8_t * _buff; + vector_to_char_array(c_vec,_buff); + BN_bin2bn(_buff, c_vec.size(),this->bn_); + if(is_res_negtive) + { + BN_set_negative(this->bn_,1); + } + return *this; + } + + Plaintext Plaintext::operator|=(const Plaintext &op2) + { + std::vector<uint8_t> a_vec; + std::vector<uint8_t> b_vec; + std::vector<uint8_t> c_vec; + + bool is_res_negtive; + if ((this->IsNegative() && !op2.IsNegative()) || + (!this->IsNegative() && op2.IsNegative())) { + is_res_negtive = true; + } else { + is_res_negtive = false; + } + + a_vec = bnTobin(bn_); + std::reverse(a_vec.begin(), a_vec.end()); + b_vec = bnTobin(op2.bn_); + std::reverse(b_vec.begin(), b_vec.end()); + + int size = std::max(a_vec.size(), b_vec.size()); + + for (int i = 0; i < size; i++) + { + a_vec[i] = ((size_t)i < a_vec.size()) ? a_vec[i] : (uint32_t)0; + a_vec[i] |= b_vec[i]; + c_vec.push_back(a_vec[i]); + } + std::reverse(c_vec.begin(), c_vec.end()); + + uint8_t * _buff; + vector_to_char_array(c_vec,_buff); + BN_bin2bn(_buff, c_vec.size(),this->bn_); + if(is_res_negtive) + { + BN_set_negative(this->bn_,1); + } + return *this; + } + Plaintext Plaintext::operator^=(const Plaintext &op2) + { + std::vector<uint8_t> a_vec; + std::vector<uint8_t> b_vec; + std::vector<uint8_t> c_vec; + + bool is_res_negtive; + if ((this->IsNegative() && !op2.IsNegative()) || + (!this->IsNegative() && op2.IsNegative())) { + is_res_negtive = true; + } else { + is_res_negtive = false; + } + a_vec = bnTobin(bn_); + std::reverse(a_vec.begin(), a_vec.end()); + b_vec = bnTobin(op2.bn_); + std::reverse(b_vec.begin(), b_vec.end()); + + int size = std::max(a_vec.size(), b_vec.size()); + + for (int i = 0; i < size; i++) + { + a_vec[i] = ((size_t)i < a_vec.size()) ? a_vec[i] : (uint32_t)0; + a_vec[i] ^= b_vec[i]; + c_vec.push_back(a_vec[i]); + } + std::reverse(c_vec.begin(), c_vec.end()); + uint8_t * _buff; + vector_to_char_array(c_vec,_buff); + BN_bin2bn(_buff, c_vec.size(),this->bn_); + if(is_res_negtive) + { + BN_set_negative(this->bn_,1); + } + return *this; + } + Plaintext Plaintext::operator<<=(size_t op2) + { + BN_lshift(bn_,bn_,op2); + return *this; + } + Plaintext Plaintext::operator>>=(size_t op2) + { + BN_rshift(bn_,bn_,op2); + return *this; + } + + void Plaintext::NegateInplace(){ + if(BN_is_negative(bn_)) + { + BN_set_negative(bn_,0); + } + else{ + BN_set_negative(bn_,1); + } + } +} \ No newline at end of file diff --git a/heu/library/algorithms/leichi_paillier/plaintext.h b/heu/library/algorithms/leichi_paillier/plaintext.h new file mode 100644 index 00000000..2e9e5e44 --- /dev/null +++ b/heu/library/algorithms/leichi_paillier/plaintext.h @@ -0,0 +1,208 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "openssl/bn.h" +#include "heu/library/algorithms/util/spi_traits.h" +// #include "yacl/base/exception.h" +#include <ostream> +#include <string> +#include <utility> +#include "yacl/base/byte_container_view.h" +#include "cereal/archives/portable_binary.hpp" +#include <cstdint> +#include <iostream> +#include <vector> +#include <cmath> +#include <stdio.h> +#include "heu/library/algorithms/leichi_paillier/utils.h" + +using int128_t = __int128_t; +using uint128_t = __uint128_t; +namespace heu::lib::algorithms::leichi_paillier { + template <typename T> + void ValueVecToPtsVec(std::vector<T>& value_vec, std::vector<T*>& pts_vec) { + int size = value_vec.size(); + for (int i = 0; i < size; i++) { + pts_vec.push_back(&value_vec[i]); + } + } + + class Plaintext { + public: + BIGNUM* bn_; + public: + Plaintext() { + bn_ = BN_new(); + } + ~Plaintext(){ + BN_free(bn_); + } + Plaintext(const Plaintext& other) { + bn_ = BN_dup(other.bn_); + } + + Plaintext operator-() const; + + Plaintext& operator=(const Plaintext& other) { + if (this != &other) { + BN_copy(bn_, other.bn_); + } + return *this; + } + + Plaintext operator*(const Plaintext& other) const { + Plaintext result; + BN_CTX *bn_ctx = BN_CTX_new(); + BN_mul(result.bn_, bn_, other.bn_, bn_ctx); + BN_CTX_free(bn_ctx); + return result; + } + + Plaintext operator+(const Plaintext& other) const { + Plaintext result; + BN_add(result.bn_, bn_, other.bn_); + return result; + } + + Plaintext operator-(const Plaintext &op2) const { + Plaintext result; + BN_sub(result.bn_,bn_,op2.bn_); + return result; + } + + Plaintext operator/(const Plaintext &op2) const { + Plaintext result; + Plaintext rem; + BN_CTX *bn_ctx = BN_CTX_new(); + BN_div(result.bn_,rem.bn_,bn_,op2.bn_,bn_ctx); + BN_CTX_free(bn_ctx); + return result; + } + + Plaintext operator%(const Plaintext &op2) const { + Plaintext result; + BN_CTX *bn_ctx = BN_CTX_new(); + BN_mod(result.bn_,bn_,op2.bn_,bn_ctx); + BN_CTX_free(bn_ctx); + return result; + } + + static Plaintext generateRandom(const Plaintext &op2) { + Plaintext randomNum; + BN_rand_range(randomNum.bn_, op2.bn_); + return randomNum; + } + + BIGNUM* getValue() const { + return bn_; + } + + size_t numBits() const { + return BN_num_bits(bn_); + } + + friend std::ostream &operator<<(std::ostream &os, const Plaintext &pt); + + Plaintext operator&(const Plaintext &op2) const; + Plaintext operator|(const Plaintext &op2) const; + Plaintext operator^(const Plaintext &op2) const; + Plaintext operator<<(size_t op2) const; + Plaintext operator>>(size_t op2) const; + + Plaintext operator+=(const Plaintext &op2); + Plaintext operator-=(const Plaintext &op2); + Plaintext operator*=(const Plaintext &op2); + Plaintext operator/=(const Plaintext &op2); + Plaintext operator%=(const Plaintext &op2); + Plaintext operator&=(const Plaintext &op2); + Plaintext operator|=(const Plaintext &op2); + Plaintext operator^=(const Plaintext &op2); + Plaintext operator<<=(size_t op2); + Plaintext operator>>=(size_t op2); + + bool IsZero() const{return BN_is_zero(bn_);} + bool IsPositive() const{return (BN_is_negative(bn_) == 0 && !BN_is_zero(bn_))?true:false;} + bool IsNegative() const{return (BN_is_negative(bn_) == 1)?true:false;} + size_t BitCount() const { return BN_num_bits(bn_); } + bool operator!=(const Plaintext &other) const; + bool operator>(const Plaintext &other) const{return (BN_cmp(bn_, other.bn_) > 0)?true:false;} + bool operator<(const Plaintext &other) const{return (BN_cmp(bn_, other.bn_) < 0)?true:false;} + bool operator>=(const Plaintext &other) const{return (BN_cmp(bn_, other.bn_) >= 0)?true:false;} + bool operator<=(const Plaintext &other) const{return (BN_cmp(bn_, other.bn_) <= 0)?true:false;} + bool operator==(const Plaintext &other) const{return (BN_cmp(bn_, other.bn_) == 0)?true:false;} + + std::string ToHexString() const; + std::string ToString() const; + + Plaintext get_prime(size_t bit_len){ + Plaintext result; + BN_generate_prime_ex(result.bn_,bit_len,1,NULL,NULL,NULL); + return result; + } + + void get_prime(size_t bit_len,const Plaintext &op) + { + BN_generate_prime_ex(op.bn_,bit_len,1,NULL,NULL,NULL); + } + + template <typename T> + explicit Plaintext(T &value) { + Set(value); + } + + template <typename T> + void Set(T value); + + template <typename T> + T Get() const; + + void Set(const std::string &num, int radix){BN_dec2bn(&bn_, num.c_str());} + + Plaintext Absolute(const Plaintext &pt){ + Plaintext result; + BN_copy(result.bn_,bn_); + BN_set_negative(bn_, 0); + return result; + } + + static void RandomExactBits(size_t bit_size, Plaintext *r); + static void RandomLtN(const Plaintext &n, Plaintext *r); + void NegateInplace(); + + yacl::Buffer Serialize() const{ + uint32_t n_bits_len = BN_num_bits(bn_); + uint8_t* n_arr = new uint8_t[n_bits_len]; + std::vector<uint8_t> vec_tmp; + BN_bn2bin(bn_, n_arr); + uint32_t bytes_len = std::ceil(n_bits_len/8.0); + for(uint32_t i=0;i<bytes_len;i++) + { + vec_tmp.push_back(n_arr[i]); + } + yacl::Buffer buf(vec_tmp.data(), std::ceil(BN_num_bits(bn_)/8.0)); + return buf; + } + void Deserialize(yacl::ByteContainerView buffer){ + std::istringstream is((std::string)buffer); + BN_bin2bn((uint8_t *)(is.str().data()), is.str().size(),bn_); + BN_set_negative(bn_,1); + } + + yacl::Buffer ToBytes(size_t byte_len, Endian endian = Endian::native) const{yacl::Buffer buf(byte_len);return buf;} + void ToBytes(unsigned char *buf, size_t buf_len, + Endian endian = Endian::native) const{} + }; + +} \ No newline at end of file diff --git a/heu/library/algorithms/leichi_paillier/public_key.cc b/heu/library/algorithms/leichi_paillier/public_key.cc new file mode 100644 index 00000000..21d4dc19 --- /dev/null +++ b/heu/library/algorithms/leichi_paillier/public_key.cc @@ -0,0 +1,31 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "heu/library/algorithms/leichi_paillier/public_key.h" +namespace heu::lib::algorithms::leichi_paillier { + + void SetCacheTableDensity(size_t density) { + YACL_ENFORCE(density > 0, "density must > 0"); + } + + void PublicKey::Init() { + + } + + std::string PublicKey::ToString() const { + return fmt::format( + n_.ToHexString(), n_.BitCount(), + PlaintextBound().ToHexString(), PlaintextBound().BitCount() - 1); + } +} \ No newline at end of file diff --git a/heu/library/algorithms/leichi_paillier/public_key.h b/heu/library/algorithms/leichi_paillier/public_key.h new file mode 100644 index 00000000..3fad9366 --- /dev/null +++ b/heu/library/algorithms/leichi_paillier/public_key.h @@ -0,0 +1,61 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "openssl/bn.h" +#include "heu/library/algorithms/leichi_paillier/plaintext.h" +#pragma once +namespace heu::lib::algorithms::leichi_paillier { + + void SetCacheTableDensity(size_t density); + + class PublicKey { + public: + Plaintext n_; + Plaintext g_; + Plaintext max_plaintext_; + + void Init(); + [[nodiscard]] std::string ToString() const; + + bool operator==(const PublicKey &other) const { + return (n_ == other.n_)?true:false ; + } + + bool operator!=(const PublicKey &other) const { + return (!this->operator==(other))?true:false; + } + + [[nodiscard]] const Plaintext &PlaintextBound() const & { return max_plaintext_; } + + yacl::Buffer Serialize() const{ + uint32_t n_bits_len = BN_num_bits(n_.bn_); + uint8_t* n_arr = new uint8_t[n_bits_len]; + std::vector<uint8_t> vec_tmp; + BN_bn2bin(n_.bn_, n_arr); + uint32_t bytes_len = std::ceil(n_bits_len/8.0); + for(uint32_t i=0;i<bytes_len;i++) + { + vec_tmp.push_back(n_arr[i]); + } + yacl::Buffer buf(vec_tmp.data(), std::ceil(BN_num_bits(n_.bn_)/8.0)); + return buf; + }; + void Deserialize(yacl::ByteContainerView in){ + std::istringstream is((std::string)in); + BN_bin2bn((uint8_t *)(is.str().data()), is.str().size(),n_.bn_); + }; + }; + +} // namespace heu::lib::algorithms::leichi_paillier + diff --git a/heu/library/algorithms/leichi_paillier/runtime.cc b/heu/library/algorithms/leichi_paillier/runtime.cc new file mode 100644 index 00000000..98562e79 --- /dev/null +++ b/heu/library/algorithms/leichi_paillier/runtime.cc @@ -0,0 +1,1149 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "heu/library/algorithms/leichi_paillier/runtime.h" +#include <fstream> +namespace heu::lib::algorithms::leichi_paillier { + bool Runtime::dev_connect() + { + if (pPcie.pcie_is_open()) + return true; + return pPcie.open_device() > 0 ? true : false; + } + + bool Runtime::dev_close() + { + pPcie.close_device(); + return OK; + } + + size_t Runtime::dev_write_reg(size_t in_data, size_t addr) + { + return pPcie.write_reg(addr, in_data); + } + + bool Runtime::dev_reset() + { + dev_write_reg(0,0x0010+0x24000); + sleep(2); + dev_write_reg(1, 0x0010+0x24000); + sleep(2); + return OK; + } + + size_t Runtime::dev_read_reg(size_t *out_data, size_t addr) + { + return pPcie.read_reg(addr, (unsigned int *)out_data); + } + + size_t Runtime::dev_write_ddr(uint8_t *in_data, size_t write_len, size_t addr) + { + if (pPcie.write_data(addr, in_data, write_len) < 1){ + return -1; + } + return write_len; + } + + size_t Runtime::dev_read_ddr(uint8_t *out_data, size_t read_len, size_t addr) + { + if (pPcie.read_data(addr, out_data, read_len) < 1){ + return -1; + } + return read_len; + } + + size_t Runtime::dev_write_init(uint8_t *in_data, size_t write_len, size_t addr) + { + if (pPcie.write_data_bypass(addr, in_data, write_len) < 1){ + return -1; + } + return write_len; + } + + size_t Runtime::dev_read_init(uint8_t *out_data, size_t read_len, size_t addr) + { + return pPcie.read_data_bypass(addr, out_data, read_len); + } + + size_t Runtime::dev_reset_device() + { + size_t ret1 = 0; + size_t ret2 = 0; + ret1 = dev_write_reg(0,0x0000+0x24000); + usleep(1000); + ret2 = dev_write_reg(1,0x0000+0x24000); + usleep(1000); + if(ret1 < 1 || ret1 > 100 || ret2 < 1 || ret2 > 100){ + return 0; + } + else{ + return 1; + } + } + + size_t Runtime::api_write_inst(uint8_t *inst, size_t write_len,size_t addr) + { + return dev_write_init(inst, write_len,addr); + } + + void Runtime::api_set_inst_length(size_t len) + { + dev_write_reg(len, FPGA_SEND_BASE_ADDR+0x04); + usleep(1000); + dev_write_reg(1,FPGA_SEND_BASE_ADDR+0x08); + } + + size_t Runtime::api_set_inst_length_clear() + { + size_t ret1 = 0; + size_t ret2 = 0; + ret1 = dev_write_reg(0, FPGA_SEND_BASE_ADDR+0x04); + usleep(1000); + ret2 = dev_write_reg(0,FPGA_SEND_BASE_ADDR+0x08); + if(ret1 < 1 or ret1 > 100 or ret2 < 1 or ret2 > 100){ + return 1; + } + else{ + return 0; + } + } + + size_t Runtime::check_inst_pointer(size_t *out_data) + { + size_t addr = 0x0c+FPGA_SEND_BASE_ADDR; + return dev_read_reg(out_data,addr); + } + + void Runtime::vector_to_char_array(std::vector<uint8_t>& vec, unsigned char* &arr) + { + arr = reinterpret_cast<unsigned char*>(vec.data()); + } + + void Runtime::char_array_to_vector(std::vector<uint8_t> &vec,unsigned char* buff,uint32_t buff_size) + { + for(uint32_t i=0;i<buff_size;i++){ + vec.push_back(buff[i]); + } + } + + OPERATION_TYPE Runtime::operation_trans(std::string operation) + { + OPERATION_TYPE operation_num=NONE; + if("MONT" == operation) + operation_num = MONT; + else if("MONT_CONST" == operation) + operation_num = MONT_CONST; + else if("MOD_ADD" == operation) + operation_num = MOD_ADD; + else if(operation == "MOD_ADD_CONST") + operation_num = MOD_ADD_CONST; + else if(operation == "MOD_MUL_CONST") + operation_num = MOD_MUL_CONST; + else if(operation == "MOD_EXP") + operation_num = MOD_EXP; + else if(operation == "MOD_MUL") + operation_num = MOD_MUL; + else if(operation == "MOD_EXP_CONST_A") + operation_num = MOD_EXP_CONST_A; + else if(operation == "MOD_EXP_CONST_E") + operation_num = MOD_EXP_CONST_E; + else if(operation == "MOD_EXP_CONST_A") + operation_num = MOD_EXP_CONST_A; + else if(operation == "PAILLIER_ENC") + operation_num = PAILLIER_ENC; + else if(operation == "MOD_INV_CONST_P") + operation_num = MOD_INV_CONST_P; + else if(operation == "MOD_INV") + operation_num = MOD_INV; + return operation_num; + } + + int Runtime::get_p_bitcount_chip(uint32_t p_bitcount) + { + uint32_t count = 0; + switch(p_bitcount){ + case 4096: + count = 5; + break; + case 3072: + count = 4; + break; + case 2048: + count = 3; + break; + case 1024: + count = 2; + break; + case 512: + count = 1; + break; + } + return count; + } + + void Runtime::data_inverse(uint8_t* data,uint32_t len) + { + uint32_t i = 0; + uint8_t tmp; + for(i=0;i<len/2;i++) + { + tmp = data[len-1-i]; + data[len-1-i] = data[i]; + data[i] = tmp; + } + } + + int Runtime::gen_param(uint32_t kernels, uint32_t p_bitcount, uint32_t case_v, uint32_t e_bitcount, uint32_t bool_ele_r_square, uint32_t bits) + { + int para = 0; + para += ((kernels % (int)(pow(2, 39))) << 25); + para += ((p_bitcount % (int)(pow(2, 3))) << 22); + para += ((case_v % (int)(pow(2, 3))) << 19); + para += ((e_bitcount % (int)(pow(2, 13))) << 6); + para += ((bool_ele_r_square % (int)(pow(2, 1))) << 5); + para += ((bits % (int)(pow(2, 5))) << 0); + return para; + } + + void Runtime::gen_mont_para(BIGNUM *p, uint32_t p_bitcount,uint8_t*output,uint32_t &len) + { + int block_len = 256; + BIGNUM *temp = BN_new(); + BIGNUM *bn_n_prime = BN_new(); + BIGNUM *bn_r_square = BN_new(); + BN_CTX *bn_ctx = BN_CTX_new(); + len = 0; + + BN_lshift(bn_r_square, BN_value_one(), block_len); + BN_mod_inverse(bn_n_prime, p, bn_r_square, bn_ctx); + BN_zero(temp); + BN_sub(bn_n_prime, temp, bn_n_prime); + BN_nnmod(bn_n_prime, bn_n_prime, bn_r_square, bn_ctx); + BN_lshift(bn_r_square, BN_value_one(), 2 * (p_bitcount)); + BN_mod(bn_r_square, bn_r_square, p, bn_ctx); + + BN_bn2binpad(bn_n_prime,output+len,BYTECOUNT(block_len)); + data_inverse(output+len,BYTECOUNT(block_len)); + + len += BYTECOUNT(block_len); + BN_bn2binpad(bn_r_square,output+len,BYTECOUNT(p_bitcount)); + data_inverse(output+len,BYTECOUNT(p_bitcount)); + len += BYTECOUNT(p_bitcount); + + BN_free(temp); + BN_free(bn_n_prime); + BN_free(bn_r_square); + BN_CTX_free(bn_ctx); + } + + void Runtime::gen_p_mont_para(BIGNUM *p, uint32_t p_bitcount,uint8_t*output,uint32_t &len,uint8_t flg) + { + uint32_t block_len = 256; + BIGNUM *temp = BN_new(); + BIGNUM *temp1 = BN_new(); + BIGNUM *temp2 = BN_new(); + BIGNUM *bn_n_prime = BN_new(); + BIGNUM *bn_r_square = BN_new(); + BN_CTX *bn_ctx = BN_CTX_new(); + BIGNUM *bn_r = BN_new(); + BIGNUM *bn_s = BN_new(); + BIGNUM *bn_x = BN_new(); + BIGNUM *bn_y = BN_new(); + BIGNUM *bn_p = BN_new(); + BIGNUM *bn_s_0 = BN_new(); + BIGNUM *bn_x_0 = BN_new(); + BIGNUM *bn_y_i = BN_new(); + BIGNUM *bn_q = BN_new(); + len = 0; + + BN_lshift(bn_r_square, BN_value_one(), block_len); + BN_mod_inverse(bn_n_prime, p, bn_r_square, bn_ctx); + BN_zero(temp); + BN_sub(bn_n_prime, temp, bn_n_prime); + BN_nnmod(bn_n_prime, bn_n_prime, bn_r_square, bn_ctx); + + BN_lshift(bn_r_square, BN_value_one(), 2 * (p_bitcount)); + BN_mod(bn_r_square, bn_r_square, p, bn_ctx); + BN_lshift(bn_r, BN_value_one(), block_len); + BN_one(bn_x); + BN_set_word(bn_s, 0); + BN_set_word(bn_x_0,0); + BN_set_word(bn_s_0,0); + BN_set_word(bn_y_i,0); + BN_set_word(bn_q,0); + BN_zero(temp1); + BN_zero(temp2); + uint32_t len_tmp = 0; + for (uint32_t i = 0; i < p_bitcount/block_len; i++) { + BN_mod(bn_s_0, bn_s, bn_r, bn_ctx); + BN_mod(bn_x_0, bn_x, bn_r, bn_ctx); + len_tmp = block_len*i; + BN_rshift(bn_y_i, bn_r_square, len_tmp); + BN_mod(bn_y_i, bn_y_i, bn_r, bn_ctx); + BN_mul(bn_q, bn_x_0, bn_y_i, bn_ctx); + BN_add(bn_q, bn_q, bn_s_0); + BN_mul(bn_q, bn_q, bn_n_prime, bn_ctx); + BN_mod(bn_q, bn_q, bn_r, bn_ctx); + BN_mul(temp1, bn_x, bn_y_i, bn_ctx); + BN_mul(temp2, bn_q, p, bn_ctx); + BN_add(bn_s, temp1, bn_s); + BN_add(bn_s, bn_s, temp2); + BN_rshift(bn_s, bn_s, block_len); + } + BN_bn2binpad(bn_n_prime,output+len,BYTECOUNT(block_len)); + data_inverse(output+len,BYTECOUNT(block_len)); + len += BYTECOUNT(block_len); + + if(flg==1){ + BN_bn2binpad(bn_r_square,output+len,BYTECOUNT(p_bitcount)); + data_inverse(output+len,BYTECOUNT(p_bitcount)); + len += BYTECOUNT(p_bitcount); + } + else if(flg==2){ + BN_bn2binpad(bn_r_square,output+len,BYTECOUNT(p_bitcount)); + data_inverse(output+len,BYTECOUNT(p_bitcount)); + len += BYTECOUNT(p_bitcount); + + BN_bn2binpad(bn_s,output+len,BYTECOUNT(p_bitcount)); + data_inverse(output+len,BYTECOUNT(p_bitcount)); + len += BYTECOUNT(p_bitcount); + } + else if(flg == 3){;} + + BN_free(bn_r); + BN_free(bn_s); + BN_free(bn_x); + BN_free(bn_y); + BN_free(bn_p); + BN_free(bn_n_prime); + BN_free(bn_s_0); + BN_free(bn_x_0); + BN_free(bn_y_i); + BN_free(bn_q); + BN_free(temp); + BN_free(temp1); + BN_free(temp2); + BN_free(bn_r_square); + BN_CTX_free(bn_ctx); + } + + int Runtime::dev_compiler(std::string operation_type,uint32_t str_len,uint32_t size,uint32_t p_bitcount,uint32_t e_bitcount,struct executor &executor_dat) + { + compiler._program.type = "vector"; + compiler._program.operation_type = operation_type; + compiler._program.p_bitcount = p_bitcount; + compiler._program.e_bitcount = e_bitcount; + compiler._program.vec_size = size; + compiler.compile(); + executor_dat.inst = compiler.executor.inst; + executor_dat.inst_fpga = compiler.executor.inst_fpga; + executor_dat.in_para_address = compiler.executor.in_para_address; + executor_dat.inst_address = compiler.executor.inst_address; + executor_dat.out_address = compiler.executor.out_address; + executor_dat.out_length = compiler.executor.out_length; + return 0; + } + + void Runtime::dev_gen_data(OPERATION_TYPE operation_type, uint8_t *a, uint8_t *b, uint8_t *n, uint32_t vec_size, std::vector<uint8_t> m_flg,uint32_t p_bitcount, uint32_t e_bitcount,uint8_t *output, bool split_flg,uint32_t a_len,uint32_t b_len,uint32_t n_len,uint32_t &offset) + { + offset = 0; + uint32_t mont_len = 0; + int kernels = 0; + int case_v = 0b000; + int bool_ele_r_square = 0; + int bits = 0; + int para = 0; + + BIGNUM *_n = BN_new(); + int para_p_bitcount = get_p_bitcount_chip(p_bitcount); + if(operation_type == MOD_ADD || operation_type == MOD_ADD_CONST ) {case_v = 0b001;} + uint8_t para_bytes[32]; + memset(para_bytes,0,32); + if(operation_type == PAILLIER_ENC){ + para = gen_param(kernels, para_p_bitcount, case_v, p_bitcount, bool_ele_r_square, bits); + memcpy((uint8_t *)(output+offset),(uint8_t *)&para,sizeof(para)); + + offset += 32; + BN_bin2bn(n, BYTECOUNT(p_bitcount/2), _n); + } + else if(operation_type == MOD_INV_CONST_P){ + para = gen_param(kernels, para_p_bitcount, case_v, e_bitcount, bool_ele_r_square, bits); + memcpy((uint8_t *)(output+offset),(uint8_t *)&para,sizeof(para)); + memcpy((uint8_t *)para_bytes,(uint8_t *)&para,sizeof(para)); + offset += 32; + BN_bin2bn(n, BYTECOUNT(p_bitcount/2), _n); + } + else{ + para = gen_param(kernels, para_p_bitcount, case_v, e_bitcount, bool_ele_r_square, bits); + memcpy((uint8_t *)(output+offset),(uint8_t *)&para,sizeof(para)); + offset += 32; + if(split_flg){ + BN_bin2bn(n, BYTECOUNT(p_bitcount/2), _n); + } + else{ + BN_bin2bn(n, BYTECOUNT(p_bitcount), _n); + } + } + + BIGNUM *pubkey_nsquare = BN_new(); + BIGNUM *pubkey_g = BN_new(); + BIGNUM *m = BN_new(); + BIGNUM *temp = BN_new(); + BIGNUM *temp1 = BN_new(); + BN_CTX *bn_ctx = BN_CTX_new(); + + if(operation_type == PAILLIER_ENC){ + BN_mul(pubkey_nsquare,_n,_n,bn_ctx); + BN_bn2binpad(pubkey_nsquare,output+offset,BYTECOUNT(p_bitcount)); + data_inverse(output+offset,BYTECOUNT(p_bitcount)); + offset += BYTECOUNT(p_bitcount); + } + else if(operation_type == MOD_INV_CONST_P){;} + else{ + if(split_flg){ + BN_mul(pubkey_nsquare,_n,_n,bn_ctx); + BN_bn2binpad(pubkey_nsquare,output+offset,BYTECOUNT(p_bitcount)); + } + else{ + BN_bn2binpad(_n,output+offset,BYTECOUNT(p_bitcount)); + } + + data_inverse(output+offset,BYTECOUNT(p_bitcount)); + offset += BYTECOUNT(p_bitcount); + } + + if(operation_type == PAILLIER_ENC){ + gen_p_mont_para(pubkey_nsquare,p_bitcount, (uint8_t *)(output+offset),mont_len,2); + offset += mont_len; + + BN_lshift(temp1, _n, 1); + BN_bn2binpad(temp1,output+offset,BYTECOUNT(p_bitcount)); + data_inverse(output+offset,BYTECOUNT(p_bitcount)); + offset += BYTECOUNT(p_bitcount); + + BN_add(pubkey_g,_n,BN_value_one()); + BN_bn2binpad(pubkey_g,output+offset,BYTECOUNT(p_bitcount)); + data_inverse(output+offset,BYTECOUNT(p_bitcount)); + offset += BYTECOUNT(p_bitcount); + } + else if(operation_type == MOD_INV_CONST_P){;} + else{ + if(operation_type == MONT || operation_type==MONT_CONST){ + gen_p_mont_para(pubkey_nsquare,p_bitcount, (uint8_t *)(output+offset),mont_len,0); + offset += mont_len; + } + else if(operation_type == MOD_ADD || operation_type==MOD_ADD_CONST){;} + else{ + if(split_flg){ + BN_mul(pubkey_nsquare,_n,_n,bn_ctx); + gen_mont_para(pubkey_nsquare,p_bitcount, (uint8_t *)(output+offset),mont_len); + offset += mont_len; + } + else{ + gen_mont_para(_n,p_bitcount, (uint8_t *)(output+offset),mont_len); + offset += mont_len; + } + } + } + + uint32_t a_offset=0 ; + uint32_t b_offset=0 ; + uint32_t number_of_outs = std::ceil((float)vec_size/16.0); + uint32_t last_number = 0; + uint32_t numbers = 0; + uint8_t x= 0; + std::vector<uint8_t> xx; + switch(operation_type){ + case MOD_INV: + break; + case MONT: + for(uint32_t i=0;i<vec_size;i++){ + memcpy((uint8_t *)(output+offset),(uint8_t *)(a+a_offset),BYTECOUNT(p_bitcount)); + data_inverse(output+offset,BYTECOUNT(p_bitcount)); + offset += BYTECOUNT(p_bitcount); + a_offset += BYTECOUNT(p_bitcount); + } + for(uint32_t i=0;i<vec_size;i++){ + memcpy((uint8_t *)(output+offset),(uint8_t *)(b+b_offset),BYTECOUNT(p_bitcount)); + data_inverse(output+offset,BYTECOUNT(p_bitcount)); + offset += BYTECOUNT(p_bitcount); + b_offset += BYTECOUNT(p_bitcount); + } + break; + case MONT_CONST: + memcpy((uint8_t *)(output+offset),(uint8_t *)(b),BYTECOUNT(p_bitcount)); + data_inverse(output+offset,BYTECOUNT(p_bitcount)); + offset += BYTECOUNT(p_bitcount); + for(uint32_t i=0;i<vec_size;i++){ + memcpy((uint8_t *)(output+offset),(uint8_t *)(a+a_offset),BYTECOUNT(p_bitcount)); + data_inverse(output+offset,BYTECOUNT(p_bitcount)); + offset += BYTECOUNT(p_bitcount); + a_offset += BYTECOUNT(p_bitcount); + } + break; + case MOD_MUL: + if(split_flg){ + memcpy((uint8_t *)(output+offset),(uint8_t *)(a+a_offset),a_len); + offset += a_len; + memcpy((uint8_t *)(output+offset),(uint8_t *)(b+b_offset),b_len); + offset += b_len; + } + else{ + for(uint32_t i=0;i<vec_size;i++){ + memcpy((uint8_t *)(output+offset),(uint8_t *)(a+a_offset),BYTECOUNT(p_bitcount)); + data_inverse(output+offset,BYTECOUNT(p_bitcount)); + offset += BYTECOUNT(p_bitcount); + a_offset += BYTECOUNT(p_bitcount); + } + for(uint32_t i=0;i<vec_size;i++){ + memcpy((uint8_t *)(output+offset),(uint8_t *)(b+b_offset),BYTECOUNT(p_bitcount)); + data_inverse(output+offset,BYTECOUNT(p_bitcount)); + offset += BYTECOUNT(p_bitcount); + b_offset += BYTECOUNT(p_bitcount); + } + } + break; + case MOD_MUL_CONST: + memcpy((uint8_t *)(output+offset),(uint8_t *)(b),BYTECOUNT(p_bitcount)); + data_inverse(output+offset,BYTECOUNT(p_bitcount)); + offset += BYTECOUNT(p_bitcount); + for(uint32_t i=0;i<vec_size;i++){ + memcpy((uint8_t *)(output+offset),(uint8_t *)(a+a_offset),BYTECOUNT(p_bitcount)); + data_inverse(output+offset,BYTECOUNT(p_bitcount)); + offset += BYTECOUNT(p_bitcount); + a_offset += BYTECOUNT(p_bitcount); + } + break; + case MOD_EXP: + if(split_flg){ + memcpy((uint8_t *)(output+offset),(uint8_t *)(a+a_offset),a_len); + offset += a_len; + memcpy((uint8_t *)(output+offset),(uint8_t *)(b+b_offset),b_len); + offset += b_len; + } + else{ + for(uint32_t i=0;i<vec_size;i++){ + memcpy((uint8_t *)(output+offset),(uint8_t *)(a+a_offset),BYTECOUNT(p_bitcount)); + data_inverse(output+offset,BYTECOUNT(p_bitcount)); + offset += BYTECOUNT(p_bitcount); + a_offset += BYTECOUNT(p_bitcount); + } + for(uint32_t i=0;i<vec_size;i++){ + memcpy((uint8_t *)(output+offset),(uint8_t *)(b+b_offset),BYTECOUNT(e_bitcount)); + data_inverse(output+offset,BYTECOUNT(e_bitcount)); + offset += BYTECOUNT(e_bitcount); + b_offset += BYTECOUNT(e_bitcount); + } + } + break; + case MOD_EXP_CONST_A: + memcpy((uint8_t *)(output+offset),(uint8_t *)(a+a_offset),BYTECOUNT(p_bitcount)); + data_inverse(output+offset,BYTECOUNT(p_bitcount)); + offset += BYTECOUNT(p_bitcount); + + for(uint32_t i=0;i<vec_size;i++){ + memcpy((uint8_t *)(output+offset),(uint8_t *)(b+b_offset),BYTECOUNT(e_bitcount)); + data_inverse(output+offset,BYTECOUNT(e_bitcount)); + offset += BYTECOUNT(e_bitcount); + b_offset += BYTECOUNT(e_bitcount); + } + break; + case MOD_EXP_CONST_E: + if(split_flg){ + memcpy((uint8_t *)(output+offset),(uint8_t *)(b+b_offset),b_len); + data_inverse(output+offset,b_len); + offset += b_len; + memcpy((uint8_t *)(output+offset),(uint8_t *)(a+a_offset),a_len); + offset += a_len; + } + else{ + memcpy((uint8_t *)(output+offset),(uint8_t *)(b+b_offset),BYTECOUNT(e_bitcount)); + data_inverse(output+offset,BYTECOUNT(e_bitcount)); + offset += BYTECOUNT(e_bitcount); + for(uint32_t i=0;i<vec_size;i++){ + memcpy((uint8_t *)(output+offset),(uint8_t *)(a+a_offset),BYTECOUNT(p_bitcount)); + data_inverse(output+offset,BYTECOUNT(p_bitcount)); + offset += BYTECOUNT(p_bitcount); + a_offset += BYTECOUNT(p_bitcount); + } + } + break; + case MOD_INV_CONST_P: + if(split_flg){ + memcpy((uint8_t *)(output+offset),(uint8_t *)(b+b_offset),BYTECOUNT(p_bitcount)); + data_inverse(output+offset,BYTECOUNT(p_bitcount)); + offset += BYTECOUNT(p_bitcount); + for(uint32_t i=0;i<number_of_outs;i++){ + last_number = ((vec_size%16)==0)?16:(vec_size%16); + numbers = (i == (number_of_outs-1))? last_number:16; + for(uint32_t j = 0; j<numbers;j++) + for(uint32_t k = 0;k<(p_bitcount/8);k++) + { + x = a[(i*16+j)*(p_bitcount/8)+k]; + xx.push_back(x); + } + } + memcpy((uint8_t *)(output+offset),xx.data(),xx.size()); + offset += xx.size(); + } + break; + case MOD_ADD: + for(uint32_t i=0;i<vec_size;i++){ + memcpy((uint8_t *)(output+offset),(uint8_t *)(a+a_offset),BYTECOUNT(p_bitcount)); + data_inverse(output+offset,BYTECOUNT(p_bitcount)); + offset += BYTECOUNT(p_bitcount); + a_offset += BYTECOUNT(p_bitcount); + } + for(uint32_t i=0;i<vec_size;i++){ + memcpy((uint8_t *)(output+offset),(uint8_t *)(b+b_offset),BYTECOUNT(p_bitcount)); + data_inverse(output+offset,BYTECOUNT(p_bitcount)); + offset += BYTECOUNT(p_bitcount); + b_offset += BYTECOUNT(p_bitcount); + } + break; + case MOD_ADD_CONST: + memcpy((uint8_t *)(output+offset),(uint8_t *)(b+b_offset),BYTECOUNT(p_bitcount)); + data_inverse(output+offset,BYTECOUNT(p_bitcount)); + offset += BYTECOUNT(p_bitcount); + + for(uint32_t i=0;i<vec_size;i++){ + memcpy((uint8_t *)(output+offset),(uint8_t *)(a+a_offset),BYTECOUNT(p_bitcount)); + data_inverse(output+offset,BYTECOUNT(p_bitcount)); + offset += BYTECOUNT(p_bitcount); + a_offset += BYTECOUNT(p_bitcount); + } + break; + case PAILLIER_ENC: + for(uint32_t i=0;i<vec_size;i++){ + BN_bin2bn(b+b_offset, BYTECOUNT(p_bitcount/2), m); + BN_lshift(temp, m, 1); + if(m_flg[i] ==1){ + BN_add_word(temp,1); + } + + BN_bn2binpad(temp,output+offset,BYTECOUNT(p_bitcount/2)); + uint8_t tmp_before[2048]; + memset(tmp_before,0,2048); + memcpy(tmp_before,output+offset,BYTECOUNT(p_bitcount/2)); + data_inverse(output+offset,BYTECOUNT(p_bitcount/2)); + uint8_t tmp[2048]; + memset(tmp,0,2048); + memcpy(tmp,output+offset,BYTECOUNT(p_bitcount/2)); + offset += BYTECOUNT(p_bitcount)/2; + b_offset += BYTECOUNT(p_bitcount/2); + } + for(uint32_t i=0;i<vec_size;i++){ + memcpy((uint8_t *)(output+offset),(uint8_t *)(a+a_offset),BYTECOUNT(p_bitcount/2)); + data_inverse(output+offset,BYTECOUNT(p_bitcount/2)); + offset += BYTECOUNT(p_bitcount/2); + a_offset += BYTECOUNT(p_bitcount/2); + } + break; + default: + break; + } + BN_free(_n); + BN_free(pubkey_nsquare); + BN_free(pubkey_g); + BN_free(m); + BN_free(temp); + BN_CTX_free(bn_ctx); + } + + void writeBin(char *path, uint8_t *buf, uint32_t size) + { + FILE *outfile; + + if ((outfile = fopen(path, "wb")) == NULL) + { + printf("\nCan not open the path: %s \n", path); + exit(-1); + } + fwrite(buf, sizeof(uint8_t), size, outfile); + fclose(outfile); + } + + DEV_STATE Runtime::dev_run(struct executor executor_dat,uint8_t *dat_in ,uint32_t dat_len,uint8_t *out) + { + DEV_STATE status = OK; + size_t cur_val = 0; + uint32_t time_cnt = 0; + if(!dev_reset_device()){return Failed;} + + if (!dev_write_ddr(dat_in, dat_len, executor_dat.in_para_address)){ + status = Write_Faild; + return status; + } + + if (!dev_write_ddr(executor_dat.inst.data(), executor_dat.inst.size(),executor_dat.inst_address)){ + std::cout << "dev_write_ddr inst wrong! \n" <<std::endl; + status = Write_Faild; + return status; + } + + if (!api_write_inst(executor_dat.inst_fpga.data(), executor_dat.inst_fpga.size(),0)){ + status = Write_Faild; + return status; + } + + uint32_t inst_fpga_len = executor_dat.inst_fpga.size()/16-1; + api_set_inst_length(inst_fpga_len); + + while (1){ + check_inst_pointer(&cur_val); + if((time_cnt %100) == 0){} + if (cur_val >= 1){ + status = OK; + break; + } + if(time_cnt>100000){ + status = Time_out; + break; + } + usleep(1000); + time_cnt ++; + } + api_set_inst_length_clear(); + dev_read_ddr(out, executor_dat.out_length, executor_dat.out_address); + dev_reset_device(); + return status; + } + + DEV_STATE Runtime::dev_alg_operation(std::string operation_name , uint8_t*a, uint32_t a_len,uint8_t *b,uint32_t b_len ,uint8_t *n,uint32_t n_len,uint32_t vec_size, uint32_t p_bitcount, uint32_t e_bitcount,uint8_t *output,uint32_t &output_len,std::vector<uint8_t> m_flg,bool split_flg) + { + uint32_t input_len=0; + DEV_STATE status = OK; + OPERATION_TYPE operation_type; + struct executor executor_dat; + output_len = 0; + uint32_t input_size = 0; + input_size = BYTECOUNT(p_bitcount)*vec_size+5*BYTECOUNT(p_bitcount)+2*BYTECOUNT(256)+BYTECOUNT(p_bitcount)*vec_size; + uint8_t *input = new uint8_t[input_size]; + memset(input,0x00,input_size); + operation_type = operation_trans(operation_name); + dev_gen_data(operation_type, a, b, n, vec_size, m_flg,p_bitcount, e_bitcount,input, split_flg,a_len,b_len,n_len,input_len); + dev_compiler(operation_name.c_str(),operation_name.length(),vec_size,p_bitcount,e_bitcount,executor_dat); + dev_run(executor_dat,input,input_len,output); + output_len = executor_dat.out_length; + delete [] input; + return status; + } + + void Runtime::big_sub_const_b(uint8_t* input_a, uint8_t* input_b, uint8_t * output, const uint32_t length, const int p_bitcount, const int p_bitcount_const) + { + BIGNUM *a = BN_new(); + BIGNUM *b = BN_new(); + BIGNUM *result = BN_new(); + + BN_bin2bn((unsigned char*)input_b, BYTECOUNT(p_bitcount_const), b); + uint32_t offset = 0; + for (uint32_t i = 0; i < length; i++) { + BN_bin2bn((unsigned char*)(input_a + offset), BYTECOUNT(p_bitcount), a); + BN_sub(result, a, b); + BN_bn2binpad(result, (unsigned char*)(output + offset), BYTECOUNT(p_bitcount)); + offset += BYTECOUNT(p_bitcount); + } + + BN_free(a); + BN_free(b); + BN_free(result); + } + + void Runtime::big_div_const_b(uint8_t* input_a, uint8_t* input_b, uint8_t * output, const uint32_t length, const int p_bitcount, const int p_bitcount_const, const int p_bitcount_result) + { + BIGNUM *a = BN_new(); + BIGNUM *b = BN_new(); + BIGNUM *result = BN_new(); + BIGNUM *rem = BN_new(); + BN_CTX *bn_ctx = BN_CTX_new(); + + BN_bin2bn((unsigned char*)input_b, BYTECOUNT(p_bitcount_const), b); + uint32_t offset = 0; + uint32_t offset_result = 0; + for (uint32_t i = 0; i < length; i++) { + BN_bin2bn((unsigned char*)(input_a + offset), BYTECOUNT(p_bitcount), a); + BN_div(result, rem, a, b, bn_ctx); + BN_bn2binpad(result, (unsigned char*)(output + offset_result), BYTECOUNT(p_bitcount_result)); + offset += BYTECOUNT(p_bitcount); + offset_result += BYTECOUNT(p_bitcount_result); + } + + BN_free(a); + BN_free(b); + BN_free(result); + BN_free(rem); + BN_CTX_free(bn_ctx); + } + + void Runtime::big_com_and_sub(uint8_t* input_a, uint8_t* input_b, uint8_t* input_p, uint8_t* output, uint8_t* output_flag, const uint32_t length, const int p_bitcount, const int p_bitcount_const, const int p_bitcount_result) + { + BIGNUM *a = BN_new(); + BIGNUM *b = BN_new(); + BIGNUM *p = BN_new(); + BIGNUM *result = BN_new(); + BN_CTX *bn_ctx = BN_CTX_new(); + + BN_bin2bn((unsigned char*)input_b, BYTECOUNT(p_bitcount_const), b); + BN_bin2bn((unsigned char*)input_p, BYTECOUNT(p_bitcount_const), p); + uint32_t offset = 0; + uint32_t offset_result = 0; + int32_t com_bool = 0; + for (uint32_t i = 0; i < length; i++) { + BN_bin2bn((unsigned char*)(input_a + offset), BYTECOUNT(p_bitcount), a); + com_bool = BN_cmp(a, b); + if(com_bool>=0) { + output_flag[i] = 1; + BN_sub(result, p, a); + BN_bn2binpad(result, (unsigned char*)(output + offset_result), BYTECOUNT(p_bitcount_result)); + } + else{ + output_flag[i] = 0; + BN_bn2binpad(a, (unsigned char*)(output + offset_result), BYTECOUNT(p_bitcount_result)); + } + + offset += BYTECOUNT(p_bitcount); + offset_result += BYTECOUNT(p_bitcount_result); + } + BN_free(a); + BN_free(b); + BN_free(p); + BN_free(result); + BN_CTX_free(bn_ctx); + } + + void Runtime::mod_mul_const(uint8_t* input_a, uint8_t* input_b, uint8_t* input_p, uint8_t * output, const uint32_t length, const int p_bitcount) { + BIGNUM *a = BN_new(); + BIGNUM *b = BN_new(); + BIGNUM *p = BN_new(); + BIGNUM *result = BN_new(); + BN_CTX *bn_ctx = BN_CTX_new(); + + BN_bin2bn((unsigned char*)input_p, BYTECOUNT(p_bitcount), p); + BN_bin2bn((unsigned char*)input_b, BYTECOUNT(p_bitcount), b); + uint32_t offset = 0; + for (uint32_t i = 0; i < length; i++) { + BN_bin2bn((unsigned char*)(input_a + offset), BYTECOUNT(p_bitcount), a); + BN_mod_mul(result, a, b, p, bn_ctx); + BN_bn2binpad(result, (unsigned char*)(output + offset), BYTECOUNT(p_bitcount)); + offset += BYTECOUNT(p_bitcount); + } + BN_free(a); + BN_free(b); + BN_free(p); + BN_free(result); + BN_CTX_free(bn_ctx); + } + + void Runtime::paillier_decrypt_step2(struct _private_key private_key, uint32_t p_bitcount, uint32_t vec_size, uint8_t *plaintext_byte,uint32_t plaintext_byte_len,uint8_t *out,uint8_t *out_flg){ + uint32_t p_bitcount_const = 8; + uint32_t p_bitcount_result = 0; + BIGNUM *p = BN_new(); + BIGNUM *q = BN_new(); + BIGNUM *n = BN_new(); + BIGNUM *g = BN_new(); + BIGNUM *mu = BN_new(); + BIGNUM *n_square = BN_new(); + BIGNUM *lamda = BN_new(); + BIGNUM *rem = BN_new(); + BIGNUM *temp = BN_new(); + BIGNUM *temp_1 = BN_new(); + BN_CTX *bn_ctx = BN_CTX_new(); + BN_CTX *bn_n_ctx = BN_CTX_new(); + + uint8_t * p_buff; + uint8_t * q_buff; + vector_to_char_array(private_key.p,p_buff); + BN_bin2bn(p_buff, private_key.p.size(), p); + + vector_to_char_array(private_key.q,q_buff); + BN_bin2bn(q_buff, private_key.q.size(), q); + + BN_mul(n,p,q,bn_n_ctx); + BN_sub(p, p,BN_value_one()); + BN_sub(q, q,BN_value_one()); + BN_mul(lamda,p,q,bn_ctx); + BN_mul(n_square,n,n,bn_ctx); + BN_add(g,n,BN_value_one()); + + BN_mod_mul(mu, g, lamda, n_square, bn_ctx); + BN_sub(mu,mu,BN_value_one()); + BN_div(mu,rem,mu,n,bn_ctx); + BN_mod_inverse(mu,mu,n,bn_ctx); + + uint8_t *result = new uint8_t[vec_size*BYTECOUNT(p_bitcount)]; + memset(result,0,vec_size*BYTECOUNT(p_bitcount)); + uint8_t b_bytes[1]={1}; + data_inverse(plaintext_byte,plaintext_byte_len); + big_sub_const_b(plaintext_byte, b_bytes, result, vec_size, p_bitcount, p_bitcount_const); + + p_bitcount_const = p_bitcount/2; + p_bitcount_result = p_bitcount/2; + uint8_t n_bytes[BYTECOUNT(p_bitcount/2)]; + memset(n_bytes,0,BYTECOUNT(p_bitcount/2)); + BN_bn2binpad(n, (unsigned char*)(n_bytes), BYTECOUNT(p_bitcount/2)); + + uint8_t *div_result = new uint8_t[vec_size*BYTECOUNT(p_bitcount)]; + memset(div_result,0,vec_size*BYTECOUNT(p_bitcount)); + + big_div_const_b(result, n_bytes, div_result, vec_size, p_bitcount, p_bitcount_const, p_bitcount_result); + + uint8_t mu_bytes[BYTECOUNT(p_bitcount/2)]; + memset(mu_bytes,0,BYTECOUNT(p_bitcount/2)); + + BN_bn2binpad(mu, (unsigned char*)(mu_bytes), BYTECOUNT(p_bitcount/2)); + + uint8_t *mul_result = new uint8_t[vec_size*BYTECOUNT(p_bitcount)]; + memset(mul_result,0,vec_size*BYTECOUNT(p_bitcount)); + uint32_t p_bitcount_c = p_bitcount/2; + + mod_mul_const(div_result, mu_bytes, n_bytes, mul_result, vec_size, p_bitcount_c); + BN_set_word(temp,2); + BN_mul(temp,temp,n,bn_ctx); + BN_set_word(temp_1,3); + BN_div(temp,rem,temp,temp_1,bn_ctx); + uint8_t temp_bytes[BYTECOUNT(p_bitcount/2)]; + memset(temp_bytes,0,BYTECOUNT(p_bitcount/2)); + BN_bn2binpad(temp, (unsigned char*)(temp_bytes), BYTECOUNT(p_bitcount/2)); + big_com_and_sub(mul_result, temp_bytes, n_bytes, out, out_flg, vec_size, p_bitcount_c, p_bitcount_const, p_bitcount_result); + BN_free(p); + BN_free(q); + BN_free(n); + BN_free(g); + BN_free(n_square); + BN_free(lamda); + BN_free(rem); + BN_free(temp); + BN_free(temp_1); + BN_free(mu); + BN_CTX_free(bn_ctx); + BN_CTX_free(bn_n_ctx); + + delete [] div_result; + delete [] mul_result; + delete [] result; + } +// } + +DEV_STATE Runtime::paillier_encrypt(uint8_t *m,uint8_t *r,std::vector<uint8_t> m_flg,uint32_t vec_size,struct _public_key public_key,uint8_t *output,uint32_t &output_len) +{ + bool split_flg = true; + DEV_STATE status = OK; + std::string operation_name = "PAILLIER_ENC"; + BIGNUM *n = BN_new(); + uint32_t p_bitcount = 2*public_key.n_bitcount; + uint32_t e_bitcount = p_bitcount/2; + uint8_t * n_buff; + uint32_t bits_len = 0; + + uint8_t nn[BYTECOUNT(p_bitcount/2)]; + memset(nn,0x00,sizeof(nn)); + + vector_to_char_array(public_key.n,n_buff); + BN_bin2bn(n_buff, public_key.n.size(), n); + bits_len = BN_num_bits(n); + BN_bn2binpad(n,nn,BYTECOUNT(p_bitcount/2)); + + uint32_t r_len = 0; + uint32_t m_len = 0; + uint32_t nn_len = 0; + + nn_len = BYTECOUNT(bits_len); + dev_alg_operation(operation_name,r,r_len,m,m_len,nn,nn_len,vec_size,p_bitcount,e_bitcount,output,output_len,m_flg,split_flg); + BN_free(n); + return status; +} + +DEV_STATE Runtime::paillier_decrypt(uint8_t *ct,uint32_t vec_size,struct _private_key private_key,uint8_t *m_output,uint8_t *m_output_flg) +{ + bool split_flg = true; + std::string operation_name = "MOD_EXP_CONST_E"; + struct executor executor_dat; + BIGNUM *p = BN_new(); + BIGNUM *q = BN_new(); + BIGNUM *n = BN_new(); + BIGNUM *lamda = BN_new(); + BN_CTX *bn_ctx = BN_CTX_new(); + uint32_t p_bitcount = 2*private_key.n_bitcount; + uint32_t e_bitcount = p_bitcount/2; + uint8_t * p_buff; + uint8_t * q_buff; + vector_to_char_array(private_key.p,p_buff); + BN_bin2bn(p_buff, private_key.p.size(), p); + vector_to_char_array(private_key.q,q_buff); + BN_bin2bn(q_buff, private_key.q.size(), q); + BN_mul(n,p,q,bn_ctx); + BN_sub(p, p,BN_value_one()); + BN_sub(q, q,BN_value_one()); + BN_mul(lamda,p,q,bn_ctx); + + uint32_t ct_len = BYTECOUNT(p_bitcount)*vec_size; + uint32_t bits_len=0; + uint32_t lamba_byte_len =0; + uint32_t n_len =0; + + bits_len = BN_num_bits(lamda); + lamba_byte_len = BYTECOUNT(bits_len); + uint8_t lamba_byte[lamba_byte_len]; + memset(lamba_byte,0x00,sizeof(lamba_byte)); + BN_bn2binpad(lamda,lamba_byte,lamba_byte_len); + bits_len = BN_num_bits(n); + n_len = BYTECOUNT(bits_len); + uint8_t nn[n_len]; + memset(nn,0x00,sizeof(nn)); + BN_bn2binpad(n,nn,n_len); + + uint32_t output_len = 0; + uint8_t *m1_output = new uint8_t[vec_size*BYTECOUNT(p_bitcount)]; + memset(m1_output,0x00,vec_size*BYTECOUNT(p_bitcount)); + + std::vector<uint8_t > m_flg; + dev_alg_operation(operation_name,ct,ct_len,lamba_byte,lamba_byte_len,nn,n_len,vec_size,p_bitcount,e_bitcount,m1_output,output_len,m_flg,split_flg); + paillier_decrypt_step2(private_key,p_bitcount,vec_size,m1_output,output_len,m_output,m_output_flg); + + BN_free(p); + BN_free(q); + BN_free(n); + BN_free(lamda); + BN_CTX_free(bn_ctx); + delete [] m1_output; + return OK; +} + +DEV_STATE Runtime::paillier_add(uint8_t *ct1,uint32_t ct1_len,uint8_t *ct2 ,uint32_t ct2_len,uint32_t vec_size,uint8_t *ct_output,uint32_t &output_len,struct _public_key public_key) +{ + DEV_STATE status = OK; + std::string operation_name = "MOD_MUL"; + bool split_flg = true; + uint32_t p_bitcount = 2*public_key.n_bitcount; + uint32_t e_bitcount = p_bitcount/2; + uint8_t * n_buff; + uint32_t bits_len = 0; + uint32_t nn_len = 0; + + BIGNUM *n = BN_new(); + vector_to_char_array(public_key.n,n_buff); + BN_bin2bn(n_buff, public_key.n.size(), n); + bits_len = BN_num_bits(n); + nn_len = BYTECOUNT(bits_len); + + uint8_t nn[BYTECOUNT(p_bitcount/2)]; + memset(nn,0x00,sizeof(nn)); + BN_bn2binpad(n,nn,BYTECOUNT(p_bitcount/2)); + + std::vector<uint8_t > m_flg; + dev_alg_operation(operation_name,ct1,ct1_len,ct2,ct2_len,nn,nn_len,vec_size,p_bitcount,e_bitcount,ct_output,output_len,m_flg,split_flg); + BN_free(n); + return status; +} + +DEV_STATE Runtime::paillier_mul(uint8_t *ct1,uint32_t ct1_len,uint8_t *m,uint32_t m_size,uint32_t vec_size,uint8_t *ct_output,uint32_t &output_len,struct _public_key public_key) +{ + DEV_STATE status = OK; + std::string operation_name; + bool split_flg = true; + uint32_t m_len = 0; + uint32_t nn_len = 0; + if(m_size == 1) + { + operation_name = "MOD_EXP_CONST_E"; + } + else + { + operation_name = "MOD_EXP"; + } + + uint32_t p_bitcount = 2*public_key.n_bitcount; + uint32_t e_bitcount = p_bitcount/2; + uint8_t * n_buff; + uint32_t bits_len = 0; + + BIGNUM *n = BN_new(); + + vector_to_char_array(public_key.n,n_buff); + BN_bin2bn(n_buff, public_key.n.size(), n); + bits_len = BN_num_bits(n); + + nn_len = BYTECOUNT(bits_len); + + uint8_t nn[nn_len]; + memset(nn,0x00,sizeof(nn)); + BN_bn2binpad(n,nn,nn_len); + + std::vector<uint8_t > m_flg; + + if(m_size == 1) + { + m_len = BYTECOUNT(e_bitcount); + dev_alg_operation(operation_name,ct1,ct1_len,m,m_len,nn,nn_len,vec_size,p_bitcount,e_bitcount,ct_output,output_len,m_flg,split_flg); + } + else{ + uint32_t offset =0; + m_len = m_size*BYTECOUNT(e_bitcount); + for(uint32_t i=0;i<m_size;i++) + { + data_inverse(m+offset,BYTECOUNT(e_bitcount)); + offset +=BYTECOUNT(e_bitcount); + } + dev_alg_operation(operation_name,ct1,ct1_len,m,m_len,nn,nn_len,vec_size,p_bitcount,e_bitcount,ct_output,output_len,m_flg,split_flg); + } + + BN_free(n); + return status; +} + +DEV_STATE Runtime::paillier_sub(uint8_t *ct1,uint32_t ct1_len,uint8_t *ct2,uint32_t ct2_len,uint32_t vec_size,uint8_t *ct_output,uint32_t &output_len,struct _public_key public_key) +{ + DEV_STATE status = OK; + std::string operateion_name = "MOD_INV_CONST_P"; + uint32_t p_bitcount = 2*public_key.n_bitcount; + uint32_t e_bitcount = p_bitcount/2; + uint8_t * n_buff; + uint32_t bits_len = 0; + + BIGNUM *n = BN_new(); + BIGNUM *n_squre = BN_new(); + BN_CTX *bn_ctx = BN_CTX_new(); + + vector_to_char_array(public_key.n,n_buff); + BN_bin2bn(n_buff, public_key.n.size(), n); + bits_len = BN_num_bits(n); + BN_mul(n_squre,n,n,bn_ctx); + + uint8_t n_squre_byte[BYTECOUNT(p_bitcount)]; + uint32_t n_squre_offset = 0; + + uint8_t nn[BYTECOUNT(p_bitcount)]; + memset(nn,0x00,sizeof(nn)); + + bits_len = BN_num_bits(n_squre); + BN_bn2binpad(n_squre,n_squre_byte+n_squre_offset,BYTECOUNT(p_bitcount)); + bits_len = BN_num_bits(n); + uint32_t nn_len = 0; + nn_len = BYTECOUNT(bits_len); + + memset(nn,0x00,sizeof(nn)); + BN_bn2binpad(n,nn,nn_len); + uint32_t split_flg = true; + uint32_t n_squre_byte_len = BYTECOUNT(bits_len); + std::vector<uint8_t > m_flg; + dev_alg_operation(operateion_name,ct2,ct2_len,n_squre_byte,n_squre_byte_len,nn,nn_len,vec_size,p_bitcount,e_bitcount,ct_output,output_len,m_flg,split_flg); + operateion_name = "MOD_MUL"; + dev_alg_operation(operateion_name,ct1,ct1_len,ct_output,output_len,nn,nn_len,vec_size,p_bitcount,e_bitcount,ct_output,output_len,m_flg,split_flg); + BN_free(n); + BN_free(n_squre); + BN_CTX_free(bn_ctx); + return status; +} + +} \ No newline at end of file diff --git a/heu/library/algorithms/leichi_paillier/runtime.h b/heu/library/algorithms/leichi_paillier/runtime.h new file mode 100644 index 00000000..b3b035f5 --- /dev/null +++ b/heu/library/algorithms/leichi_paillier/runtime.h @@ -0,0 +1,138 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "openssl/bn.h" +#include <stdlib.h> +#include <string.h> +#include <openssl/bn.h> +#include <openssl/rand.h> +#include <openssl/err.h> +#include "heu/library/algorithms/leichi_paillier/pcie/pcie.h" +#include <cstdint> +#include <iostream> +#include <vector> +#include <cmath> +#include "heu/library/algorithms/leichi_paillier/compiler/compiler.h" +#include <stdio.h> +#include <unordered_map> +#include <sstream> + +namespace heu::lib::algorithms::leichi_paillier { + #define BYTECOUNT(x) ((x)>>3) + #define FPGA_SEND_BASE_ADDR 0x24000 + #define NUMBER_OF_PE 16 + #pragma pack(1) + enum DEV_STATE{ + OK = 0x00, + Failed = 0x01, + Invaled_Dat = 0x02, + Invaled_Param = 0x03, + Write_Faild = 0x04, + Read_Faild = 0x05, + Time_out = 0x06 + }; + enum OPERATION_TYPE{ + NONE = 0, + MONT = 1, + MONT_CONST = 2, + MOD_MUL = 3, + MOD_MUL_CONST = 4, + MOD_EXP = 5, + MOD_EXP_CONST_A = 6, + MOD_EXP_CONST_E = 7, + MOD_INV_CONST_P = 8, + MOD_ADD = 9, + MOD_ADD_CONST = 10, + // SRAM_DATA_SHIFT = 11, + PAILLIER_ENC = 12, + MOD_INV = 13 + }; + + struct _public_key{ + int n_bitcount; + std::vector<uint8_t> g; + std::vector<uint8_t> n; + }; + + struct _private_key{ + int n_bitcount; + std::vector<uint8_t> p; + std::vector<uint8_t> q; + }; + + struct _paillier_key{ + _public_key public_key; + _private_key private_key; + }; + + struct executor + { + std::vector<uint8_t> data_in; + std::vector<uint8_t> data_out; + std::vector<uint8_t> inst; + std::vector<uint8_t> inst_fpga; + uint32_t in_para_address; + uint32_t inst_address; + uint32_t out_address; + uint32_t out_length; + }; + #pragma pack() + class Runtime { + public: + Runtime() = default; + bool dev_connect(); + bool dev_close(); + bool dev_reset(); + DEV_STATE paillier_encrypt(uint8_t *m,uint8_t *r,std::vector<uint8_t> m_flg,uint32_t vec_size,struct _public_key public_key,uint8_t *output,uint32_t &output_len); + DEV_STATE paillier_decrypt(uint8_t *ct,uint32_t vec_size,struct _private_key private_key,uint8_t *m_output,uint8_t *m_output_flg); + DEV_STATE paillier_add(uint8_t *ct1,uint32_t ct1_len,uint8_t *ct2 ,uint32_t ct2_len,uint32_t vec_size,uint8_t *ct_output,uint32_t &output_len,struct _public_key public_key); + DEV_STATE paillier_mul(uint8_t *ct1,uint32_t ct1_len,uint8_t *m,uint32_t m_size,uint32_t vec_size,uint8_t *ct_output,uint32_t &output_len,struct _public_key public_key); + DEV_STATE paillier_sub(uint8_t *ct1,uint32_t ct1_len,uint8_t *ct2,uint32_t ct2_len,uint32_t vec_size,uint8_t *ct_output,uint32_t &output_len,struct _public_key public_key); + private: + CPcieComm pPcie; + Compiler compiler; + void data_inverse(uint8_t* data,uint32_t len); + void char_array_to_vector(std::vector<uint8_t> &vec,unsigned char* buff,uint32_t buff_size); + void vector_to_char_array(std::vector<uint8_t>& vec, unsigned char* &arr) ; + void gen_mont_para(BIGNUM *p, uint32_t p_bitcount,uint8_t*output,uint32_t &len); + void gen_p_mont_para(BIGNUM *p, uint32_t p_bitcount,uint8_t*output,uint32_t &len,uint8_t flg); + + int gen_param(uint32_t kernels, uint32_t p_bitcount, uint32_t case_v, uint32_t e_bitcount, uint32_t bool_ele_r_square, uint32_t bits); + OPERATION_TYPE operation_trans(std::string operation); + + size_t dev_write_reg(size_t in_data, size_t addr); + size_t dev_read_reg(size_t *out_data, size_t addr); + size_t dev_write_ddr(uint8_t *in_data, size_t write_len, size_t addr); + size_t dev_read_ddr(uint8_t *out_data, size_t read_len, size_t addr); + size_t dev_write_init(uint8_t *in_data, size_t write_len, size_t addr); + size_t dev_read_init(uint8_t *out_data, size_t read_len, size_t addr); + void api_set_inst_length(size_t len) ; + size_t dev_reset_device(); + size_t api_write_inst(uint8_t *inst, size_t write_len,size_t addr); + size_t api_set_inst_length_clear(); + size_t check_inst_pointer(size_t *out_data); + + int get_p_bitcount_chip(uint32_t p_bitcount); + void big_sub_const_b(uint8_t* input_a, uint8_t* input_b, uint8_t * output, const uint32_t length, const int p_bitcount, const int p_bitcount_const); + void big_div_const_b(uint8_t* input_a, uint8_t* input_b, uint8_t * output, const uint32_t length, const int p_bitcount, const int p_bitcount_const, const int p_bitcount_result) ; + void big_com_and_sub(uint8_t* input_a, uint8_t* input_b, uint8_t* input_p, uint8_t* output, uint8_t* output_flag, const uint32_t length, const int p_bitcount, const int p_bitcount_const, const int p_bitcount_result); + void mod_mul_const(uint8_t* input_a, uint8_t* input_b, uint8_t* input_p, uint8_t * output, const uint32_t length, const int p_bitcount); + int dev_compiler(std::string operation_type,uint32_t str_len,uint32_t size,uint32_t p_bitcount,uint32_t e_bitcount,struct executor &executor_dat); + void dev_gen_data(OPERATION_TYPE operation_type, uint8_t *a, uint8_t *b, uint8_t *n, uint32_t vec_size, std::vector<uint8_t> m_flg,uint32_t p_bitcount, uint32_t e_bitcount,uint8_t *output, bool split_flg,uint32_t a_len,uint32_t b_len,uint32_t n_len,uint32_t &offset); + DEV_STATE dev_run(struct executor executor_dat,uint8_t *dat_in ,uint32_t dat_len,uint8_t *out); + DEV_STATE dev_alg_operation(std::string operation_name , uint8_t*a, uint32_t a_len,uint8_t *b,uint32_t b_len ,uint8_t *n,uint32_t n_len,uint32_t vec_size, uint32_t p_bitcount, uint32_t e_bitcount,uint8_t *output,uint32_t &output_len,std::vector<uint8_t> m_flg,bool split_flg=false); + void paillier_decrypt_step2(struct _private_key private_key, uint32_t p_bitcount, uint32_t vec_size, uint8_t *plaintext_byte,uint32_t plaintext_byte_len,uint8_t *out,uint8_t *out_flg); + }; +} \ No newline at end of file diff --git a/heu/library/algorithms/leichi_paillier/secret_key.cc b/heu/library/algorithms/leichi_paillier/secret_key.cc new file mode 100644 index 00000000..c8844b47 --- /dev/null +++ b/heu/library/algorithms/leichi_paillier/secret_key.cc @@ -0,0 +1,17 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "heu/library/algorithms/leichi_paillier/secret_key.h" + namespace heu::lib::algorithms::leichi_paillier { +} \ No newline at end of file diff --git a/heu/library/algorithms/leichi_paillier/secret_key.h b/heu/library/algorithms/leichi_paillier/secret_key.h new file mode 100644 index 00000000..442c1ab0 --- /dev/null +++ b/heu/library/algorithms/leichi_paillier/secret_key.h @@ -0,0 +1,43 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "heu/library/algorithms/util/he_object.h" +#include "openssl/bn.h" +#include "heu/library/algorithms/leichi_paillier/plaintext.h" +namespace heu::lib::algorithms::leichi_paillier { + class SecretKey { + public: + Plaintext p_, q_; + bool operator==(const SecretKey &other) const { + return p_ == other.p_ && q_ == other.q_ ; + } + + bool operator!=(const SecretKey &other) const { + return !this->operator==(other); + } + + [[nodiscard]] std::string ToString() const { + return fmt::format("leichi_paillier SK, p={}[{}bits], q={}[{}bits]", p_.ToHexString(), + p_.BitCount(), q_.ToHexString(), q_.BitCount()); + } + + yacl::Buffer Serialize() const { YACL_THROW("Not implemented."); } + void Deserialize(yacl::ByteContainerView in) { + YACL_THROW("Not implemented."); + } + }; + +} + diff --git a/heu/library/algorithms/leichi_paillier/utils.cc b/heu/library/algorithms/leichi_paillier/utils.cc new file mode 100644 index 00000000..0ca7cdc9 --- /dev/null +++ b/heu/library/algorithms/leichi_paillier/utils.cc @@ -0,0 +1,17 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "utils.h" + namespace heu::lib::algorithms::leichi_paillier { +} \ No newline at end of file diff --git a/heu/library/algorithms/leichi_paillier/utils.h b/heu/library/algorithms/leichi_paillier/utils.h new file mode 100644 index 00000000..a61367d6 --- /dev/null +++ b/heu/library/algorithms/leichi_paillier/utils.h @@ -0,0 +1,52 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "openssl/bn.h" +#include "heu/library/algorithms/leichi_paillier/plaintext.h" +#include "heu/library/algorithms/leichi_paillier/ciphertext.h" +namespace heu::lib::algorithms::leichi_paillier { + template <typename T> + std::vector<uint8_t> Tobin(T &op) + { + uint32_t n_bits_len = op.numBits(); + uint8_t* n_arr = new uint8_t[n_bits_len]; + std::vector<uint8_t> vec_tmp; + BN_bn2bin(op.bn_, n_arr); + uint32_t bytes_len = ((n_bits_len)>>3); + for(uint32_t i=0;i<bytes_len;i++) + { + vec_tmp.push_back(n_arr[i]); + } + delete[] n_arr; + return vec_tmp; + } + + template <typename T> + std::vector<uint8_t> bnTobin(T &bn) + { + uint32_t n_bits_len = BN_num_bits(bn); + uint8_t* n_arr = new uint8_t[n_bits_len]; + std::vector<uint8_t> vec_tmp; + BN_bn2bin(bn, n_arr); + uint32_t bytes_len = std::ceil(n_bits_len/8.0); + for(uint32_t i=0;i<bytes_len;i++) + { + vec_tmp.push_back(n_arr[i]); + } + delete[] n_arr; + return vec_tmp; + } + +} \ No newline at end of file diff --git a/heu/library/algorithms/leichi_paillier/vector_decryptor.cc b/heu/library/algorithms/leichi_paillier/vector_decryptor.cc new file mode 100644 index 00000000..47c6e660 --- /dev/null +++ b/heu/library/algorithms/leichi_paillier/vector_decryptor.cc @@ -0,0 +1,103 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "heu/library/algorithms/leichi_paillier/vector_decryptor.h" +#include "heu/library/algorithms/leichi_paillier/runtime.h" +namespace heu::lib::algorithms::leichi_paillier { +void Decryptor::Decrypt(ConstSpan<Ciphertext> in_cts, Span<Plaintext> out_pts) const { + + Runtime _runtime; + std::vector<Plaintext> _pts(in_cts.size()); + uint8_t *ct_bytes = new uint8_t[in_cts.size()*BYTECOUNT(pk_.n_.numBits()*2)]; + uint8_t *pt_bytes = new uint8_t[in_cts.size()*BYTECOUNT(pk_.n_.numBits())]; + memset(pt_bytes,0,in_cts.size()*BYTECOUNT(pk_.n_.numBits())); + uint8_t *pt_flg = new uint8_t[in_cts.size()*BYTECOUNT(pk_.n_.numBits())]; + memset(pt_flg,0,in_cts.size()); + uint32_t ct_offset = 0; + uint32_t pt_offset = 0; + struct _paillier_key paillier_key; + paillier_key.private_key.p = Tobin(sk_.p_); + paillier_key.private_key.q = Tobin(sk_.q_); + paillier_key.private_key.n_bitcount = pk_.n_.numBits(); + + for (auto item : in_cts) { + BN_bn2binpad(item->bn_,ct_bytes+ct_offset,BYTECOUNT(pk_.n_.numBits()*2)); + ct_offset += BYTECOUNT(pk_.n_.numBits()*2); + } + + if(_runtime.dev_connect()) + { + _runtime.paillier_decrypt(ct_bytes,in_cts.size(),paillier_key.private_key,pt_bytes,pt_flg); + for (size_t i = 0; i < _pts.size(); i++) { + BN_bin2bn(pt_bytes+pt_offset,BYTECOUNT(pk_.n_.numBits()),_pts[i].bn_); + pt_offset +=BYTECOUNT(pk_.n_.numBits()); + if(pt_flg[i] == 1) + { + BN_set_negative(_pts[i].bn_,1); + } + } + + std::reverse(_pts.begin(),_pts.end()); + for (size_t i = 0; i < _pts.size(); i++) { + *out_pts[i] = Plaintext(_pts[i]); + } + } + _runtime.dev_close(); + delete []pt_bytes; + delete []ct_bytes; + delete []pt_flg; +} + +std::vector<Plaintext> Decryptor::Decrypt(ConstSpan<Ciphertext> cts) const +{ + Runtime _runtime; + std::vector<Plaintext> _pts(cts.size()); + uint8_t *ct_bytes = new uint8_t[cts.size()*BYTECOUNT(pk_.n_.numBits()*2)]; + uint8_t *pt_bytes = new uint8_t[cts.size()*BYTECOUNT(pk_.n_.numBits())]; + memset(pt_bytes,0,cts.size()*BYTECOUNT(pk_.n_.numBits())); + uint8_t *pt_flg = new uint8_t[cts.size()*BYTECOUNT(pk_.n_.numBits())]; + memset(pt_flg,0,cts.size()); + uint32_t ct_offset = 0; + uint32_t pt_offset = 0; + struct _paillier_key paillier_key; + paillier_key.private_key.p = Tobin(sk_.p_); + paillier_key.private_key.q = Tobin(sk_.q_); + paillier_key.private_key.n_bitcount = pk_.n_.numBits(); + + for (auto item : cts) { + BN_bn2binpad(item->bn_,ct_bytes+ct_offset,BYTECOUNT(pk_.n_.numBits()*2)); + ct_offset += BYTECOUNT(pk_.n_.numBits()*2); + } + + if(_runtime.dev_connect()) + { + _runtime.paillier_decrypt(ct_bytes,cts.size(),paillier_key.private_key,pt_bytes,pt_flg); + for (std::size_t i = 0; i < _pts.size(); i++) { + BN_bin2bn(pt_bytes+pt_offset,BYTECOUNT(pk_.n_.numBits()),_pts[i].bn_); + pt_offset +=BYTECOUNT(pk_.n_.numBits()); + if(pt_flg[i] == 1) + { + BN_set_negative(_pts[i].bn_,1); + } + } + std::reverse(_pts.begin(),_pts.end()); + } + _runtime.dev_close(); + delete []pt_bytes; + delete []ct_bytes; + delete []pt_flg; + return _pts; +} + +} // namespace heu::lib::algorithms::leichi_paillier \ No newline at end of file diff --git a/heu/library/algorithms/leichi_paillier/vector_decryptor.h b/heu/library/algorithms/leichi_paillier/vector_decryptor.h new file mode 100644 index 00000000..204a3b27 --- /dev/null +++ b/heu/library/algorithms/leichi_paillier/vector_decryptor.h @@ -0,0 +1,37 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "heu/library/algorithms/leichi_paillier/ciphertext.h" +#include "heu/library/algorithms/leichi_paillier/public_key.h" +#include "heu/library/algorithms/leichi_paillier/secret_key.h" +#include "openssl/bn.h" +#include <utility> +#include "heu/library/algorithms/leichi_paillier/utils.h" +namespace heu::lib::algorithms::leichi_paillier { + class Decryptor { + public: + explicit Decryptor(PublicKey pk, SecretKey sk) + : pk_(std::move(pk)), sk_(std::move(sk)) {} + + std::vector<Plaintext> Decrypt(ConstSpan<Ciphertext> cts) const; + void Decrypt(ConstSpan<Ciphertext> in_cts, Span<Plaintext> out_pts) const; + + private: + PublicKey pk_; + SecretKey sk_; + }; +} + + diff --git a/heu/library/algorithms/leichi_paillier/vector_encryptor.cc b/heu/library/algorithms/leichi_paillier/vector_encryptor.cc new file mode 100644 index 00000000..75656a11 --- /dev/null +++ b/heu/library/algorithms/leichi_paillier/vector_encryptor.cc @@ -0,0 +1,109 @@ +#include "heu/library/algorithms/leichi_paillier/vector_encryptor.h" +#include "heu/library/algorithms/leichi_paillier/runtime.h" +namespace heu::lib::algorithms::leichi_paillier { + + std::vector<Ciphertext> Encryptor::EncryptZero(int64_t size) const { + Runtime _runtime; + uint8_t *r_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits())*size]; + uint8_t *m_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits())*size]; + uint8_t *ct_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits()*2)*size]; + uint32_t r_offset = 0; + uint32_t m_offset = 0; + uint32_t ct_offset = 0; + uint32_t ct_len =0; + + struct _paillier_key paillier_key; + paillier_key.public_key.n = Tobin(pk_.n_); + paillier_key.public_key.g = Tobin(pk_.g_); + paillier_key.public_key.n_bitcount = pk_.n_.numBits(); + + std::vector<Ciphertext> result(size); + Plaintext pt_neg_zero; + pt_neg_zero.Set(0); + + std::vector<uint8_t> m_flg; + for (int64_t i = 0;i<size;i++) + { + m_flg.push_back(0); + } + for (int64_t i = 0;i<size;i++) { + BN_bn2binpad(Plaintext::generateRandom(pk_.n_).bn_,r_bytes+r_offset,BYTECOUNT(pk_.n_.numBits())); + r_offset += BYTECOUNT(pk_.n_.numBits()); + BN_bn2binpad(pt_neg_zero.bn_,m_bytes+m_offset,BYTECOUNT(pk_.n_.numBits())); + m_offset += BYTECOUNT(pk_.n_.numBits()); + } + + if(_runtime.dev_connect()) + { + // _runtime.dev_reset(); + _runtime.paillier_encrypt(m_bytes,r_bytes,m_flg,size,paillier_key.public_key,ct_bytes,ct_len); + for (int64_t i = 0; i < size; i++) { + BN_bin2bn(ct_bytes+ct_offset,BYTECOUNT(pk_.n_.numBits()*2),result[i].bn_); + ct_offset +=BYTECOUNT(pk_.n_.numBits()*2); + } + } + _runtime.dev_close(); + delete []r_bytes; + delete []m_bytes; + delete []ct_bytes; + m_flg.clear(); + return result; + } + int CompareBignum(const BIGNUM* bn1, const BIGNUM* bn2) { + return BN_cmp(bn1, bn2); + } + + std::vector<Ciphertext> Encryptor::Encrypt(ConstSpan<Plaintext> pts) const { + std::vector<Ciphertext> result(pts.size()); + Runtime _runtime; + uint8_t *r_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits())*pts.size()]; + uint8_t *m_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits())*pts.size()]; + uint8_t *ct_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits()*2)*pts.size()]; + uint32_t r_offset = 0; + uint32_t m_offset = 0; + uint32_t ct_offset = 0; + uint32_t ct_len =0; + std::vector<uint8_t> m_flg; + struct _paillier_key paillier_key; + paillier_key.public_key.n = Tobin(pk_.n_); + paillier_key.public_key.g = Tobin(pk_.g_); + paillier_key.public_key.n_bitcount = pk_.n_.numBits(); + + for (auto item : pts) { + YACL_ENFORCE(CompareBignum(item->bn_, max_plaintext.bn_) < 0, + "Plaintext out of range"); + BN_bn2binpad(Plaintext::generateRandom(pk_.n_).bn_,r_bytes+r_offset,BYTECOUNT(pk_.n_.numBits())); + r_offset += BYTECOUNT(pk_.n_.numBits()); + BN_bn2binpad(item->bn_,m_bytes+m_offset,BYTECOUNT(pk_.n_.numBits())); + m_offset += BYTECOUNT(pk_.n_.numBits()); + } + + for (auto item : pts) + { + if(BN_is_negative(item->bn_)) + { + m_flg.push_back(1); + } + else + { + m_flg.push_back(0); + } + } + + if(_runtime.dev_connect()) + { + // _runtime.dev_reset(); + _runtime.paillier_encrypt(m_bytes,r_bytes,m_flg,pts.size(),paillier_key.public_key,ct_bytes,ct_len); + for (std::size_t i = 0; i < pts.size(); i++) { + BN_bin2bn(ct_bytes+ct_offset,BYTECOUNT(pk_.n_.numBits()*2),result[i].bn_); + ct_offset +=BYTECOUNT(pk_.n_.numBits()*2); + } + } + _runtime.dev_close(); + delete []r_bytes; + delete []m_bytes; + delete []ct_bytes; + m_flg.clear(); + return result; + } +} \ No newline at end of file diff --git a/heu/library/algorithms/leichi_paillier/vector_encryptor.h b/heu/library/algorithms/leichi_paillier/vector_encryptor.h new file mode 100644 index 00000000..d840ccba --- /dev/null +++ b/heu/library/algorithms/leichi_paillier/vector_encryptor.h @@ -0,0 +1,18 @@ +#pragma once +#include "heu/library/algorithms/leichi_paillier/ciphertext.h" +#include "heu/library/algorithms/leichi_paillier/plaintext.h" +#include "heu/library/algorithms/leichi_paillier/public_key.h" +#include "heu/library/algorithms/leichi_paillier/utils.h" +namespace heu::lib::algorithms::leichi_paillier { + class Encryptor { + public: + explicit Encryptor(const PublicKey& pk): pk_(std::move(pk)){max_plaintext = pk.max_plaintext_;} + std::vector<Ciphertext> EncryptZero(int64_t size) const; + std::vector<Ciphertext> Encrypt(ConstSpan<Plaintext> pts) const; + std::pair<std::vector<Ciphertext>, std::vector<std::string>> EncryptWithAudit( + ConstSpan<Plaintext> pts) const{YACL_THROW("Not Implemented.");}; + private: + PublicKey pk_; + Plaintext max_plaintext; + }; +} diff --git a/heu/library/algorithms/leichi_paillier/vector_evaluator.cc b/heu/library/algorithms/leichi_paillier/vector_evaluator.cc new file mode 100644 index 00000000..b99ee54a --- /dev/null +++ b/heu/library/algorithms/leichi_paillier/vector_evaluator.cc @@ -0,0 +1,501 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "heu/library/algorithms/leichi_paillier/vector_evaluator.h" +#include "heu/library/algorithms/leichi_paillier/vector_encryptor.h" +#include "heu/library/algorithms/util/he_assert.h" +#include "heu/library/algorithms/leichi_paillier/runtime.h" +namespace heu::lib::algorithms::leichi_paillier { + void Evaluator::Randomize(Span<Ciphertext> ct) const { + + } + std::vector<Ciphertext> Evaluator::Add(ConstSpan<Ciphertext> a, + ConstSpan<Ciphertext> b) const { + std::vector<Ciphertext> result(a.size()); + struct _paillier_key paillier_key; + paillier_key.public_key.n = Tobin(pk_.n_); + paillier_key.public_key.g = Tobin(pk_.g_); + paillier_key.public_key.n_bitcount = pk_.n_.numBits(); + + uint8_t *a_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits()*2)*a.size()]; + uint8_t *b_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits()*2)*a.size()]; + uint8_t *out_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits()*2)*a.size()]; + uint32_t a_len = a.size()*BYTECOUNT(pk_.n_.numBits()*2); + uint32_t b_len = b.size()*BYTECOUNT(pk_.n_.numBits()*2); + uint32_t vec_size = a.size(); + uint32_t output_len = 0; + uint32_t a_offset = 0; + uint32_t b_offset = 0; + uint32_t out_offset = 0; + + for (auto item : a) { + BN_bn2binpad(item->bn_,a_bytes+a_offset,BYTECOUNT(pk_.n_.numBits()*2)); + a_offset += BYTECOUNT(pk_.n_.numBits()*2); + } + + for (auto item : b) { + BN_bn2binpad(item->bn_,b_bytes+b_offset,BYTECOUNT(pk_.n_.numBits()*2)); + b_offset += BYTECOUNT(pk_.n_.numBits()*2); + } + Runtime _runtime; + if(_runtime.dev_connect()) + { + _runtime.paillier_add(a_bytes,a_len,b_bytes,b_len,vec_size,out_bytes,output_len,paillier_key.public_key); + for (std::size_t i = 0; i < a.size(); i++) { + BN_bin2bn(out_bytes+out_offset,BYTECOUNT(pk_.n_.numBits()*2),result[i].bn_); + out_offset +=BYTECOUNT(pk_.n_.numBits()*2); + } + } + _runtime.dev_close(); + + delete [] a_bytes; + delete [] b_bytes; + delete [] out_bytes; + return result; + } + + std::vector<Ciphertext> Evaluator::Add(ConstSpan<Ciphertext> a, + ConstSpan<Plaintext> b) const { + std::vector<Ciphertext> result(a.size()); + struct _paillier_key paillier_key; + paillier_key.public_key.n = Tobin(pk_.n_); + paillier_key.public_key.g = Tobin(pk_.g_); + paillier_key.public_key.n_bitcount = pk_.n_.numBits(); + + uint8_t *a_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits()*2)*a.size()]; + uint8_t *b_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits())*a.size()]; + uint8_t *out_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits()*2)*a.size()]; + uint32_t a_len = a.size()*BYTECOUNT(pk_.n_.numBits()*2); + uint32_t vec_size = a.size(); + uint32_t output_len = 0; + uint32_t a_offset = 0; + uint32_t b_offset = 0; + uint32_t out_offset = 0; + + for (auto item : a) { + BN_bn2binpad(item->bn_,a_bytes+a_offset,BYTECOUNT(pk_.n_.numBits()*2)); + a_offset += BYTECOUNT(pk_.n_.numBits()*2); + } + uint8_t *r_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits())*b.size()]; + uint32_t r_offset = 0; + uint8_t *ct_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits()*2)*b.size()]; + uint32_t ct_len =0; + std::vector<uint8_t> b_flg; + + for (auto item_b : b) { + BN_bn2binpad(Plaintext::generateRandom(pk_.n_).bn_,r_bytes+r_offset,BYTECOUNT(pk_.n_.numBits())); + r_offset += BYTECOUNT(pk_.n_.numBits()); + BN_bn2binpad(item_b->bn_,b_bytes+b_offset,BYTECOUNT(pk_.n_.numBits())); + b_offset += BYTECOUNT(pk_.n_.numBits()); + } + + for (auto item : b) + { + if(BN_is_negative(item->bn_)) + { + b_flg.push_back(1); + } + else + { + b_flg.push_back(0); + } + } + Runtime _runtime; + if(_runtime.dev_connect()) + { + + _runtime.paillier_encrypt(b_bytes,r_bytes,b_flg,b.size(),paillier_key.public_key,ct_bytes,ct_len); + _runtime.paillier_add(a_bytes,a_len,ct_bytes,ct_len,vec_size,out_bytes,output_len,paillier_key.public_key); + for (std::size_t i = 0; i < a.size(); i++) { + BN_bin2bn(out_bytes+out_offset,BYTECOUNT(pk_.n_.numBits()*2),result[i].bn_); + out_offset +=BYTECOUNT(pk_.n_.numBits()*2); + } + } + _runtime.dev_close(); + + delete [] a_bytes; + delete [] b_bytes; + delete [] out_bytes; + delete [] r_bytes; + delete [] ct_bytes; + b_flg.clear(); + return result; + } + + std::vector<Ciphertext> Evaluator::Add(ConstSpan<Plaintext> a, + ConstSpan<Ciphertext> b) const { + return Add(b, a); + } + + std::vector<Plaintext> Evaluator::Add(ConstSpan<Plaintext> a, + ConstSpan<Plaintext> b) const { + HE_ASSERT(a.size() == b.size(), "PT + PT error: size mismatch."); + std::vector<Plaintext> sum; + size_t vec_size = a.size(); + for (size_t i = 0; i < vec_size; i++) { + sum.push_back(*a[i] + *b[i]); + } + return sum; + } + + void Evaluator::AddInplace(Span<Ciphertext> a, ConstSpan<Ciphertext> b) const { + auto sum = Add(a, b); + size_t vec_size = sum.size(); + for (size_t i = 0; i < vec_size; i++) { + *a[i] = sum[i]; + } + } + + void Evaluator::AddInplace(Span<Ciphertext> a, ConstSpan<Plaintext> b) const { + auto sum = Add(a, b); + size_t vec_size = sum.size(); + for (size_t i = 0; i < vec_size; i++) { + *a[i] = sum[i]; + } + } + + void Evaluator::AddInplace(Span<Plaintext> a, ConstSpan<Plaintext> b) const { + auto sum = Add(a, b); + size_t vec_size = sum.size(); + for (size_t i = 0; i < vec_size; i++) { + *a[i] = sum[i]; + } + } + + std::vector<Ciphertext> Evaluator::Sub(ConstSpan<Ciphertext> a, + ConstSpan<Ciphertext> b) const { + std::vector<Ciphertext> result(a.size()); + HE_ASSERT(a.size() == b.size(), "CT - CT error: size mismatch."); + struct _paillier_key paillier_key; + paillier_key.public_key.n = Tobin(pk_.n_); + paillier_key.public_key.g = Tobin(pk_.g_); + paillier_key.public_key.n_bitcount = pk_.n_.numBits(); + uint8_t *a_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits()*2)*a.size()]; + uint8_t *b_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits())*2*a.size()]; + uint8_t *out_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits()*2)*a.size()]; + uint32_t a_len = a.size()*BYTECOUNT(pk_.n_.numBits()*2); + uint32_t b_len = b.size()*BYTECOUNT(pk_.n_.numBits()*2); + uint32_t vec_size = a.size(); + uint32_t output_len = 0; + uint32_t a_offset = 0; + uint32_t b_offset = 0; + uint32_t out_offset = 0; + + for (auto item : a) { + BN_bn2binpad(item->bn_,a_bytes+a_offset,BYTECOUNT(pk_.n_.numBits()*2)); + a_offset += BYTECOUNT(pk_.n_.numBits()*2); + } + + for (auto item : b) { + BN_bn2binpad(item->bn_,b_bytes+b_offset,BYTECOUNT(pk_.n_.numBits()*2)); + b_offset += BYTECOUNT(pk_.n_.numBits()*2); + } + + Runtime _runtime; + if(_runtime.dev_connect()) + { + _runtime.paillier_sub(a_bytes,a_len,b_bytes,b_len,vec_size,out_bytes,output_len,paillier_key.public_key); + for (std::size_t i = 0; i < a.size(); i++) { + BN_bin2bn(out_bytes+out_offset,BYTECOUNT(pk_.n_.numBits()*2),result[i].bn_); + out_offset +=BYTECOUNT(pk_.n_.numBits()*2); + } + } + + _runtime.dev_close(); + delete [] a_bytes; + delete [] b_bytes; + delete [] out_bytes; + return result; + } + + std::vector<Ciphertext> Evaluator::Sub(ConstSpan<Ciphertext> a, + ConstSpan<Plaintext> b) const { + HE_ASSERT(a.size() == b.size(), "CT - PT error: size mismatch."); + std::vector<Ciphertext> result(a.size()); + struct _paillier_key paillier_key; + paillier_key.public_key.n = Tobin(pk_.n_); + paillier_key.public_key.g = Tobin(pk_.g_); + paillier_key.public_key.n_bitcount = pk_.n_.numBits(); + uint8_t *a_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits()*2)*a.size()]; + uint8_t *b_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits())*2*a.size()]; + uint8_t *out_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits()*2)*a.size()]; + uint32_t a_len = a.size()*BYTECOUNT(pk_.n_.numBits()*2); + uint32_t vec_size = a.size(); + uint32_t output_len = 0; + uint32_t a_offset = 0; + uint32_t b_offset = 0; + uint32_t out_offset = 0; + + for (auto item : a) { + BN_bn2binpad(item->bn_,a_bytes+a_offset,BYTECOUNT(pk_.n_.numBits()*2)); + a_offset += BYTECOUNT(pk_.n_.numBits()*2); + } + + uint8_t *r_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits())*b.size()]; + uint32_t r_offset = 0; + uint8_t *ct_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits()*2)*b.size()]; + uint32_t ct_len =0; + std::vector<uint8_t> b_flg; + + for (auto item : b) { + BN_bn2binpad(Plaintext::generateRandom(pk_.n_).bn_,r_bytes+r_offset,BYTECOUNT(pk_.n_.numBits())); + r_offset += BYTECOUNT(pk_.n_.numBits()); + BN_bn2binpad(item->bn_,b_bytes+b_offset,BYTECOUNT(pk_.n_.numBits())); + b_offset += BYTECOUNT(pk_.n_.numBits()); + } + + for (auto item : b) + { + if(BN_is_negative(item->bn_)) + { + b_flg.push_back(1); + } + else + { + b_flg.push_back(0); + } + } + + Runtime _runtime; + if(_runtime.dev_connect()) + { + _runtime.paillier_encrypt(b_bytes,r_bytes,b_flg,b.size(),paillier_key.public_key,ct_bytes,ct_len); + _runtime.paillier_sub(a_bytes,a_len,ct_bytes,ct_len,vec_size,out_bytes,output_len,paillier_key.public_key); + for (std::size_t i = 0; i < a.size(); i++) { + BN_bin2bn(out_bytes+out_offset,BYTECOUNT(pk_.n_.numBits()*2),result[i].bn_); + out_offset +=BYTECOUNT(pk_.n_.numBits()*2); + } + } + _runtime.dev_close(); + delete [] a_bytes; + delete [] b_bytes; + delete [] out_bytes; + delete [] r_bytes; + delete [] ct_bytes; + b_flg.clear(); + return result; + } + + std::vector<Ciphertext> Evaluator::Sub(ConstSpan<Plaintext> a, + ConstSpan<Ciphertext> b) const { + HE_ASSERT(a.size() == b.size(), "CT - PT error: size mismatch."); + std::vector<Ciphertext> result(a.size()); + struct _paillier_key paillier_key; + paillier_key.public_key.n = Tobin(pk_.n_); + paillier_key.public_key.g = Tobin(pk_.g_); + paillier_key.public_key.n_bitcount = pk_.n_.numBits(); + uint8_t *a_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits()*2)*a.size()]; + uint8_t *b_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits())*2*a.size()]; + uint8_t *out_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits()*2)*a.size()]; + uint32_t b_len = b.size()*BYTECOUNT(pk_.n_.numBits()*2); + uint32_t vec_size = a.size(); + uint32_t output_len = 0; + uint32_t a_offset = 0; + uint32_t b_offset = 0; + uint32_t out_offset = 0; + + for (auto item : b) { + BN_bn2binpad(item->bn_,b_bytes+b_offset,BYTECOUNT(pk_.n_.numBits()*2)); + b_offset += BYTECOUNT(pk_.n_.numBits()*2); + } + + uint8_t *r_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits())*b.size()]; + uint32_t r_offset = 0; + uint8_t *ct_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits()*2)*b.size()]; + uint32_t ct_len =0; + + for (auto item : a) { + BN_bn2binpad(Plaintext::generateRandom(pk_.n_).bn_,r_bytes+r_offset,BYTECOUNT(pk_.n_.numBits())); + r_offset += BYTECOUNT(pk_.n_.numBits()); + BN_bn2binpad(item->bn_,a_bytes+a_offset,BYTECOUNT(pk_.n_.numBits())); + a_offset += BYTECOUNT(pk_.n_.numBits()); + } + std::vector<uint8_t> a_flg; + for (auto item : a) + { + if(BN_is_negative(item->bn_)) + { + a_flg.push_back(1); + } + else + { + a_flg.push_back(0); + } + } + + Runtime _runtime; + if(_runtime.dev_connect()) + { + _runtime.paillier_encrypt(a_bytes,r_bytes,a_flg,b.size(),paillier_key.public_key,ct_bytes,ct_len); + _runtime.paillier_sub(ct_bytes,ct_len,b_bytes,b_len,vec_size,out_bytes,output_len,paillier_key.public_key); + for (std::size_t i = 0; i < a.size(); i++) { + BN_bin2bn(out_bytes+out_offset,BYTECOUNT(pk_.n_.numBits()*2),result[i].bn_); + out_offset +=BYTECOUNT(pk_.n_.numBits()*2); + } + } + _runtime.dev_close(); + delete [] a_bytes; + delete [] b_bytes; + delete [] out_bytes; + delete [] r_bytes; + delete [] ct_bytes; + return result; + } + + std::vector<Plaintext> Evaluator::Sub(ConstSpan<Plaintext> a, + ConstSpan<Plaintext> b) const { + HE_ASSERT(a.size() == b.size(), "PT - PT error: size mismatch."); + size_t size = a.size(); + std::vector<Plaintext> result; + for (size_t i = 0; i < size; i++) { + result.push_back(*a[i] - *b[i]); + } + return result; + } + + void Evaluator::SubInplace(Span<Ciphertext> a, ConstSpan<Ciphertext> b) const { + auto res = Sub(a, b); + size_t vec_size = res.size(); + for (size_t i = 0; i < vec_size; i++) { + *a[i] = res[i]; + } + } + void Evaluator::SubInplace(Span<Ciphertext> a, ConstSpan<Plaintext> p) const { + auto res = Sub(a, p); + size_t vec_size = res.size(); + for (size_t i = 0; i < vec_size; i++) { + *a[i] = res[i]; + } + } + void Evaluator::SubInplace(Span<Plaintext> a, ConstSpan<Plaintext> b) const { + auto res = Sub(a, b); + size_t vec_size = res.size(); + for (size_t i = 0; i < vec_size; i++) { + *a[i] = res[i]; + } + } + + std::vector<Ciphertext> Evaluator::Mul(ConstSpan<Ciphertext> a, + ConstSpan<Plaintext> b) const { + HE_ASSERT((a.size() == b.size() || b.size() == 1), + "CT * PT error: size mismatch."); + std::vector<Ciphertext> result(a.size()); + + struct _paillier_key paillier_key; + paillier_key.public_key.n = Tobin(pk_.n_); + paillier_key.public_key.g = Tobin(pk_.g_); + paillier_key.public_key.n_bitcount = pk_.n_.numBits(); + + uint8_t *a_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits()*2)*a.size()]; + uint8_t *b_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits())*a.size()]; + uint8_t *out_bytes = new uint8_t[BYTECOUNT(pk_.n_.numBits()*2)*a.size()]; + uint32_t a_len = a.size()*BYTECOUNT(pk_.n_.numBits()*2); + uint32_t b_len = b.size(); + uint32_t vec_size = a.size(); + uint32_t output_len = 0; + uint32_t a_offset = 0; + uint32_t b_offset = 0; + uint32_t out_offset = 0; + + for (auto item : a) { + BN_bn2binpad(item->bn_,a_bytes+a_offset,BYTECOUNT(pk_.n_.numBits()*2)); + a_offset += BYTECOUNT(pk_.n_.numBits()*2); + } + + for (auto item : b) { + BN_bn2binpad(item->bn_,b_bytes+b_offset,BYTECOUNT(pk_.n_.numBits())); + b_offset += BYTECOUNT(pk_.n_.numBits()); + } + + Runtime _runtime; + if(_runtime.dev_connect()) + { + _runtime.paillier_mul(a_bytes,a_len,b_bytes,b_len,vec_size,out_bytes,output_len,paillier_key.public_key); + for (std::size_t i = 0; i < a.size(); i++) { + BN_bin2bn(out_bytes+out_offset,BYTECOUNT(pk_.n_.numBits()*2),result[i].bn_); + out_offset +=BYTECOUNT(pk_.n_.numBits()*2); + } + } + _runtime.dev_close(); + + delete [] a_bytes; + delete [] b_bytes; + delete [] out_bytes; + return result; + } + + std::vector<Ciphertext> Evaluator::Mul(ConstSpan<Plaintext> a, + ConstSpan<Ciphertext> b) const { + return Mul(b, a); + } + + std::vector<Plaintext> Evaluator::Mul(ConstSpan<Plaintext> a, + ConstSpan<Plaintext> b) const { + HE_ASSERT((a.size() == b.size() || b.size() == 1), + "PT * PT error: size mismatch."); + std::vector<Plaintext> product; + size_t vec_size = a.size(); + if (b.size() == 1) { + for (size_t i = 0; i < vec_size; i++) { + product.push_back(*a[i] * *b[0]); + } + } else { + for (size_t i = 0; i < vec_size; i++) { + product.push_back(*a[i] * *b[i]); + } + } + return product; + } + + void Evaluator::MulInplace(Span<Ciphertext> a, ConstSpan<Plaintext> b) const { + auto product = Mul(a, b); + size_t vec_size = product.size(); + for (size_t i = 0; i < vec_size; i++) { + *a[i] = product[i]; + } + } + + void Evaluator::MulInplace(Span<Plaintext> a, ConstSpan<Plaintext> b) const { + auto product = Mul(a, b); + size_t vec_size = product.size(); + for (size_t i = 0; i < vec_size; i++) { + *a[i] = product[i]; + } + } + + std::vector<Ciphertext> Evaluator::Negate(ConstSpan<Ciphertext> a) const { + + std::vector<Ciphertext> result(a.size()); + for (uint32_t i = 0;i<a.size();i++) { + BN_copy(result[i].bn_,a[i]->bn_); + if(BN_is_negative(result[i].bn_)) + { + BN_set_negative(result[i].bn_,0); + BN_add_word(result[i].bn_,1); + } + else{ + BN_set_negative(result[i].bn_,1); + } + } + return result; + } + + void Evaluator::NegateInplace(Span<Ciphertext> a) const { + auto neg_a = Negate(a); + size_t vec_size = neg_a.size(); + for (size_t i = 0; i < vec_size; i++) { + *a[i] = neg_a[i]; + } + } +} \ No newline at end of file diff --git a/heu/library/algorithms/leichi_paillier/vector_evaluator.h b/heu/library/algorithms/leichi_paillier/vector_evaluator.h new file mode 100644 index 00000000..90f3e49c --- /dev/null +++ b/heu/library/algorithms/leichi_paillier/vector_evaluator.h @@ -0,0 +1,69 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "heu/library/algorithms/leichi_paillier/ciphertext.h" +#include "heu/library/algorithms/leichi_paillier/plaintext.h" +#include "heu/library/algorithms/leichi_paillier/public_key.h" +#include "heu/library/algorithms/leichi_paillier/secret_key.h" +#include "heu/library/algorithms/leichi_paillier/utils.h" +namespace heu::lib::algorithms::leichi_paillier { + class Evaluator { + private: + PublicKey pk_; + public: + explicit Evaluator(const PublicKey& pk): pk_(std::move(pk)){} + + void Randomize(Span<Ciphertext> ct) const; + + std::vector<Ciphertext> Add(ConstSpan<Ciphertext> a, + ConstSpan<Ciphertext> b) const; + std::vector<Ciphertext> Add(ConstSpan<Ciphertext> a, + ConstSpan<Plaintext> b) const; + std::vector<Ciphertext> Add(ConstSpan<Plaintext> a, + ConstSpan<Ciphertext> b) const; + std::vector<Plaintext> Add(ConstSpan<Plaintext> a, + ConstSpan<Plaintext> b) const; + + void AddInplace(Span<Ciphertext> a, ConstSpan<Ciphertext> b) const; + void AddInplace(Span<Ciphertext> a, ConstSpan<Plaintext> b) const; + void AddInplace(Span<Plaintext> a, ConstSpan<Plaintext> b) const; + + std::vector<Ciphertext> Sub(ConstSpan<Ciphertext> a, + ConstSpan<Ciphertext> b) const; + std::vector<Ciphertext> Sub(ConstSpan<Ciphertext> a, + ConstSpan<Plaintext> b) const; + std::vector<Ciphertext> Sub(ConstSpan<Plaintext> a, + ConstSpan<Ciphertext> b) const; + std::vector<Plaintext> Sub(ConstSpan<Plaintext> a, + ConstSpan<Plaintext> b) const; + + void SubInplace(Span<Ciphertext> a, ConstSpan<Ciphertext> b) const; + void SubInplace(Span<Ciphertext> a, ConstSpan<Plaintext> p) const; + void SubInplace(Span<Plaintext> a, ConstSpan<Plaintext> b) const; + + std::vector<Ciphertext> Mul(ConstSpan<Ciphertext> a, + ConstSpan<Plaintext> b) const; + std::vector<Ciphertext> Mul(ConstSpan<Plaintext> a, + ConstSpan<Ciphertext> b) const; + std::vector<Plaintext> Mul(ConstSpan<Plaintext> a, + ConstSpan<Plaintext> b) const; + + void MulInplace(Span<Ciphertext> a, ConstSpan<Plaintext> b) const; + void MulInplace(Span<Plaintext> a, ConstSpan<Plaintext> b) const; + + std::vector<Ciphertext> Negate(ConstSpan<Ciphertext> a) const; + void NegateInplace(Span<Ciphertext> a) const; + }; +} diff --git a/heu/library/numpy/decryptor.cc b/heu/library/numpy/decryptor.cc index be2a03a6..3401d295 100644 --- a/heu/library/numpy/decryptor.cc +++ b/heu/library/numpy/decryptor.cc @@ -26,7 +26,7 @@ auto DoCallDecrypt(const CLAZZ& sub_decryptor, const CMatrix& in, size_t range_bits, PMatrix* out) -> std::enable_if_t< std::experimental::is_detected_v<kHasVectorizedDecrypt, CLAZZ, CT>> { - yacl::parallel_for(0, in.size(), 1, [&](int64_t beg, int64_t end) { + yacl::parallel_for(0, in.size(), in.size(), [&](int64_t beg, int64_t end) { std::vector<const CT*> cts; cts.reserve(end - beg); for (int64_t i = beg; i < end; ++i) { diff --git a/heu/library/numpy/encryptor.cc b/heu/library/numpy/encryptor.cc index 88312717..ed5c45f9 100644 --- a/heu/library/numpy/encryptor.cc +++ b/heu/library/numpy/encryptor.cc @@ -25,7 +25,7 @@ template <typename CLAZZ, typename PT> auto DoCallEncrypt(const CLAZZ& sub_encryptor, const PMatrix& in, CMatrix* out) -> std::enable_if_t< std::experimental::is_detected_v<kHasVectorizedEncrypt, CLAZZ, PT>> { - yacl::parallel_for(0, in.size(), 1, [&](int64_t beg, int64_t end) { + yacl::parallel_for(0, in.size(), in.size(), [&](int64_t beg, int64_t end) { std::vector<const PT*> pts; pts.reserve(end - beg); for (int64_t i = beg; i < end; ++i) { diff --git a/heu/library/numpy/evaluator.cc b/heu/library/numpy/evaluator.cc index 20dde9ac..205b5786 100644 --- a/heu/library/numpy/evaluator.cc +++ b/heu/library/numpy/evaluator.cc @@ -115,7 +115,7 @@ using kHasVectorizedMul = decltype(std::declval<const CLAZZ&>().Mul( const auto* y_base = y.data(); \ RET::value_type* out_base = out->data(); \ int64_t rows = out->rows(); \ - yacl::parallel_for(0, out->size(), 1, [&](int64_t beg, int64_t end) { \ + yacl::parallel_for(0, out->size(), out->size(), [&](int64_t beg, int64_t end) { \ std::vector<const SUB_TX*> in_x; \ std::vector<const SUB_TY*> in_y; \ in_x.reserve(end - beg); \ diff --git a/heu/library/phe/base/BUILD.bazel b/heu/library/phe/base/BUILD.bazel index df583260..2928739b 100644 --- a/heu/library/phe/base/BUILD.bazel +++ b/heu/library/phe/base/BUILD.bazel @@ -21,6 +21,7 @@ yacl_cc_library( "//heu/library/algorithms/paillier_zahlen", "//heu/library/algorithms/paillier_ipcl", "//heu/library/algorithms/elgamal", + "//heu/library/algorithms/leichi_paillier:leichi_paillier_defs", ], ) diff --git a/heu/library/phe/base/schema.cc b/heu/library/phe/base/schema.cc index 7da2d845..b9b3f2fc 100644 --- a/heu/library/phe/base/schema.cc +++ b/heu/library/phe/base/schema.cc @@ -48,6 +48,7 @@ static const std::map<SchemaType, std::vector<std::string>> "paillier_ipcl", "paillier-ipcl"), MAP_ITEM(true, ElGamal, "elgamal", "ec_elgamal", "exponential_elgamal", "exp_elgamal", "lifted_elgamal"), + MAP_ITEM(ENABLE_LEICHI, Leichi, "leichi","leichi-paillier", "paillier-leichi"), // MAP_ITEM(ENABLE, YOUR_ALGO, "one_or_more_name_alias"), }; diff --git a/heu/library/phe/base/schema.h b/heu/library/phe/base/schema.h index edb57020..d041464d 100644 --- a/heu/library/phe/base/schema.h +++ b/heu/library/phe/base/schema.h @@ -26,6 +26,7 @@ // [SPI: Please register your algorithm here] || progress: (1 of 5) // Do not forget to add your algo header file here // #include "heu/library/algorithms/your_algo/algo.h" +#include "heu/library/algorithms/leichi_paillier/leichi.h" namespace heu::lib::phe { @@ -44,6 +45,7 @@ enum class SchemaType { ENUM_ELEMENT(true, ZPaillier) // Preferred ENUM_ELEMENT(true, FPaillier) ENUM_ELEMENT(true, ElGamal) + ENUM_ELEMENT(ENABLE_LEICHI, Leichi) // YOUR_ALGO }; // clang-format on @@ -82,7 +84,8 @@ std::ostream& operator<<(std::ostream& os, SchemaType st); INVOKE(ENABLE_IPCL, func_or_macro, ::heu::lib::algorithms::paillier_ipcl, ##__VA_ARGS__) \ INVOKE(true, func_or_macro, ::heu::lib::algorithms::paillier_z, ##__VA_ARGS__) \ INVOKE(true, func_or_macro, ::heu::lib::algorithms::paillier_f, ##__VA_ARGS__) \ - INVOKE(true, func_or_macro, ::heu::lib::algorithms::elgamal, ##__VA_ARGS__) + INVOKE(true, func_or_macro, ::heu::lib::algorithms::elgamal, ##__VA_ARGS__) \ + INVOKE(ENABLE_LEICHI, func_or_macro, ::heu::lib::algorithms::leichi_paillier, ##__VA_ARGS__) \ // [SPI: Please register your algorithm here] || progress: (4 of 5) // If you add a new schema, change this !! @@ -93,6 +96,7 @@ std::ostream& operator<<(std::ostream& os, SchemaType st); func_or_macro(::heu::lib::algorithms, MPInt, ##__VA_ARGS__) \ INVOKE(true, func_or_macro, ::heu::lib::algorithms::mock, Plaintext, ##__VA_ARGS__) \ INVOKE(ENABLE_IPCL, func_or_macro, ::heu::lib::algorithms::paillier_ipcl, Plaintext, ##__VA_ARGS__) \ + INVOKE(ENABLE_LEICHI, func_or_macro, ::heu::lib::algorithms::leichi_paillier, Plaintext, ##__VA_ARGS__) \ // INVOKE(true, func_or_macro, ::heu::lib::algorithms::your_algo, Plaintext, ##__VA_ARGS__) // clang-format on