diff --git a/fbpcs/emp_games/he_aggregation/test/HEAggGameTest.cpp b/fbpcs/emp_games/he_aggregation/test/HEAggGameTest.cpp index 37434d738..41614a52e 100644 --- a/fbpcs/emp_games/he_aggregation/test/HEAggGameTest.cpp +++ b/fbpcs/emp_games/he_aggregation/test/HEAggGameTest.cpp @@ -29,6 +29,10 @@ #include "fbpcs/emp_games/he_aggregation/HEAggGame.h" #include "fbpcs/emp_games/he_aggregation/HEAggOptions.h" +#include "privacy_infra/elgamal/ElGamal.h" + +namespace heschme = facebook::privacy_infra::elgamal; + namespace pcf2_he { std::unordered_map runGame( @@ -56,6 +60,52 @@ void verifyOutput( folly::toJson(actualOutput), folly::toJson(expectedOutput)); } +TEST(HEAggGameTest, HECiphertextAdditionTest) { + const std::string baseDir_ = + private_measurement::test_util::getBaseDirFromPath(__FILE__); + + // Generate private key, public key and decryption table + auto sk = heschme::PrivateKey::generate(); + auto pk = sk.toPublicKey(); + heschme::initializeElGamalDecryptionTable(FLAGS_decryption_table_size); + + // Encrypt values + int x = 111; + int y = 222; + heschme::Ciphertext c1 = pk.encrypt(x); + heschme::Ciphertext c2 = pk.encrypt(y); + + // Perform addition + heschme::Ciphertext c3 = heschme::Ciphertext::add_with_ciphertext(c1, c2); + + // Decrypt + uint64_t decrypted = sk.decrypt(c3); + + EXPECT_EQ(decrypted, x + y); +} +TEST(HEAggGameTest, HEPlaintextAdditionTest) { + const std::string baseDir_ = + private_measurement::test_util::getBaseDirFromPath(__FILE__); + + // Generate private key, public key and decryption table + auto sk = heschme::PrivateKey::generate(); + auto pk = sk.toPublicKey(); + heschme::initializeElGamalDecryptionTable(FLAGS_decryption_table_size); + + // Encrypt values + int x = 111; + int y = 444; + heschme::Ciphertext c1 = pk.encrypt(x); + + // Perform addition + heschme::Ciphertext c2 = heschme::Ciphertext::add_with_plaintext(c1, y); + + // Decrypt + uint64_t decrypted = sk.decrypt(c2); + + EXPECT_EQ(decrypted, x + y); +} + TEST(HEAggGameTest, HEAggGameCorrectnessTest) { const std::string baseDir_ = private_measurement::test_util::getBaseDirFromPath(__FILE__);