Skip to content

Commit aebf824

Browse files
committed
fix(multiprocess_predictor): dont block main process when running inference
1 parent b430e7f commit aebf824

File tree

4 files changed

+25
-13
lines changed

4 files changed

+25
-13
lines changed

src/nrtk_explorer/app/images/annotations.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ async def get_annotations(
7575
)
7676

7777
to_detect = {id: id_to_image[id] for id in misses}
78-
predictions = predictor.infer(to_detect)
78+
predictions = await predictor.infer(to_detect)
7979
for id, annotations in predictions.items():
8080
self.cache.add_item(
8181
id, annotations, self.add_to_cache_callback, self.delete_from_cache_callback

src/nrtk_explorer/app/transforms.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
from typing import Dict, Callable
3-
from collections.abc import Mapping
3+
from collections.abc import MutableMapping
44

55
from trame.ui.quasar import QLayout
66
from trame.widgets import quasar
@@ -50,7 +50,7 @@
5050
logger.setLevel(logging.INFO)
5151

5252

53-
class LazyDict(Mapping):
53+
class LazyDict(MutableMapping):
5454
"""If function provided for value, run function when value is accessed"""
5555

5656
def __init__(self, *args, **kw):
@@ -63,6 +63,9 @@ def __getitem__(self, key):
6363
def __setitem__(self, key, value):
6464
self._raw_dict[key] = value
6565

66+
def __delitem__(self, key):
67+
del self._raw_dict[key]
68+
6669
def __iter__(self):
6770
return iter(self._raw_dict)
6871

src/nrtk_explorer/library/multiprocess_predictor.py

+16-8
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import multiprocessing
2+
import asyncio
23
import signal
34
import threading
45
import logging
@@ -8,6 +9,7 @@
89

910

1011
def _child_worker(request_queue, result_queue, model_name, force_cpu):
12+
signal.signal(signal.SIGINT, signal.SIG_IGN) # Ignore Ctrl+C in child
1113
logger = logging.getLogger(__name__)
1214
predictor = Predictor(model_name=model_name, force_cpu=force_cpu)
1315

@@ -105,32 +107,38 @@ def set_model(self, model_name, force_cpu=False):
105107
)
106108
return self._wait_for_response(req_id)
107109

108-
def infer(self, images):
110+
async def infer(self, images):
109111
if not images:
110112
return {}
111113
with self._lock:
112114
req_id = str(uuid.uuid4())
113115
new_req = {"command": "INFER", "req_id": req_id, "payload": {"images": images}}
114116
self._request_queue.put(new_req)
115117

116-
resp = self._wait_for_response(req_id)
118+
resp = await self._wait_for_response_async(req_id)
117119
return resp.get("result")
118120

119-
def reset(self):
120-
with self._lock:
121-
req_id = str(uuid.uuid4())
122-
self._request_queue.put({"command": "RESET", "req_id": req_id})
123-
return self._wait_for_response(req_id)
121+
async def _wait_for_response_async(self, req_id):
122+
return await asyncio.get_event_loop().run_in_executor(None, self._get_response, req_id, 40)
124123

125124
def _wait_for_response(self, req_id):
125+
return self._get_response(req_id, 40)
126+
127+
def _get_response(self, req_id, timeout=40):
126128
while True:
127129
try:
128-
r_id, data = self._result_queue.get(timeout=40)
130+
r_id, data = self._result_queue.get(timeout=timeout)
129131
except queue.Empty:
130132
raise TimeoutError("No response from worker.")
131133
if r_id == req_id:
132134
return data
133135

136+
def reset(self):
137+
with self._lock:
138+
req_id = str(uuid.uuid4())
139+
self._request_queue.put({"command": "RESET", "req_id": req_id})
140+
return self._wait_for_response(req_id)
141+
134142
def shutdown(self):
135143
with self._lock:
136144
try:

tests/test_predictor.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from nrtk_explorer.library.scoring import compute_score
55
from nrtk_explorer.library.dataset import get_dataset
66
from utils import get_images, DATASET
7+
import asyncio
78

89

910
def test_predictor_small():
@@ -23,7 +24,7 @@ def predictor():
2324
def test_detect(predictor):
2425
"""Test the detect method with sample images."""
2526
images = get_images()
26-
results = predictor.infer(images)
27+
results = asyncio.run(predictor.infer(images))
2728
assert len(results) == len(images), "Number of results should match number of images"
2829
for img_id, preds in results.items():
2930
assert isinstance(preds, list), f"Predictions for {img_id} should be a list"
@@ -33,7 +34,7 @@ def test_set_model(predictor):
3334
"""Test setting a new model and performing detection."""
3435
predictor.set_model(model_name="hustvl/yolos-tiny")
3536
images = get_images()
36-
results = predictor.infer(images)
37+
results = asyncio.run(predictor.infer(images))
3738
assert len(results) == len(
3839
images
3940
), "Number of results should match number of images after setting new model"

0 commit comments

Comments
 (0)