diff --git a/example_configs/toy_authentication.yml b/example_configs/toy_authentication.yml
index 49a982c78..0207284c6 100644
--- a/example_configs/toy_authentication.yml
+++ b/example_configs/toy_authentication.yml
@@ -12,9 +12,6 @@ authentication:
   tiled_admins:
     - provider: toy
       id: admin
-database:
-  uri: "sqlite:///file:authn_mem?mode=memory&cache=shared&uri=true"
-  init_if_not_exists: true
 access_control:
   access_policy: "tiled.access_control.access_policies:TagBasedAccessPolicy"
   args:
diff --git a/tiled/_tests/test_configs/config_in_memory_authn.yml b/tiled/_tests/test_configs/config_in_memory_authn.yml
new file mode 100644
index 000000000..9377c2a56
--- /dev/null
+++ b/tiled/_tests/test_configs/config_in_memory_authn.yml
@@ -0,0 +1,16 @@
+# config.yml
+trees:
+  - path: /
+    tree: tiled.examples.generated_minimal:tree
+uvicorn:
+  host: 0.0.0.0
+  port: 8000
+authentication:
+  providers:
+    - provider: test
+      authenticator: tiled.authenticators:DictionaryAuthenticator
+      args:
+        users_to_passwords:
+          alice: PASSWORD
+  secret_keys:
+    - SECRET
diff --git a/tiled/_tests/test_connection_pool.py b/tiled/_tests/test_connection_pool.py
new file mode 100644
--- /dev/null
+++ b/tiled/_tests/test_connection_pool.py
@@ -0,0 +1,23 @@
+from typing import Union
+
+import pytest
+from sqlalchemy.engine import URL, make_url
+
+from tiled.server.connection_pool import is_memory_sqlite
+
+
+@pytest.mark.parametrize(
+    ("uri", "expected"),
+    [
+        ("sqlite://", True),  # accepts str
+        (make_url("sqlite://"), True),  # accepts URL
+        ("sqlite:///:memory:", True),
+        ("sqlite:///file::memory:?cache=shared", True),
+        ("sqlite:///file:name:?cache=shared&mode=memory", True),
+        ("sqlite:///file:name?mode=memory&mode=memory", True),  # repeated key
+        ("sqlite:////tmp/example.db", False),
+    ],
+)
+def test_is_memory_sqlite(uri: Union[str, URL], expected: bool):
+    actual = is_memory_sqlite(uri)
+    assert actual is expected
diff --git a/tiled/_tests/test_in_memory_authn.py b/tiled/_tests/test_in_memory_authn.py
new file mode 100644
--- /dev/null
+++ b/tiled/_tests/test_in_memory_authn.py
@@ -0,0 +1,30 @@
+from pathlib import Path
+
+import yaml
+
+from tiled._tests.utils import enter_username_password
+from tiled.client import Context, from_context
+from tiled.server.app import build_app_from_config
+
+here = Path(__file__).parent.absolute()
+
+
+def test_good_path():
+    """Log in and out when no ``database`` section is configured.
+
+    The config defines an authenticator but no database, so the server is
+    expected to fall back to an in-memory SQLite authentication database.
+    """
+    with open(here / "test_configs" / "config_in_memory_authn.yml") as config_file:
+        config = yaml.load(config_file, Loader=yaml.BaseLoader)
+
+    app = build_app_from_config(config)
+    context = Context.from_app(app)
+    try:
+        with enter_username_password("alice", "PASSWORD"):
+            client = from_context(context, remember_me=False)
+
+        client.logout()
+    finally:
+        # Always release the context, even if login or logout fails.
+        context.close()
diff --git a/tiled/_tests/test_validation.py b/tiled/_tests/test_validation.py
index 932d082fb..b58089e0f 100644
--- a/tiled/_tests/test_validation.py
+++ b/tiled/_tests/test_validation.py
@@ -1,7 +1,6 @@
 """
 This tests tiled's validation registry
 """
-
 import numpy as np
 import pandas as pd
 import pytest
diff --git a/tiled/authn_database/core.py b/tiled/authn_database/core.py
index ab172ae70..1fc1422d8 100644
--- a/tiled/authn_database/core.py
+++ b/tiled/authn_database/core.py
@@ -76,14 +76,13 @@ async def create_default_roles(db):
 
 
 async def initialize_database(engine: AsyncEngine) -> None:
-    async with engine.connect() as conn:
+    async with engine.begin() as conn:
         # Create all tables.
         await conn.run_sync(Base.metadata.create_all)
-        await conn.commit()
 
-    # Initialize Roles table.
-    async with AsyncSession(engine) as db:
-        await create_default_roles(db)
+        # Initialize Roles table.
+        async with AsyncSession(engine) as db:
+            await create_default_roles(db)
 
 
 async def purge_expired(db: AsyncSession, cls) -> int:
diff --git a/tiled/catalog/adapter.py b/tiled/catalog/adapter.py
index c5114e417..8b35d2055 100644
--- a/tiled/catalog/adapter.py
+++ b/tiled/catalog/adapter.py
@@ -68,7 +68,11 @@
     ZARR_MIMETYPE,
 )
 from ..query_registration import QueryTranslationRegistry
-from ..server.connection_pool import close_database_connection_pool, get_database_engine
+from ..server.connection_pool import (
+    close_database_connection_pool,
+    get_database_engine,
+    is_memory_sqlite,
+)
 from ..server.core import NoEntry
 from ..server.schemas import Asset, DataSource, Management, Revision
 from ..server.settings import DatabaseSettings
@@ -229,10 +233,7 @@ async def execute(self, statement, explain=None):
         return result
 
     async def startup(self):
