refactor(multiprocess_predictor): enum for commands
PaulHax committed Jan 21, 2025
1 parent be04aa2 commit d0df2f4
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions src/nrtk_explorer/library/multiprocess_predictor.py
@@ -5,11 +5,18 @@
import logging
import queue
import uuid
from enum import Enum
from .predictor import Predictor

WORKER_RESPONSE_TIMEOUT = 40


class Command(Enum):
SET_MODEL = "SET_MODEL"
INFER = "INFER"
RESET = "RESET"


def _child_worker(request_queue, result_queue, model_name, force_cpu):
signal.signal(signal.SIGINT, signal.SIG_IGN) # Ignore Ctrl+C in child
logger = logging.getLogger(__name__)
@@ -25,12 +32,12 @@ def _child_worker(request_queue, result_queue, model_name, force_cpu):
logger.debug("Worker: Received EXIT command. Shutting down.")
break

command = msg["command"]
command = Command(msg["command"])
req_id = msg["req_id"]
payload = msg.get("payload", {})
logger.debug(f"Worker: Received {command} with ID {req_id}")
logger.debug(f"Worker: Received {command.value} with ID {req_id}")

if command == "SET_MODEL":
if command == Command.SET_MODEL:
try:
predictor = Predictor(
model_name=payload["model_name"], force_cpu=payload["force_cpu"]
@@ -39,14 +46,14 @@ def _child_worker(request_queue, result_queue, model_name, force_cpu):
except Exception as e:
logger.exception("Failed to set model.")
result_queue.put((req_id, {"status": "ERROR", "message": str(e)}))
elif command == "INFER":
elif command == Command.INFER:
try:
predictions = predictor.eval(payload["images"])
result_queue.put((req_id, {"status": "OK", "result": predictions}))
except Exception as e:
logger.exception("Inference failed.")
result_queue.put((req_id, {"status": "ERROR", "message": str(e)}))
elif command == "RESET":
elif command == Command.RESET:
try:
predictor.reset()
result_queue.put((req_id, {"status": "OK"}))
@@ -115,7 +122,7 @@ def set_model(self, model_name, force_cpu=False):
req_id = str(uuid.uuid4())
self._request_queue.put(
{
"command": "SET_MODEL",
"command": Command.SET_MODEL.value,
"req_id": req_id,
"payload": {
"model_name": self.model_name,
@@ -130,7 +137,11 @@ async def infer(self, images):
return {}
with self._lock:
req_id = str(uuid.uuid4())
new_req = {"command": "INFER", "req_id": req_id, "payload": {"images": images}}
new_req = {
"command": Command.INFER.value,
"req_id": req_id,
"payload": {"images": images},
}
self._request_queue.put(new_req)

resp = await self._wait_for_response_async(req_id)
@@ -139,7 +150,7 @@ async def infer(self, images):
def reset(self):
with self._lock:
req_id = str(uuid.uuid4())
self._request_queue.put({"command": "RESET", "req_id": req_id})
self._request_queue.put({"command": Command.RESET.value, "req_id": req_id})
return self._wait_for_response(req_id)

def shutdown(self):
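
For context, the pattern this commit adopts — sending a Command member's .value through the request queue and rebuilding the member with Command(msg["command"]) on the worker side — is sketched below in a minimal, self-contained form. The PING/ECHO commands, the _worker function, and the queue wiring are illustrative stand-ins, not code from nrtk_explorer.

# --- Illustrative sketch of the enum-based command protocol (not part of the commit) ---
import multiprocessing
import uuid
from enum import Enum


class Command(Enum):
    PING = "PING"
    ECHO = "ECHO"


def _worker(request_queue, result_queue):
    # Block on requests until a None sentinel asks the worker to exit.
    while True:
        msg = request_queue.get()
        if msg is None:
            break
        # The parent sends the enum's string value; rebuild the member here.
        # An unknown string raises ValueError, surfacing protocol typos early.
        command = Command(msg["command"])
        req_id = msg["req_id"]
        if command == Command.PING:
            result_queue.put((req_id, {"status": "OK", "result": "pong"}))
        elif command == Command.ECHO:
            result_queue.put((req_id, {"status": "OK", "result": msg.get("payload")}))


if __name__ == "__main__":
    requests = multiprocessing.Queue()
    results = multiprocessing.Queue()
    proc = multiprocessing.Process(target=_worker, args=(requests, results), daemon=True)
    proc.start()

    req_id = str(uuid.uuid4())
    # Send the .value (a plain string) so the queued message stays simple to pickle and log.
    requests.put({"command": Command.ECHO.value, "req_id": req_id, "payload": "hello"})
    print(results.get(timeout=10))  # (req_id, {'status': 'OK', 'result': 'hello'})

    requests.put(None)  # Sentinel: shut the worker down.
    proc.join(timeout=10)

Compared with raw strings, dispatching on enum members keeps the set of supported commands declared in one place, and a misspelled command becomes an immediate ValueError when the message is decoded rather than a silently skipped branch.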