Skip to content

Commit 1dbcdd7

Browse files
authored
Merge branch 'master' into double-pluralization-of-names
2 parents 35f11ee + 50d6ddb commit 1dbcdd7

File tree

5 files changed

+109
-3
lines changed

5 files changed

+109
-3
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ repos:
1616
- id: trailing-whitespace
1717

1818
- repo: https://github.com/astral-sh/ruff-pre-commit
19-
rev: v0.9.10
19+
rev: v0.11.5
2020
hooks:
2121
- id: ruff
2222
args: [--fix, --show-fixes]

CHANGES.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ Version history
55

66
- Type annotations for ARRAY column attributes now include the Python type of
77
the array elements
8+
- Added support for specifying engine arguments via ``--engine-arg``
9+
(PR by @LajosCseppento)
810

911
**3.0.0**
1012

README.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ Examples::
6767
sqlacodegen postgresql:///some_local_db
6868
sqlacodegen --generator tables mysql+pymysql://user:password@localhost/dbname
6969
sqlacodegen --generator dataclasses sqlite:///database.db
70+
# --engine-arg values are parsed with ast.literal_eval
71+
sqlacodegen oracle+oracledb://user:[email protected]:1521/XE --engine-arg thick_mode=True
72+
sqlacodegen oracle+oracledb://user:[email protected]:1521/XE --engine-arg thick_mode=True --engine-arg connect_args='{"user": "user", "dsn": "..."}'
7073

7174
To see the list of generic options::
7275

src/sqlacodegen/cli.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from __future__ import annotations
22

33
import argparse
4+
import ast
45
import sys
56
from contextlib import ExitStack
6-
from typing import TextIO
7+
from typing import Any, TextIO
78

89
from sqlalchemy.engine import create_engine
910
from sqlalchemy.schema import MetaData
@@ -29,6 +30,28 @@
2930
from importlib.metadata import entry_points, version
3031

3132

33+
def _parse_engine_arg(arg_str: str) -> tuple[str, Any]:
34+
if "=" not in arg_str:
35+
raise argparse.ArgumentTypeError("engine-arg must be in key=value format")
36+
37+
key, value = arg_str.split("=", 1)
38+
try:
39+
value = ast.literal_eval(value)
40+
except Exception:
41+
pass # Leave as string if literal_eval fails
42+
43+
return key, value
44+
45+
46+
def _parse_engine_args(arg_list: list[str]) -> dict[str, Any]:
47+
result = {}
48+
for arg in arg_list or []:
49+
key, value = _parse_engine_arg(arg)
50+
result[key] = value
51+
52+
return result
53+
54+
3255
def main() -> None:
3356
generators = {ep.name: ep for ep in entry_points(group="sqlacodegen.generators")}
3457
parser = argparse.ArgumentParser(
@@ -58,6 +81,17 @@ def main() -> None:
5881
action="store_true",
5982
help="ignore views (always true for sqlmodels generator)",
6083
)
84+
parser.add_argument(
85+
"--engine-arg",
86+
action="append",
87+
help=(
88+
"engine arguments in key=value format, e.g., "
89+
'--engine-arg=connect_args=\'{"user": "scott"}\' '
90+
"--engine-arg thick_mode=true or "
91+
'--engine-arg thick_mode=\'{"lib_dir": "/path"}\' '
92+
"(values are parsed with ast.literal_eval)"
93+
),
94+
)
6195
parser.add_argument("--outfile", help="file to write output to (default: stdout)")
6296
args = parser.parse_args()
6397

@@ -80,7 +114,8 @@ def main() -> None:
80114
print(f"Using pgvector {version('pgvector')}")
81115

82116
# Use reflection to fill in the metadata
83-
engine = create_engine(args.url)
117+
engine_args = _parse_engine_args(args.engine_arg)
118+
engine = create_engine(args.url, **engine_args)
84119
metadata = MetaData()
85120
tables = args.tables.split(",") if args.tables else None
86121
schemas = args.schemas.split(",") if args.schemas else [None]

tests/test_cli.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,72 @@ class Foo(SQLModel, table=True):
150150
)
151151

152152

153+
def test_cli_engine_arg(db_path: Path, tmp_path: Path) -> None:
154+
output_path = tmp_path / "outfile"
155+
subprocess.run(
156+
[
157+
"sqlacodegen",
158+
f"sqlite:///{db_path}",
159+
"--generator",
160+
"tables",
161+
"--engine-arg",
162+
'connect_args={"timeout": 10}',
163+
"--outfile",
164+
str(output_path),
165+
],
166+
check=True,
167+
)
168+
169+
assert (
170+
output_path.read_text()
171+
== """\
172+
from sqlalchemy import Column, Integer, MetaData, Table, Text
173+
174+
metadata = MetaData()
175+
176+
177+
t_foo = Table(
178+
'foo', metadata,
179+
Column('id', Integer, primary_key=True),
180+
Column('name', Text, nullable=False)
181+
)
182+
"""
183+
)
184+
185+
186+
def test_cli_invalid_engine_arg(db_path: Path, tmp_path: Path) -> None:
187+
output_path = tmp_path / "outfile"
188+
189+
# Expect exception:
190+
# TypeError: 'this_arg_does_not_exist' is an invalid keyword argument for Connection()
191+
with pytest.raises(subprocess.CalledProcessError) as exc_info:
192+
subprocess.run(
193+
[
194+
"sqlacodegen",
195+
f"sqlite:///{db_path}",
196+
"--generator",
197+
"tables",
198+
"--engine-arg",
199+
'connect_args={"this_arg_does_not_exist": 10}',
200+
"--outfile",
201+
str(output_path),
202+
],
203+
check=True,
204+
capture_output=True,
205+
)
206+
207+
if sys.version_info < (3, 13):
208+
assert (
209+
"'this_arg_does_not_exist' is an invalid keyword argument"
210+
in exc_info.value.stderr.decode()
211+
)
212+
else:
213+
assert (
214+
"got an unexpected keyword argument 'this_arg_does_not_exist'"
215+
in exc_info.value.stderr.decode()
216+
)
217+
218+
153219
def test_main() -> None:
154220
expected_version = version("sqlacodegen")
155221
completed = subprocess.run(

0 commit comments

Comments
 (0)