2020from aepsych .config import Config
2121from aepsych .strategy import Strategy
2222from sqlalchemy import create_engine
23- from sqlalchemy .orm import sessionmaker
23+ from sqlalchemy .orm import scoped_session , sessionmaker
2424from sqlalchemy .orm .session import close_all_sessions
2525
2626logger = logging .getLogger ()
@@ -46,33 +46,22 @@ def __init__(self, db_path: Optional[str] = None, update: bool = True) -> None:
4646 else :
4747 logger .info (f"No DB found at { db_path } , creating a new DB!" )
4848
49- self ._engine = self .get_engine ()
49+ self ._full_db_path = Path (self ._db_dir )
50+ self ._full_db_path .mkdir (parents = True , exist_ok = True )
51+ self ._full_db_path = self ._full_db_path .joinpath (self ._db_name )
5052
51- if update and self .is_update_required ():
52- self .perform_updates ()
53-
54- def get_engine (self ) -> sessionmaker :
55- """Get the engine for the database.
56-
57- Returns:
58- sessionmaker: The sessionmaker object for the database.
59- """
60- if not hasattr (self , "_engine" ) or self ._engine is None :
61- self ._full_db_path = Path (self ._db_dir )
62- self ._full_db_path .mkdir (parents = True , exist_ok = True )
63- self ._full_db_path = self ._full_db_path .joinpath (self ._db_name )
64-
65- self ._engine = create_engine (f"sqlite:///{ self ._full_db_path .as_posix ()} " )
53+ self ._engine = create_engine (f"sqlite:///{ self ._full_db_path .as_posix ()} " )
6654
67- # create the table metadata and tables
68- tables .Base .metadata .create_all (self ._engine )
55+ # create the table metadata and tables
56+ tables .Base .metadata .create_all (self ._engine )
6957
70- # create an ongoing session to be used. Provides a conduit
71- # to the db so the instantiated objects work properly.
72- Session = sessionmaker (bind = self .get_engine () )
73- self . _session = Session ( )
58+ # Create a session to be start and closed on each use
59+ self . session = scoped_session (
60+ sessionmaker (bind = self ._engine , expire_on_commit = False )
61+ )
7462
75- return self ._engine
63+ if update and self .is_update_required ():
64+ self .perform_updates ()
7665
7766 def delete_db (self ) -> None :
7867 """Delete the database."""
@@ -107,21 +96,6 @@ def perform_updates(self) -> None:
10796 tables .DbParamTable .update (self ._engine )
10897 tables .DbOutcomeTable .update (self ._engine )
10998
110- @contextmanager
111- def session_scope (self ):
112- """Provide a transactional scope around a series of operations."""
113- Session = sessionmaker (bind = self .get_engine ())
114- session = Session ()
115- try :
116- yield session
117- session .commit ()
118- except Exception as err :
119- logger .error (f"db session use failed: { err } " )
120- session .rollback ()
121- raise
122- finally :
123- session .close ()
124-
12599 # @retry(stop_max_attempt_number=8, wait_exponential_multiplier=1.8)
126100 def execute_sql_query (self , query : str , vals : Dict [str , str ]) -> List [Any ]:
127101 """Execute an arbitrary query written in sql.
@@ -133,7 +107,7 @@ def execute_sql_query(self, query: str, vals: Dict[str, str]) -> List[Any]:
133107 Returns:
134108 List[Any]: The results of the query.
135109 """
136- with self .session_scope () as session :
110+ with self .session () as session :
137111 return session .execute (query , vals ).all ()
138112
139113 def get_master_records (self ) -> List [tables .DBMasterTable ]:
@@ -142,7 +116,8 @@ def get_master_records(self) -> List[tables.DBMasterTable]:
142116 Returns:
143117 List[tables.DBMasterTable]: The list of master records.
144118 """
145- records = self ._session .query (tables .DBMasterTable ).all ()
119+ with self .session () as session :
120+ records = session .query (tables .DBMasterTable ).all ()
146121 return records
147122
148123 def get_master_record (self , master_id : int ) -> Optional [tables .DBMasterTable ]:
@@ -154,11 +129,12 @@ def get_master_record(self, master_id: int) -> Optional[tables.DBMasterTable]:
154129 Returns:
155130 tables.DBMasterTable or None: The master record or None if it doesn't exist.
156131 """
157- records = (
158- self ._session .query (tables .DBMasterTable )
159- .filter (tables .DBMasterTable .unique_id == master_id )
160- .all ()
161- )
132+ with self .session () as session :
133+ records = (
134+ session .query (tables .DBMasterTable )
135+ .filter (tables .DBMasterTable .unique_id == master_id )
136+ .all ()
137+ )
162138
163139 if 0 < len (records ):
164140 return records [0 ]
@@ -260,11 +236,7 @@ def get_params_for(self, master_id: int) -> List[List[tables.DbParamTable]]:
260236 raw_record = self .get_raw_for (master_id )
261237
262238 if raw_record is not None :
263- return [
264- rec .children_param
265- for rec in self .get_raw_for (master_id )
266- if rec is not None
267- ]
239+ return [raw .children_param for raw in raw_record ]
268240
269241 return []
270242
@@ -283,14 +255,19 @@ def get_outcomes_for(self, master_id: int) -> List[List[tables.DbParamTable]]:
283255 raw_record = self .get_raw_for (master_id )
284256
285257 if raw_record is not None :
286- return [
287- rec .children_outcome
288- for rec in self .get_raw_for (master_id )
289- if rec is not None
290- ]
258+ return [raw .children_outcome for raw in raw_record ]
291259
292260 return []
293261
262+ @staticmethod
263+ def _add_commit (session , obj ):
264+ # Helps guarantee duplicated objects across session can still be written
265+ merged = session .merge (obj )
266+ session .add (merged )
267+ session .commit ()
268+ session .refresh (merged )
269+ return merged
270+
294271 def record_setup (
295272 self ,
296273 description : str = None ,
@@ -313,34 +290,36 @@ def record_setup(
313290 Returns:
314291 str: The experiment id.
315292 """
316- self .get_engine ()
317-
318- master_table = tables . DBMasterTable ()
319- master_table .experiment_description = description
320- master_table .experiment_name = name
321- master_table . experiment_id = exp_id if exp_id is not None else str (uuid .uuid4 ())
322- master_table . participant_id = (
323- par_id if par_id is not None else str ( uuid . uuid4 ())
324- )
325- master_table . extra_metadata = extra_metadata
326- self . _session . add ( master_table )
293+ with self .session () as session :
294+ master_table = tables . DBMasterTable ()
295+ master_table . experiment_description = description
296+ master_table .experiment_name = name
297+ master_table .experiment_id = (
298+ exp_id if exp_id is not None else str (uuid .uuid4 ())
299+ )
300+ master_table . participant_id = (
301+ par_id if par_id is not None else str ( uuid . uuid4 () )
302+ )
303+ master_table . extra_metadata = extra_metadata
327304
328- logger . debug ( f"record_setup = [ { master_table } ]" )
305+ master_table = self . _add_commit ( session , master_table )
329306
330- record = tables .DbReplayTable ()
331- record .message_type = "setup"
332- record .message_contents = request
307+ logger .debug (f"record_setup = [{ master_table } ]" )
333308
334- if request is not None and "extra_info" in request :
335- record .extra_info = request ["extra_info" ]
309+ record = tables .DbReplayTable ()
310+ record .message_type = "setup"
311+ record .message_contents = request
336312
337- record .timestamp = datetime .datetime .now ()
338- record .parent = master_table
339- logger .debug (f"record_setup = [{ record } ]" )
313+ if request is not None and "extra_info" in request :
314+ record .extra_info = request ["extra_info" ]
340315
341- self ._session .add (record )
342- self ._session .commit ()
316+ record .timestamp = datetime .datetime .now ()
317+ record .parent = master_table
318+ logger .debug (f"record_setup = [{ record } ]" )
343319
320+ self ._add_commit (session , record )
321+
322+ master_table
344323 # return the master table if it has a link to the list of child rows
345324 # tis needs to be passed into all future calls to link properly
346325 return master_table
@@ -355,19 +334,19 @@ def record_message(
355334 type (str): The type of the message.
356335 request (Dict[str, Any]): The request.
357336 """
358- # create a linked setup table
359- record = tables .DbReplayTable ()
360- record .message_type = type
361- record .message_contents = request
337+ with self .session () as session :
338+ # create a linked setup table
339+ record = tables .DbReplayTable ()
340+ record .message_type = type
341+ record .message_contents = request
362342
363- if "extra_info" in request :
364- record .extra_info = request ["extra_info" ]
343+ if "extra_info" in request :
344+ record .extra_info = request ["extra_info" ]
365345
366- record .timestamp = datetime .datetime .now ()
367- record .parent = master_table
346+ record .timestamp = datetime .datetime .now ()
347+ record .parent = master_table
368348
369- self ._session .add (record )
370- self ._session .commit ()
349+ self ._add_commit (session , record )
371350
372351 def record_raw (
373352 self ,
@@ -387,19 +366,19 @@ def record_raw(
387366 Returns:
388367 tables.DbRawTable: The raw entry.
389368 """
390- raw_entry = tables .DbRawTable ()
391- raw_entry .model_data = model_data
369+ with self .session () as session :
370+ raw_entry = tables .DbRawTable ()
371+ raw_entry .model_data = model_data
392372
393- if timestamp is None :
394- raw_entry .timestamp = datetime .datetime .now ()
395- else :
396- raw_entry .timestamp = timestamp
397- raw_entry .parent = master_table
373+ if timestamp is None :
374+ raw_entry .timestamp = datetime .datetime .now ()
375+ else :
376+ raw_entry .timestamp = timestamp
377+ raw_entry .parent = master_table
398378
399- raw_entry .extra_data = json .dumps (extra_data )
379+ raw_entry .extra_data = json .dumps (extra_data )
400380
401- self ._session .add (raw_entry )
402- self ._session .commit ()
381+ raw_entry = self ._add_commit (session , raw_entry )
403382
404383 return raw_entry
405384
@@ -413,14 +392,14 @@ def record_param(
413392 param_name (str): The parameter name.
414393 param_value (str): The parameter value.
415394 """
416- param_entry = tables .DbParamTable ()
417- param_entry .param_name = param_name
418- param_entry .param_value = param_value
395+ with self .session () as session :
396+ param_entry = tables .DbParamTable ()
397+ param_entry .param_name = param_name
398+ param_entry .param_value = param_value
419399
420- param_entry .parent = raw_table
400+ param_entry .parent = raw_table
421401
422- self ._session .add (param_entry )
423- self ._session .commit ()
402+ self ._add_commit (session , param_entry )
424403
425404 def record_outcome (
426405 self , raw_table : tables .DbRawTable , outcome_name : str , outcome_value : float
@@ -432,14 +411,14 @@ def record_outcome(
432411 outcome_name (str): The outcome name.
433412 outcome_value (float): The outcome value.
434413 """
435- outcome_entry = tables .DbOutcomeTable ()
436- outcome_entry .outcome_name = outcome_name
437- outcome_entry .outcome_value = outcome_value
414+ with self .session () as session :
415+ outcome_entry = tables .DbOutcomeTable ()
416+ outcome_entry .outcome_name = outcome_name
417+ outcome_entry .outcome_value = outcome_value
438418
439- outcome_entry .parent = raw_table
419+ outcome_entry .parent = raw_table
440420
441- self ._session .add (outcome_entry )
442- self ._session .commit ()
421+ self ._add_commit (session , outcome_entry )
443422
444423 def record_strat (
445424 self , master_table : tables .DBMasterTable , strat : io .BytesIO
@@ -450,13 +429,13 @@ def record_strat(
450429 master_table (tables.DBMasterTable): The master table.
451430 strat (BytesIO): The strategy in buffer form.
452431 """
453- strat_entry = tables .DbStratTable ()
454- strat_entry .strat = strat
455- strat_entry .timestamp = datetime .datetime .now ()
456- strat_entry .parent = master_table
432+ with self .session () as session :
433+ strat_entry = tables .DbStratTable ()
434+ strat_entry .strat = strat
435+ strat_entry .timestamp = datetime .datetime .now ()
436+ strat_entry .parent = master_table
457437
458- self ._session .add (strat_entry )
459- self ._session .commit ()
438+ self ._add_commit (session , strat_entry )
460439
461440 def record_config (self , master_table : tables .DBMasterTable , config : Config ) -> None :
462441 """Record a config in the database.
@@ -465,13 +444,13 @@ def record_config(self, master_table: tables.DBMasterTable, config: Config) -> N
465444 master_table (tables.DBMasterTable): The master table.
466445 config (Config): The config.
467446 """
468- config_entry = tables .DbConfigTable ()
469- config_entry .config = config
470- config_entry .timestamp = datetime .datetime .now ()
471- config_entry .parent = master_table
447+ with self .session () as session :
448+ config_entry = tables .DbConfigTable ()
449+ config_entry .config = config
450+ config_entry .timestamp = datetime .datetime .now ()
451+ config_entry .parent = master_table
472452
473- self ._session .add (config_entry )
474- self ._session .commit ()
453+ self ._add_commit (session , config_entry )
475454
476455 def summarize_experiments (self ) -> pd .DataFrame :
477456 """Provides a summary of the experiments contained in the database as a pandas dataframe.
0 commit comments