From 33bd8f684ae0f2c1ff883f2e5de617e7ef44e42e Mon Sep 17 00:00:00 2001 From: Eugene Shershen Date: Sun, 31 Jul 2022 00:44:15 +0300 Subject: [PATCH] init session --- fastapi_async_sqlalchemy/middleware.py | 5 ++++- tests/test_session.py | 8 ++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/fastapi_async_sqlalchemy/middleware.py b/fastapi_async_sqlalchemy/middleware.py index 20c9cb1..c100ca9 100644 --- a/fastapi_async_sqlalchemy/middleware.py +++ b/fastapi_async_sqlalchemy/middleware.py @@ -67,11 +67,14 @@ def __init__(self, session_args: Dict = None, commit_on_exit: bool = False): self.session_args = session_args or {} self.commit_on_exit = commit_on_exit + async def _init_session(self): + self.token = _session.set(_Session(**self.session_args)) + async def __aenter__(self): if not isinstance(_Session, sessionmaker): raise SessionNotInitialisedError - self.token = _session.set(_Session(**self.session_args)) + await self._init_session() return type(self) async def __aexit__(self, exc_type, exc_value, traceback): diff --git a/tests/test_session.py b/tests/test_session.py index 19a681a..b5fb3b1 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -100,6 +100,14 @@ def test_outside_of_route_without_context_fails(app, db, SQLAlchemyMiddleware): db.session +@pytest.mark.asyncio +async def test_init_session(app, db, SQLAlchemyMiddleware): + app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + + await db()._init_session() + assert isinstance(db.session, AsyncSession) + + @pytest.mark.parametrize("commit_on_exit", [True, False]) @pytest.mark.asyncio async def test_db_context_session_args(app, db, SQLAlchemyMiddleware, commit_on_exit):