Skip to content

Commit c4446fc

Browse files
committed
Add utility class to support background server
Summary: A subclass of the aepsych server with methods specifically to run the server in a background process. This will be used to ensured that even within the same main script, the server will run like an actual server and does not do anything sneaky like bypassing the async queue. Test Plan: New test
1 parent 9619f2d commit c4446fc

File tree

3 files changed

+175
-12
lines changed

3 files changed

+175
-12
lines changed

aepsych/server/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@
55
# This source code is licensed under the license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
from .server import AEPsychServer
8+
from .server import AEPsychBackgroundServer, AEPsychServer
99

10-
__all__ = ["AEPsychServer"]
10+
__all__ = ["AEPsychServer", "AEPsychBackgroundServer"]

aepsych/server/server.py

Lines changed: 60 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
replay,
3333
)
3434
from aepsych.strategy import SequentialStrategy, Strategy
35+
from multiprocess import Process
3536

3637
logger = utils_logging.getLogger()
3738

@@ -48,11 +49,9 @@ def __init__(
4849
host: str = "0.0.0.0",
4950
port: int = 5555,
5051
database_path: str = "./databases/default.db",
51-
max_workers: Optional[int] = None,
5252
):
5353
self.host = host
5454
self.port = port
55-
self.max_workers = max_workers
5655
self.clients_connected = 0
5756
self.db: db.Database = db.Database(database_path)
5857
self.is_performing_replay = False
@@ -278,11 +277,6 @@ def start_blocking(self) -> None:
278277
process or machine."""
279278
asyncio.run(self.serve())
280279

281-
def start_background(self):
282-
"""Starts the server in a background thread. Used for scripts where the
283-
client and server are in the same process."""
284-
raise NotImplementedError
285-
286280
async def serve(self) -> None:
287281
"""Serves the server on the set IP and port. This creates a coroutine
288282
for asyncio to handle requests asyncronously.
@@ -291,7 +285,7 @@ async def serve(self) -> None:
291285
self.handle_client, self.host, self.port
292286
)
293287
self.loop = asyncio.get_running_loop()
294-
pool = concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers)
288+
pool = concurrent.futures.ThreadPoolExecutor()
295289
self.loop.set_default_executor(pool)
296290

297291
async with self.server:
@@ -427,6 +421,64 @@ def __getstate__(self):
427421
return state
428422

429423

