Skip to content

Commit da82518

Browse files
committed
Add support for async db operations
Summary: Use scoped sessions that are shared between threads to allow async db operations. Test Plan: Tests should pass
1 parent e24192d commit da82518

File tree

3 files changed

+179
-184
lines changed

3 files changed

+179
-184
lines changed

aepsych/database/db.py

Lines changed: 101 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from aepsych.config import Config
2121
from aepsych.strategy import Strategy
2222
from sqlalchemy import create_engine
23-
from sqlalchemy.orm import sessionmaker
23+
from sqlalchemy.orm import scoped_session, sessionmaker
2424
from sqlalchemy.orm.session import close_all_sessions
2525

2626
logger = logging.getLogger()
@@ -46,33 +46,22 @@ def __init__(self, db_path: Optional[str] = None, update: bool = True) -> None:
4646
else:
4747
logger.info(f"No DB found at {db_path}, creating a new DB!")
4848

49-
self._engine = self.get_engine()
49+
self._full_db_path = Path(self._db_dir)
50+
self._full_db_path.mkdir(parents=True, exist_ok=True)
51+
self._full_db_path = self._full_db_path.joinpath(self._db_name)
5052

51-
if update and self.is_update_required():
52-
self.perform_updates()
53-
54-
def get_engine(self) -> sessionmaker:
55-
"""Get the engine for the database.
56-
57-
Returns:
58-
sessionmaker: The sessionmaker object for the database.
59-
"""
60-
if not hasattr(self, "_engine") or self._engine is None:
61-
self._full_db_path = Path(self._db_dir)
62-
self._full_db_path.mkdir(parents=True, exist_ok=True)
63-
self._full_db_path = self._full_db_path.joinpath(self._db_name)
64-
65-
self._engine = create_engine(f"sqlite:///{self._full_db_path.as_posix()}")
53+
self._engine = create_engine(f"sqlite:///{self._full_db_path.as_posix()}")
6654

67-
# create the table metadata and tables
68-
tables.Base.metadata.create_all(self._engine)
55+
# create the table metadata and tables
56+
tables.Base.metadata.create_all(self._engine)
6957

70-
# create an ongoing session to be used. Provides a conduit
71-
# to the db so the instantiated objects work properly.
72-
Session = sessionmaker(bind=self.get_engine())
73-
self._session = Session()
58+
# Create a session to be started and closed on each use
59+
self.session = scoped_session(
60+
sessionmaker(bind=self._engine, expire_on_commit=False)
61+
)
7462

75-
return self._engine
63+
if update and self.is_update_required():
64+
self.perform_updates()
7665

7766
def delete_db(self) -> None:
7867
"""Delete the database."""
@@ -107,21 +96,6 @@ def perform_updates(self) -> None:
10796
tables.DbParamTable.update(self._engine)
10897
tables.DbOutcomeTable.update(self._engine)
10998

110-
@contextmanager
111-
def session_scope(self):
112-
"""Provide a transactional scope around a series of operations."""
113-
Session = sessionmaker(bind=self.get_engine())
114-
session = Session()
115-
try:
116-
yield session
117-
session.commit()
118-
except Exception as err:
119-
logger.error(f"db session use failed: {err}")
120-
session.rollback()
121-
raise
122-
finally:
123-
session.close()
124-
12599
# @retry(stop_max_attempt_number=8, wait_exponential_multiplier=1.8)
126100
def execute_sql_query(self, query: str, vals: Dict[str, str]) -> List[Any]:
127101
"""Execute an arbitrary query written in sql.
@@ -133,7 +107,7 @@ def execute_sql_query(self, query: str, vals: Dict[str, str]) -> List[Any]:
133107
Returns:
134108
List[Any]: The results of the query.
135109
"""
136-
with self.session_scope() as session:
110+
with self.session() as session:
137111
return session.execute(query, vals).all()
138112

139113
def get_master_records(self) -> List[tables.DBMasterTable]:
@@ -142,7 +116,8 @@ def get_master_records(self) -> List[tables.DBMasterTable]:
142116
Returns:
143117
List[tables.DBMasterTable]: The list of master records.
144118
"""
145-
records = self._session.query(tables.DBMasterTable).all()
119+
with self.session() as session:
120+
records = session.query(tables.DBMasterTable).all()
146121
return records
147122

148123
def get_master_record(self, master_id: int) -> Optional[tables.DBMasterTable]:
@@ -154,11 +129,12 @@ def get_master_record(self, master_id: int) -> Optional[tables.DBMasterTable]:
154129
Returns:
155130
tables.DBMasterTable or None: The master record or None if it doesn't exist.
156131
"""
157-
records = (
158-
self._session.query(tables.DBMasterTable)
159-
.filter(tables.DBMasterTable.unique_id == master_id)
160-
.all()
161-
)
132+
with self.session() as session:
133+
records = (
134+
session.query(tables.DBMasterTable)
135+
.filter(tables.DBMasterTable.unique_id == master_id)
136+
.all()
137+
)
162138

