Skip to content

Commit 083197a

Browse files
authored
Merge pull request #22 from h0rn3t/task-local-sessions
Task local sessions
2 parents 52272ca + 83d7bef commit 083197a

File tree

5 files changed

+129
-22
lines changed

5 files changed

+129
-22
lines changed

README.md

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,10 @@ app.add_middleware(
127127
routes.py
128128

129129
```python
130+
import asyncio
131+
130132
from fastapi import APIRouter
131-
from sqlalchemy import column
132-
from sqlalchemy import table
133+
from sqlalchemy import column, table, text
133134

134135
from databases import first_db, second_db
135136

@@ -147,4 +148,22 @@ async def get_files_from_first_db():
147148
async def get_files_from_second_db():
148149
result = await second_db.session.execute(foo.select())
149150
return result.fetchall()
151+
152+
153+
@router.get("/concurrent-queries")
154+
async def parallel_select():
155+
async with first_db(multi_sessions=True):
156+
async def execute_query(query):
157+
return await first_db.session.execute(text(query))
158+
159+
tasks = [
160+
asyncio.create_task(execute_query("SELECT 1")),
161+
asyncio.create_task(execute_query("SELECT 2")),
162+
asyncio.create_task(execute_query("SELECT 3")),
163+
asyncio.create_task(execute_query("SELECT 4")),
164+
asyncio.create_task(execute_query("SELECT 5")),
165+
asyncio.create_task(execute_query("SELECT 6")),
166+
]
167+
168+
await asyncio.gather(*tasks)
150169
```

fastapi_async_sqlalchemy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22

33
__all__ = ["db", "SQLAlchemyMiddleware"]
44

5-
__version__ = "0.6.1"
5+
__version__ = "0.7.0.dev1"

fastapi_async_sqlalchemy/middleware.py

Lines changed: 77 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,31 @@
1+
import asyncio
2+
from asyncio import Task
13
from contextvars import ContextVar
24
from typing import Dict, Optional, Union
35

46
from sqlalchemy.engine import Engine
57
from sqlalchemy.engine.url import URL
6-
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
8+
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
79
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
810
from starlette.requests import Request
911
from starlette.types import ASGIApp
1012

1113
from fastapi_async_sqlalchemy.exceptions import MissingSessionError, SessionNotInitialisedError
1214

1315
try:
14-
from sqlalchemy.ext.asyncio import async_sessionmaker
16+
from sqlalchemy.ext.asyncio import async_sessionmaker # noqa: F811
1517
except ImportError:
1618
from sqlalchemy.orm import sessionmaker as async_sessionmaker
1719

1820

1921
def create_middleware_and_session_proxy():
2022
_Session: Optional[async_sessionmaker] = None
23+
_session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None)
24+
_multi_sessions_ctx: ContextVar[bool] = ContextVar("_multi_sessions_context", default=False)
25+
_commit_on_exit_ctx: ContextVar[bool] = ContextVar("_commit_on_exit_ctx", default=False)
2126
# Usage of context vars inside closures is not recommended, since they are not properly
2227
# garbage collected, but in our use case context var is created on program startup and
2328
# is used throughout the whole its lifecycle.
24-
_session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None)
2529