-        if (self.engine.dialect.name == "sqlite") and (
-            self.engine.url.database == ":memory:"
-            or self.engine.url.query.get("mode") == "memory"
-        ):
+        if is_memory_sqlite(self.engine.url):
             # Special-case for in-memory SQLite: Because it is transient we can
             # skip over anything related to migrations.
             await initialize_database(self.engine)
diff --git a/tiled/server/app.py b/tiled/server/app.py
index 1cb605d1f..b8a4cfed3 100644
--- a/tiled/server/app.py
+++ b/tiled/server/app.py
@@ -468,13 +468,14 @@ def override_get_settings():
                 settings.database_settings.max_overflow = database.max_overflow
             if database.init_if_not_exists is not None:
                 settings.database_init_if_not_exists = database.init_if_not_exists
-            if authenticators:
-                # If we support authentication providers, we need a database, so if one is
-                # not set, use a SQLite database in memory. Horizontally scaled deployments
-                # must specify a persistent database.
-                settings.database_settings.uri = (
-                    settings.database_settings.uri or "sqlite://"
-                )
+        if authenticators:
+            # If we support authentication providers, we need a database, so if one is
+            # not set, use a SQLite database in memory. Horizontally scaled deployments
+            # must specify a persistent database.
+            settings.database_settings.uri = (
+                settings.database_settings.uri
+                or "sqlite:///file:authdb?mode=memory&cache=shared&uri=true"
+            )
         if (
             authenticators
             and len(authenticators) == 1
diff --git a/tiled/server/connection_pool.py b/tiled/server/connection_pool.py
--- a/tiled/server/connection_pool.py
+++ b/tiled/server/connection_pool.py
@@ -5,7 +5,7 @@
 
 from fastapi import Depends
 from sqlalchemy import event
-from sqlalchemy.engine import make_url
+from sqlalchemy.engine import URL, make_url
 from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
 from sqlalchemy.pool import AsyncAdaptedQueuePool
 
@@ -55,9 +55,7 @@ async def __aexit__(self, *excinfo):
 
 
 def open_database_connection_pool(database_settings: DatabaseSettings) -> AsyncEngine:
-    if make_url(database_settings.uri).database == ":memory:":
-        # For SQLite databases that exist only in process memory,
-        # pooling is not applicable. Just return an engine and don't cache it.
+    if is_memory_sqlite(database_settings.uri):
         engine = create_async_engine(
             ensure_specified_sql_driver(database_settings.uri),
             echo=DEFAULT_ECHO,
@@ -65,8 +63,6 @@ def open_database_connection_pool(database_settings: DatabaseSettings) -> AsyncE
         )
 
     else:
-        # For file-backed SQLite databases, and for PostgreSQL databases,
-        # connection pooling offers a significant performance boost.
         engine = create_async_engine(
             ensure_specified_sql_driver(database_settings.uri),
             echo=DEFAULT_ECHO,
@@ -137,3 +133,45 @@ def _set_sqlite_pragma(conn, record):
     cursor = conn.cursor()
     cursor.execute("PRAGMA foreign_keys=ON")
     cursor.close()
+
+
+def is_memory_sqlite(url: Union[URL, str]) -> bool:
+    """
+    Check if a SQLAlchemy URL is a memory-backed SQLite database.
+
+    Handles various memory database URL formats:
+    - sqlite://
+    - sqlite:///:memory:
+    - sqlite:///file::memory:?cache=shared
+    - sqlite:///file:name?mode=memory&cache=shared
+
+    Parameters
+    ----------
+    url : URL or str
+        Database URL; it is passed through ``make_url`` before inspection.
+
+    Returns
+    -------
+    bool
+        True if the URL refers to an in-memory SQLite database.
+    """
+    url = make_url(url)
+    # Any non-SQLite dialect is, by definition, not in-memory SQLite.
+    if url.get_dialect().name != "sqlite":
+        return False
+
+    # No database component at all ("sqlite://") is the default memory DB.
+    if not url.database:
+        return True
+
+    # Explicit ":memory:" marker anywhere in the database name
+    # (case-insensitive), e.g. ":memory:" or "file::memory:".
+    if ":memory:" in str(url.database).lower():
+        return True
+
+    # "mode=memory" query parameter. SQLAlchemy stores a repeated query key
+    # as a tuple of strings, so normalize before calling ``str.lower``.
+    modes = url.query.get("mode", ())
+    if isinstance(modes, str):
+        modes = (modes,)
+    return any(mode.lower() == "memory" for mode in modes)
diff --git a/tiled/server/metrics.py b/tiled/server/metrics.py
index d39016d9b..978ae131e 100644
--- a/tiled/server/metrics.py
+++ b/tiled/server/metrics.py
@@ -6,7 +6,7 @@
 
 from prometheus_client import Counter, Gauge, Histogram
 from sqlalchemy import event
-from sqlalchemy.pool import QueuePool
+from sqlalchemy.pool import QueuePool, StaticPool
 
 REQUEST_DURATION = Histogram(
     "tiled_request_duration_seconds",
@@ -220,6 +220,10 @@ def on_checkout(dbapi_connection, connection_record, connection_proxy):
         DB_POOL_CHECKEDOUT.labels(name).inc()
         DB_POOL_CHECKOUTS_TOTAL.labels(name).inc()
 
+        # Skip for single-connection database:
+        if isinstance(pool, StaticPool):
+            return
+
         # First overflow: we just used the very first overflow slot
         if pool.overflow() == 1:
             DB_POOL_FIRST_OVERFLOW_TOTAL.labels(name).inc()