Skip to content

Commit e52bab3

Browse files
authored
Do not require SQL URIs to be prefixed with SQLAlchemy driver (#810)
* Automatically set SQL driver if unset. * Handle special SQLite URIs * Consistently use database URI with schema. * Interpret filepaths as SQLite URIs. * Parse uri earlier. * Update CHANGELOG * Fix missing arg in refactor. * Make utility accept Path-like objects. * Deal with Windows paths in test case
1 parent 0ad610f commit e52bab3

File tree

10 files changed

+136
-18
lines changed

10 files changed

+136
-18
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ Write the date in place of the "Unreleased" in the case a new version is release
1717

1818
- Drop support for Python 3.8, which is reached end of life
1919
upstream on 7 October 2024.
20+
- Do not require SQL database URIs to specify a "driver" (Python
21+
library to be used for connecting).
2022

2123
## v0.1.0b10 (2024-10-11)
2224

tiled/_tests/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from ..catalog import from_uri, in_memory
1414
from ..client.base import BaseClient
1515
from ..server.settings import get_settings
16+
from ..utils import ensure_specified_sql_driver
1617
from .utils import enter_username_password as utils_enter_uname_passwd
1718
from .utils import temp_postgres
1819

@@ -152,7 +153,7 @@ async def postgresql_with_example_data_adapter(request, tmpdir):
152153
if uri.endswith("/"):
153154
uri = uri[:-1]
154155
uri_with_database_name = f"{uri}/{DATABASE_NAME}"
155-
engine = create_async_engine(uri_with_database_name)
156+
engine = create_async_engine(ensure_specified_sql_driver(uri_with_database_name))
156157
try:
157158
async with engine.connect():
158159
pass

tiled/_tests/test_utils.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from pathlib import Path
2+
3+
from ..utils import ensure_specified_sql_driver
4+
5+
6+
def test_ensure_specified_sql_driver():
7+
# Postgres
8+
# Default driver is added if missing.
9+
assert (
10+
ensure_specified_sql_driver(
11+
"postgresql://user:password@localhost:5432/database"
12+
)
13+
== "postgresql+asyncpg://user:password@localhost:5432/database"
14+
)
15+
# Default driver passes through if specified.
16+
assert (
17+
ensure_specified_sql_driver(
18+
"postgresql+asyncpg://user:password@localhost:5432/database"
19+
)
20+
== "postgresql+asyncpg://user:password@localhost:5432/database"
21+
)
22+
# Do not override user-provided.
23+
assert (
24+
ensure_specified_sql_driver(
25+
"postgresql+custom://user:password@localhost:5432/database"
26+
)
27+
== "postgresql+custom://user:password@localhost:5432/database"
28+
)
29+
30+
# SQLite
31+
# Default driver is added if missing.
32+
assert (
33+
ensure_specified_sql_driver("sqlite:////test.db")
34+
== "sqlite+aiosqlite:////test.db"
35+
)
36+
# Default driver passes through if specified.
37+
assert (
38+
ensure_specified_sql_driver("sqlite+aiosqlite:////test.db")
39+
== "sqlite+aiosqlite:////test.db"
40+
)
41+
# Do not override user-provided.
42+
assert (
43+
ensure_specified_sql_driver("sqlite+custom:////test.db")
44+
== "sqlite+custom:////test.db"
45+
)
46+
# Handle SQLite :memory: URIs
47+
assert (
48+
ensure_specified_sql_driver("sqlite+aiosqlite://:memory:")
49+
== "sqlite+aiosqlite://:memory:"
50+
)
51+
assert (
52+
ensure_specified_sql_driver("sqlite://:memory:")
53+
== "sqlite+aiosqlite://:memory:"
54+
)
55+
# Handle SQLite relative URIs
56+
assert (
57+
ensure_specified_sql_driver("sqlite+aiosqlite:///test.db")
58+
== "sqlite+aiosqlite:///test.db"
59+
)
60+
assert (
61+
ensure_specified_sql_driver("sqlite:///test.db")
62+
== "sqlite+aiosqlite:///test.db"
63+
)
64+
# Filepaths are implicitly SQLite databases.
65+
# Relative path
66+
assert ensure_specified_sql_driver("test.db") == "sqlite+aiosqlite:///test.db"
67+
# Path object
68+
assert ensure_specified_sql_driver(Path("test.db")) == "sqlite+aiosqlite:///test.db"
69+
# Relative path anchored to .
70+
assert ensure_specified_sql_driver("./test.db") == "sqlite+aiosqlite:///test.db"
71+
# Absolute path
72+
assert (
73+
ensure_specified_sql_driver(Path("/tmp/test.db"))
74+
== f"sqlite+aiosqlite:///{Path('/tmp/test.db')}"
75+
)

tiled/_tests/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from ..client import context
1616
from ..client.base import BaseClient
17+
from ..utils import ensure_specified_sql_driver
1718

1819
if sys.version_info < (3, 9):
1920
import importlib_resources as resources
@@ -33,7 +34,7 @@ async def temp_postgres(uri):
3334
if uri.endswith("/"):
3435
uri = uri[:-1]
3536
# Create a fresh database.
36-
engine = create_async_engine(uri)
37+
engine = create_async_engine(ensure_specified_sql_driver(uri))
3738
database_name = f"tiled_test_disposable_{uuid.uuid4().hex}"
3839
async with engine.connect() as connection:
3940
await connection.execute(

tiled/authn_database/connection_pool.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
33

44
from ..server.settings import get_settings
5+
from ..utils import ensure_specified_sql_driver
56

67
# A given process probably only has one of these at a time, but we
78
# key on database_settings just case in some testing context or something
@@ -16,7 +17,9 @@ def open_database_connection_pool(database_settings):
1617
# kwargs["pool_pre_ping"] = database_settings.pool_pre_ping
1718
# kwargs["max_overflow"] = database_settings.max_overflow
1819
engine = create_async_engine(
19-
database_settings.uri, connect_args=connect_args, **kwargs
20+
ensure_specified_sql_driver(database_settings.uri),
21+
connect_args=connect_args,
22+
**kwargs,
2023
)
2124
_connection_pools[database_settings] = engine
2225
return engine

tiled/catalog/adapter.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,12 @@
6464
from ..server.schemas import Asset, DataSource, Management, Revision, Spec
6565
from ..structures.core import StructureFamily
6666
from ..utils import (
67-
SCHEME_PATTERN,
6867
UNCHANGED,
6968
Conflicts,
7069
OneShotCachedMap,
7170
UnsupportedQueryType,
7271
ensure_awaitable,
72+
ensure_specified_sql_driver,
7373
ensure_uri,
7474
import_object,
7575
path_from_uri,
@@ -1347,7 +1347,7 @@ def from_uri(
13471347
echo=DEFAULT_ECHO,
13481348
adapters_by_mimetype=None,
13491349
):
1350-
uri = str(uri)
1350+
uri = ensure_specified_sql_driver(uri)
13511351
if init_if_not_exists:
13521352
# The alembic stamping can only be does synchronously.
13531353
# The cleanest option available is to start a subprocess
@@ -1366,9 +1366,6 @@ def from_uri(
13661366
stderr = process.stderr.decode()
13671367
logging.info(f"Subprocess stdout: {stdout}")
13681368
logging.error(f"Subprocess stderr: {stderr}")
1369-
if not SCHEME_PATTERN.match(uri):
1370-
# Interpret URI as filepath.
1371-
uri = f"sqlite+aiosqlite:///{uri}"
13721369

13731370
parsed_url = make_url(uri)
13741371
if (parsed_url.get_dialect().name == "sqlite") and (
@@ -1381,7 +1378,10 @@ def from_uri(
13811378
else:
13821379
poolclass = None # defer to sqlalchemy default
13831380
engine = create_async_engine(
1384-
uri, echo=echo, json_serializer=json_serializer, poolclass=poolclass
1381+
uri,
1382+
echo=echo,
1383+
json_serializer=json_serializer,
1384+
poolclass=poolclass,
13851385
)
13861386
if engine.dialect.name == "sqlite":
13871387
event.listens_for(engine.sync_engine, "connect")(_set_sqlite_pragma)

tiled/commandline/_admin.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,10 @@ def initialize_database(database_uri: str):
2727
REQUIRED_REVISION,
2828
initialize_database,
2929
)
30+
from ..utils import ensure_specified_sql_driver
3031

3132
async def do_setup():
32-
engine = create_async_engine(database_uri)
33+
engine = create_async_engine(ensure_specified_sql_driver(database_uri))
3334
redacted_url = engine.url._replace(password="[redacted]")
3435
try:
3536
await check_database(engine, REQUIRED_REVISION, ALL_REVISIONS)
@@ -71,9 +72,10 @@ def upgrade_database(
7172
ALEMBIC_INI_TEMPLATE_PATH,
7273
)
7374
from ..authn_database.core import ALL_REVISIONS
75+
from ..utils import ensure_specified_sql_driver
7476

7577
async def do_setup():
76-
engine = create_async_engine(database_uri)
78+
engine = create_async_engine(ensure_specified_sql_driver(database_uri))
7779
redacted_url = engine.url._replace(password="[redacted]")
7880
current_revision = await get_current_revision(engine, ALL_REVISIONS)
7981
await engine.dispose()
@@ -107,9 +109,10 @@ def downgrade_database(
107109
ALEMBIC_INI_TEMPLATE_PATH,
108110
)
109111
from ..authn_database.core import ALL_REVISIONS
112+
from ..utils import ensure_specified_sql_driver
110113

111114
async def do_setup():
112-
engine = create_async_engine(database_uri)
115+
engine = create_async_engine(ensure_specified_sql_driver(database_uri))
113116
redacted_url = engine.url._replace(password="[redacted]")
114117
current_revision = await get_current_revision(engine, ALL_REVISIONS)
115118
if current_revision is None:

tiled/commandline/_catalog.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,9 @@ def init(
4444
from ..alembic_utils import UninitializedDatabase, check_database, stamp_head
4545
from ..catalog.alembic_constants import ALEMBIC_DIR, ALEMBIC_INI_TEMPLATE_PATH
4646
from ..catalog.core import ALL_REVISIONS, REQUIRED_REVISION, initialize_database
47-
from ..utils import SCHEME_PATTERN
47+
from ..utils import ensure_specified_sql_driver
4848

49-
if not SCHEME_PATTERN.match(database):
50-
# Interpret URI as filepath.
51-
database = f"sqlite+aiosqlite:///{database}"
49+
database = ensure_specified_sql_driver(database)
5250

5351
async def do_setup():
5452
engine = create_async_engine(database)
@@ -94,6 +92,9 @@ def upgrade_database(
9492
from ..alembic_utils import get_current_revision, upgrade
9593
from ..catalog.alembic_constants import ALEMBIC_DIR, ALEMBIC_INI_TEMPLATE_PATH
9694
from ..catalog.core import ALL_REVISIONS
95+
from ..utils import ensure_specified_sql_driver
96+
97+
database_uri = ensure_specified_sql_driver(database_uri)
9798

9899
async def do_setup():
99100
engine = create_async_engine(database_uri)
@@ -127,6 +128,9 @@ def downgrade_database(
127128
from ..alembic_utils import downgrade, get_current_revision
128129
from ..catalog.alembic_constants import ALEMBIC_DIR, ALEMBIC_INI_TEMPLATE_PATH
129130
from ..catalog.core import ALL_REVISIONS
131+
from ..utils import ensure_specified_sql_driver
132+
133+
database_uri = ensure_specified_sql_driver(database_uri)
130134

131135
async def do_setup():
132136
engine = create_async_engine(database_uri)

tiled/commandline/_serve.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,9 @@ def serve_directory(
129129
from ..alembic_utils import stamp_head
130130
from ..catalog.alembic_constants import ALEMBIC_DIR, ALEMBIC_INI_TEMPLATE_PATH
131131
from ..catalog.core import initialize_database
132+
from ..utils import ensure_specified_sql_driver
132133

133-
engine = create_async_engine(database)
134+
engine = create_async_engine(ensure_specified_sql_driver(database))
134135
asyncio.run(initialize_database(engine))
135136
stamp_head(ALEMBIC_INI_TEMPLATE_PATH, ALEMBIC_DIR, database)
136137

@@ -389,8 +390,9 @@ def serve_catalog(
389390
from ..alembic_utils import stamp_head
390391
from ..catalog.alembic_constants import ALEMBIC_DIR, ALEMBIC_INI_TEMPLATE_PATH
391392
from ..catalog.core import initialize_database
393+
from ..utils import ensure_specified_sql_driver
392394

393-
engine = create_async_engine(database)
395+
engine = create_async_engine(ensure_specified_sql_driver(database))
394396
asyncio.run(initialize_database(engine))
395397
stamp_head(ALEMBIC_INI_TEMPLATE_PATH, ALEMBIC_DIR, database)
396398

tiled/utils.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -721,6 +721,33 @@ def ensure_uri(uri_or_path) -> str:
721721
return str(uri_str)
722722

723723

724+
SCHEME_TO_SCHEME_PLUS_DRIVER = {
725+
"postgresql": "postgresql+asyncpg",
726+
"sqlite": "sqlite+aiosqlite",
727+
}
728+
729+
730+
def ensure_specified_sql_driver(uri: str) -> str:
731+
"""
732+
Given a URI without a driver in the scheme, add Tiled's preferred driver.
733+
734+
If a driver is already specified, the specified one will be used; it
735+
will NOT be overriden by this function.
736+
737+
'postgresql://...' -> 'postgresql+asynpg://...'
738+
'sqlite://...' -> 'sqlite+aiosqlite://...'
739+
'postgresql+asyncpg://...' -> 'postgresql+asynpg://...'
740+
'postgresql+my_custom_driver://...' -> 'postgresql+my_custom_driver://...'
741+
'/path/to/file.db' -> 'sqlite+aiosqlite:////path/to/file.db'
742+
"""
743+
if not SCHEME_PATTERN.match(str(uri)):
744+
# Interpret URI as filepath.
745+
uri = f"sqlite+aiosqlite:///{Path(uri)}"
746+
scheme, rest = uri.split(":", 1)
747+
new_scheme = SCHEME_TO_SCHEME_PLUS_DRIVER.get(scheme, scheme)
748+
return ":".join([new_scheme, rest])
749+
750+
724751
class catch_warning_msg(warnings.catch_warnings):
725752
"""Backward compatible version of catch_warnings for python <3.11.
726753

0 commit comments

Comments
 (0)