Skip to content

Commit

Permalink
test: monkeypatch instead of leaking test specifics into production code
Browse files Browse the repository at this point in the history
  • Loading branch information
tumidi committed Jul 4, 2024
1 parent 36a93e8 commit 92b5295
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 72 deletions.
29 changes: 2 additions & 27 deletions questionpy_server/worker/runtime/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,6 @@
# The QuestionPy Server is free software released under terms of the MIT license. See LICENSE.md.
# (c) Technische Universität Berlin, innoCampus <[email protected]>

import sys
from io import BufferedReader, FileIO, StringIO
from questionpy_server.worker.runtime.subprocess import main

from questionpy_server.worker.runtime.connection import WorkerToServerConnection
from questionpy_server.worker.runtime.manager import WorkerManager


def setup_server_communication() -> WorkerToServerConnection:
"""Setup stdin/stdout/stderr.
The application server communicates with its worker through stdin/stdout, which means only this class is allowed to
read from and write to these pipes. Other output should go through stderr.
"""
file_stdin = FileIO(sys.stdin.buffer.fileno(), "r", closefd=False)
file_stdout = FileIO(sys.stdout.fileno(), "w", closefd=False)
connection = WorkerToServerConnection(BufferedReader(file_stdin), file_stdout)

sys.stdin = StringIO()
sys.stdout = sys.stderr # All writes to sys.stdout should go to stderr.
return connection


if __name__ == "__main__":
sys.dont_write_bytecode = True
con = setup_server_communication()
manager = WorkerManager(con)
manager.bootstrap()
manager.loop()
main()
10 changes: 0 additions & 10 deletions questionpy_server/worker/runtime/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from questionpy_server.worker.runtime.connection import WorkerToServerConnection
from questionpy_server.worker.runtime.messages import (
CreateQuestionFromOptions,
DebugExec,
Exit,
GetOptionsForm,
GetQPyPackageManifest,
Expand Down Expand Up @@ -70,8 +69,6 @@ def __init__(self, server_connection: WorkerToServerConnection):
ViewAttempt.message_id: self.on_msg_view_attempt,
ScoreAttempt.message_id: self.on_msg_score_attempt,
}
if __debug__:
self._message_dispatch[DebugExec.message_id] = self.on_msg_debug_exec

self._on_request_callbacks: list[OnRequestCallback] = []

Expand Down Expand Up @@ -191,13 +188,6 @@ def on_msg_score_attempt(self, msg: ScoreAttempt) -> ScoreAttempt.Response:
attempt_scored_model = question.score_attempt(msg.attempt_state, msg.scoring_state, msg.response)
return ScoreAttempt.Response(attempt_scored_model=attempt_scored_model)

if __debug__:

def on_msg_debug_exec(self, msg: DebugExec) -> DebugExec.Response:
locals_ = msg.locals.copy()
exec(msg.code, {"manager": self, "env": self._env, **globals()}, locals_) # noqa: S102
return DebugExec.Response()

@staticmethod
def _raise_not_initialized(msg: MessageToWorker) -> NoReturn:
errmsg = f"'{InitWorker.__name__}' message expected, '{type(msg).__name__}' received"
Expand Down
20 changes: 0 additions & 20 deletions questionpy_server/worker/runtime/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,6 @@ class MessageIds(IntEnum):
VIEW_ATTEMPT = 51
SCORE_ATTEMPT = 52

if __debug__:
DEBUG_EXEC = 900

# Worker to server.
WORKER_STARTED = 1000
SANDBOX_ENABLED = 1001
Expand All @@ -55,9 +52,6 @@ class MessageIds(IntEnum):
RETURN_VIEW_ATTEMPT = 1051
RETURN_SCORE_ATTEMPT = 1052

if __debug__:
RETURN_DEBUG_EXEC = 1900

ERROR = 1100


Expand Down Expand Up @@ -201,20 +195,6 @@ class Response(MessageToServer):
attempt_scored_model: AttemptScoredModel


if __debug__:

class DebugExec(MessageToWorker):
message_id: ClassVar[MessageIds] = MessageIds.DEBUG_EXEC
code: str
locals: dict[str, object] = {}

class Response(MessageToServer):
message_id: ClassVar[MessageIds] = MessageIds.RETURN_DEBUG_EXEC
else:
# So imports needn't by wrapped in 'if __debug__'.
DebugExec = NotImplemented # type: ignore[misc]


class WorkerError(MessageToServer):
"""Error message."""

Expand Down
32 changes: 32 additions & 0 deletions questionpy_server/worker/runtime/subprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# This file is part of the QuestionPy Server. (https://questionpy.org)
# The QuestionPy Server is free software released under terms of the MIT license. See LICENSE.md.
# (c) Technische Universität Berlin, innoCampus <[email protected]>

import sys
from io import BufferedReader, FileIO, StringIO

from questionpy_server.worker.runtime.connection import WorkerToServerConnection
from questionpy_server.worker.runtime.manager import WorkerManager


