Skip to content

Commit d0df2f4

Browse files
committed
refactor(multiprocess_predictor): enum for commands
1 parent be04aa2 commit d0df2f4

File tree

1 file changed

+19
-8
lines changed

1 file changed

+19
-8
lines changed

src/nrtk_explorer/library/multiprocess_predictor.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,18 @@
55
import logging
66
import queue
77
import uuid
8+
from enum import Enum
89
from .predictor import Predictor
910

1011
WORKER_RESPONSE_TIMEOUT = 40
1112

1213

14+
class Command(Enum):
15+
SET_MODEL = "SET_MODEL"
16+
INFER = "INFER"
17+
RESET = "RESET"
18+
19+
1320
def _child_worker(request_queue, result_queue, model_name, force_cpu):
1421
signal.signal(signal.SIGINT, signal.SIG_IGN) # Ignore Ctrl+C in child
1522
logger = logging.getLogger(__name__)
@@ -25,12 +32,12 @@ def _child_worker(request_queue, result_queue, model_name, force_cpu):
2532
logger.debug("Worker: Received EXIT command. Shutting down.")
2633
break
2734

28-
command = msg["command"]
35+
command = Command(msg["command"])
2936
req_id = msg["req_id"]
3037
payload = msg.get("payload", {})
31-
logger.debug(f"Worker: Received {command} with ID {req_id}")
38+
logger.debug(f"Worker: Received {command.value} with ID {req_id}")
3239

33-
if command == "SET_MODEL":
40+
if command == Command.SET_MODEL:
3441
try:
3542
predictor = Predictor(
3643
model_name=payload["model_name"], force_cpu=payload["force_cpu"]
@@ -39,14 +46,14 @@ def _child_worker(request_queue, result_queue, model_name, force_cpu):
3946
except Exception as e:
4047
logger.exception("Failed to set model.")
4148
result_queue.put((req_id, {"status": "ERROR", "message": str(e)}))
42-
elif command == "INFER":
49+
elif command == Command.INFER:
4350
try:
4451
predictions = predictor.eval(payload["images"])
4552
result_queue.put((req_id, {"status": "OK", "result": predictions}))
4653
except Exception as e:
4754
logger.exception("Inference failed.")
4855
result_queue.put((req_id, {"status": "ERROR", "message": str(e)}))
49-
elif command == "RESET":
56+
elif command == Command.RESET:
5057
try:
5158
predictor.reset()
5259
result_queue.put((req_id, {"status": "OK"}))
@@ -115,7 +122,7 @@ def set_model(self, model_name, force_cpu=False):
115122
req_id = str(uuid.uuid4())
116123
self._request_queue.put(
117124
{
118-
"command": "SET_MODEL",
125+
"command": Command.SET_MODEL.value,
119126
"req_id": req_id,
120127
"payload": {
121128
"model_name": self.model_name,
@@ -130,7 +137,11 @@ async def infer(self, images):
130137
return {}
131138
with self._lock:
132139
req_id = str(uuid.uuid4())
133-
new_req = {"command": "INFER", "req_id": req_id, "payload": {"images": images}}
140+
new_req = {
141+
"command": Command.INFER.value,
142+
"req_id": req_id,
143+
"payload": {"images": images},
144+
}
134145
self._request_queue.put(new_req)
135146

136147
resp = await self._wait_for_response_async(req_id)
@@ -139,7 +150,7 @@ async def infer(self, images):
139150
def reset(self):
140151
with self._lock:
141152
req_id = str(uuid.uuid4())
142-
self._request_queue.put({"command": "RESET", "req_id": req_id})
153+
self._request_queue.put({"command": Command.RESET.value, "req_id": req_id})
143154
return self._wait_for_response(req_id)
144155

145156
def shutdown(self):

0 commit comments

Comments
 (0)