Skip to content

Commit 918ea6c

Browse files
committed
feat: support sqlalchemy select API
1 parent 2133fd7 commit 918ea6c

File tree

4 files changed

+93
-46
lines changed

4 files changed

+93
-46
lines changed

RELEASE.md

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Release type: minor
2+
3+
Support SQLAlchemy select API when resolving.

src/strawberry_sqlalchemy_mapper/field.py

+35-29
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@
2929
from typing_extensions import Annotated, TypeAlias
3030

3131
from sqlakeyset.types import Keyset
32+
from sqlalchemy import Select, select
3233
from sqlalchemy.ext.asyncio import AsyncSession
33-
from sqlalchemy.orm import Query, Session
34+
from sqlalchemy.orm import Session
3435
from strawberry import relay
3536
from strawberry.annotation import StrawberryAnnotation
3637
from strawberry.extensions.field_extension import (
@@ -59,11 +60,11 @@
5960
assert argument # type: ignore[truthy-function]
6061

6162

62-
connection_session: contextvars.ContextVar[
63-
Union[Session, AsyncSession, None]
64-
] = contextvars.ContextVar(
65-
"connection-session",
66-
default=None,
63+
connection_session: contextvars.ContextVar[Union[Session, AsyncSession, None]] = (
64+
contextvars.ContextVar(
65+
"connection-session",
66+
default=None,
67+
)
6768
)
6869

6970

@@ -97,7 +98,7 @@ def __init__(
9798
@dataclasses.dataclass
9899
class StrawberrySQLAlchemyAsyncQuery:
99100
session: AsyncSession
100-
query: Callable[[Session], Query]
101+
query: Callable[[], Select]
101102
iterator: Iterator[Any] | None = None
102103
limit: int | None = None
103104
offset: int | None = None
@@ -120,16 +121,13 @@ def __aiter__(self):
120121

121122
async def __anext__(self):
122123
if self.iterator is None:
124+
q = self.query()
125+
if self.limit is not None:
126+
q = q.limit(self.limit)
127+
if self.offset is not None:
128+
q = q.offset(self.offset)
123129

124-
def query_runner(s: Session):
125-
q = self.query(s)
126-
if self.limit is not None:
127-
q = q.limit(self.limit)
128-
if self.offset is not None:
129-
q = q.offset(self.offset)
130-
return list(q)
131-
132-
self.iterator = iter(await self.session.run_sync(query_runner))
130+
self.iterator = iter(await self.session.scalars(q))
133131

134132
try:
135133
return next(self.iterator)
@@ -325,7 +323,7 @@ def default_resolver(
325323
if session is None:
326324
session = field_sessionmaker()
327325

328-
def _get_query(s: Session):
326+
def _get_orm_query(s: Session):
329327
if root is not None:
330328
# root won't be None when resolving nested connections.
331329
# TODO: Maybe we want to send this to a dataloader?
@@ -338,16 +336,29 @@ def _get_query(s: Session):
338336

339337
return query
340338

339+
def _get_select_query():
340+
if root is not None:
341+
# root won't be None when resolving nested connections.
342+
# TODO: Maybe we want to send this to a dataloader?
343+
query = getattr(root, field.python_name)
344+
else:
345+
query = select(model)
346+
347+
if field.keyset is not None:
348+
query = query.order_by(*field.keyset)
349+
350+
return query
351+
341352
if isinstance(session, AsyncSession):
342353
return cast(
343354
Iterable[Any],
344355
StrawberrySQLAlchemyAsyncQuery(
345356
session=session,
346-
query=lambda s: _get_query(s),
357+
query=_get_select_query,
347358
),
348359
)
349360

350-
return _get_query(session)
361+
return _get_orm_query(session)
351362

352363
field.base_resolver = StrawberryResolver(default_resolver)
353364

@@ -415,8 +426,7 @@ def field(
415426
graphql_type: Any | None = None,
416427
extensions: Sequence[FieldExtension] = (),
417428
sessionmaker: _SessionMaker | None = None,
418-
) -> _T:
419-
...
429+
) -> _T: ...
420430

421431

422432
@overload
@@ -437,8 +447,7 @@ def field(
437447
graphql_type: Any | None = None,
438448
extensions: Sequence[FieldExtension] = (),
439449
sessionmaker: _SessionMaker | None = None,
440-
) -> Any:
441-
...
450+
) -> Any: ...
442451

443452

444453
@overload
@@ -459,8 +468,7 @@ def field(
459468
graphql_type: Any | None = None,
460469
extensions: Sequence[FieldExtension] = (),
461470
sessionmaker: _SessionMaker | None = None,
462-
) -> StrawberrySQLAlchemyField:
463-
...
471+
) -> StrawberrySQLAlchemyField: ...
464472

465473

466474
def field(
@@ -599,8 +607,7 @@ def connection(
599607
extensions: Sequence[FieldExtension] = (),
600608
sessionmaker: _SessionMaker | None = None,
601609
keyset: Keyset | None = None,
602-
) -> Any:
603-
...
610+
) -> Any: ...
604611

605612

606613
@overload
@@ -622,8 +629,7 @@ def connection(
622629
extensions: Sequence[FieldExtension] = (),
623630
sessionmaker: _SessionMaker | None = None,
624631
keyset: Keyset | None = None,
625-
) -> Any:
626-
...
632+
) -> Any: ...
627633

