Skip to content

Commit 504d3e2

Browse files
authored
Merge pull request #15 from amrahhh/features/session_manager
add session manager
2 parents 1b109ce + e363a3e commit 504d3e2

File tree

1 file changed

+174
-32
lines changed

1 file changed

+174
-32
lines changed

sqla_async_orm_queries/models.py

Lines changed: 174 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,24 @@ class Base(AsyncAttrs, DeclarativeBase):
2828
pass
2929

3030

31+
class SessionManager:
32+
def __init__(self, session_factory=None) -> None:
33+
self.session_factory = session_factory or SessionLocal
34+
self.session = None
35+
36+
async def __aenter__(self):
37+
self.session = self.session_factory()
38+
await self.session.begin()
39+
return self.session
40+
41+
async def __aexit__(self, exc_type, exc, tb):
42+
if exc_type:
43+
await self.session.rollback()
44+
else:
45+
await self.session.commit()
46+
await self.session.close()
47+
48+
3149
class Model(Base):
3250
__abstract__ = True
3351

@@ -50,9 +68,25 @@ def _build_columns(cls, columns: list[str]):
5068

5169
@classmethod
5270
async def execute_query(
53-
cls, query: Union[Query, str], scalar: bool = False, all: bool = False
71+
cls,
72+
query: Union[Query, str],
73+
scalar: bool = False,
74+
all: bool = False,
75+
session: AsyncSession = None,
5476
):
55-
async with SessionLocal() as session:
77+
if session is None:
78+
async with SessionLocal() as session:
79+
result = await session.execute(query)
80+
if scalar and all:
81+
data = result.scalars().all()
82+
elif not scalar and all:
83+
data = result.all()
84+
elif scalar and not all:
85+
data = result.scalar()
86+
else:
87+
raise NotImplementedError
88+
return data
89+
else:
5690
result = await session.execute(query)
5791
if scalar and all:
5892
data = result.scalars().all()
@@ -65,12 +99,21 @@ async def execute_query(
6599
return data
66100

67101
@classmethod
68-
async def create(cls, data: dict):
69-
async with SessionLocal() as session:
102+
async def create(cls, data: dict, session: AsyncSession = None):
103+
if session is None:
104+
async with SessionLocal() as session:
105+
try:
106+
data = cls(**data)
107+
session.add(data)
108+
await session.commit()
109+
return data
110+
except Exception as e:
111+
await session.rollback()
112+
raise e
113+
else:
70114
try:
71115
data = cls(**data)
72116
session.add(data)
73-
await session.commit()
74117
return data
75118
except Exception as e:
76119
await session.rollback()
@@ -83,10 +126,26 @@ async def select_one(
83126
order_by: list[str] = None,
84127
load_with: list[str] = None,
85128
loader_func: Callable = None,
86-
columns: list[str] = None
129+
columns: list[str] = None,
130+
session: AsyncSession = None
87131
):
88132
loaders = []
89-
async with SessionLocal() as session:
133+
if session is None:
134+
async with SessionLocal() as session:
135+
if load_with:
136+
loaders = cls._build_loader(load_with, loader_func)
137+
if columns:
138+
selected_columns = [getattr(cls, col) for col in columns]
139+
query = select(*selected_columns).where(*args).options(*loaders)
140+
else:
141+
query = select(cls).where(*args).options(*loaders)
142+
143+
query = cls._order_by(query, order_by)
144+
result = await session.execute(query)
145+
data = result.scalar()
146+
147+
return data
148+
else:
90149
if load_with:
91150
loaders = cls._build_loader(load_with, loader_func)
92151
if columns:
@@ -108,10 +167,27 @@ async def select_all(
108167
order_by: list[str] = None,
109168
load_with: list[str] = None,
110169
loader_func: Callable = None,
111-
columns: list[str] = None
170+
columns: list[str] = None,
171+
session: AsyncSession = None
112172
):
113173
loaders = []
114-
async with SessionLocal() as session:
174+
if session is None:
175+
async with SessionLocal() as session:
176+
if load_with:
177+
loaders = cls._build_loader(load_with, loader_func)
178+
179+
if columns:
180+
selected_columns = [getattr(cls, col) for col in columns]
181+
query = select(*selected_columns).where(*args).options(*loaders)
182+
else:
183+
query = select(cls).where(*args).options(*loaders)
184+
185+
query = cls._order_by(query, order_by)
186+
result = await session.execute(query)
187+
data = result.scalars().all()
188+
189+
return data
190+
else:
115191
if load_with:
116192
loaders = cls._build_loader(load_with, loader_func)
117193

@@ -125,36 +201,62 @@ async def select_all(
125201
result = await session.execute(query)
126202
data = result.scalars().all()
127203

128-
return data
204+
return data
129205

130206
@classmethod
131-
async def update(cls, data: dict, *args: BinaryExpression):
132-
async with SessionLocal() as session:
207+
async def update(
208+
cls, data: dict, *args: BinaryExpression, session: AsyncSession = None
209+
):
210+
if session is None:
211+
async with SessionLocal() as session:
212+
try:
213+
query = update(cls).where(*args).values(**data).returning(cls.id)
214+
db_data = await session.execute(query)
215+
db_data = db_data.scalar()
216+
await session.commit()
217+
return db_data
218+
except Exception as e:
219+
await session.rollback()
220+
raise e
221+
else:
133222
try:
134223
query = update(cls).where(*args).values(**data).returning(cls.id)
135224
db_data = await session.execute(query)
136225
db_data = db_data.scalar()
137-
await session.commit()
138226
return db_data
139227
except Exception as e:
140228
await session.rollback()
141229
raise e
142230

143231
@classmethod
144-
async def delete(cls, *args: BinaryExpression):
145-
async with SessionLocal() as session:
232+
async def delete(cls, *args: BinaryExpression, session: AsyncSession = None):
233+
if session is None:
234+
async with SessionLocal() as session:
235+
try:
236+
query = delete(cls).where(*args)
237+
db_data = await session.execute(query)
238+
await session.commit()
239+
return db_data
240+
except Exception as e:
241+
await session.rollback()
242+
raise e
243+
else:
146244
try:
147245
query = delete(cls).where(*args)
148246
db_data = await session.execute(query)
149-
await session.commit()
150247
return db_data
151248
except Exception as e:
152249
await session.rollback()
153250
raise e
154251

155252
@classmethod
156-
async def get_count(cls, *args: BinaryExpression):
157-
async with SessionLocal() as session:
253+
async def get_count(cls, *args: BinaryExpression, session: AsyncSession = None):
254+
if session is None:
255+
async with SessionLocal() as session:
256+
result = await session.execute(select(count(cls.id)).where(*args))
257+
total_count = result.scalar()
258+
return total_count
259+
else:
158260
result = await session.execute(select(count(cls.id)).where(*args))
159261
total_count = result.scalar()
160262
return total_count
@@ -167,12 +269,28 @@ async def select_with_pagination(
167269
limit: int = 10,
168270
order_by: list[str] = None,
169271
load_with: list[str] = None,
170-
loader_func: Callable = None
272+
loader_func: Callable = None,
273+
session: AsyncSession = None
171274
):
172275
loaders = []
173276
if offset < 0:
174277
raise Exception("offset can not be a negative")
175-
async with SessionLocal() as session:
278+
if session is None:
279+
async with SessionLocal() as session:
280+
if load_with:
281+
loaders = cls._build_loader(load_with, loader_func)
282+
query = (
283+
select(cls)
284+
.where(*args)
285+
.offset(offset)
286+
.limit(limit)
287+
.options(*loaders)
288+
)
289+
query = cls._order_by(query, order_by)
290+
result = await session.execute(query)
291+
data = result.scalars().all()
292+
return data
293+
else:
176294
if load_with:
177295
loaders = cls._build_loader(load_with, loader_func)
178296
query = (
@@ -183,18 +301,33 @@ async def select_with_pagination(
183301
data = result.scalars().all()
184302
return data
185303

186-
async def apply(self):
187-
async with SessionLocal() as session:
304+
async def apply(self, session: AsyncSession = None):
305+
if session is None:
306+
async with SessionLocal() as session:
307+
try:
308+
session.add(self)
309+
await session.commit()
310+
except Exception as e:
311+
await session.rollback()
312+
raise e
313+
else:
188314
try:
189315
session.add(self)
190-
await session.commit()
191316
except Exception as e:
192317
await session.rollback()
193318
raise e
194319

195320
@classmethod
196-
async def apply_all(self, models: List[TModels]):
197-
async with SessionLocal() as session:
321+
async def apply_all(self, models: List[TModels], session: AsyncSession = None):
322+
if session is None:
323+
async with SessionLocal() as session:
324+
try:
325+
session.add_all(models)
326+
await session.commit()
327+
except Exception as e:
328+
await session.rollback()
329+
raise e
330+
else:
198331
try:
199332
session.add_all(models)
200333
await session.commit()
@@ -211,22 +344,31 @@ async def select_with_joins(
211344
columns: List[str],
212345
order_by: List[str] = None,
213346
offset: int = 0,
214-
limit: int = 10
347+
limit: int = 10,
348+
session: AsyncSession = None
215349
):
216350
if offset < 0:
217351
raise ValueError("Offset cannot be negative")
218-
219-
async with SessionLocal() as session:
352+
if session is None:
353+
async with SessionLocal() as session:
354+
query = select(*columns)
355+
for join_table, condition in zip(join_tables, join_conditions):
356+
query = query.join(join_table, condition)
357+
query = query.where(*args)
358+
if order_by:
359+
query = query.order_by(*order_by)
360+
query = query.offset(offset).limit(limit)
361+
result = await session.execute(query)
362+
data = result.all()
363+
return data
364+
else:
220365
query = select(*columns)
221366
for join_table, condition in zip(join_tables, join_conditions):
222367
query = query.join(join_table, condition)
223368
query = query.where(*args)
224369
if order_by:
225370
query = query.order_by(*order_by)
226-
227371
query = query.offset(offset).limit(limit)
228-
229372
result = await session.execute(query)
230373
data = result.all()
231-
232-
return data
374+
return data

0 commit comments

Comments
 (0)