163139
if 0 < len(records):
164140
return records[0]
@@ -260,11 +236,7 @@ def get_params_for(self, master_id: int) -> List[List[tables.DbParamTable]]:
260236
raw_record = self.get_raw_for(master_id)
261237

262238
if raw_record is not None:
263-
return [
264-
rec.children_param
265-
for rec in self.get_raw_for(master_id)
266-
if rec is not None
267-
]
239+
return [raw.children_param for raw in raw_record]
268240

269241
return []
270242

@@ -283,14 +255,19 @@ def get_outcomes_for(self, master_id: int) -> List[List[tables.DbParamTable]]:
283255
raw_record = self.get_raw_for(master_id)
284256

285257
if raw_record is not None:
286-
return [
287-
rec.children_outcome
288-
for rec in self.get_raw_for(master_id)
289-
if rec is not None
290-
]
258+
return [raw.children_outcome for raw in raw_record]
291259

292260
return []
293261

262+
@staticmethod
263+
def _add_commit(session, obj):
264+
# Helps guarantee duplicated objects across sessions can still be written
265+
merged = session.merge(obj)
266+
session.add(merged)
267+
session.commit()
268+
session.refresh(merged)
269+
return merged
270+
294271
def record_setup(
295272
self,
296273
description: str = None,
@@ -313,34 +290,36 @@ def record_setup(
313290
Returns:
314291
str: The experiment id.
315292
"""
316-
self.get_engine()
317-
318-
master_table = tables.DBMasterTable()
319-
master_table.experiment_description = description
320-
master_table.experiment_name = name
321-
master_table.experiment_id = exp_id if exp_id is not None else str(uuid.uuid4())
322-
master_table.participant_id = (
323-
par_id if par_id is not None else str(uuid.uuid4())
324-
)
325-
master_table.extra_metadata = extra_metadata
326-
self._session.add(master_table)
293+
with self.session() as session:
294+
master_table = tables.DBMasterTable()
295+
master_table.experiment_description = description
296+
master_table.experiment_name = name
297+
master_table.experiment_id = (
298+
exp_id if exp_id is not None else str(uuid.uuid4())
299+
)
300+
master_table.participant_id = (
301+
par_id if par_id is not None else str(uuid.uuid4())
302+
)
303+
master_table.extra_metadata = extra_metadata
327304

328-
logger.debug(f"record_setup = [{master_table}]")
305+
master_table = self._add_commit(session, master_table)
329306

330-
record = tables.DbReplayTable()
331-
record.message_type = "setup"
332-
record.message_contents = request
307+
logger.debug(f"record_setup = [{master_table}]")
333308

334-
if request is not None and "extra_info" in request:
335-
record.extra_info = request["extra_info"]
309+
record = tables.DbReplayTable()
310+
record.message_type = "setup"
311+
record.message_contents = request
336312

337-
record.timestamp = datetime.datetime.now()
338-
record.parent = master_table
339-
logger.debug(f"record_setup = [{record}]")
313+
if request is not None and "extra_info" in request:
314+
record.extra_info = request["extra_info"]
340315

341-
self._session.add(record)
342-
self._session.commit()
316+
record.timestamp = datetime.datetime.now()
317+
record.parent = master_table
318+
logger.debug(f"record_setup = [{record}]")
343319

320+
self._add_commit(session, record)
321+
322+
master_table
344323
# return the master table if it has a link to the list of child rows
345324
# this needs to be passed into all future calls to link properly
346325
return master_table
@@ -355,19 +334,19 @@ def record_message(
355334
type (str): The type of the message.
356335
request (Dict[str, Any]): The request.
357336
"""
358-
# create a linked setup table
359-
record = tables.DbReplayTable()
360-
record.message_type = type
361-
record.message_contents = request
337+
with self.session() as session:
338+
# create a linked setup table
339+
record = tables.DbReplayTable()
340+
record.message_type = type
341+
record.message_contents = request
362342

363-
if "extra_info" in request:
364-
record.extra_info = request["extra_info"]
343+
if "extra_info" in request:
344+
record.extra_info = request["extra_info"]
365345

366-
record.timestamp = datetime.datetime.now()
367-
record.parent = master_table
346+
record.timestamp = datetime.datetime.now()
347+
record.parent = master_table
368348

369-
self._session.add(record)
370-
self._session.commit()
349+
self._add_commit(session, record)
371350

372351
def record_raw(
373352
self,
@@ -387,19 +366,19 @@ def record_raw(
387366
Returns:
388367
tables.DbRawTable: The raw entry.
389368
"""
390-
raw_entry = tables.DbRawTable()
391-
raw_entry.model_data = model_data
369+
with self.session() as session:
370+
raw_entry = tables.DbRawTable()
371+
raw_entry.model_data = model_data
392372

393-
if timestamp is None:
394-
raw_entry.timestamp = datetime.datetime.now()
395-
else:
396-
raw_entry.timestamp = timestamp
397-
raw_entry.parent = master_table
373+
if timestamp is None:
374+
raw_entry.timestamp = datetime.datetime.now()
375+
else:
376+
raw_entry.timestamp = timestamp
377+
raw_entry.parent = master_table
398378

399-
raw_entry.extra_data = json.dumps(extra_data)
379+
raw_entry.extra_data = json.dumps(extra_data)
400380

401-
self._session.add(raw_entry)
402-
self._session.commit()
381+
raw_entry = self._add_commit(session, raw_entry)
403382

404383
return raw_entry
405384

@@ -413,14 +392,14 @@ def record_param(
413392
param_name (str): The parameter name.
414393
param_value (str): The parameter value.
415394
"""
416-
param_entry = tables.DbParamTable()
417-
param_entry.param_name = param_name
418-
param_entry.param_value = param_value
395+
with self.session() as session:
396+
param_entry = tables.DbParamTable()
397+
param_entry.param_name = param_name
398+
param_entry.param_value = param_value
419399

420-
param_entry.parent = raw_table
400+
param_entry.parent = raw_table
421401

422-
self._session.add(param_entry)
423-
self._session.commit()
402+
self._add_commit(session, param_entry)
424403

425404
def record_outcome(
426405
self, raw_table: tables.DbRawTable, outcome_name: str, outcome_value: float
@@ -432,14 +411,14 @@ def record_outcome(
432411
outcome_name (str): The outcome name.
433412
outcome_value (float): The outcome value.
434413
"""
435-
outcome_entry = tables.DbOutcomeTable()
436-
outcome_entry.outcome_name = outcome_name
437-
outcome_entry.outcome_value = outcome_value
414+
with self.session() as session:
415+
outcome_entry = tables.DbOutcomeTable()
416+
outcome_entry.outcome_name = outcome_name
417+
outcome_entry.outcome_value = outcome_value
438418

439-
outcome_entry.parent = raw_table
419+
outcome_entry.parent = raw_table
440420

441-
self._session.add(outcome_entry)
442-
self._session.commit()
421+
self._add_commit(session, outcome_entry)
443422

444423
def record_strat(
445424
self, master_table: tables.DBMasterTable, strat: io.BytesIO
@@ -450,13 +429,13 @@ def record_strat(
450429
master_table (tables.DBMasterTable): The master table.
451430
strat (BytesIO): The strategy in buffer form.
452431
"""
453-
strat_entry = tables.DbStratTable()
454-
strat_entry.strat = strat
455-
strat_entry.timestamp = datetime.datetime.now()
456-
strat_entry.parent = master_table
432+
with self.session() as session:
433+
strat_entry = tables.DbStratTable()
434+
strat_entry.strat = strat
435+
strat_entry.timestamp = datetime.datetime.now()
436+
strat_entry.parent = master_table
457437

458-
self._session.add(strat_entry)
459-
self._session.commit()
438+
self._add_commit(session, strat_entry)
460439

461440
def record_config(self, master_table: tables.DBMasterTable, config: Config) -> None:
462441
"""Record a config in the database.
@@ -465,13 +444,13 @@ def record_config(self, master_table: tables.DBMasterTable, config: Config) -> N
465444
master_table (tables.DBMasterTable): The master table.
466445
config (Config): The config.
467446
"""
468-
config_entry = tables.DbConfigTable()
469-
config_entry.config = config
470-
config_entry.timestamp = datetime.datetime.now()
471-
config_entry.parent = master_table
447+
with self.session() as session:
448+
config_entry = tables.DbConfigTable()
449+
config_entry.config = config
450+
config_entry.timestamp = datetime.datetime.now()
451+
config_entry.parent = master_table
472452

473-
self._session.add(config_entry)
474-
self._session.commit()
453+
self._add_commit(session, config_entry)
475454

476455
def summarize_experiments(self) -> pd.DataFrame:
477456
"""Provides a summary of the experiments contained in the database as a pandas dataframe.

0 commit comments

Comments
 (0)