diff --git a/fbpcs/common/entity/stage_state_instance.py b/fbpcs/common/entity/stage_state_instance.py new file mode 100644 index 000000000..1cff00459 --- /dev/null +++ b/fbpcs/common/entity/stage_state_instance.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import time +from dataclasses import field, dataclass +from enum import Enum +from typing import Optional, List + +from fbpcp.entity.container_instance import ContainerInstance +from fbpcp.util.typing import checked_cast +from fbpcs.common.entity.instance_base import InstanceBase + + +class StageStateInstanceStatus(Enum): + UNKNOWN = "UNKNOWN" + CREATED = "CREATED" + STARTED = "STARTED" + COMPLETED = "COMPLETED" + FAILED = "FAILED" + + +@dataclass +class StageStateInstance(InstanceBase): + instance_id: str + stage_name: str + status: StageStateInstanceStatus = StageStateInstanceStatus.CREATED + containers: List[ContainerInstance] = field(default_factory=list) + start_time: int = field(default_factory=lambda: int(time.time())) + end_time: Optional[int] = None + + @property + def server_ips(self) -> List[str]: + return [ + checked_cast(str, container.ip_address) for container in self.containers + ] + + @property + def elapsed_time(self) -> int: + if self.end_time is None: + return int(time.time()) - self.start_time + + return self.end_time - self.start_time + + def get_instance_id(self) -> str: + return self.instance_id diff --git a/fbpcs/common/tests/test_stage_state_instance.py b/fbpcs/common/tests/test_stage_state_instance.py new file mode 100644 index 000000000..57a449518 --- /dev/null +++ b/fbpcs/common/tests/test_stage_state_instance.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +from fbpcp.entity.container_instance import ContainerInstance, ContainerInstanceStatus +from fbpcs.common.entity.stage_state_instance import ( + StageStateInstance, + StageStateInstanceStatus, +) + + +class TestStageStateInstance(unittest.TestCase): + def setUp(self): + self.stage_state_instance = StageStateInstance( + instance_id="stage_state_instance", + stage_name="test_stage", + status=StageStateInstanceStatus.COMPLETED, + containers=[ + ContainerInstance( + instance_id="test_container_instance_1", + ip_address="192.0.2.4", + status=ContainerInstanceStatus.COMPLETED, + ), + ContainerInstance( + instance_id="test_container_instance_2", + ip_address="192.0.2.5", + status=ContainerInstanceStatus.COMPLETED, + ), + ], + start_time=1646642432, + end_time=1646642432 + 5, + ) + + def test_server_ips(self) -> None: + self.assertEqual(len(self.stage_state_instance.containers), 2) + self.assertEqual( + self.stage_state_instance.server_ips, ["192.0.2.4", "192.0.2.5"] + ) + + def test_elapsed_time(self) -> None: + self.assertEqual(self.stage_state_instance.elapsed_time, 5) diff --git a/fbpcs/private_computation/entity/private_computation_instance.py b/fbpcs/private_computation/entity/private_computation_instance.py index a6567028d..ebb18e383 100644 --- a/fbpcs/private_computation/entity/private_computation_instance.py +++ b/fbpcs/private_computation/entity/private_computation_instance.py @@ -24,6 +24,10 @@ from fbpcp.entity.mpc_instance import MPCInstanceStatus from fbpcs.common.entity.instance_base import InstanceBase from fbpcs.common.entity.pcs_mpc_instance import PCSMPCInstance +from fbpcs.common.entity.stage_state_instance import ( + StageStateInstance, + StageStateInstanceStatus, +) from fbpcs.pid.entity.pid_instance import PIDInstance, PIDInstanceStatus from fbpcs.pid.entity.pid_stages import UnionPIDStage from fbpcs.pid.service.pid_service.pid_stage_mapper import STAGE_TO_FILE_FORMAT_MAP @@ -70,9 +74,14 @@ class ResultVisibility(IntEnum): PARTNER = 2 -UnionedPCInstance = Union[PIDInstance, PCSMPCInstance, PostProcessingInstance] +UnionedPCInstance = Union[ + PIDInstance, PCSMPCInstance, PostProcessingInstance, StageStateInstance +] UnionedPCInstanceStatus = Union[ - PIDInstanceStatus, MPCInstanceStatus, PostProcessingInstanceStatus + PIDInstanceStatus, + MPCInstanceStatus, + PostProcessingInstanceStatus, + StageStateInstanceStatus, ]