Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support to secondary tables relationships #218

Open
wants to merge 29 commits into
base: main
Choose a base branch
from
Open
Changes from 19 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
de06188
first fix versoin, working only if the items has the same id
Ckk3 Nov 19, 2024
0cd732d
bring back the first version, still missin the different ids logic!
Ckk3 Nov 19, 2024
401cd65
fix: now query can pickup related_model and self_model id
Ckk3 Nov 22, 2024
6f644e3
fix: not working with different ids
Ckk3 Nov 22, 2024
ca9bc1c
add nes tests
Ckk3 Nov 23, 2024
eb852ce
add tests
Ckk3 Nov 23, 2024
5770379
Fix mypy erros, still missing some tests
Ckk3 Nov 24, 2024
be77996
update code to work with sqlalchemy 1.4
Ckk3 Nov 24, 2024
fb6a580
remove old code that only works with sqlalchemy 2
Ckk3 Nov 24, 2024
0fb61bb
add seconday tables tests in test_loader
Ckk3 Nov 24, 2024
03a5438
add new tests to loadar and start mapper tests
Ckk3 Nov 26, 2024
a575650
add mapper tests
Ckk3 Nov 28, 2024
beaa3f9
refactor conftest
Ckk3 Nov 30, 2024
8a65328
refactor test_loader
Ckk3 Nov 30, 2024
9d76061
refactor test_mapper
Ckk3 Nov 30, 2024
91c24c5
run autopep
Ckk3 Nov 30, 2024
1cd8df4
run autopep
Ckk3 Nov 30, 2024
e96f179
separate test
Ckk3 Nov 30, 2024
4b6516b
fix lint
Ckk3 Nov 30, 2024
9b079d4
add release file
Ckk3 Nov 30, 2024
4baa7ae
refactor tests
Ckk3 Nov 30, 2024
33d7758
refactor loader
Ckk3 Nov 30, 2024
2a53474
fix release
Ckk3 Nov 30, 2024
d04af46
update pre-commit to work with python 3.8
Ckk3 Jan 26, 2025
3f7f13d
update loader.py
Ckk3 Jan 26, 2025
ff3e419
updated mapper
Ckk3 Jan 26, 2025
6752231
fix lint
Ckk3 Jan 26, 2025
0cd68d2
remote autopep8 from dev container because it give problems when work…
Ckk3 Jan 26, 2025
0745c64
fix lint
Ckk3 Jan 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/strawberry_sqlalchemy_mapper/exc.py
Original file line number Diff line number Diff line change
@@ -36,3 +36,11 @@ def __init__(self, model):
f"Model `{model}` is not polymorphic or is not the base model of its "
+ "inheritance chain, and thus cannot be used as an interface."
)


class InvalidLocalRemotePairs(Exception):
def __init__(self, relationship_name):
super().__init__(
f"The `local_remote_pairs` for the relationship `{relationship_name}` is invalid or missing. "
+ "This is likely an issue with the library. Please report this error to the maintainers."
)
77 changes: 68 additions & 9 deletions src/strawberry_sqlalchemy_mapper/loader.py
Original file line number Diff line number Diff line change
@@ -11,6 +11,7 @@
Tuple,
Union,
)
from strawberry_sqlalchemy_mapper.exc import InvalidLocalRemotePairs

from sqlalchemy import select, tuple_
from sqlalchemy.engine.base import Connection
@@ -45,12 +46,16 @@ def __init__(
"One of bind or async_bind_factory must be set for loader to function properly."
)

