Skip to content

Commit

Permalink
fix(multiprocess_predictor): dont block main process when running inf…
Browse files Browse the repository at this point in the history
…erence
  • Loading branch information
PaulHax committed Jan 21, 2025
1 parent 25253e3 commit e94e3de
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/nrtk_explorer/app/images/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ async def get_annotations(
)

to_detect = {id: id_to_image[id] for id in misses}
predictions = predictor.infer(to_detect)
predictions = await predictor.infer(to_detect)
for id, annotations in predictions.items():
self.cache.add_item(
id, annotations, self.add_to_cache_callback, self.delete_from_cache_callback
Expand Down
7 changes: 5 additions & 2 deletions src/nrtk_explorer/app/transforms.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from typing import Dict, Callable
from collections.abc import Mapping
from collections.abc import MutableMapping

from trame.ui.quasar import QLayout
from trame.widgets import quasar
Expand Down Expand Up @@ -50,7 +50,7 @@
logger.setLevel(logging.INFO)


class LazyDict(Mapping):
class LazyDict(MutableMapping):
"""If function provided for value, run function when value is accessed"""

def __init__(self, *args, **kw):
Expand All @@ -63,6 +63,9 @@ def __getitem__(self, key):
def __setitem__(self, key, value):
self._raw_dict[key] = value

def __delitem__(self, key):
del self._raw_dict[key]

def __iter__(self):
return iter(self._raw_dict)

Expand Down
24 changes: 16 additions & 8 deletions src/nrtk_explorer/library/multiprocess_predictor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import multiprocessing
import asyncio
import signal
import threading
import logging
Expand All @@ -8,6 +9,7 @@


def _child_worker(request_queue, result_queue, model_name, force_cpu):
signal.signal(signal.SIGINT, signal.SIG_IGN) # Ignore Ctrl+C in child
logger = logging.getLogger(__name__)
predictor = Predictor(model_name=model_name, force_cpu=force_cpu)

Expand Down Expand Up @@ -105,32 +107,38 @@ def set_model(self, model_name, force_cpu=False):
)
return self._wait_for_response(req_id)

def infer(self, images):
async def infer(self, images):
if not images:
return {}
with self._lock:
req_id = str(uuid.uuid4())
new_req = {"command": "INFER", "req_id": req_id, "payload": {"images": images}}
self._request_queue.put(new_req)

resp = self._wait_for_response(req_id)
resp = await self._wait_for_response_async(req_id)
return resp.get("result")

def reset(self):
with self._lock:
req_id = str(uuid.uuid4())
self._request_queue.put({"command": "RESET", "req_id": req_id})
return self._wait_for_response(req_id)
async def _wait_for_response_async(self, req_id):
return await asyncio.get_event_loop().run_in_executor(None, self._get_response, req_id, 40)

def _wait_for_response(self, req_id):
return self._get_response(req_id, 40)

def _get_response(self, req_id, timeout=40):
while True:
try:
r_id, data = self._result_queue.get(timeout=40)
r_id, data = self._result_queue.get(timeout=timeout)
except queue.Empty:
raise TimeoutError("No response from worker.")
if r_id == req_id:
return data

def reset(self):
with self._lock:
req_id = str(uuid.uuid4())
self._request_queue.put({"command": "RESET", "req_id": req_id})
return self._wait_for_response(req_id)

def shutdown(self):
with self._lock:
try:
Expand Down
5 changes: 3 additions & 2 deletions tests/test_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from nrtk_explorer.library.scoring import compute_score
from nrtk_explorer.library.dataset import get_dataset
from utils import get_images, DATASET
import asyncio


def test_predictor_small():
Expand All @@ -23,7 +24,7 @@ def predictor():
def test_detect(predictor):
"""Test the detect method with sample images."""
images = get_images()
results = predictor.infer(images)
results = asyncio.run(predictor.infer(images))
assert len(results) == len(images), "Number of results should match number of images"
for img_id, preds in results.items():
assert isinstance(preds, list), f"Predictions for {img_id} should be a list"
Expand All @@ -33,7 +34,7 @@ def test_set_model(predictor):
"""Test setting a new model and performing detection."""
predictor.set_model(model_name="hustvl/yolos-tiny")
images = get_images()
results = predictor.infer(images)
results = asyncio.run(predictor.infer(images))
assert len(results) == len(
images
), "Number of results should match number of images after setting new model"
Expand Down

0 comments on commit e94e3de

Please sign in to comment.