2630
class SQLAlchemyMiddleware(BaseHTTPMiddleware):
2731
def __init__(
@@ -61,38 +65,97 @@ def session(self) -> AsyncSession:
6165
if _Session is None:
6266
raise SessionNotInitialisedError
6367

64-
session = _session.get()
65-
if session is None:
66-
raise MissingSessionError
67-
68-
return session
68+
multi_sessions = _multi_sessions_ctx.get()
69+
if multi_sessions:
70+
"""In this case, we need to create a new session for each task.
71+
We also need to commit the session on exit if commit_on_exit is True.
72+
This is useful when we need to run multiple queries in parallel.
73+
For example, when we need to run multiple queries in parallel in a route handler.
74+
Example:
75+
```python
76+
async with db(multi_sessions=True):
77+
async def execute_query(query):
78+
return await db.session.execute(text(query))
79+
80+
tasks = [
81+
asyncio.create_task(execute_query("SELECT 1")),
82+
asyncio.create_task(execute_query("SELECT 2")),
83+
asyncio.create_task(execute_query("SELECT 3")),
84+
asyncio.create_task(execute_query("SELECT 4")),
85+
asyncio.create_task(execute_query("SELECT 5")),
86+
asyncio.create_task(execute_query("SELECT 6")),
87+
]
88+
89+
await asyncio.gather(*tasks)
90+
```
91+
"""
92+
commit_on_exit = _commit_on_exit_ctx.get()
93+
task: Task = asyncio.current_task() # type: ignore
94+
if not hasattr(task, "_db_session"):
95+
task._db_session = _Session() # type: ignore
96+
97+
def cleanup(future):
98+
session = getattr(task, "_db_session", None)
99+
if session:
100+
101+
async def do_cleanup():
102+
try:
103+
if future.exception():
104+
await session.rollback()
105+
else:
106+
if commit_on_exit:
107+
await session.commit()
108+
finally:
109+
await session.close()
110+
111+
asyncio.create_task(do_cleanup())
112+
113+
task.add_done_callback(cleanup)
114+
return task._db_session # type: ignore
115+
else:
116+
session = _session.get()
117+
if session is None:
118+
raise MissingSessionError
119+
return session
69120

70121
class DBSession(metaclass=DBSessionMeta):
71-
def __init__(self, session_args: Dict = None, commit_on_exit: bool = False):
122+
def __init__(
123+
self,
124+
session_args: Dict = None,
125+
commit_on_exit: bool = False,
126+
multi_sessions: bool = False,
127+
):
72128
self.token = None
129+
self.multi_sessions_token = None
130+
self.commit_on_exit_token = None
73131
self.session_args = session_args or {}
74132
self.commit_on_exit = commit_on_exit
133+
self.multi_sessions = multi_sessions
75134

76135
async def __aenter__(self):
77136
if not isinstance(_Session, async_sessionmaker):
78137
raise SessionNotInitialisedError
79138

80-
self.token = _session.set(_Session(**self.session_args)) # type: ignore
139+
if self.multi_sessions:
140+
self.multi_sessions_token = _multi_sessions_ctx.set(True)
141+
self.commit_on_exit_token = _commit_on_exit_ctx.set(self.commit_on_exit)
142+
143+
self.token = _session.set(_Session(**self.session_args))
81144
return type(self)
82145

83146
async def __aexit__(self, exc_type, exc_value, traceback):
84147
session = _session.get()
85-
86148
try:
87149
if exc_type is not None:
88150
await session.rollback()
89-
elif (
90-
self.commit_on_exit
91-
): # Note: Changed this to elif to avoid commit after rollback
151+
elif self.commit_on_exit:
92152
await session.commit()
93153
finally:
94154
await session.close()
95155
_session.reset(self.token)
156+
if self.multi_sessions_token is not None:
157+
_multi_sessions_ctx.reset(self.multi_sessions_token)
158+
_commit_on_exit_ctx.reset(self.commit_on_exit_token)
96159

97160
return SQLAlchemyMiddleware, DBSession
98161

requirements.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ coverage>=5.2.1
88
entrypoints==0.3
99
fastapi==0.90.0 # pyup: ignore
1010
flake8==3.7.9
11-
idna==2.8
11+
idna==3.7
1212
importlib-metadata==1.5.0
1313
isort==4.3.21
1414
mccabe==0.6.1
@@ -17,13 +17,13 @@ packaging>=22.0
1717
pathspec>=0.9.0
1818
pluggy==0.13.0
1919
pycodestyle==2.5.0
20-
pydantic==1.10.13
20+
pydantic==1.10.18
2121
pyflakes==2.1.1
2222
pyparsing==2.4.2
2323
pytest==7.2.0
2424
pytest-cov==2.11.1
2525
PyYAML>=5.4
26-
regex==2020.2.20
26+
regex>=2020.2.20
2727
requests>=2.22.0
2828
httpx>=0.20.0
2929
six==1.12.0
@@ -36,7 +36,7 @@ toml>=0.10.1
3636
typed-ast>=1.4.2
3737
urllib3>=1.25.9
3838
wcwidth==0.1.7
39-
zipp==3.1.0
39+
zipp==3.19.1
4040
black==24.4.2
4141
pytest-asyncio==0.21.0
42-
greenlet==2.0.2
42+
greenlet==3.1.1

tests/test_session.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import asyncio
2+
13
import pytest
4+
from sqlalchemy import text
25
from sqlalchemy.exc import IntegrityError
36
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
47
from starlette.middleware.base import BaseHTTPMiddleware
@@ -148,3 +151,25 @@ async def test_db_context_session_args(app, db, SQLAlchemyMiddleware, commit_on_
148151
session_args = {"expire_on_commit": False}
149152
async with db(session_args=session_args):
150153
db.session
154+
155+
156+
@pytest.mark.asyncio
157+
async def test_multi_sessions(app, db, SQLAlchemyMiddleware):
158+
app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)
159+
160+
async with db(multi_sessions=True):
161+
162+
async def execute_query(query):
163+
return await db.session.execute(text(query))
164+
165+
tasks = [
166+
asyncio.create_task(execute_query("SELECT 1")),
167+
asyncio.create_task(execute_query("SELECT 2")),
168+
asyncio.create_task(execute_query("SELECT 3")),
169+
asyncio.create_task(execute_query("SELECT 4")),
170+
asyncio.create_task(execute_query("SELECT 5")),
171+
asyncio.create_task(execute_query("SELECT 6")),
172+
]
173+
174+
res = await asyncio.gather(*tasks)
175+
assert len(res) == 6

0 commit comments

Comments
 (0)