424+
class AEPsychBackgroundServer(AEPsychServer):
425+
"""A class to handle the server in a background thread. Unlike the normal
426+
AEPsychServer, this does not create the db right away until the server is
427+
started. When starting this server, it'll be sent to another process, a db
428+
will be initialized, then the server will be served. This server should then
429+
be interacted with by the main thread via a client."""
430+
431+
def __init__(
432+
self,
433+
host: str = "0.0.0.0",
434+
port: int = 5555,
435+
database_path: str = "./databases/default.db",
436+
):
437+
self.host = host
438+
self.port = port
439+
self.database_path = database_path
440+
self.clients_connected = 0
441+
self.is_performing_replay = False
442+
self.exit_server_loop = False
443+
self._db_raw_record = None
444+
self.skip_computations = False
445+
self.background_process = None
446+
self.strat_names = None
447+
self.extensions = None
448+
self._strats = []
449+
self._parnames = []
450+
self._configs = []
451+
self._master_records = []
452+
self.strat_id = -1
453+
self.outcome_names = []
454+
455+
def _start_server(self) -> None:
456+
self.db: db.Database = db.Database(self.database_path)
457+
if self.db.is_update_required():
458+
self.db.perform_updates()
459+
460+
super().start_blocking()
461+
462+
def start(self):
463+
"""Starts the server in a background thread. Used by the client to start
464+
the server for a client in another process or machine."""
465+
self.background_process = Process(target=self._start_server, daemon=True)
466+
self.background_process.start()
467+
468+
def stop(self):
469+
"""Stops the server and closes the background process."""
470+
self.exit_server_loop = True
471+
self.background_process.terminate()
472+
self.background_process.join()
473+
self.background_process.close()
474+
self.background_process = None
475+
476+
def __getstate__(self):
477+
# Override parent's __getstate__ to not worry about the db
478+
state = self.__dict__.copy()
479+
return state
480+
481+
430482
def parse_argument():
431483
parser = argparse.ArgumentParser(description="AEPsych Server")
432484
parser.add_argument(

tests/server/test_server.py

Lines changed: 113 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import aepsych.server as server
1818
import aepsych.utils_logging as utils_logging
19-
from aepsych.server.sockets import BAD_REQUEST
2019

2120
dummy_config = """
2221
[common]
@@ -88,7 +87,7 @@ async def asyncSetUp(self):
8887
self.port = 5555
8988

9089
# setup logger
91-
server.logger = utils_logging.getLogger("unittests")
90+
self.logger = utils_logging.getLogger("unittests")
9291

9392
# random datebase path name without dashes
9493
database_path = self.database_path
@@ -533,5 +532,117 @@ async def _mock_client2(request: Dict[str, Any]) -> Any:
533532
self.assertTrue(self.s.clients_connected == 2)
534533

535534

535+
class BackgroundServerTestCase(unittest.IsolatedAsyncioTestCase):
536+
@property
537+
def database_path(self):
538+
return "./{}_test_server.db".format(str(uuid.uuid4().hex))
539+
540+
async def asyncSetUp(self):
541+
self.ip = "127.0.0.1"
542+
self.port = 5555
543+
544+
# setup logger
545+
self.logger = utils_logging.getLogger("unittests")
546+
547+
# random datebase path name without dashes
548+
database_path = self.database_path
549+
self.s = server.AEPsychBackgroundServer(
550+
database_path=database_path, host=self.ip, port=self.port
551+
)
552+
self.db_name = database_path.split("/")[1]
553+
self.db_path = database_path
554+
555+
# Writer will be made in tests
556+
self.writer = None
557+
558+
async def asyncTearDown(self):
559+
# Stops the client
560+
if self.writer is not None:
561+
self.writer.close()
562+
563+
time.sleep(0.1)
564+
565+
# cleanup the db
566+
db_path = Path(self.db_path)
567+
try:
568+
print(db_path)
569+
db_path.unlink()
570+
except PermissionError as e:
571+
print("Failed to deleted database: ", e)
572+
573+
async def test_background_server(self):
574+
self.assertIsNone(self.s.background_process)
575+
self.s.start()
576+
self.assertTrue(self.s.background_process.is_alive())
577+
578+
# Make a client
579+
try_again = True
580+
attempts = 0
581+
while try_again:
582+
try_again = False
583+
attempts += 1
584+
try:
585+
reader, self.writer = await asyncio.open_connection(self.ip, self.port)
586+
except ConnectionRefusedError:
587+
if attempts > 10:
588+
raise ConnectionRefusedError
589+
try_again = True
590+
time.sleep(1)
591+
592+
async def _mock_client(request: Dict[str, Any]) -> Any:
593+
self.writer.write(json.dumps(request).encode())
594+
await self.writer.drain()
595+
596+
response = await reader.read(1024 * 512)
597+
return response.decode()
598+
599+
setup_request = {
600+
"type": "setup",
601+
"version": "0.01",
602+
"message": {"config_str": dummy_config},
603+
}
604+
ask_request = {"type": "ask", "message": ""}
605+
tell_request = {
606+
"type": "tell",
607+
"message": {"config": {"x": [0.5]}, "outcome": 1},
608+
"extra_info": {},
609+
}
610+
611+
await _mock_client(setup_request)
612+
613+
expected_x = [0, 1, 2, 3]
614+
expected_z = list(reversed(expected_x))
615+
expected_y = [x % 2 for x in expected_x]
616+
i = 0
617+
while True:
618+
response = await _mock_client(ask_request)
619+
response = json.loads(response)
620+
tell_request["message"]["config"]["x"] = [expected_x[i]]
621+
tell_request["message"]["config"]["z"] = [expected_z[i]]
622+
tell_request["message"]["outcome"] = expected_y[i]
623+
tell_request["extra_info"]["e1"] = 1
624+
tell_request["extra_info"]["e2"] = 2
625+
i = i + 1
626+
await _mock_client(tell_request)
627+
628+
if response["is_finished"]:
629+
break
630+
631+
self.s.stop()
632+
self.assertIsNone(self.s.background_process)
633+
634+
# Create a synchronous server to check db contents
635+
s = server.AEPsychServer(database_path=self.db_path)
636+
unique_id = s.db.get_master_records()[-1].unique_id
637+
out_df = s.get_dataframe_from_replay(unique_id)
638+
self.assertTrue((out_df.x == expected_x).all())
639+
self.assertTrue((out_df.z == expected_z).all())
640+
self.assertTrue((out_df.response == expected_y).all())
641+
self.assertTrue((out_df.e1 == [1] * 4).all())
642+
self.assertTrue((out_df.e2 == [2] * 4).all())
643+
self.assertTrue("post_mean" in out_df.columns)
644+
self.assertTrue("post_var" in out_df.columns)
645+
646+
536647
if __name__ == "__main__":
537648
unittest.main()

0 commit comments

Comments
 (0)