async def _scalars_all(self, *args, **kwargs):
async def _scalars_all(self, *args, disabled_optimization_to_secondary_tables=False, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thought: maybe call this enable_ and have True as the default?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont want to do this because it removes optimizations that we only need to remove when we need to pick up secondary tables values, so if the default is True we will lose peformance in queries that dont need it.
But I agree that this var name aren't good enought, so I will change the name to query_secondary_tables and refactor the function.

if self._async_bind_factory:
async with self._async_bind_factory() as bind:
if disabled_optimization_to_secondary_tables is True:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick:

Suggested change
if disabled_optimization_to_secondary_tables is True:
if disabled_optimization_to_secondary_tables:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated! Thank you

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated!

return (await bind.execute(*args, **kwargs)).all()
return (await bind.scalars(*args, **kwargs)).all()
else:
assert self._bind is not None
if disabled_optimization_to_secondary_tables is True:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick:

Suggested change
if disabled_optimization_to_secondary_tables is True:
if disabled_optimization_to_secondary_tables:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated!

return self._bind.execute(*args, **kwargs).all()
return self._bind.scalars(*args, **kwargs).all()

def loader_for(self, relationship: RelationshipProperty) -> DataLoader:
@@ -63,14 +68,63 @@ def loader_for(self, relationship: RelationshipProperty) -> DataLoader:
related_model = relationship.entity.entity

async def load_fn(keys: List[Tuple]) -> List[Any]:
query = select(related_model).filter(
tuple_(
*[remote for _, remote in relationship.local_remote_pairs or []]
).in_(keys)
)
if relationship.secondary is None:
query = select(related_model).filter(
tuple_(
*[remote for _, remote in relationship.local_remote_pairs or []]
).in_(keys)
)
else:
# Use another query when relationship uses a secondary table
self_model = relationship.parent.entity

if not relationship.local_remote_pairs:
raise InvalidLocalRemotePairs(
f"{related_model.__name__} -- {self_model.__name__}")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

polish: for ruff/black this parenthesis should be closed in the next line. I think you forgot to pre-commit install =P (ditto for the lines below)

Also, we are probably missing a lint check in here which runs ruff/black/etc (and maybe migrate to ruff formatter instead of black soon)

Copy link
Contributor Author

@Ckk3 Ckk3 Jan 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry about that, I see now that pre-commit dont run due to some updates that dont work with dev container python version (3.8).
I updated that imports and now i'm fixing all the erros ;)


self_model_key_label = str(
relationship.local_remote_pairs[0][1].key)
related_model_key_label = str(
relationship.local_remote_pairs[1][1].key)

self_model_key = str(
relationship.local_remote_pairs[0][0].key)
related_model_key = str(
relationship.local_remote_pairs[1][0].key)

remote_to_use = relationship.local_remote_pairs[0][1]
query_keys = tuple([item[0] for item in keys])

# This query returns rows in this format -> (self_model.key, related_model)
query = (
select(
getattr(self_model, self_model_key).label(
self_model_key_label),
related_model
)
.join(
relationship.secondary,
getattr(relationship.secondary.c,
related_model_key_label) == getattr(related_model, related_model_key)
)
.join(
self_model,
getattr(relationship.secondary.c,
self_model_key_label) == getattr(self_model, self_model_key)
)
.filter(
remote_to_use.in_(query_keys)
)
)

if relationship.order_by:
query = query.order_by(*relationship.order_by)
rows = await self._scalars_all(query)

if relationship.secondary is not None:
# We need to retrieve values from both the self_model and related_model. To achieve this, we must disable the default SQLAlchemy optimization that returns only related_model values. This is necessary because we use the keys variable to match both related_model and self_model.
rows = await self._scalars_all(query, disabled_optimization_to_secondary_tables=True)
else:
rows = await self._scalars_all(query)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion:

Suggested change
if relationship.secondary is not None:
# We need to retrieve values from both the self_model and related_model. To achieve this, we must disable the default SQLAlchemy optimization that returns only related_model values. This is necessary because we use the keys variable to match both related_model and self_model.
rows = await self._scalars_all(query, disabled_optimization_to_secondary_tables=True)
else:
rows = await self._scalars_all(query)
# We need to retrieve values from both the self_model and related_model.
# To achieve this, we must disable the default SQLAlchemy optimization
# that returns only related_model values. This is necessary because we
# use the keys variable to match both related_model and self_model.
rows = await self._scalars_all(query, disabled_optimization_to_secondary_tables=relationship.secondary is not None)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Updated!


def group_by_remote_key(row: Any) -> Tuple:
return tuple(
@@ -82,8 +136,13 @@ def group_by_remote_key(row: Any) -> Tuple:
)

grouped_keys: Mapping[Tuple, List[Any]] = defaultdict(list)
for row in rows:
grouped_keys[group_by_remote_key(row)].append(row)
if relationship.secondary is None:
for row in rows:
grouped_keys[group_by_remote_key(row)].append(row)
else:
for row in rows:
grouped_keys[(row[0],)].append(row[1])

if relationship.uselist:
return [grouped_keys[key] for key in keys]
else:
93 changes: 63 additions & 30 deletions src/strawberry_sqlalchemy_mapper/mapper.py
Original file line number Diff line number Diff line change
@@ -82,6 +82,7 @@
from strawberry_sqlalchemy_mapper.exc import (
HybridPropertyNotAnnotated,
InterfaceModelNotPolymorphic,
InvalidLocalRemotePairs,
UnsupportedAssociationProxyTarget,
UnsupportedColumnType,
UnsupportedDescriptorType,
@@ -154,7 +155,8 @@ def from_type(cls, type_: type, *, strict: Literal[True]) -> Self: ...

@overload
@classmethod
def from_type(cls, type_: type, *, strict: bool = False) -> Optional[Self]: ...
def from_type(cls, type_: type, *,
strict: bool = False) -> Optional[Self]: ...

@classmethod
def from_type(
@@ -165,7 +167,8 @@ def from_type(
) -> Optional[Self]:
definition = getattr(type_, cls.TYPE_KEY_NAME, None)
if strict and definition is None:
raise TypeError(f"{type_!r} does not have a StrawberrySQLAlchemyType in it")
raise TypeError(
f"{type_!r} does not have a StrawberrySQLAlchemyType in it")
return definition


@@ -228,11 +231,12 @@ class StrawberrySQLAlchemyMapper(Generic[BaseModelType]):

def __init__(
self,
model_to_type_name: Optional[Callable[[Type[BaseModelType]], str]] = None,
model_to_interface_name: Optional[Callable[[Type[BaseModelType]], str]] = None,
extra_sqlalchemy_type_to_strawberry_type_map: Optional[
Mapping[Type[TypeEngine], Type[Any]]
] = None,
model_to_type_name: Optional[Callable[[
Type[BaseModelType]], str]] = None,
model_to_interface_name: Optional[Callable[[
Type[BaseModelType]], str]] = None,
extra_sqlalchemy_type_to_strawberry_type_map: Optional[Mapping[Type[TypeEngine], Type[Any]]
] = None,
) -> None:
if model_to_type_name is None:
model_to_type_name = self._default_model_to_type_name
@@ -295,7 +299,8 @@ def _edge_type_for(self, type_name: str) -> Type[Any]:
"""
edge_name = f"{type_name}Edge"
if edge_name not in self.edge_types:
lazy_type = StrawberrySQLAlchemyLazy(type_name=type_name, mapper=self)
lazy_type = StrawberrySQLAlchemyLazy(
type_name=type_name, mapper=self)
self.edge_types[edge_name] = edge_type = strawberry.type(
dataclasses.make_dataclass(
edge_name,
@@ -314,14 +319,15 @@ def _connection_type_for(self, type_name: str) -> Type[Any]:
connection_name = f"{type_name}Connection"
if connection_name not in self.connection_types:
edge_type = self._edge_type_for(type_name)
lazy_type = StrawberrySQLAlchemyLazy(type_name=type_name, mapper=self)
lazy_type = StrawberrySQLAlchemyLazy(
type_name=type_name, mapper=self)
self.connection_types[connection_name] = connection_type = strawberry.type(
dataclasses.make_dataclass(
connection_name,
[
("edges", List[edge_type]), # type: ignore[valid-type]
],
bases=(relay.ListConnection[lazy_type],), # type: ignore[valid-type]
bases=(relay.ListConnection[lazy_type],), # type: ignore[valid-type]
)
)
setattr(connection_type, _GENERATED_FIELD_KEYS_KEY, ["edges"])
@@ -387,7 +393,7 @@ def _convert_relationship_to_strawberry_type(
if relationship.uselist:
# Use list if excluding relay pagination
if use_list:
return List[ForwardRef(type_name)] # type: ignore
return List[ForwardRef(type_name)] # type: ignore

return self._connection_type_for(type_name)
else:
@@ -451,7 +457,8 @@ def _get_association_proxy_annotation(
strawberry_type.__forward_arg__
)
else:
strawberry_type = self._connection_type_for(strawberry_type.__name__)
strawberry_type = self._connection_type_for(
strawberry_type.__name__)
return strawberry_type

def make_connection_wrapper_resolver(
@@ -500,13 +507,29 @@ async def resolve(self, info: Info):
if relationship.key not in instance_state.unloaded:
related_objects = getattr(self, relationship.key)
else:
relationship_key = tuple(
[
getattr(self, local.key)
for local, _ in relationship.local_remote_pairs or []
if local.key
]
)
if relationship.secondary is None:
relationship_key = tuple(
[
getattr(self, local.key)
for local, _ in relationship.local_remote_pairs or []
if local.key
]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

praise: you can pass the iterator to the tuple directly, no need to create a list for that

Suggested change
[
getattr(self, local.key)
for local, _ in relationship.local_remote_pairs or []
if local.key
]
getattr(self, local.key)
for local, _ in relationship.local_remote_pairs or []
if local.key

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated!

)
else:
# If has a secondary table, gets only the first ID as additional IDs require a separate query
if not relationship.local_remote_pairs:
raise InvalidLocalRemotePairs(
f"{relationship.entity.entity.__name__} -- {relationship.parent.entity.__name__}")

local_remote_pairs_secondary_table_local = relationship.local_remote_pairs[
0][0]
relationship_key = tuple(
[
getattr(
self, str(local_remote_pairs_secondary_table_local.key)),
]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: ditto

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated!

)

if any(item is None for item in relationship_key):
if relationship.uselist:
return []
@@ -536,7 +559,8 @@ def connection_resolver_for(
if relationship.uselist and not use_list:
return self.make_connection_wrapper_resolver(
relationship_resolver,
self.model_to_type_or_interface_name(relationship.entity.entity), # type: ignore[arg-type]
self.model_to_type_or_interface_name(
relationship.entity.entity), # type: ignore[arg-type]
)
else:
return relationship_resolver
@@ -554,13 +578,15 @@ def association_proxy_resolver_for(
Return an async field resolver for the given association proxy.
"""
in_between_relationship = mapper.relationships[descriptor.target_collection]
in_between_resolver = self.relationship_resolver_for(in_between_relationship)
in_between_resolver = self.relationship_resolver_for(
in_between_relationship)
in_between_mapper: Mapper = mapper.relationships[ # type: ignore[assignment]
descriptor.target_collection
].entity
assert descriptor.value_attr in in_between_mapper.relationships
end_relationship = in_between_mapper.relationships[descriptor.value_attr]
end_relationship_resolver = self.relationship_resolver_for(end_relationship)
end_relationship_resolver = self.relationship_resolver_for(
end_relationship)
end_type_name = self.model_to_type_or_interface_name(
end_relationship.entity.entity # type: ignore[arg-type]
)
@@ -587,7 +613,8 @@ async def resolve(self, info: Info):
if outputs and isinstance(outputs[0], list):
outputs = list(chain.from_iterable(outputs))
else:
outputs = [output for output in outputs if output is not None]
outputs = [
output for output in outputs if output is not None]
else:
outputs = await end_relationship_resolver(in_between_objects, info)
if not isinstance(outputs, collections.abc.Iterable):
@@ -683,7 +710,8 @@ def convert(type_: Any) -> Any:
setattr(type_, key, field(resolver=val))
generated_field_keys.append(key)

self._handle_columns(mapper, type_, excluded_keys, generated_field_keys)
self._handle_columns(
mapper, type_, excluded_keys, generated_field_keys)
relationship: RelationshipProperty
for key, relationship in mapper.relationships.items():
if (
@@ -805,7 +833,7 @@ def convert(type_: Any) -> Any:
setattr(
type_,
attr,
types.MethodType(func, type_), # type: ignore[arg-type]
types.MethodType(func, type_), # type: ignore[arg-type]
)

# Adjust types that inherit from other types/interfaces that implement Node
@@ -818,7 +846,8 @@ def convert(type_: Any) -> Any:
setattr(
type_,
attr,
types.MethodType(cast(classmethod, meth).__func__, type_),
types.MethodType(
cast(classmethod, meth).__func__, type_),
)

# need to make fields that are already in the type
@@ -846,7 +875,8 @@ def convert(type_: Any) -> Any:
model=model,
),
)
setattr(mapped_type, _GENERATED_FIELD_KEYS_KEY, generated_field_keys)
setattr(mapped_type, _GENERATED_FIELD_KEYS_KEY,
generated_field_keys)
setattr(mapped_type, _ORIGINAL_TYPE_KEY, type_)
return mapped_type

@@ -886,14 +916,16 @@ def _fix_annotation_namespaces(self) -> None:
self.edge_types.values(),
self.connection_types.values(),
):
strawberry_definition = get_object_definition(mapped_type, strict=True)
strawberry_definition = get_object_definition(
mapped_type, strict=True)
for f in strawberry_definition.fields:
if f.name in getattr(mapped_type, _GENERATED_FIELD_KEYS_KEY):
namespace = {}
if hasattr(mapped_type, _ORIGINAL_TYPE_KEY):
namespace.update(
sys.modules[
getattr(mapped_type, _ORIGINAL_TYPE_KEY).__module__
getattr(mapped_type,
_ORIGINAL_TYPE_KEY).__module__
].__dict__
)
namespace.update(self.mapped_types)
@@ -924,7 +956,8 @@ def _map_unmapped_relationships(self) -> None:
if type_name not in self.mapped_interfaces:
unmapped_interface_models.add(model)
for model in unmapped_models:
self.type(model)(type(self.model_to_type_name(model), (object,), {}))
self.type(model)(
type(self.model_to_type_name(model), (object,), {}))
for model in unmapped_interface_models:
self.interface(model)(
type(self.model_to_interface_name(model), (object,), {})
460 changes: 460 additions & 0 deletions tests/conftest.py

Large diffs are not rendered by default.

182 changes: 174 additions & 8 deletions tests/relay/test_connection.py
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@
from sqlalchemy.ext.asyncio.engine import AsyncEngine
from sqlalchemy.orm import sessionmaker
from strawberry import relay
from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper, connection
from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper, connection, StrawberrySQLAlchemyLoader
from strawberry_sqlalchemy_mapper.relay import KeysetConnection


@@ -37,7 +37,8 @@ class Fruit(relay.Node):

@strawberry.type
class Query:
fruits: relay.ListConnection[Fruit] = connection(sessionmaker=sessionmaker)
fruits: relay.ListConnection[Fruit] = connection(
sessionmaker=sessionmaker)

schema = strawberry.Schema(query=Query)

@@ -74,7 +75,8 @@ class Fruit(relay.Node):

@strawberry.type
class Query:
fruits: relay.ListConnection[Fruit] = connection(sessionmaker=sessionmaker)
fruits: relay.ListConnection[Fruit] = connection(
sessionmaker=sessionmaker)

schema = strawberry.Schema(query=Query)

@@ -259,7 +261,8 @@ class Fruit(relay.Node):

@strawberry.type
class Query:
fruits: relay.ListConnection[Fruit] = connection(sessionmaker=sessionmaker)
fruits: relay.ListConnection[Fruit] = connection(
sessionmaker=sessionmaker)

schema = strawberry.Schema(query=Query)

@@ -319,7 +322,8 @@ class Fruit(relay.Node):

@strawberry.type
class Query:
fruits: relay.ListConnection[Fruit] = connection(sessionmaker=sessionmaker)
fruits: relay.ListConnection[Fruit] = connection(
sessionmaker=sessionmaker)

schema = strawberry.Schema(query=Query)

@@ -381,7 +385,8 @@ class Fruit(relay.Node):

@strawberry.type
class Query:
fruits: relay.ListConnection[Fruit] = connection(sessionmaker=sessionmaker)
fruits: relay.ListConnection[Fruit] = connection(
sessionmaker=sessionmaker)

schema = strawberry.Schema(query=Query)

@@ -441,7 +446,8 @@ class Fruit(relay.Node):

@strawberry.type
class Query:
fruits: relay.ListConnection[Fruit] = connection(sessionmaker=sessionmaker)
fruits: relay.ListConnection[Fruit] = connection(
sessionmaker=sessionmaker)

schema = strawberry.Schema(query=Query)

@@ -467,7 +473,8 @@ class Query:
session.commit()

result = schema.execute_sync(
query, {"first": 1, "before": relay.to_base64("arrayconnection", 2)}
query, {"first": 1, "before": relay.to_base64(
"arrayconnection", 2)}
)
assert result.errors is None

@@ -755,3 +762,162 @@ class Query:
},
}
}