628634

629635
def connection(

src/strawberry_sqlalchemy_mapper/relay.py

+53-15
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@
1414
)
1515

1616
import sqlakeyset
17+
import sqlakeyset.asyncio
1718
import strawberry
18-
from sqlalchemy import and_, or_
19+
from sqlalchemy import Row, Select, and_, or_
1920
from sqlalchemy.exc import NoResultFound
2021
from sqlalchemy.ext.asyncio import AsyncSession
2122
from sqlalchemy.inspection import inspect as sqlalchemy_inspect
23+
from sqlalchemy.orm import Query
2224
from strawberry import relay
2325
from strawberry.relay.exceptions import NodeIDAnnotationError
2426
from strawberry.relay.types import NodeType
@@ -27,7 +29,7 @@
2729
if TYPE_CHECKING:
2830
from typing_extensions import Literal, Self
2931

30-
from sqlalchemy.orm import Query, Session
32+
from sqlalchemy.orm import Session
3133
from strawberry.types.info import Info
3234
from strawberry.utils.await_maybe import AwaitableOrValue
3335

@@ -64,7 +66,7 @@ class KeysetConnection(relay.Connection[NodeType]):
6466
@classmethod
6567
def resolve_connection(
6668
cls,
67-
nodes: Union[Query, StrawberrySQLAlchemyAsyncQuery], # type: ignore[override]
69+
nodes: Union[Query, Select, StrawberrySQLAlchemyAsyncQuery], # type: ignore[override]
6870
*,
6971
info: Info,
7072
before: Optional[str] = None,
@@ -110,40 +112,76 @@ def resolve_connection(page: sqlakeyset.Page):
110112
end_cursor=page.paging.get_bookmark_at(-1) if page else None,
111113
),
112114
edges=[
113-
edge_class.resolve_edge(n, cursor=page.paging.get_bookmark_at(i))
115+
edge_class.resolve_edge(
116+
n[0] if isinstance(n, Row) else n,
117+
cursor=page.paging.get_bookmark_at(i),
118+
)
114119
for i, n in enumerate(page)
115120
],
116121
)
117122

118-
def resolve_nodes(s: Session, nodes=nodes):
119-
if isinstance(nodes, StrawberrySQLAlchemyAsyncQuery):
120-
nodes = nodes.query(s)
123+
def resolve_nodes(s: Session, nodes: Union[Query, Select]):
124+
if isinstance(nodes, Select):
125+
return resolve_connection(
126+
sqlakeyset.select_page(
127+
s,
128+
nodes,
129+
per_page=per_page,
130+
after=(
131+
sqlakeyset.unserialize_bookmark(after).place
132+
if after
133+
else None
134+
),
135+
before=(
136+
sqlakeyset.unserialize_bookmark(before).place
137+
if before
138+
else None
139+
),
140+
)
141+
)
121142

122143
return resolve_connection(
123144
sqlakeyset.get_page(
124145
nodes,
146+
per_page=per_page,
147+
after=(
148+
sqlakeyset.unserialize_bookmark(after).place if after else None
149+
),
125150
before=(
126151
sqlakeyset.unserialize_bookmark(before).place
127152
if before
128153
else None
129154
),
155+
)
156+
)
157+
158+
async def resolve_nodes_async(s: AsyncSession, nodes: Select):
159+
# the asynchronous SQLAlchemy API only supports select
160+
return resolve_connection(
161+
await sqlakeyset.asyncio.select_page(
162+
s,
163+
nodes,
164+
per_page=per_page,
130165
after=(
131166
sqlakeyset.unserialize_bookmark(after).place if after else None
132167
),
133-
per_page=per_page,
168+
before=(
169+
sqlakeyset.unserialize_bookmark(before).place
170+
if before
171+
else None
172+
),
134173
)
135174
)
136175

137-
# TODO: It would be better to aboid session.run_sync in here but
138-
# sqlakeyset doesn't have a `get_page` async counterpart.
139176
if isinstance(session, AsyncSession):
177+
if isinstance(nodes, StrawberrySQLAlchemyAsyncQuery):
178+
nodes = nodes.query()
140179

141-
async def resolve_async(nodes=nodes):
142-
return await session.run_sync(lambda s: resolve_nodes(s))
143-
144-
return resolve_async()
180+
assert isinstance(nodes, Select)
181+
return resolve_nodes_async(session, nodes)
145182

146-
return resolve_nodes(session)
183+
assert isinstance(nodes, (Query, Select))
184+
return resolve_nodes(session, nodes)
147185

148186

149187
@overload

tests/relay/test_node.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ class Query:
109109
await session.commit()
110110

111111
session.add_all([f1, f2, f3])
112-
session.commit()
112+
await session.commit()
113113

114114
for f in [f1, f2, f3]:
115115
result = await schema.execute(query, {"id": relay.to_base64("Fruit", f.id)})
@@ -266,7 +266,7 @@ class Query:
266266
await session.commit()
267267

268268
session.add_all([f1, f2, f3])
269-
session.commit()
269+
await session.commit()
270270

271271
result = await schema.execute(
272272
query,

0 commit comments

Comments
 (0)