Skip to content

Commit 5b36546

Browse files
authored
Merge pull request #24 from h0rn3t/multi_sessions_wip
WIP: multi_sessions
2 parents 1b0b5a5 + 955f538 commit 5b36546

File tree

2 files changed

+17
-24
lines changed

2 files changed

+17
-24
lines changed

fastapi_async_sqlalchemy/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22

33
__all__ = ["db", "SQLAlchemyMiddleware"]
44

5-
__version__ = "0.7.0.dev2"
5+
__version__ = "0.7.0.dev3"

fastapi_async_sqlalchemy/middleware.py

+16-23
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,6 @@ def create_middleware_and_session_proxy():
2121
_Session: Optional[async_sessionmaker] = None
2222
_session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None)
2323
_multi_sessions_ctx: ContextVar[bool] = ContextVar("_multi_sessions_context", default=False)
24-
_task_session_ctx: ContextVar[Optional[AsyncSession]] = ContextVar(
25-
"_task_session_ctx", default=None
26-
)
2724
_commit_on_exit_ctx: ContextVar[bool] = ContextVar("_commit_on_exit_ctx", default=False)
2825
# Usage of context vars inside closures is not recommended, since they are not properly
2926
# garbage collected, but in our use case context var is created on program startup and
@@ -92,25 +89,22 @@ async def execute_query(query):
9289
```
9390
"""
9491
commit_on_exit = _commit_on_exit_ctx.get()
95-
session = _task_session_ctx.get()
96-
if session is None:
97-
session = _Session()
98-
_task_session_ctx.set(session)
99-
100-
async def cleanup():
101-
try:
102-
if commit_on_exit:
103-
await session.commit()
104-
except Exception:
105-
await session.rollback()
106-
raise
107-
finally:
108-
await session.close()
109-
_task_session_ctx.set(None)
110-
111-
task = asyncio.current_task()
112-
if task is not None:
113-
task.add_done_callback(lambda t: asyncio.create_task(cleanup()))
92+
# Always create a new session for each access when multi_sessions=True
93+
session = _Session()
94+
95+
async def cleanup():
96+
try:
97+
if commit_on_exit:
98+
await session.commit()
99+
except Exception:
100+
await session.rollback()
101+
raise
102+
finally:
103+
await session.close()
104+
105+
task = asyncio.current_task()
106+
if task is not None:
107+
task.add_done_callback(lambda t: asyncio.create_task(cleanup()))
114108
return session
115109
else:
116110
session = _session.get()
@@ -126,7 +120,6 @@ def __init__(
126120
multi_sessions: bool = False,
127121
):
128122
self.token = None
129-
self.multi_sessions_token = None
130123
self.commit_on_exit_token = None
131124
self.session_args = session_args or {}
132125
self.commit_on_exit = commit_on_exit

0 commit comments

Comments
 (0)