Skip to content

Commit

Permalink
fix: rollback to multiprocessing.Pool + maxtasksperchild to avoid…
Browse files Browse the repository at this point in the history
… deadlock while exiting
  • Loading branch information
ClemDoum committed Jan 3, 2024
1 parent 510106a commit 4c937a6
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 38 deletions.
73 changes: 39 additions & 34 deletions neo4j-app/neo4j_app/app/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import concurrent.futures
import inspect
import logging
import multiprocessing
Expand Down Expand Up @@ -37,7 +36,7 @@
_TASK_MANAGER: Optional[TaskManager] = None
_TEST_DB_FILE: Optional[Path] = None
_TEST_LOCK: Optional[multiprocessing.Lock] = None
_PROCESS_EXECUTOR: Optional[concurrent.futures.ProcessPoolExecutor] = None
_WORKER_POOL: Optional[multiprocessing.Pool] = None
_MP_CONTEXT = None


Expand Down Expand Up @@ -95,7 +94,13 @@ async def neo4j_driver_enter(**__):


async def neo4j_driver_exit(exc_type, exc_value, trace):
await _NEO4J_DRIVER.__aexit__(exc_type, exc_value, trace)
already_closed = False
try:
await _NEO4J_DRIVER.verify_connectivity()
except:
already_closed = True
if not already_closed:
await _NEO4J_DRIVER.__aexit__(exc_type, exc_value, trace)


def lifespan_neo4j_driver() -> neo4j.AsyncDriver:
Expand Down Expand Up @@ -146,7 +151,7 @@ def _lifespan_test_db_path() -> Path:

def test_process_manager_enter(**_):
global _PROCESS_MANAGER
_PROCESS_MANAGER = multiprocessing.Manager()
_PROCESS_MANAGER = lifespan_mp_context().Manager()


def test_process_manager_exit(exc_type, exc_value, trace):
Expand All @@ -172,20 +177,20 @@ def _lifespan_test_lock() -> multiprocessing.Lock:
return cast(multiprocessing.Lock, _TEST_LOCK)


def process_executor_enter(**_):
def worker_pool_enter(**_):
# pylint: disable=consider-using-with
global _PROCESS_EXECUTOR
global _WORKER_POOL
process_id = os.getpid()
config = lifespan_config()
n_workers = min(config.neo4j_app_n_async_workers, config.neo4j_app_task_queue_size)
n_workers = max(1, n_workers)
# TODO: let the process choose its ID and set it with the worker process ID,
# this will help debugging
worker_ids = [f"worker-{process_id}-{i}" for i in range(n_workers)]
_PROCESS_EXECUTOR = concurrent.futures.ProcessPoolExecutor( # pylint: disable=unnecessary-dunder-call
max_workers=n_workers,
mp_context=lifespan_mp_context(),
).__enter__()
_WORKER_POOL = multiprocessing.Pool(
processes=config.neo4j_app_n_async_workers, maxtasksperchild=1
)
_WORKER_POOL.__enter__() # pylint: disable=unnecessary-dunder-call
kwargs = dict()
worker_cls = config.to_worker_cls()
if worker_cls.__name__ == "MockWorker":
Expand All @@ -195,22 +200,22 @@ def process_executor_enter(**_):
for w_id in worker_ids:
kwargs.update({"worker_id": w_id})
logger.info("starting worker %s", w_id)
_PROCESS_EXECUTOR.submit(worker_cls.work_forever_from_config, **kwargs)
_WORKER_POOL.apply_async(worker_cls.work_forever_from_config, kwds=kwargs)

logger.info("worker pool ready !")


def process_executor_exit(exc_type, exc_value, trace):
def worker_pool_exit(exc_type, exc_value, trace):
# pylint: disable=unused-argument
pool = lifespan_process_executor()
pool.shutdown(wait=False)
pool = lifespan_worker_pool()
pool.__exit__(exc_type, exc_value, trace)
logger.debug("async worker pool has shut down !")


