diff --git a/fbpcs/emp_games/common/Util.h b/fbpcs/emp_games/common/Util.h index 0f7d3244a..badc9871c 100644 --- a/fbpcs/emp_games/common/Util.h +++ b/fbpcs/emp_games/common/Util.h @@ -7,6 +7,8 @@ #pragma once +#include +#include #include #include "folly/dynamic.h" diff --git a/fbpcs/emp_games/common/test/UtilTest.cpp b/fbpcs/emp_games/common/test/UtilTest.cpp new file mode 100644 index 000000000..d2f8faaba --- /dev/null +++ b/fbpcs/emp_games/common/test/UtilTest.cpp @@ -0,0 +1,58 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +#include "../Util.h" + +namespace private_measurement { + +TEST(UtilTest, TestGetTlsInfoFromArguments) { + auto tlsInfo = common::getTlsInfoFromArgs( + false, + "cert_path", + "server_cert_path", + "private_key_path", + "passphrase_path"); + + EXPECT_FALSE(tlsInfo.useTls); + EXPECT_STREQ(tlsInfo.rootCaCertPath.c_str(), ""); + EXPECT_STREQ(tlsInfo.certPath.c_str(), ""); + EXPECT_STREQ(tlsInfo.keyPath.c_str(), ""); + EXPECT_STREQ(tlsInfo.passphrasePath.c_str(), ""); + + const char* home_dir = std::getenv("HOME"); + if (home_dir == nullptr) { + home_dir = ""; + } + + std::string home_dir_string(home_dir); + + tlsInfo = common::getTlsInfoFromArgs( + true, + "cert_path", + "server_cert_path", + "private_key_path", + "passphrase_path"); + + EXPECT_TRUE(tlsInfo.useTls); + EXPECT_STREQ( + tlsInfo.rootCaCertPath.c_str(), (home_dir_string + "/cert_path").c_str()); + EXPECT_STREQ( + tlsInfo.certPath.c_str(), + (home_dir_string + "/server_cert_path").c_str()); + EXPECT_STREQ( + tlsInfo.keyPath.c_str(), (home_dir_string + "/private_key_path").c_str()); + EXPECT_STREQ( + tlsInfo.passphrasePath.c_str(), + (home_dir_string + "/passphrase_path").c_str()); +} +} // namespace private_measurement diff --git a/fbpcs/emp_games/compactor/main.cpp b/fbpcs/emp_games/compactor/main.cpp index 1fcde409b..94d8e700b 100644 --- a/fbpcs/emp_games/compactor/main.cpp +++ b/fbpcs/emp_games/compactor/main.cpp @@ -18,6 +18,7 @@ #include "fbpcf/engine/communication/SocketPartyCommunicationAgentFactory.h" #include "fbpcf/io/api/FileIOWrappers.h" #include "fbpcf/scheduler/LazySchedulerFactory.h" +#include "fbpcs/emp_games/common/Util.h" #include "fbpcs/emp_games/compactor/AttributionOutput.h" #include "fbpcs/emp_games/compactor/CompactorGame.h" #include "fbpcs/performance_tools/CostEstimation.h" @@ -101,11 +102,12 @@ int main(int argc, char** argv) { XLOG(INFO) << "Creating communication agent factory\n"; - fbpcf::engine::communication::SocketPartyCommunicationAgent::TlsInfo tlsInfo; - tlsInfo.certPath = ""; - tlsInfo.keyPath = ""; - tlsInfo.passphrasePath = ""; - tlsInfo.useTls = false; + auto tlsInfo = common::getTlsInfoFromArgs( + FLAGS_use_tls, + FLAGS_ca_cert_path, + FLAGS_server_cert_path, + FLAGS_private_key_path, + ""); std::map< int, diff --git a/fbpcs/emp_games/dotproduct/MainUtil.h b/fbpcs/emp_games/dotproduct/MainUtil.h index 0bc3b1c60..8924814ab 100644 --- a/fbpcs/emp_games/dotproduct/MainUtil.h +++ b/fbpcs/emp_games/dotproduct/MainUtil.h @@ -24,17 +24,14 @@ inline common::SchedulerStatistics startDotProductApp( std::string& outFilePath, int numFeatures, int labelWidth, - bool debugMode) { + bool debugMode, + fbpcf::engine::communication::SocketPartyCommunicationAgent::TlsInfo& + tlsInfo) { std::map< int, fbpcf::engine::communication::SocketPartyCommunicationAgentFactory:: PartyInfo> partyInfos({{0, {serverIp, port}}, {1, {serverIp, port}}}); - fbpcf::engine::communication::SocketPartyCommunicationAgent::TlsInfo tlsInfo; - tlsInfo.certPath = ""; - tlsInfo.keyPath = ""; - tlsInfo.passphrasePath = ""; - tlsInfo.useTls = false; auto metricCollector = std::make_shared("dotproduct"); diff --git a/fbpcs/emp_games/dotproduct/main.cpp b/fbpcs/emp_games/dotproduct/main.cpp index 6035565c1..15213ade0 100644 --- a/fbpcs/emp_games/dotproduct/main.cpp +++ b/fbpcs/emp_games/dotproduct/main.cpp @@ -44,6 +44,13 @@ int main(int argc, char* argv[]) { XLOGF(INFO, "Base output path: {}", FLAGS_output_base_path); common::SchedulerStatistics schedulerStatistics; + + auto tlsInfo = common::getTlsInfoFromArgs( + FLAGS_use_tls, + FLAGS_ca_cert_path, + FLAGS_server_cert_path, + FLAGS_private_key_path, + ""); try { if (FLAGS_party == common::PUBLISHER) { XLOG(INFO) @@ -57,7 +64,8 @@ int main(int argc, char* argv[]) { FLAGS_output_base_path, FLAGS_num_features, FLAGS_label_width, - FLAGS_debug); + FLAGS_debug, + tlsInfo); } else if (FLAGS_party == common::PARTNER) { XLOG(INFO) @@ -71,7 +79,8 @@ int main(int argc, char* argv[]) { FLAGS_output_base_path, FLAGS_num_features, FLAGS_label_width, - FLAGS_debug); + FLAGS_debug, + tlsInfo); } else { XLOGF(FATAL, "Invalid Party: {}", FLAGS_party); } diff --git a/fbpcs/emp_games/pcf2_aggregation/MainUtil.h b/fbpcs/emp_games/pcf2_aggregation/MainUtil.h index 9bfad2cd9..91526b9c1 100644 --- a/fbpcs/emp_games/pcf2_aggregation/MainUtil.h +++ b/fbpcs/emp_games/pcf2_aggregation/MainUtil.h @@ -55,7 +55,9 @@ inline common::SchedulerStatistics startAggregationAppsForShardedFilesHelper( std::string aggregationFormats, std::vector& inputSecretShareFilenames, std::vector& inputClearTextFilenames, - std::vector& outputFilenames) { + std::vector& outputFilenames, + fbpcf::engine::communication::SocketPartyCommunicationAgent::TlsInfo& + tlsInfo) { // aggregate scheduler statistics across apps common::SchedulerStatistics schedulerStatistics{ 0, 0, 0, 0, folly::dynamic::object()}; @@ -77,13 +79,6 @@ inline common::SchedulerStatistics startAggregationAppsForShardedFilesHelper( {{0, {serverIp, port + index * 100}}, {1, {serverIp, port + index * 100}}}); - fbpcf::engine::communication::SocketPartyCommunicationAgent::TlsInfo - tlsInfo; - tlsInfo.certPath = ""; - tlsInfo.keyPath = ""; - tlsInfo.passphrasePath = ""; - tlsInfo.useTls = false; - auto metricCollector = std::make_shared( "aggregation_metrics_for_thread_" + std::to_string(index)); @@ -126,7 +121,8 @@ inline common::SchedulerStatistics startAggregationAppsForShardedFilesHelper( aggregationFormats, inputSecretShareFilenames, inputClearTextFilenames, - outputFilenames); + outputFilenames, + tlsInfo); schedulerStatistics.add(remainingStats); } } @@ -146,7 +142,9 @@ inline common::SchedulerStatistics startAggregationAppsForShardedFiles( int16_t concurrency, std::string serverIp, int port, - std::string aggregationFormats) { + std::string aggregationFormats, + fbpcf::engine::communication::SocketPartyCommunicationAgent::TlsInfo& + tlsInfo) { // use only as many threads as the number of files auto numThreads = std::min((int)inputSecretShareFilenames.size(), (int)concurrency); @@ -162,7 +160,8 @@ inline common::SchedulerStatistics startAggregationAppsForShardedFiles( aggregationFormats, inputSecretShareFilenames, inputClearTextFilenames, - outputFilenames); + outputFilenames, + tlsInfo); } } // namespace pcf2_aggregation diff --git a/fbpcs/emp_games/pcf2_aggregation/main.cpp b/fbpcs/emp_games/pcf2_aggregation/main.cpp index 458cb70f7..057fdab67 100644 --- a/fbpcs/emp_games/pcf2_aggregation/main.cpp +++ b/fbpcs/emp_games/pcf2_aggregation/main.cpp @@ -89,6 +89,13 @@ int main(int argc, char* argv[]) { inputEncryption = common::InputEncryption::Plaintext; } + auto tlsInfo = common::getTlsInfoFromArgs( + FLAGS_use_tls, + FLAGS_ca_cert_path, + FLAGS_server_cert_path, + FLAGS_private_key_path, + ""); + if (FLAGS_party == common::PUBLISHER) { XLOGF(INFO, "Aggregation Format: {}", FLAGS_aggregators); @@ -106,7 +113,8 @@ int main(int argc, char* argv[]) { concurrency, FLAGS_server_ip, FLAGS_port, - FLAGS_aggregators); + FLAGS_aggregators, + tlsInfo); } else if (FLAGS_party == common::PARTNER) { XLOG(INFO) << "Starting private aggregation as Partner, will wait for Publisher..."; @@ -121,7 +129,8 @@ int main(int argc, char* argv[]) { concurrency, FLAGS_server_ip, FLAGS_port, - FLAGS_aggregators); + FLAGS_aggregators, + tlsInfo); } else { XLOGF(FATAL, "Invalid Party: {}", FLAGS_party); diff --git a/fbpcs/emp_games/pcf2_attribution/MainUtil.h b/fbpcs/emp_games/pcf2_attribution/MainUtil.h index 25dad1634..217088108 100644 --- a/fbpcs/emp_games/pcf2_attribution/MainUtil.h +++ b/fbpcs/emp_games/pcf2_attribution/MainUtil.h @@ -65,7 +65,9 @@ inline common::SchedulerStatistics startAttributionAppsForShardedFilesHelper( int port, std::string attributionRules, std::vector& inputFilenames, - std::vector& outputFilenames) { + std::vector& outputFilenames, + fbpcf::engine::communication::SocketPartyCommunicationAgent::TlsInfo& + tlsInfo) { // aggregate scheduler statistics across apps common::SchedulerStatistics schedulerStatistics{ 0, 0, 0, 0, folly::dynamic::object()}; @@ -86,13 +88,6 @@ inline common::SchedulerStatistics startAttributionAppsForShardedFilesHelper( {{0, {serverIp, port + static_cast(index) * 100}}, {1, {serverIp, port + static_cast(index) * 100}}}); - fbpcf::engine::communication::SocketPartyCommunicationAgent::TlsInfo - tlsInfo; - tlsInfo.certPath = ""; - tlsInfo.keyPath = ""; - tlsInfo.passphrasePath = ""; - tlsInfo.useTls = false; - auto metricCollector = std::make_shared( "attribution_metrics_for_thread_" + std::to_string(index)); @@ -133,7 +128,8 @@ inline common::SchedulerStatistics startAttributionAppsForShardedFilesHelper( port, attributionRules, inputFilenames, - outputFilenames); + outputFilenames, + tlsInfo); schedulerStatistics.add(remainingStats); } } @@ -150,7 +146,9 @@ inline common::SchedulerStatistics startAttributionAppsForShardedFiles( int16_t concurrency, std::string serverIp, int port, - std::string attributionRules) { + std::string attributionRules, + fbpcf::engine::communication::SocketPartyCommunicationAgent::TlsInfo& + tlsInfo) { // use only as many threads as the number of files auto numThreads = std::min(static_cast(inputFilenames.size()), concurrency); @@ -166,7 +164,8 @@ inline common::SchedulerStatistics startAttributionAppsForShardedFiles( port, attributionRules, inputFilenames, - outputFilenames); + outputFilenames, + tlsInfo); } } // namespace pcf2_attribution diff --git a/fbpcs/emp_games/pcf2_attribution/main.cpp b/fbpcs/emp_games/pcf2_attribution/main.cpp index dbeaf55c9..58d7bdd7b 100644 --- a/fbpcs/emp_games/pcf2_attribution/main.cpp +++ b/fbpcs/emp_games/pcf2_attribution/main.cpp @@ -64,6 +64,13 @@ int main(int argc, char* argv[]) { CHECK_LE(concurrency, pcf2_attribution::kMaxConcurrency) << "Concurrency must be at most " << pcf2_attribution::kMaxConcurrency; + auto tlsInfo = common::getTlsInfoFromArgs( + FLAGS_use_tls, + FLAGS_ca_cert_path, + FLAGS_server_cert_path, + FLAGS_private_key_path, + ""); + if (FLAGS_party == common::PUBLISHER) { XLOGF(INFO, "Attribution Rules: {}", FLAGS_attribution_rules); @@ -81,7 +88,8 @@ int main(int argc, char* argv[]) { concurrency, FLAGS_server_ip, FLAGS_port, - FLAGS_attribution_rules); + FLAGS_attribution_rules, + tlsInfo); } else if (FLAGS_input_encryption == 2) { schedulerStatistics = pcf2_attribution::startAttributionAppsForShardedFiles< @@ -93,7 +101,8 @@ int main(int argc, char* argv[]) { concurrency, FLAGS_server_ip, FLAGS_port, - FLAGS_attribution_rules); + FLAGS_attribution_rules, + tlsInfo); } else { schedulerStatistics = pcf2_attribution::startAttributionAppsForShardedFiles< @@ -105,7 +114,8 @@ int main(int argc, char* argv[]) { concurrency, FLAGS_server_ip, FLAGS_port, - FLAGS_attribution_rules); + FLAGS_attribution_rules, + tlsInfo); } } else if (FLAGS_party == common::PARTNER) { @@ -123,7 +133,8 @@ int main(int argc, char* argv[]) { concurrency, FLAGS_server_ip, FLAGS_port, - FLAGS_attribution_rules); + FLAGS_attribution_rules, + tlsInfo); } else if (FLAGS_input_encryption == 2) { schedulerStatistics = pcf2_attribution::startAttributionAppsForShardedFiles< @@ -135,7 +146,8 @@ int main(int argc, char* argv[]) { concurrency, FLAGS_server_ip, FLAGS_port, - FLAGS_attribution_rules); + FLAGS_attribution_rules, + tlsInfo); } else { schedulerStatistics = @@ -148,7 +160,8 @@ int main(int argc, char* argv[]) { concurrency, FLAGS_server_ip, FLAGS_port, - FLAGS_attribution_rules); + FLAGS_attribution_rules, + tlsInfo); } } else { diff --git a/fbpcs/emp_games/pcf2_shard_combiner/main.cpp b/fbpcs/emp_games/pcf2_shard_combiner/main.cpp index 144bb8a53..74059b58e 100644 --- a/fbpcs/emp_games/pcf2_shard_combiner/main.cpp +++ b/fbpcs/emp_games/pcf2_shard_combiner/main.cpp @@ -121,6 +121,13 @@ int main(int argc, char* argv[]) { common::SchedulerStatistics schedulerStatistics; + auto tlsInfo = common::getTlsInfoFromArgs( + FLAGS_use_tls, + FLAGS_ca_cert_path, + FLAGS_server_cert_path, + FLAGS_private_key_path, + ""); + if (FLAGS_metrics_format_type == "ad_object") { schedulerStatistics = runApp( FLAGS_party, @@ -135,7 +142,8 @@ int main(int argc, char* argv[]) { FLAGS_use_xor_encryption, FLAGS_visibility, FLAGS_server_ip, - FLAGS_port); + FLAGS_port, + tlsInfo); } else if (FLAGS_metrics_format_type == "lift") { schedulerStatistics = runApp( FLAGS_party, @@ -150,7 +158,8 @@ int main(int argc, char* argv[]) { FLAGS_use_xor_encryption, FLAGS_visibility, FLAGS_server_ip, - FLAGS_port); + FLAGS_port, + tlsInfo); } else { std::string errStr = folly::sformat( "unsupported metrics format type: {}", FLAGS_metrics_format_type); diff --git a/fbpcs/emp_games/pcf2_shard_combiner/util/MainUtil.h b/fbpcs/emp_games/pcf2_shard_combiner/util/MainUtil.h index 9b3cba1a7..8f5c3bb90 100644 --- a/fbpcs/emp_games/pcf2_shard_combiner/util/MainUtil.h +++ b/fbpcs/emp_games/pcf2_shard_combiner/util/MainUtil.h @@ -33,7 +33,9 @@ common::SchedulerStatistics runApp( bool useXorEncryption, int32_t visibility, std::string ip, - std::uint16_t port) { + std::uint16_t port, + fbpcf::engine::communication::SocketPartyCommunicationAgent::TlsInfo& + tlsInfo) { assert(inputEncryption == common::InputEncryption::Xor); assert(visibility == 0 || visibility == 1 || visibility == 2); @@ -50,12 +52,6 @@ common::SchedulerStatistics runApp( break; } - fbpcf::engine::communication::SocketPartyCommunicationAgent::TlsInfo tlsInfo; - tlsInfo.certPath = ""; - tlsInfo.keyPath = ""; - tlsInfo.passphrasePath = ""; - tlsInfo.useTls = false; - std::map< int32_t, fbpcf::engine::communication::SocketPartyCommunicationAgentFactory::