diff --git a/README.md b/README.md index 84546be..6d1e722 100644 --- a/README.md +++ b/README.md @@ -127,9 +127,10 @@ app.add_middleware( routes.py ```python +import asyncio + from fastapi import APIRouter -from sqlalchemy import column -from sqlalchemy import table +from sqlalchemy import column, table, text from databases import first_db, second_db @@ -147,4 +148,22 @@ async def get_files_from_first_db(): async def get_files_from_second_db(): result = await second_db.session.execute(foo.select()) return result.fetchall() + + +@router.get("/concurrent-queries") +async def parallel_select(): + async with first_db(multi_sessions=True): + async def execute_query(query): + return await first_db.session.execute(text(query)) + + tasks = [ + asyncio.create_task(execute_query("SELECT 1")), + asyncio.create_task(execute_query("SELECT 2")), + asyncio.create_task(execute_query("SELECT 3")), + asyncio.create_task(execute_query("SELECT 4")), + asyncio.create_task(execute_query("SELECT 5")), + asyncio.create_task(execute_query("SELECT 6")), + ] + + await asyncio.gather(*tasks) ``` diff --git a/fastapi_async_sqlalchemy/__init__.py b/fastapi_async_sqlalchemy/__init__.py index ab2c3b7..21c6edc 100644 --- a/fastapi_async_sqlalchemy/__init__.py +++ b/fastapi_async_sqlalchemy/__init__.py @@ -2,4 +2,4 @@ __all__ = ["db", "SQLAlchemyMiddleware"] -__version__ = "0.6.1" +__version__ = "0.7.0.dev1" diff --git a/fastapi_async_sqlalchemy/middleware.py b/fastapi_async_sqlalchemy/middleware.py index bb6024d..e3b1563 100644 --- a/fastapi_async_sqlalchemy/middleware.py +++ b/fastapi_async_sqlalchemy/middleware.py @@ -1,9 +1,11 @@ +import asyncio +from asyncio import Task from contextvars import ContextVar from typing import Dict, Optional, Union from sqlalchemy.engine import Engine from sqlalchemy.engine.url import URL -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request from starlette.types import ASGIApp @@ -11,17 +13,19 @@ from fastapi_async_sqlalchemy.exceptions import MissingSessionError, SessionNotInitialisedError try: - from sqlalchemy.ext.asyncio import async_sessionmaker + from sqlalchemy.ext.asyncio import async_sessionmaker # noqa: F811 except ImportError: from sqlalchemy.orm import sessionmaker as async_sessionmaker def create_middleware_and_session_proxy(): _Session: Optional[async_sessionmaker] = None + _session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None) + _multi_sessions_ctx: ContextVar[bool] = ContextVar("_multi_sessions_context", default=False) + _commit_on_exit_ctx: ContextVar[bool] = ContextVar("_commit_on_exit_ctx", default=False) # Usage of context vars inside closures is not recommended, since they are not properly # garbage collected, but in our use case context var is created on program startup and # is used throughout the whole its lifecycle. - _session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None) class SQLAlchemyMiddleware(BaseHTTPMiddleware): def __init__( @@ -61,38 +65,97 @@ def session(self) -> AsyncSession: if _Session is None: raise SessionNotInitialisedError - session = _session.get() - if session is None: - raise MissingSessionError - - return session + multi_sessions = _multi_sessions_ctx.get() + if multi_sessions: + """In this case, we need to create a new session for each task. + We also need to commit the session on exit if commit_on_exit is True. + This is useful when we need to run multiple queries in parallel. + For example, when we need to run multiple queries in parallel in a route handler. + Example: + ```python + async with db(multi_sessions=True): + async def execute_query(query): + return await db.session.execute(text(query)) + + tasks = [ + asyncio.create_task(execute_query("SELECT 1")), + asyncio.create_task(execute_query("SELECT 2")), + asyncio.create_task(execute_query("SELECT 3")), + asyncio.create_task(execute_query("SELECT 4")), + asyncio.create_task(execute_query("SELECT 5")), + asyncio.create_task(execute_query("SELECT 6")), + ] + + await asyncio.gather(*tasks) + ``` + """ + commit_on_exit = _commit_on_exit_ctx.get() + task: Task = asyncio.current_task() # type: ignore + if not hasattr(task, "_db_session"): + task._db_session = _Session() # type: ignore + + def cleanup(future): + session = getattr(task, "_db_session", None) + if session: + + async def do_cleanup(): + try: + if future.exception(): + await session.rollback() + else: + if commit_on_exit: + await session.commit() + finally: + await session.close() + + asyncio.create_task(do_cleanup()) + + task.add_done_callback(cleanup) + return task._db_session # type: ignore + else: + session = _session.get() + if session is None: + raise MissingSessionError + return session class DBSession(metaclass=DBSessionMeta): - def __init__(self, session_args: Dict = None, commit_on_exit: bool = False): + def __init__( + self, + session_args: Dict = None, + commit_on_exit: bool = False, + multi_sessions: bool = False, + ): self.token = None + self.multi_sessions_token = None + self.commit_on_exit_token = None self.session_args = session_args or {} self.commit_on_exit = commit_on_exit + self.multi_sessions = multi_sessions async def __aenter__(self): if not isinstance(_Session, async_sessionmaker): raise SessionNotInitialisedError - self.token = _session.set(_Session(**self.session_args)) # type: ignore + if self.multi_sessions: + self.multi_sessions_token = _multi_sessions_ctx.set(True) + self.commit_on_exit_token = _commit_on_exit_ctx.set(self.commit_on_exit) + + self.token = _session.set(_Session(**self.session_args)) return type(self) async def __aexit__(self, exc_type, exc_value, traceback): session = _session.get() - try: if exc_type is not None: await session.rollback() - elif ( - self.commit_on_exit - ): # Note: Changed this to elif to avoid commit after rollback + elif self.commit_on_exit: await session.commit() finally: await session.close() _session.reset(self.token) + if self.multi_sessions_token is not None: + _multi_sessions_ctx.reset(self.multi_sessions_token) + _commit_on_exit_ctx.reset(self.commit_on_exit_token) return SQLAlchemyMiddleware, DBSession diff --git a/requirements.txt b/requirements.txt index 0279928..e3a0644 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,7 @@ coverage>=5.2.1 entrypoints==0.3 fastapi==0.90.0 # pyup: ignore flake8==3.7.9 -idna==2.8 +idna==3.7 importlib-metadata==1.5.0 isort==4.3.21 mccabe==0.6.1 @@ -17,13 +17,13 @@ packaging>=22.0 pathspec>=0.9.0 pluggy==0.13.0 pycodestyle==2.5.0 -pydantic==1.10.13 +pydantic==1.10.18 pyflakes==2.1.1 pyparsing==2.4.2 pytest==7.2.0 pytest-cov==2.11.1 PyYAML>=5.4 -regex==2020.2.20 +regex>=2020.2.20 requests>=2.22.0 httpx>=0.20.0 six==1.12.0 @@ -36,7 +36,7 @@ toml>=0.10.1 typed-ast>=1.4.2 urllib3>=1.25.9 wcwidth==0.1.7 -zipp==3.1.0 +zipp==3.19.1 black==24.4.2 pytest-asyncio==0.21.0 -greenlet==2.0.2 +greenlet==3.1.1 diff --git a/tests/test_session.py b/tests/test_session.py index 06163ac..82f5dc9 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,4 +1,7 @@ +import asyncio + import pytest +from sqlalchemy import text from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from starlette.middleware.base import BaseHTTPMiddleware @@ -148,3 +151,25 @@ async def test_db_context_session_args(app, db, SQLAlchemyMiddleware, commit_on_ session_args = {"expire_on_commit": False} async with db(session_args=session_args): db.session + + +@pytest.mark.asyncio +async def test_multi_sessions(app, db, SQLAlchemyMiddleware): + app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + + async with db(multi_sessions=True): + + async def execute_query(query): + return await db.session.execute(text(query)) + + tasks = [ + asyncio.create_task(execute_query("SELECT 1")), + asyncio.create_task(execute_query("SELECT 2")), + asyncio.create_task(execute_query("SELECT 3")), + asyncio.create_task(execute_query("SELECT 4")), + asyncio.create_task(execute_query("SELECT 5")), + asyncio.create_task(execute_query("SELECT 6")), + ] + + res = await asyncio.gather(*tasks) + assert len(res) == 6