# TODO Investigate this test
@pytest.mark.skip("This test is currently failing because the Query with relay.ListConnection generates two DepartmentConnection, which violates the schema's expectations. After investigation, it appears this issue is related to the Relay implementation rather than the secondary table issue. We'll address this later. Additionally, note that the `result.data` may be incorrect in this test.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thought: strange...

I know that this lib will generate automatic connections for relations, and that indeed could conflict with the departments you are defining. But you don't have a departments relation inside Employee

Maybe the secondary table is generating such connection with that name? If so, is that correct/expected?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will check this again, but the expected is that creates only one Connection

@pytest.mark.asyncio
async def test_query_with_secondary_table_with_values_list(
secondary_tables,
base,
async_engine,
async_sessionmaker
):
async with async_engine.begin() as conn:
await conn.run_sync(base.metadata.create_all)

mapper = StrawberrySQLAlchemyMapper()
EmployeeModel, DepartmentModel = secondary_tables

@mapper.type(DepartmentModel)
class Department():
pass

@mapper.type(EmployeeModel)
class Employee():
pass

@strawberry.type
class Query:
departments: relay.ListConnection[Department] = connection(
sessionmaker=async_sessionmaker)

mapper.finalize()
schema = strawberry.Schema(query=Query)

query = """\
query {
departments {
edges {
node {
id
name
employees {
edges {
node {
id
name
role
department {
edges {
node {
id
name
}
}
}
}
}
}
}
}
}
}
"""

# Create test data
async with async_sessionmaker(expire_on_commit=False) as session:
department1 = DepartmentModel(id=10, name="Department Test 1")
department2 = DepartmentModel(id=3, name="Department Test 2")
e1 = EmployeeModel(id=1, name="John", role="Developer")
e2 = EmployeeModel(id=5, name="Bill", role="Doctor")
e3 = EmployeeModel(id=4, name="Maria", role="Teacher")
department1.employees.append(e1)
department1.employees.append(e2)
department2.employees.append(e3)
session.add_all([department1, department2, e1, e2, e3])
await session.commit()

result = await schema.execute(query, context_value={
"sqlalchemy_loader": StrawberrySQLAlchemyLoader(
async_bind_factory=async_sessionmaker
)
})
assert result.errors is None
assert result.data == {
"departments": {
"edges": [
{
"node": {
"id": 10,
"name": "Department Test 1",
"employees": {
"edges": [
{
"node": {
"id": 5,
"name": "Bill",
"role": "Doctor",
"department": {
"edges": [
{
"node": {
"id": 10,
"name": "Department Test 1"
}
}
]
}
}
},
{
"node": {
"id": 1,
"name": "John",
"role": "Developer",
"department": {
"edges": [
{
"node": {
"id": 10,
"name": "Department Test 1"
}
}
]
}
}
}
]
}
}
},
{
"node": {
"id": 3,
"name": "Department Test 2",
"employees": {
"edges": [
{
"node": {
"id": 4,
"name": "Maria",
"role": "Teacher",
"department": {
"edges": [
{
"node": {
"id": 3,
"name": "Department Test 2"
}
}
]
}
}
}
]
}
}
}
]
}
}
187 changes: 138 additions & 49 deletions tests/test_loader.py
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@
from sqlalchemy import Column, ForeignKey, Integer, String, Table
from sqlalchemy.orm import relationship
from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyLoader
from strawberry_sqlalchemy_mapper.exc import InvalidLocalRemotePairs

pytest_plugins = ("pytest_asyncio",)

@@ -26,38 +27,6 @@ class Department(base):
return Employee, Department


@pytest.fixture
def secondary_tables(base):
EmployeeDepartmentJoinTable = Table(
"employee_department_join_table",
base.metadata,
Column("employee_id", ForeignKey("employee.e_id"), primary_key=True),
Column("department_id", ForeignKey("department.d_id"), primary_key=True),
)

class Employee(base):
__tablename__ = "employee"
e_id = Column(Integer, autoincrement=True, primary_key=True)
name = Column(String, nullable=False)
departments = relationship(
"Department",
secondary="employee_department_join_table",
back_populates="employees",
)

class Department(base):
__tablename__ = "department"
d_id = Column(Integer, autoincrement=True, primary_key=True)
name = Column(String, nullable=False)
employees = relationship(
"Employee",
secondary="employee_department_join_table",
back_populates="departments",
)

return Employee, Department


def test_loader_init():
loader = StrawberrySQLAlchemyLoader(bind=None)
assert loader._bind is None
@@ -146,36 +115,156 @@ async def test_loader_with_async_session(
assert {e.name for e in employees} == {"e1"}


@pytest.mark.xfail
def create_default_data_on_secondary_table_tests(session, Employee, Department):
e1 = Employee(name="e1", id=1)
e2 = Employee(name="e2", id=2)
d1 = Department(name="d1")
d2 = Department(name="d2")
d3 = Department(name="d3")
session.add_all([e1, e2, d1, d2, d3])
session.flush()

e1.department.append(d1)
e1.department.append(d2)
e2.department.append(d2)
return e1, e2, d1, d2, d3


@pytest.mark.asyncio
async def test_loader_for_secondary(engine, base, sessionmaker, secondary_tables):
async def test_loader_for_secondary_table(engine, base, sessionmaker, secondary_tables):
Employee, Department = secondary_tables
base.metadata.create_all(engine)

with sessionmaker() as session:
e1 = Employee(name="e1")
e2 = Employee(name="e2")
d1 = Department(name="d1")
d2 = Department(name="d2")
session.add(e1)
session.add(e2)
session.add(d1)
session.add(d2)
session.flush()
e1, _, _, _, _ = create_default_data_on_secondary_table_tests(session=session, Employee=Employee, Department=Department)
session.commit()

e1.departments.append(d1)
e1.departments.append(d2)
e2.departments.append(d2)
base_loader = StrawberrySQLAlchemyLoader(bind=session)
loader = base_loader.loader_for(Employee.department.property)

key = tuple(
[
getattr(
e1, str(Employee.department.property.local_remote_pairs[0][0].key)),
]
)

departments = await loader.load(key)
assert {d.name for d in departments} == {"d1", "d2"}


@pytest.mark.asyncio
async def test_loader_for_secondary_tables_with_another_foreign_key(engine, base, sessionmaker, secondary_tables_with_another_foreign_key):
Employee, Department = secondary_tables_with_another_foreign_key
base.metadata.create_all(engine)

with sessionmaker() as session:
e1, _, _, _, _ = create_default_data_on_secondary_table_tests(session=session, Employee=Employee, Department=Department)
session.commit()

base_loader = StrawberrySQLAlchemyLoader(bind=session)
loader = base_loader.loader_for(Employee.departments.property)
loader = base_loader.loader_for(Employee.department.property)

key = tuple(
[
getattr(e1, local.key)
for local, _ in Employee.departments.property.local_remote_pairs
getattr(
e1, str(Employee.department.property.local_remote_pairs[0][0].key)),
]
)

departments = await loader.load(key)
assert {d.name for d in departments} == {"d1", "d2"}


@pytest.mark.asyncio
async def test_loader_for_secondary_tables_with_more_secondary_tables(engine, base, sessionmaker, secondary_tables_with_more_secondary_tables):
Employee, Department, Building = secondary_tables_with_more_secondary_tables
base.metadata.create_all(engine)

with sessionmaker() as session:
e1, e2, _, _, _ = create_default_data_on_secondary_table_tests(session=session, Employee=Employee, Department=Department)

b1 = Building(id=2, name="Building 1")
b1.employees.append(e1)
b1.employees.append(e2)
session.add(b1)
session.commit()

base_loader = StrawberrySQLAlchemyLoader(bind=session)
loader = base_loader.loader_for(Employee.department.property)

key = tuple(
[
getattr(
e1, str(Employee.department.property.local_remote_pairs[0][0].key)),
]
)

departments = await loader.load(key)
assert {d.name for d in departments} == {"d1", "d2"}


@pytest.mark.asyncio
async def test_loader_for_secondary_tables_with_use_list_false(engine, base, sessionmaker, secondary_tables_with_use_list_false):
Employee, Department = secondary_tables_with_use_list_false
base.metadata.create_all(engine)

with sessionmaker() as session:
e1, _, _, _, _ = create_default_data_on_secondary_table_tests(session=session, Employee=Employee, Department=Department)
session.commit()

base_loader = StrawberrySQLAlchemyLoader(bind=session)
loader = base_loader.loader_for(Employee.department.property)

key = tuple(
[
getattr(
e1, str(Employee.department.property.local_remote_pairs[0][0].key)),
]
)

departments = await loader.load(key)
assert {d.name for d in departments} == {"d1"}


@pytest.mark.asyncio
async def test_loader_for_secondary_tables_with_normal_relationship(engine, base, sessionmaker, secondary_tables_with_normal_relationship):
Employee, Department, Building = secondary_tables_with_normal_relationship
base.metadata.create_all(engine)

with sessionmaker() as session:
e1, e2, _, _, _ = create_default_data_on_secondary_table_tests(session=session, Employee=Employee, Department=Department)

b1 = Building(id=2, name="Building 1")
b1.employees.append(e1)
b1.employees.append(e2)
session.add(b1)
session.commit()

base_loader = StrawberrySQLAlchemyLoader(bind=session)
loader = base_loader.loader_for(Employee.department.property)

key = tuple(
[
getattr(
e1, str(Employee.department.property.local_remote_pairs[0][0].key)),
]
)

departments = await loader.load(key)
assert {d.name for d in departments} == {"d1", "d2"}


@pytest.mark.asyncio
async def test_loader_for_secondary_tables_should_raise_exception_if_relationship_dont_has_local_remote_pairs(engine, base, sessionmaker, secondary_tables_with_normal_relationship):
Employee, Department, Building = secondary_tables_with_normal_relationship
base.metadata.create_all(engine)

with sessionmaker() as session:
base_loader = StrawberrySQLAlchemyLoader(bind=session)

Employee.department.property.local_remote_pairs = []
loader = base_loader.loader_for(Employee.department.property)

with pytest.raises(expected_exception=InvalidLocalRemotePairs):
await loader.load((1,))
120 changes: 120 additions & 0 deletions tests/test_mapper.py
Original file line number Diff line number Diff line change
@@ -379,3 +379,123 @@ def departments(self) -> Department: ...
}
'''
assert str(schema) == textwrap.dedent(expected).strip()


def test_relationships_schema_with_secondary_tables(secondary_tables, mapper, expected_schema_from_secondary_tables):
EmployeeModel, DepartmentModel = secondary_tables

@mapper.type(EmployeeModel)
class Employee:
pass

@mapper.type(DepartmentModel)
class Department:
pass

@strawberry.type
class Query:
@strawberry.field
def departments(self) -> List[Department]: ...

mapper.finalize()
schema = strawberry.Schema(query=Query)

assert str(schema) == textwrap.dedent(expected_schema_from_secondary_tables).strip()


def test_relationships_schema_with_secondary_tables_with_another_foreign_key(secondary_tables_with_another_foreign_key, mapper, expected_schema_from_secondary_tables):
EmployeeModel, DepartmentModel = secondary_tables_with_another_foreign_key

@mapper.type(EmployeeModel)
class Employee:
pass

@mapper.type(DepartmentModel)
class Department:
pass

@strawberry.type
class Query:
@strawberry.field
def departments(self) -> List[Department]: ...

mapper.finalize()
schema = strawberry.Schema(query=Query)

assert str(schema) == textwrap.dedent(expected_schema_from_secondary_tables).strip()


def test_relationships_schema_with_secondary_tables_with_more_secondary_tables(secondary_tables_with_more_secondary_tables, mapper, expected_schema_from_secondary_tables_with_more_secondary_tables):
EmployeeModel, DepartmentModel, BuildingModel = secondary_tables_with_more_secondary_tables

@mapper.type(EmployeeModel)
class Employee:
pass

@mapper.type(DepartmentModel)
class Department:
pass

@mapper.type(BuildingModel)
class Building:
pass

@strawberry.type
class Query:
@strawberry.field
def departments(self) -> List[Department]: ...

mapper.finalize()
schema = strawberry.Schema(query=Query)

assert str(schema) == textwrap.dedent(expected_schema_from_secondary_tables_with_more_secondary_tables).strip()


def test_relationships_schema_with_secondary_tables_with_use_list_false(secondary_tables_with_use_list_false, mapper, expected_schema_from_secondary_tables_with_more_secondary_tables_with_use_list_false):
EmployeeModel, DepartmentModel = secondary_tables_with_use_list_false

@mapper.type(EmployeeModel)
class Employee:
pass

@mapper.type(DepartmentModel)
class Department:
pass


@strawberry.type
class Query:
@strawberry.field
def departments(self) -> List[Department]: ...

mapper.finalize()
schema = strawberry.Schema(query=Query)

assert str(schema) == textwrap.dedent(expected_schema_from_secondary_tables_with_more_secondary_tables_with_use_list_false).strip()


def test_relationships_schema_with_secondary_tables_with_normal_relationship(secondary_tables_with_normal_relationship, mapper, expected_schema_from_secondary_tables_with_more_secondary_tables_with__with_normal_relationship):
EmployeeModel, DepartmentModel, BuildingModel = secondary_tables_with_normal_relationship

@mapper.type(EmployeeModel)
class Employee:
pass

@mapper.type(DepartmentModel)
class Department:
pass

@mapper.type(BuildingModel)
class Building():
pass


@strawberry.type
class Query:
@strawberry.field
def departments(self) -> List[Department]: ...

mapper.finalize()
schema = strawberry.Schema(query=Query)

assert str(schema) == textwrap.dedent(expected_schema_from_secondary_tables_with_more_secondary_tables_with__with_normal_relationship).strip()
1,020 changes: 1,020 additions & 0 deletions tests/test_secondary_tables_query.py

Large diffs are not rendered by default.