|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | +# |
| 4 | +# This source code is licensed under the MIT license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +# pyre-strict |
| 8 | + |
| 9 | +import logging |
| 10 | + |
| 11 | +from typing import DefaultDict, List, Optional |
| 12 | + |
| 13 | +from fbpcp.service.mpc import MPCService |
| 14 | +from fbpcs.common.entity.pcs_mpc_instance import PCSMPCInstance |
| 15 | +from fbpcs.onedocker_binary_config import OneDockerBinaryConfig |
| 16 | +from fbpcs.onedocker_binary_names import OneDockerBinaryNames |
| 17 | +from fbpcs.private_computation.entity.infra_config import PrivateComputationGameType |
| 18 | +from fbpcs.private_computation.entity.pcs_feature import PCSFeature |
| 19 | +from fbpcs.private_computation.entity.private_computation_instance import ( |
| 20 | + PrivateComputationInstance, |
| 21 | + PrivateComputationInstanceStatus, |
| 22 | +) |
| 23 | +from fbpcs.private_computation.entity.product_config import ( |
| 24 | + AttributionConfig, |
| 25 | + ResultVisibility, |
| 26 | +) |
| 27 | +from fbpcs.private_computation.repository.private_computation_game import GameNames |
| 28 | +from fbpcs.private_computation.service.constants import DEFAULT_LOG_COST_TO_S3 |
| 29 | +from fbpcs.private_computation.service.private_computation_stage_service import ( |
| 30 | + PrivateComputationStageService, |
| 31 | +) |
| 32 | +from fbpcs.private_computation.service.utils import ( |
| 33 | + create_and_start_mpc_instance, |
| 34 | + get_updated_pc_status_mpc_game, |
| 35 | + map_private_computation_role_to_mpc_party, |
| 36 | +) |
| 37 | + |
| 38 | + |
| 39 | +class ShardCombinerStageService(PrivateComputationStageService): |
| 40 | + """Handles business logic for the private computation combine aggregate metrics stage |
| 41 | +
|
| 42 | + Private attributes: |
| 43 | + _onedocker_binary_config_map: Stores a mapping from mpc game to OneDockerBinaryConfig (binary version and tmp directory) |
| 44 | + _mpc_svc: creates and runs MPC instances |
| 45 | + _log_cost_to_s3: TODO |
| 46 | + _container_timeout: optional duration in seconds before cloud containers timeout |
| 47 | + """ |
| 48 | + |
| 49 | + def __init__( |
| 50 | + self, |
| 51 | + onedocker_binary_config_map: DefaultDict[str, OneDockerBinaryConfig], |
| 52 | + mpc_service: MPCService, |
| 53 | + log_cost_to_s3: bool = DEFAULT_LOG_COST_TO_S3, |
| 54 | + container_timeout: Optional[int] = None, |
| 55 | + ) -> None: |
| 56 | + self._onedocker_binary_config_map = onedocker_binary_config_map |
| 57 | + self._mpc_service = mpc_service |
| 58 | + self._log_cost_to_s3 = log_cost_to_s3 |
| 59 | + self._container_timeout = container_timeout |
| 60 | + |
| 61 | + # TODO T88759390: Make this function truly async. It is not because it calls blocking functions. |
| 62 | + # Make an async version of run_async() so that it can be called by Thrift |
| 63 | + async def run_async( |
| 64 | + self, |
| 65 | + pc_instance: PrivateComputationInstance, |
| 66 | + server_ips: Optional[List[str]] = None, |
| 67 | + ) -> PrivateComputationInstance: |
| 68 | + """Runs the private computation combine aggregate metrics stage |
| 69 | +
|
| 70 | + Args: |
| 71 | + pc_instance: the private computation instance to run aggregate metrics with |
| 72 | + server_ips: only used by the partner role. These are the ip addresses of the publisher's containers. |
| 73 | +
|
| 74 | + Returns: |
| 75 | + An updated version of pc_instance that stores an MPCInstance |
| 76 | + """ |
| 77 | + |
| 78 | + num_shards = ( |
| 79 | + pc_instance.infra_config.num_mpc_containers |
| 80 | + * pc_instance.infra_config.num_files_per_mpc_container |
| 81 | + ) |
| 82 | + |
| 83 | + # TODO T101225989: map aggregation_type from the compute stage to metrics_format_type |
| 84 | + metrics_format_type = ( |
| 85 | + "lift" |
| 86 | + if pc_instance.infra_config.game_type is PrivateComputationGameType.LIFT |
| 87 | + else "ad_object" |
| 88 | + ) |
| 89 | + |
| 90 | + binary_name = OneDockerBinaryNames.PCF2_SHARD_COMBINER.value |
| 91 | + binary_config = self._onedocker_binary_config_map[binary_name] |
| 92 | + |
| 93 | + # Get output path of previous stage depending on what stage flow we are using |
| 94 | + # Using "PrivateComputationDecoupledStageFlow" instead of PrivateComputationDecoupledStageFlow.get_cls_name() to avoid |
| 95 | + # circular import error. |
| 96 | + if pc_instance.get_flow_cls_name in [ |
| 97 | + "PrivateComputationDecoupledStageFlow", |
| 98 | + "PrivateComputationDecoupledLocalTestStageFlow", |
| 99 | + ]: |
| 100 | + input_stage_path = pc_instance.decoupled_aggregation_stage_output_base_path |
| 101 | + elif pc_instance.get_flow_cls_name in [ |
| 102 | + "PrivateComputationPCF2StageFlow", |
| 103 | + "PrivateComputationPCF2LocalTestStageFlow", |
| 104 | + "PrivateComputationPIDPATestStageFlow", |
| 105 | + ]: |
| 106 | + input_stage_path = pc_instance.pcf2_aggregation_stage_output_base_path |
| 107 | + elif pc_instance.get_flow_cls_name in [ |
| 108 | + "PrivateComputationPCF2LiftStageFlow", |
| 109 | + "PrivateComputationPCF2LiftLocalTestStageFlow", |
| 110 | + ]: |
| 111 | + input_stage_path = pc_instance.pcf2_lift_stage_output_base_path |
| 112 | + else: |
| 113 | + if pc_instance.has_feature(PCSFeature.PRIVATE_LIFT_PCF2_RELEASE): |
| 114 | + input_stage_path = pc_instance.pcf2_lift_stage_output_base_path |
| 115 | + else: |
| 116 | + input_stage_path = pc_instance.compute_stage_output_base_path |
| 117 | + |
| 118 | + if self._log_cost_to_s3: |
| 119 | + run_name = pc_instance.infra_config.instance_id |
| 120 | + |
| 121 | + if pc_instance.product_config.common.post_processing_data: |
| 122 | + pc_instance.product_config.common.post_processing_data.s3_cost_export_output_paths.add( |
| 123 | + f"sa-logs/{run_name}_{pc_instance.infra_config.role.value.title()}.json", |
| 124 | + ) |
| 125 | + else: |
| 126 | + run_name = "" |
| 127 | + |
| 128 | + # Create and start MPC instance |
| 129 | + game_args = [ |
| 130 | + { |
| 131 | + "input_base_path": input_stage_path, |
| 132 | + "metrics_format_type": metrics_format_type, |
| 133 | + "num_shards": num_shards, |
| 134 | + "output_path": pc_instance.pcf2_shard_combine_stage_output_path, |
| 135 | + "threshold": 0 |
| 136 | + if isinstance(pc_instance.product_config, AttributionConfig) |
| 137 | + # pyre-ignore Undefined attribute [16] |
| 138 | + else pc_instance.product_config.k_anonymity_threshold, |
| 139 | + "run_name": run_name, |
| 140 | + "log_cost": self._log_cost_to_s3, |
| 141 | + }, |
| 142 | + ] |
| 143 | + # We should only export visibility to scribe when it's set |
| 144 | + if ( |
| 145 | + pc_instance.product_config.common.result_visibility |
| 146 | + is not ResultVisibility.PUBLIC |
| 147 | + ): |
| 148 | + result_visibility = int(pc_instance.product_config.common.result_visibility) |
| 149 | + for arg in game_args: |
| 150 | + arg["visibility"] = result_visibility |
| 151 | + |
| 152 | + mpc_instance = await create_and_start_mpc_instance( |
| 153 | + mpc_svc=self._mpc_service, |
| 154 | + instance_id=pc_instance.infra_config.instance_id |
| 155 | + + "_combine_shards" |
| 156 | + + str(pc_instance.infra_config.retry_counter), |
| 157 | + game_name=GameNames.PCF2_SHARD_COMBINER.value, |
| 158 | + mpc_party=map_private_computation_role_to_mpc_party( |
| 159 | + pc_instance.infra_config.role |
| 160 | + ), |
| 161 | + num_containers=1, |
| 162 | + binary_version=binary_config.binary_version, |
| 163 | + server_ips=server_ips, |
| 164 | + game_args=game_args, |
| 165 | + container_timeout=self._container_timeout, |
| 166 | + repository_path=binary_config.repository_path, |
| 167 | + ) |
| 168 | + |
| 169 | + logging.info("MPC instance started running for PCF2.0 Shard Combiner.") |
| 170 | + |
| 171 | + # Push MPC instance to PrivateComputationInstance.instances and update PL Instance status |
| 172 | + pc_instance.infra_config.instances.append( |
| 173 | + PCSMPCInstance.from_mpc_instance(mpc_instance) |
| 174 | + ) |
| 175 | + return pc_instance |
| 176 | + |
| 177 | + def get_status( |
| 178 | + self, |
| 179 | + pc_instance: PrivateComputationInstance, |
| 180 | + ) -> PrivateComputationInstanceStatus: |
| 181 | + """Updates the MPCInstances and gets latest PrivateComputationInstance status |
| 182 | +
|
| 183 | + Arguments: |
| 184 | + private_computation_instance: The PC instance that is being updated |
| 185 | +
|
| 186 | + Returns: |
| 187 | + The latest status for private_computation_instance |
| 188 | + """ |
| 189 | + return get_updated_pc_status_mpc_game(pc_instance, self._mpc_service) |
0 commit comments