def lifespan_process_executor() -> concurrent.futures.ProcessPoolExecutor:
if _PROCESS_EXECUTOR is None:
def lifespan_worker_pool() -> multiprocessing.Pool:
if _WORKER_POOL is None:
raise DependencyInjectionError("worker pool")
return cast(concurrent.futures.ProcessPoolExecutor, _PROCESS_EXECUTOR)
return cast(multiprocessing.Pool, _WORKER_POOL)


def task_manager_enter(**_):
Expand Down Expand Up @@ -287,13 +292,13 @@ def lifespan_event_publisher() -> EventPublisher:
(None, _test_lock_enter, None),
("task manager creation", task_manager_enter, None),
("event publisher creation", event_publisher_enter, None),
("async worker executor creation", process_executor_enter, process_executor_exit),
("async worker pool creation", worker_pool_enter, worker_pool_exit),
("neo4j DB migration", migrate_app_db_enter, None),
]


@contextmanager
def _log_and_reraise():
def _log_exceptions():
try:
yield
except Exception as exc:
Expand All @@ -318,7 +323,7 @@ async def run_deps(dependencies: List, **kwargs) -> AsyncGenerator[None, None]:
to_close = []
original_ex = None
try:
with _log_and_reraise():
with _log_exceptions():
logger.info("applying dependencies...")
for name, enter_fn, exit_fn in dependencies:
if enter_fn is not None:
Expand All @@ -333,24 +338,24 @@ async def run_deps(dependencies: List, **kwargs) -> AsyncGenerator[None, None]:
except Exception as e: # pylint: disable=broad-exception-caught
original_ex = e
finally:
with _log_and_reraise():
to_raise = []
if original_ex is not None:
to_raise.append(original_ex)
logger.info("rolling back dependencies...")
for name, exit_fn in to_close[::-1]:
if exit_fn is None:
continue
try:
if name is not None:
logger.debug("rolling back %s", name)
exc_info = sys.exc_info()
to_raise = []
if original_ex is not None:
to_raise.append(original_ex)
logger.info("rolling back dependencies...")
for name, exit_fn in to_close[::-1]:
if exit_fn is None:
continue
try:
if name is not None:
logger.debug("rolling back %s", name)
exc_info = sys.exc_info()
with _log_exceptions():
if inspect.iscoroutinefunction(exit_fn):
await exit_fn(*exc_info)
else:
exit_fn(*exc_info)
except Exception as e: # pylint: disable=broad-exception-caught
to_raise.append(e)
except Exception as e: # pylint: disable=broad-exception-caught
to_raise.append(e)
logger.debug("rolled back all dependencies !")
if to_raise:
raise RuntimeError(to_raise)
4 changes: 2 additions & 2 deletions neo4j-app/neo4j_app/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
lifespan_es_client,
lifespan_neo4j_driver,
lifespan_task_manager,
lifespan_process_executor,
lifespan_worker_pool,
)
from neo4j_app.app.doc import OTHER_TAG
from neo4j_app.core import AppConfig
Expand All @@ -25,7 +25,7 @@ async def ping() -> str:
await driver.verify_connectivity()
lifespan_es_client()
lifespan_task_manager()
lifespan_process_executor()
lifespan_worker_pool()
except (DriverError, DependencyInjectionError) as e:
raise HTTPException(503, detail="Service Unavailable") from e
return "pong"
Expand Down
4 changes: 2 additions & 2 deletions neo4j-app/neo4j_app/tests/app/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ def test_create_task_should_return_429_when_too_many_tasks(

# Then
assert res_0.status_code == 201, res_0.json()
# This one is queued or rejected depending on if the first one is processed or still
# in the queue
# This one is queued or rejected depending on whether the first one is processed or
# still in the queue
assert res_1.status_code in [201, 429], res_1.json()
assert res_2.status_code == 429, res_1.json()

Expand Down

0 comments on commit 4c937a6

Please sign in to comment.