Skip to content

Commit 52dd6ce

Browse files
committed
Improve checking of in-memory SQLite database.
1 parent dd71251 commit 52dd6ce

File tree

3 files changed

+56
-13
lines changed

3 files changed

+56
-13
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from typing import Union
2+
3+
import pytest
4+
from sqlalchemy.engine import URL, make_url
5+
6+
from tiled.server.connection_pool import is_memory_sqlite
7+
8+
9+
@pytest.mark.parametrize(
10+
("uri", "expected"),
11+
[
12+
("sqlite://", True), # accepts str
13+
(make_url("sqlite://"), True), # accepts URL
14+
("sqlite:///:memory:", True),
15+
("sqlite:///file::memory:?cache=shared", True),
16+
("sqlite:////tmp/example.db", False),
17+
],
18+
)
19+
def test_is_memory_sqlite(uri: Union[str, URL], expected: bool):
20+
actual = is_memory_sqlite(uri)
21+
assert actual is expected

tiled/catalog/adapter.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,11 @@
6868
ZARR_MIMETYPE,
6969
)
7070
from ..query_registration import QueryTranslationRegistry
71-
from ..server.connection_pool import close_database_connection_pool, get_database_engine
71+
from ..server.connection_pool import (
72+
close_database_connection_pool,
73+
get_database_engine,
74+
is_memory_sqlite,
75+
)
7276
from ..server.core import NoEntry
7377
from ..server.schemas import Asset, DataSource, Management, Revision
7478
from ..server.settings import DatabaseSettings
@@ -229,10 +233,7 @@ async def execute(self, statement, explain=None):
229233
return result
230234

231235
async def startup(self):
232-
if (self.engine.dialect.name == "sqlite") and (
233-
self.engine.url.database == ":memory:"
234-
or self.engine.url.query.get("mode") == "memory"
235-
):
236+
if is_memory_sqlite(self.engine.url):
236237
# Special-case for in-memory SQLite: Because it is transient we can
237238
# skip over anything related to migrations.
238239
await initialize_database(self.engine)

tiled/server/connection_pool.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from fastapi import Depends
77
from sqlalchemy import event
8-
from sqlalchemy.engine import make_url
8+
from sqlalchemy.engine import URL, make_url
99
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
1010
from sqlalchemy.pool import AsyncAdaptedQueuePool
1111

@@ -55,20 +55,14 @@ async def __aexit__(self, *excinfo):
5555

5656

5757
def open_database_connection_pool(database_settings: DatabaseSettings) -> AsyncEngine:
58-
if (
59-
make_url(database_settings.uri).database == ":memory:"
60-
or database_settings.uri == "sqlite://"
61-
):
62-
# For SQLite databases that exist only in process memory,
58+
if is_memory_sqlite(database_settings.uri):
6359
engine = create_async_engine(
6460
ensure_specified_sql_driver(database_settings.uri),
6561
echo=DEFAULT_ECHO,
6662
json_serializer=json_serializer,
6763
)
6864

6965
else:
70-
# For file-backed SQLite databases, and for PostgreSQL databases,
71-
# connection pooling offers a significant performance boost.
7266
engine = create_async_engine(
7367
ensure_specified_sql_driver(database_settings.uri),
7468
echo=DEFAULT_ECHO,
@@ -139,3 +133,30 @@ def _set_sqlite_pragma(conn, record):
139133
cursor = conn.cursor()
140134
cursor.execute("PRAGMA foreign_keys=ON")
141135
cursor.close()
136+
137+
138+
def is_memory_sqlite(url: Union[URL, str]) -> bool:
139+
"""
140+
Check if a SQLAlchemy URL is a memory-backed SQLite database.
141+
142+
Handles various memory database URL formats:
143+
- sqlite:///:memory:
144+
- sqlite:///file::memory:?cache=shared
145+
- sqlite://
146+
- etc.
147+
"""
148+
url = make_url(url)
149+
# Check if it's SQLite at all
150+
if url.get_dialect().name != "sqlite":
151+
return False
152+
153+
# Check if database is None or empty (default memory DB)
154+
if not url.database:
155+
return True
156+
157+
# Check for explicit :memory: string (case-insensitive)
158+
database = str(url.database).lower()
159+
if ":memory:" in database:
160+
return True
161+
162+
return False

0 commit comments

Comments
 (0)