From c024e632e279774ed352b1dd6575acb924242b5c Mon Sep 17 00:00:00 2001 From: Lee Gross Date: Mon, 31 Jan 2022 16:51:18 -0800 Subject: [PATCH] moving update_status to instance class (#546) Summary: Pull Request resolved: https://github.com/facebookresearch/fbpcs/pull/546 cleaning up code - as title Reviewed By: jrodal98 Differential Revision: D33829284 fbshipit-source-id: ec5b3d3f0a24ab608187150b9a8e749491305326 --- .../entity/private_computation_instance.py | 14 +++++++++ .../service/private_computation.py | 30 +++---------------- 2 files changed, 18 insertions(+), 26 deletions(-) diff --git a/fbpcs/private_computation/entity/private_computation_instance.py b/fbpcs/private_computation/entity/private_computation_instance.py index 5f15092c1..684ed3508 100644 --- a/fbpcs/private_computation/entity/private_computation_instance.py +++ b/fbpcs/private_computation/entity/private_computation_instance.py @@ -18,6 +18,9 @@ PrivateComputationBaseStageFlow, ) +from datetime import timezone, datetime +from logging import Logger + from fbpcp.entity.mpc_instance import MPCInstanceStatus from fbpcs.common.entity.instance_base import InstanceBase from fbpcs.common.entity.pcs_mpc_instance import PCSMPCInstance @@ -229,3 +232,14 @@ def get_next_runnable_stage(self) -> Optional["PrivateComputationBaseStageFlow"] * If the instance has a completed status, return the next stage in the flow (which could be None) """ return self.stage_flow.get_next_runnable_stage_from_status(self.status) + + def update_status( + self, new_status: PrivateComputationInstanceStatus, logger: Logger + ) -> None: + old_status = self.status + self.status = new_status + if old_status is not new_status: + self.status_update_ts = int(datetime.now(tz=timezone.utc).timestamp()) + logger.info( + f"Updating status of {self.instance_id} from {old_status} to {self.status} at time {self.status_update_ts}" + ) diff --git a/fbpcs/private_computation/service/private_computation.py b/fbpcs/private_computation/service/private_computation.py index d5bb8ac1d..8510cbabc 100644 --- a/fbpcs/private_computation/service/private_computation.py +++ b/fbpcs/private_computation/service/private_computation.py @@ -195,9 +195,7 @@ def _update_instance( stage_svc = stage.get_stage_service(self.stage_service_args) self.logger.info(f"Updating instance | {stage}={stage!r}") new_status = stage_svc.get_status(private_computation_instance) - private_computation_instance = self._update_status( - private_computation_instance, new_status - ) + private_computation_instance.update_status(new_status, self.logger) self.instance_repository.update(private_computation_instance) self.logger.info( f"Finished updating instance: {private_computation_instance.instance_id}" @@ -296,19 +294,15 @@ async def run_stage_async( instance_id, stage, server_ips, dry_run ) - self._update_status( - private_computation_instance=pc_instance, - new_status=stage.started_status, - ) + pc_instance.update_status(new_status=stage.started_status, logger=self.logger) self.logger.info(repr(stage)) try: stage_svc = stage_svc or stage.get_stage_service(self.stage_service_args) pc_instance = await stage_svc.run_async(pc_instance, server_ips) except Exception as e: self.logger.error(f"Caught exception when running {stage}\n{e}") - self._update_status( - private_computation_instance=pc_instance, - new_status=stage.failed_status, + pc_instance.update_status( + new_status=stage.failed_status, logger=self.logger ) raise e finally: @@ -398,22 +392,6 @@ def cancel_current_stage( def get_ts_now() -> int: return int(datetime.now(tz=timezone.utc).timestamp()) - def _update_status( - self, - private_computation_instance: PrivateComputationInstance, - new_status: PrivateComputationInstanceStatus, - ) -> PrivateComputationInstance: - old_status = private_computation_instance.status - private_computation_instance.status = new_status - if old_status != new_status: - private_computation_instance.status_update_ts = ( - PrivateComputationService.get_ts_now() - ) - self.logger.info( - f"Updating status of {private_computation_instance.instance_id} from {old_status} to {private_computation_instance.status} at time {private_computation_instance.status_update_ts}" - ) - return private_computation_instance - def _get_param( self, param_name: str, instance_param: Optional[T], override_param: Optional[T] ) -> T: