Skip to content

Commit d40e295

Browse files
tbagsfacebook-github-bot
authored andcommitted
Add PCF2 ShardCombiner stage service (#1429)
Summary: Pull Request resolved: #1429 I just following D36325473 (deb4801) We had a new stage service for PCF2.0 PL, following the instructions here: https://www.internalfb.com/intern/wiki/Private_Computation_Platform_(PCP)/Internal_Developer_Guide/How_to_add_a_stage_to_PCS/Step_2:_Add_a_stage_service/ Reviewed By: robotal Differential Revision: D38340542 fbshipit-source-id: c4bec85d7b86acded8e9a7c4f808866eb4f36040
1 parent 419dce0 commit d40e295

6 files changed

+343
-0
lines changed

fbpcs/onedocker_binary_names.py

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class OneDockerBinaryNames(Enum):
2121
PCF2_ATTRIBUTION = "private_attribution/pcf2_attribution"
2222
PCF2_AGGREGATION = "private_attribution/pcf2_aggregation"
2323
SHARD_AGGREGATOR = "private_attribution/shard-aggregator"
24+
PCF2_SHARD_COMBINER = "private_attribution/pcf2_shard-combiner"
2425

2526
PID_CLIENT = "pid/private-id-client"
2627
PID_SERVER = "pid/private-id-server"

fbpcs/private_computation/entity/private_computation_instance.py

+4
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,10 @@ def pcf2_aggregation_stage_output_base_path(self) -> str:
223223
def shard_aggregate_stage_output_path(self) -> str:
224224
return self._get_stage_output_path("shard_aggregation_stage", "json")
225225

226+
@property
227+
def pcf2_shard_combine_stage_output_path(self) -> str:
228+
return self._get_stage_output_path("pcf2_shard_combiner_stage", "json")
229+
226230
def _get_stage_output_path(self, stage: str, extension_type: str) -> str:
227231
return os.path.join(
228232
self.product_config.common.output_dir,

fbpcs/private_computation/repository/private_computation_game.py

+15
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class GameNames(Enum):
2020
LIFT = "lift"
2121
PCF2_LIFT = "pcf2_lift"
2222
SHARD_AGGREGATOR = "shard_aggregator"
23+
PCF2_SHARD_COMBINER = "pcf2_shard_combiner"
2324
DECOUPLED_ATTRIBUTION = "decoupled_attribution"
2425
DECOUPLED_AGGREGATION = "decoupled_aggregation"
2526
PCF2_ATTRIBUTION = "pcf2_attribution"
@@ -75,6 +76,20 @@ class GameNamesValue(TypedDict):
7576
OneDockerArgument(name="visibility", required=False),
7677
],
7778
},
79+
GameNames.PCF2_SHARD_COMBINER.value: {
80+
"onedocker_package_name": OneDockerBinaryNames.PCF2_SHARD_COMBINER.value,
81+
"arguments": [
82+
OneDockerArgument(name="input_base_path", required=True),
83+
OneDockerArgument(name="num_shards", required=True),
84+
OneDockerArgument(name="output_path", required=True),
85+
OneDockerArgument(name="metrics_format_type", required=True),
86+
OneDockerArgument(name="threshold", required=True),
87+
OneDockerArgument(name="first_shard_index", required=False),
88+
OneDockerArgument(name="log_cost", required=False),
89+
OneDockerArgument(name="run_name", required=False),
90+
OneDockerArgument(name="visibility", required=False),
91+
],
92+
},
7893
GameNames.DECOUPLED_ATTRIBUTION.value: {
7994
"onedocker_package_name": OneDockerBinaryNames.DECOUPLED_ATTRIBUTION.value,
8095
"arguments": [
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import logging
10+
11+
from typing import DefaultDict, List, Optional
12+
13+
from fbpcp.service.mpc import MPCService
14+
from fbpcs.common.entity.pcs_mpc_instance import PCSMPCInstance
15+
from fbpcs.onedocker_binary_config import OneDockerBinaryConfig
16+
from fbpcs.onedocker_binary_names import OneDockerBinaryNames
17+
from fbpcs.private_computation.entity.infra_config import PrivateComputationGameType
18+
from fbpcs.private_computation.entity.pcs_feature import PCSFeature
19+
from fbpcs.private_computation.entity.private_computation_instance import (
20+
PrivateComputationInstance,
21+
PrivateComputationInstanceStatus,
22+
)
23+
from fbpcs.private_computation.entity.product_config import (
24+
AttributionConfig,
25+
ResultVisibility,
26+
)
27+
from fbpcs.private_computation.repository.private_computation_game import GameNames
28+
from fbpcs.private_computation.service.constants import DEFAULT_LOG_COST_TO_S3
29+
from fbpcs.private_computation.service.private_computation_stage_service import (
30+
PrivateComputationStageService,
31+
)
32+
from fbpcs.private_computation.service.utils import (
33+
create_and_start_mpc_instance,
34+
get_updated_pc_status_mpc_game,
35+
map_private_computation_role_to_mpc_party,
36+
)
37+
38+
39+
class ShardCombinerStageService(PrivateComputationStageService):
40+
"""Handles business logic for the private computation combine aggregate metrics stage
41+
42+
Private attributes:
43+
_onedocker_binary_config_map: Stores a mapping from mpc game to OneDockerBinaryConfig (binary version and tmp directory)
44+
_mpc_svc: creates and runs MPC instances
45+
_log_cost_to_s3: TODO
46+
_container_timeout: optional duration in seconds before cloud containers timeout
47+
"""
48+
49+
def __init__(
50+
self,
51+
onedocker_binary_config_map: DefaultDict[str, OneDockerBinaryConfig],
52+
mpc_service: MPCService,
53+
log_cost_to_s3: bool = DEFAULT_LOG_COST_TO_S3,
54+
container_timeout: Optional[int] = None,
55+
) -> None:
56+
self._onedocker_binary_config_map = onedocker_binary_config_map
57+
self._mpc_service = mpc_service
58+
self._log_cost_to_s3 = log_cost_to_s3
59+
self._container_timeout = container_timeout
60+
61+
# TODO T88759390: Make this function truly async. It is not because it calls blocking functions.
62+
# Make an async version of run_async() so that it can be called by Thrift
63+
async def run_async(
64+
self,
65+
pc_instance: PrivateComputationInstance,
66+
server_ips: Optional[List[str]] = None,
67+
) -> PrivateComputationInstance:
68+
"""Runs the private computation combine aggregate metrics stage
69+
70+
Args:
71+
pc_instance: the private computation instance to run aggregate metrics with
72+
server_ips: only used by the partner role. These are the ip addresses of the publisher's containers.
73+
74+
Returns:
75+
An updated version of pc_instance that stores an MPCInstance
76+
"""
77+
78+
num_shards = (
79+
pc_instance.infra_config.num_mpc_containers
80+
* pc_instance.infra_config.num_files_per_mpc_container
81+
)
82+
83+
# TODO T101225989: map aggregation_type from the compute stage to metrics_format_type
84+
metrics_format_type = (
85+
"lift"
86+
if pc_instance.infra_config.game_type is PrivateComputationGameType.LIFT
87+
else "ad_object"
88+
)
89+
90+
binary_name = OneDockerBinaryNames.PCF2_SHARD_COMBINER.value
91+
binary_config = self._onedocker_binary_config_map[binary_name]
92+
93+
# Get output path of previous stage depending on what stage flow we are using
94+
# Using "PrivateComputationDecoupledStageFlow" instead of PrivateComputationDecoupledStageFlow.get_cls_name() to avoid
95+
# circular import error.
96+
if pc_instance.get_flow_cls_name in [
97+
"PrivateComputationDecoupledStageFlow",
98+
"PrivateComputationDecoupledLocalTestStageFlow",
99+
]:
100+
input_stage_path = pc_instance.decoupled_aggregation_stage_output_base_path
101+
elif pc_instance.get_flow_cls_name in [
102+
"PrivateComputationPCF2StageFlow",
103+
"PrivateComputationPCF2LocalTestStageFlow",
104+
"PrivateComputationPIDPATestStageFlow",
105+
]:
106+
input_stage_path = pc_instance.pcf2_aggregation_stage_output_base_path
107+
elif pc_instance.get_flow_cls_name in [
108+
"PrivateComputationPCF2LiftStageFlow",
109+
"PrivateComputationPCF2LiftLocalTestStageFlow",
110+
]:
111+
input_stage_path = pc_instance.pcf2_lift_stage_output_base_path
112+
else:
113+
if pc_instance.has_feature(PCSFeature.PRIVATE_LIFT_PCF2_RELEASE):
114+
input_stage_path = pc_instance.pcf2_lift_stage_output_base_path
115+
else:
116+
input_stage_path = pc_instance.compute_stage_output_base_path
117+
118+
if self._log_cost_to_s3:
119+
run_name = pc_instance.infra_config.instance_id
120+
121+
if pc_instance.product_config.common.post_processing_data:
122+
pc_instance.product_config.common.post_processing_data.s3_cost_export_output_paths.add(
123+
f"sa-logs/{run_name}_{pc_instance.infra_config.role.value.title()}.json",
124+
)
125+
else:
126+
run_name = ""
127+
128+
# Create and start MPC instance
129+
game_args = [
130+
{
131+
"input_base_path": input_stage_path,
132+
"metrics_format_type": metrics_format_type,
133+
"num_shards": num_shards,
134+
"output_path": pc_instance.pcf2_shard_combine_stage_output_path,
135+
"threshold": 0
136+
if isinstance(pc_instance.product_config, AttributionConfig)
137+
# pyre-ignore Undefined attribute [16]
138+
else pc_instance.product_config.k_anonymity_threshold,
139+
"run_name": run_name,
140+
"log_cost": self._log_cost_to_s3,
141+
},
142+
]
143+
# We should only export visibility to scribe when it's set
144+
if (
145+
pc_instance.product_config.common.result_visibility
146+
is not ResultVisibility.PUBLIC
147+
):
148+
result_visibility = int(pc_instance.product_config.common.result_visibility)
149+
for arg in game_args:
150+
arg["visibility"] = result_visibility
151+
152+
mpc_instance = await create_and_start_mpc_instance(
153+
mpc_svc=self._mpc_service,
154+
instance_id=pc_instance.infra_config.instance_id
155+
+ "_combine_shards"
156+
+ str(pc_instance.infra_config.retry_counter),
157+
game_name=GameNames.PCF2_SHARD_COMBINER.value,
158+
mpc_party=map_private_computation_role_to_mpc_party(
159+
pc_instance.infra_config.role
160+
),
161+
num_containers=1,
162+
binary_version=binary_config.binary_version,
163+
server_ips=server_ips,
164+
game_args=game_args,
165+
container_timeout=self._container_timeout,
166+
repository_path=binary_config.repository_path,
167+
)
168+
169+
logging.info("MPC instance started running for PCF2.0 Shard Combiner.")
170+
171+
# Push MPC instance to PrivateComputationInstance.instances and update PL Instance status
172+
pc_instance.infra_config.instances.append(
173+
PCSMPCInstance.from_mpc_instance(mpc_instance)
174+
)
175+
return pc_instance
176+
177+
def get_status(
178+
self,
179+
pc_instance: PrivateComputationInstance,
180+
) -> PrivateComputationInstanceStatus:
181+
"""Updates the MPCInstances and gets latest PrivateComputationInstance status
182+
183+
Arguments:
184+
private_computation_instance: The PC instance that is being updated
185+
186+
Returns:
187+
The latest status for private_computation_instance
188+
"""
189+
return get_updated_pc_status_mpc_game(pc_instance, self._mpc_service)

fbpcs/private_computation/service/private_computation_service_data.py

+8
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,14 @@ class PrivateComputationServiceData:
109109
service=None,
110110
)
111111

112+
PCF2_SHARD_COMBINE_STAGE_DATA: StageData = StageData(
113+
binary_name=OneDockerBinaryNames.PCF2_SHARD_COMBINER.value,
114+
game_name=BINARY_NAME_TO_GAME_NAME[
115+
OneDockerBinaryNames.PCF2_SHARD_COMBINER.value
116+
],
117+
service=None,
118+
)
119+
112120
@classmethod
113121
def get(
114122
cls, game_type: PrivateComputationGameType
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from collections import defaultdict
8+
from unittest import IsolatedAsyncioTestCase
9+
from unittest.mock import AsyncMock, MagicMock, patch
10+
11+
from fbpcp.entity.mpc_instance import MPCParty
12+
from fbpcs.common.entity.pcs_mpc_instance import PCSMPCInstance
13+
from fbpcs.onedocker_binary_config import OneDockerBinaryConfig
14+
from fbpcs.private_computation.entity.infra_config import (
15+
InfraConfig,
16+
PrivateComputationGameType,
17+
)
18+
from fbpcs.private_computation.entity.private_computation_instance import (
19+
PrivateComputationInstance,
20+
PrivateComputationInstanceStatus,
21+
PrivateComputationRole,
22+
)
23+
from fbpcs.private_computation.entity.product_config import (
24+
AttributionConfig,
25+
CommonProductConfig,
26+
LiftConfig,
27+
ProductConfig,
28+
)
29+
from fbpcs.private_computation.repository.private_computation_game import GameNames
30+
from fbpcs.private_computation.service.constants import NUM_NEW_SHARDS_PER_FILE
31+
from fbpcs.private_computation.service.pcf2_shard_combiner_stage_service import (
32+
ShardCombinerStageService,
33+
)
34+
35+
36+
class TestShardCombinerStageService(IsolatedAsyncioTestCase):
37+
@patch("fbpcp.service.mpc.MPCService")
38+
def setUp(self, mock_mpc_svc) -> None:
39+
self.mock_mpc_svc = mock_mpc_svc
40+
self.mock_mpc_svc.create_instance = MagicMock()
41+
42+
onedocker_binary_config_map = defaultdict(
43+
lambda: OneDockerBinaryConfig(
44+
tmp_directory="/test_tmp_directory/",
45+
binary_version="latest",
46+
repository_path="test_path/",
47+
)
48+
)
49+
self.stage_svc = ShardCombinerStageService(
50+
onedocker_binary_config_map, self.mock_mpc_svc
51+
)
52+
53+
async def test_shard_combiner(self) -> None:
54+
private_computation_instance = self._create_pc_instance()
55+
mpc_instance = PCSMPCInstance.create_instance(
56+
instance_id=private_computation_instance.infra_config.instance_id
57+
+ "_aggregate_metrics0",
58+
game_name=GameNames.LIFT.value,
59+
mpc_party=MPCParty.CLIENT,
60+
num_workers=private_computation_instance.infra_config.num_mpc_containers,
61+
)
62+
63+
self.mock_mpc_svc.start_instance_async = AsyncMock(return_value=mpc_instance)
64+
65+
test_server_ips = [
66+
f"192.0.2.{i}"
67+
for i in range(private_computation_instance.infra_config.num_mpc_containers)
68+
]
69+
await self.stage_svc.run_async(private_computation_instance, test_server_ips)
70+
test_game_args = [
71+
{
72+
"input_base_path": private_computation_instance.compute_stage_output_base_path,
73+
"metrics_format_type": "lift",
74+
"num_shards": private_computation_instance.infra_config.num_mpc_containers
75+
* NUM_NEW_SHARDS_PER_FILE,
76+
"output_path": private_computation_instance.pcf2_shard_combine_stage_output_path,
77+
"threshold": 0
78+
if isinstance(
79+
private_computation_instance.product_config, AttributionConfig
80+
)
81+
# pyre-ignore Undefined attribute [16]
82+
else private_computation_instance.product_config.k_anonymity_threshold,
83+
"run_name": private_computation_instance.infra_config.instance_id
84+
if self.stage_svc._log_cost_to_s3
85+
else "",
86+
"log_cost": True,
87+
}
88+
]
89+
90+
self.assertEqual(
91+
GameNames.PCF2_SHARD_COMBINER.value,
92+
self.mock_mpc_svc.create_instance.call_args[1]["game_name"],
93+
)
94+
self.assertEqual(
95+
test_game_args,
96+
self.mock_mpc_svc.create_instance.call_args[1]["game_args"],
97+
)
98+
99+
self.assertEqual(
100+
mpc_instance, private_computation_instance.infra_config.instances[0]
101+
)
102+
103+
def _create_pc_instance(self) -> PrivateComputationInstance:
104+
infra_config: InfraConfig = InfraConfig(
105+
instance_id="test_instance_123",
106+
role=PrivateComputationRole.PARTNER,
107+
status=PrivateComputationInstanceStatus.COMPUTATION_COMPLETED,
108+
status_update_ts=1600000000,
109+
instances=[],
110+
game_type=PrivateComputationGameType.LIFT,
111+
num_pid_containers=2,
112+
num_mpc_containers=2,
113+
num_files_per_mpc_container=NUM_NEW_SHARDS_PER_FILE,
114+
status_updates=[],
115+
)
116+
common: CommonProductConfig = CommonProductConfig(
117+
input_path="456",
118+
output_dir="789",
119+
)
120+
product_config: ProductConfig = LiftConfig(
121+
common=common,
122+
)
123+
return PrivateComputationInstance(
124+
infra_config=infra_config,
125+
product_config=product_config,
126+
)

0 commit comments

Comments
 (0)