@@ -28,6 +28,24 @@ class Base(AsyncAttrs, DeclarativeBase):
2828 pass
2929
3030
31+ class SessionManager :
32+ def __init__ (self , session_factory = None ) -> None :
33+ self .session_factory = session_factory or SessionLocal
34+ self .session = None
35+
36+ async def __aenter__ (self ):
37+ self .session = self .session_factory ()
38+ await self .session .begin ()
39+ return self .session
40+
41+ async def __aexit__ (self , exc_type , exc , tb ):
42+ if exc_type :
43+ await self .session .rollback ()
44+ else :
45+ await self .session .commit ()
46+ await self .session .close ()
47+
48+
3149class Model (Base ):
3250 __abstract__ = True
3351
@@ -50,9 +68,25 @@ def _build_columns(cls, columns: list[str]):
5068
5169 @classmethod
5270 async def execute_query (
53- cls , query : Union [Query , str ], scalar : bool = False , all : bool = False
71+ cls ,
72+ query : Union [Query , str ],
73+ scalar : bool = False ,
74+ all : bool = False ,
75+ session : AsyncSession = None ,
5476 ):
55- async with SessionLocal () as session :
77+ if session is None :
78+ async with SessionLocal () as session :
79+ result = await session .execute (query )
80+ if scalar and all :
81+ data = result .scalars ().all ()
82+ elif not scalar and all :
83+ data = result .all ()
84+ elif scalar and not all :
85+ data = result .scalar ()
86+ else :
87+ raise NotImplementedError
88+ return data
89+ else :
5690 result = await session .execute (query )
5791 if scalar and all :
5892 data = result .scalars ().all ()
@@ -65,12 +99,21 @@ async def execute_query(
6599 return data
66100
67101 @classmethod
68- async def create (cls , data : dict ):
69- async with SessionLocal () as session :
102+ async def create (cls , data : dict , session : AsyncSession = None ):
103+ if session is None :
104+ async with SessionLocal () as session :
105+ try :
106+ data = cls (** data )
107+ session .add (data )
108+ await session .commit ()
109+ return data
110+ except Exception as e :
111+ await session .rollback ()
112+ raise e
113+ else :
70114 try :
71115 data = cls (** data )
72116 session .add (data )
73- await session .commit ()
74117 return data
75118 except Exception as e :
76119 await session .rollback ()
@@ -83,10 +126,26 @@ async def select_one(
83126 order_by : list [str ] = None ,
84127 load_with : list [str ] = None ,
85128 loader_func : Callable = None ,
86- columns : list [str ] = None
129+ columns : list [str ] = None ,
130+ session : AsyncSession = None
87131 ):
88132 loaders = []
89- async with SessionLocal () as session :
133+ if session is None :
134+ async with SessionLocal () as session :
135+ if load_with :
136+ loaders = cls ._build_loader (load_with , loader_func )
137+ if columns :
138+ selected_columns = [getattr (cls , col ) for col in columns ]
139+ query = select (* selected_columns ).where (* args ).options (* loaders )
140+ else :
141+ query = select (cls ).where (* args ).options (* loaders )
142+
143+ query = cls ._order_by (query , order_by )
144+ result = await session .execute (query )
145+ data = result .scalar ()
146+
147+ return data
148+ else :
90149 if load_with :
91150 loaders = cls ._build_loader (load_with , loader_func )
92151 if columns :
@@ -108,10 +167,27 @@ async def select_all(
108167 order_by : list [str ] = None ,
109168 load_with : list [str ] = None ,
110169 loader_func : Callable = None ,
111- columns : list [str ] = None
170+ columns : list [str ] = None ,
171+ session : AsyncSession = None
112172 ):
113173 loaders = []
114- async with SessionLocal () as session :
174+ if session is None :
175+ async with SessionLocal () as session :
176+ if load_with :
177+ loaders = cls ._build_loader (load_with , loader_func )
178+
179+ if columns :
180+ selected_columns = [getattr (cls , col ) for col in columns ]
181+ query = select (* selected_columns ).where (* args ).options (* loaders )
182+ else :
183+ query = select (cls ).where (* args ).options (* loaders )
184+
185+ query = cls ._order_by (query , order_by )
186+ result = await session .execute (query )
187+ data = result .scalars ().all ()
188+
189+ return data
190+ else :
115191 if load_with :
116192 loaders = cls ._build_loader (load_with , loader_func )
117193
@@ -125,36 +201,62 @@ async def select_all(
125201 result = await session .execute (query )
126202 data = result .scalars ().all ()
127203
128- return data
204+ return data
129205
130206 @classmethod
131- async def update (cls , data : dict , * args : BinaryExpression ):
132- async with SessionLocal () as session :
207+ async def update (
208+ cls , data : dict , * args : BinaryExpression , session : AsyncSession = None
209+ ):
210+ if session is None :
211+ async with SessionLocal () as session :
212+ try :
213+ query = update (cls ).where (* args ).values (** data ).returning (cls .id )
214+ db_data = await session .execute (query )
215+ db_data = db_data .scalar ()
216+ await session .commit ()
217+ return db_data
218+ except Exception as e :
219+ await session .rollback ()
220+ raise e
221+ else :
133222 try :
134223 query = update (cls ).where (* args ).values (** data ).returning (cls .id )
135224 db_data = await session .execute (query )
136225 db_data = db_data .scalar ()
137- await session .commit ()
138226 return db_data
139227 except Exception as e :
140228 await session .rollback ()
141229 raise e
142230
143231 @classmethod
144- async def delete (cls , * args : BinaryExpression ):
145- async with SessionLocal () as session :
232+ async def delete (cls , * args : BinaryExpression , session : AsyncSession = None ):
233+ if session is None :
234+ async with SessionLocal () as session :
235+ try :
236+ query = delete (cls ).where (* args )
237+ db_data = await session .execute (query )
238+ await session .commit ()
239+ return db_data
240+ except Exception as e :
241+ await session .rollback ()
242+ raise e
243+ else :
146244 try :
147245 query = delete (cls ).where (* args )
148246 db_data = await session .execute (query )
149- await session .commit ()
150247 return db_data
151248 except Exception as e :
152249 await session .rollback ()
153250 raise e
154251
155252 @classmethod
156- async def get_count (cls , * args : BinaryExpression ):
157- async with SessionLocal () as session :
253+ async def get_count (cls , * args : BinaryExpression , session : AsyncSession = None ):
254+ if session is None :
255+ async with SessionLocal () as session :
256+ result = await session .execute (select (count (cls .id )).where (* args ))
257+ total_count = result .scalar ()
258+ return total_count
259+ else :
158260 result = await session .execute (select (count (cls .id )).where (* args ))
159261 total_count = result .scalar ()
160262 return total_count
@@ -167,12 +269,28 @@ async def select_with_pagination(
167269 limit : int = 10 ,
168270 order_by : list [str ] = None ,
169271 load_with : list [str ] = None ,
170- loader_func : Callable = None
272+ loader_func : Callable = None ,
273+ session : AsyncSession = None
171274 ):
172275 loaders = []
173276 if offset < 0 :
174277 raise Exception ("offset can not be a negative" )
175- async with SessionLocal () as session :
278+ if session is None :
279+ async with SessionLocal () as session :
280+ if load_with :
281+ loaders = cls ._build_loader (load_with , loader_func )
282+ query = (
283+ select (cls )
284+ .where (* args )
285+ .offset (offset )
286+ .limit (limit )
287+ .options (* loaders )
288+ )
289+ query = cls ._order_by (query , order_by )
290+ result = await session .execute (query )
291+ data = result .scalars ().all ()
292+ return data
293+ else :
176294 if load_with :
177295 loaders = cls ._build_loader (load_with , loader_func )
178296 query = (
@@ -183,18 +301,33 @@ async def select_with_pagination(
183301 data = result .scalars ().all ()
184302 return data
185303
186- async def apply (self ):
187- async with SessionLocal () as session :
304+ async def apply (self , session : AsyncSession = None ):
305+ if session is None :
306+ async with SessionLocal () as session :
307+ try :
308+ session .add (self )
309+ await session .commit ()
310+ except Exception as e :
311+ await session .rollback ()
312+ raise e
313+ else :
188314 try :
189315 session .add (self )
190- await session .commit ()
191316 except Exception as e :
192317 await session .rollback ()
193318 raise e
194319
195320 @classmethod
196- async def apply_all (self , models : List [TModels ]):
197- async with SessionLocal () as session :
321+ async def apply_all (self , models : List [TModels ], session : AsyncSession = None ):
322+ if session is None :
323+ async with SessionLocal () as session :
324+ try :
325+ session .add_all (models )
326+ await session .commit ()
327+ except Exception as e :
328+ await session .rollback ()
329+ raise e
330+ else :
198331 try :
199332 session .add_all (models )
200333 await session .commit ()
@@ -211,22 +344,31 @@ async def select_with_joins(
211344 columns : List [str ],
212345 order_by : List [str ] = None ,
213346 offset : int = 0 ,
214- limit : int = 10
347+ limit : int = 10 ,
348+ session : AsyncSession = None
215349 ):
216350 if offset < 0 :
217351 raise ValueError ("Offset cannot be negative" )
218-
219- async with SessionLocal () as session :
352+ if session is None :
353+ async with SessionLocal () as session :
354+ query = select (* columns )
355+ for join_table , condition in zip (join_tables , join_conditions ):
356+ query = query .join (join_table , condition )
357+ query = query .where (* args )
358+ if order_by :
359+ query = query .order_by (* order_by )
360+ query = query .offset (offset ).limit (limit )
361+ result = await session .execute (query )
362+ data = result .all ()
363+ return data
364+ else :
220365 query = select (* columns )
221366 for join_table , condition in zip (join_tables , join_conditions ):
222367 query = query .join (join_table , condition )
223368 query = query .where (* args )
224369 if order_by :
225370 query = query .order_by (* order_by )
226-
227371 query = query .offset (offset ).limit (limit )
228-
229372 result = await session .execute (query )
230373 data = result .all ()
231-
232- return data
374+ return data
0 commit comments