def setup_server_communication() -> WorkerToServerConnection:
"""Setup stdin/stdout/stderr.
The application server communicates with its worker through stdin/stdout, which means only this class is allowed to
read from and write to these pipes. Other output should go through stderr.
"""
file_stdin = FileIO(sys.stdin.buffer.fileno(), "r", closefd=False)
file_stdout = FileIO(sys.stdout.fileno(), "w", closefd=False)
connection = WorkerToServerConnection(BufferedReader(file_stdin), file_stdout)

sys.stdin = StringIO()
sys.stdout = sys.stderr # All writes to sys.stdout should go to stderr.
return connection


def main() -> None:
sys.dont_write_bytecode = True
con = setup_server_communication()
manager = WorkerManager(con)
manager.bootstrap()
manager.loop()
5 changes: 4 additions & 1 deletion questionpy_server/worker/worker/subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ class SubprocessWorker(BaseWorker):

_worker_type = "process"

# Allows to use a patched runtime in tests.
_worker_entrypoint = "questionpy_server.worker.runtime"

def __init__(self, package: PackageLocation, limits: WorkerResourceLimits | None):
super().__init__(package, limits)

Expand All @@ -91,7 +94,7 @@ async def start(self) -> None:
sys.executable,
*python_flags,
"-m",
"questionpy_server.worker.runtime",
self._worker_entrypoint,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# This file is part of the QuestionPy Server. (https://questionpy.org)
# The QuestionPy Server is free software released under terms of the MIT license. See LICENSE.md.
# (c) Technische Universität Berlin, innoCampus <[email protected]>

from questionpy_server.worker.runtime.subprocess import main
from tests.questionpy_server.worker.worker.conftest import patched_manager

if __name__ == "__main__":
with patched_manager():
main()
44 changes: 44 additions & 0 deletions tests/questionpy_server/worker/worker/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# This file is part of the QuestionPy Server. (https://questionpy.org)
# The QuestionPy Server is free software released under terms of the MIT license. See LICENSE.md.
# (c) Technische Universität Berlin, innoCampus <[email protected]>

from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any

import pytest

from questionpy_server.worker.pool import WorkerPool
from questionpy_server.worker.runtime.manager import WorkerManager
from questionpy_server.worker.worker.subprocess import SubprocessWorker
from questionpy_server.worker.worker.thread import ThreadWorker


@contextmanager
def patched_manager() -> Iterator[None]:
with pytest.MonkeyPatch.context() as mp:

def just_raise(*_: list[Any]) -> None:
msg = "some custom error"
raise RuntimeError(msg)

mp.setattr(WorkerManager, "on_msg_get_qpy_package_manifest", just_raise)
yield


@pytest.fixture
def patched_worker_pool(worker_pool: WorkerPool, monkeypatch: pytest.MonkeyPatch) -> Iterator[WorkerPool]:
if worker_pool._worker_type == ThreadWorker:
with patched_manager():
yield worker_pool

# Can't patch stuff inside subprocess, so we use an entrypoint wrapper.
elif worker_pool._worker_type == SubprocessWorker:
with monkeypatch.context() as mp:
patched_entrypoint = "tests.questionpy_server.worker.worker._patched_subprocess_worker"
mp.setattr(SubprocessWorker, "_worker_entrypoint", patched_entrypoint)
yield worker_pool

else:
msg = "Expected ThreadWorker or SubprocessWorker"
raise TypeError(msg)
18 changes: 4 additions & 14 deletions tests/questionpy_server/worker/worker/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from questionpy_common.constants import DIST_DIR, MANIFEST_FILENAME
from questionpy_common.manifest import PackageFile
from questionpy_server import WorkerPool
from questionpy_server.worker.runtime.messages import DebugExec, WorkerUnknownError
from questionpy_server.worker.runtime.messages import WorkerUnknownError
from questionpy_server.worker.runtime.package_location import DirPackageLocation
from questionpy_server.worker.worker import WorkerState
from questionpy_server.worker.worker.base import StaticFileSizeMismatchError
Expand Down Expand Up @@ -90,19 +90,9 @@ async def test_should_raise_static_file_size_mismatch_error_when_sizes_dont_matc
await worker.get_static_file(_STATIC_FILE_NAME)


_MAKE_GET_MANIFEST_RAISE = """
def just_raise(*_):
raise RuntimeError
manager._message_dispatch[GetQPyPackageManifest.message_id] = just_raise
"""


async def test_should_gracefully_handle_error_in_loop(worker_pool: WorkerPool) -> None:
async with worker_pool.get_worker(PACKAGE, 1, 1) as worker:
await worker.send_and_wait_for_response(DebugExec(code=_MAKE_GET_MANIFEST_RAISE), DebugExec.Response)

with pytest.raises(WorkerUnknownError):
async def test_should_gracefully_handle_error_in_loop(patched_worker_pool: WorkerPool) -> None:
async with patched_worker_pool.get_worker(PACKAGE, 1, 1) as worker:
with pytest.raises(WorkerUnknownError, match="some custom error"):
await worker.get_manifest()

assert worker.state == WorkerState.IDLE

0 comments on commit 92b5295

Please sign in to comment.