Skip to content

Commit

Permalink
Fix prefetch_relationships to work with hierarchical nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
dgarros committed Oct 22, 2024
1 parent 731bee3 commit 5c0fc44
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 38 deletions.
45 changes: 27 additions & 18 deletions backend/infrahub/core/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from infrahub.core.query.relationship import RelationshipGetPeerQuery
from infrahub.core.registry import registry
from infrahub.core.relationship import Relationship
from infrahub.core.relationship import Relationship, RelationshipManager
from infrahub.core.schema import GenericSchema, MainSchemaTypes, NodeSchema, ProfileSchema, RelationshipSchema
from infrahub.core.timestamp import Timestamp
from infrahub.exceptions import NodeNotFoundError, ProcessingError, SchemaNotFoundError
Expand Down Expand Up @@ -1138,8 +1138,8 @@ async def get_many( # pylint: disable=too-many-branches,too-many-statements

# if prefetch_relationships is enabled
# Query all the peers associated with all nodes at once.
peers_per_node = None
peers = None
peers_per_node: dict[str, dict[str, list[str]]] = {}
peers: dict[str, Node] = {}
if prefetch_relationships:
query = await NodeListGetRelationshipsQuery.init(
db=db, ids=ids, branch=branch, at=at, branch_agnostic=branch_agnostic
Expand All @@ -1152,7 +1152,8 @@ async def get_many( # pylint: disable=too-many-branches,too-many-statements
for node_peers in node_data.values():
peer_ids.extend(node_peers)

peer_ids = list(set(peer_ids))
# query the peers that are not already part of the main list
peer_ids = list(set(peer_ids) - set(ids))
peers = await cls.get_many(
ids=peer_ids,
branch=branch,
Expand All @@ -1162,7 +1163,7 @@ async def get_many( # pylint: disable=too-many-branches,too-many-statements
include_source=include_source,
)

nodes = {}
nodes: dict[str, Node] = {}

for node_id in ids: # pylint: disable=too-many-nested-blocks
if node_id not in nodes_info_by_id:
Expand All @@ -1189,19 +1190,6 @@ async def get_many( # pylint: disable=too-many-branches,too-many-statements
for attr_name, attr in node_attributes[node_id].attrs.items():
new_node_data[attr_name] = attr

# --------------------------------------------------------
# Relationships
# --------------------------------------------------------
if prefetch_relationships and peers:
for rel_schema in node.schema.relationships:
if node_id in peers_per_node and rel_schema.identifier in peers_per_node[node_id]:
rel_peers = [peers.get(id) for id in peers_per_node[node_id][rel_schema.identifier]]
if rel_schema.cardinality == "one":
if len(rel_peers) == 1:
new_node_data[rel_schema.name] = rel_peers[0]
elif rel_schema.cardinality == "many":
new_node_data[rel_schema.name] = rel_peers

new_node_data_with_profile_overrides = profile_index.apply_profiles(new_node_data)
node_class = identify_node_class(node=node)
node_branch = await registry.get_branch(db=db, branch=node.branch)
Expand All @@ -1210,6 +1198,27 @@ async def get_many( # pylint: disable=too-many-branches,too-many-statements

nodes[node_id] = item

# --------------------------------------------------------
# Relationships
# --------------------------------------------------------
if prefetch_relationships:
for node_id, node in nodes.items():
if node_id not in peers_per_node.keys():
continue

for rel_schema in node._schema.relationships:
direction_identifier = f"{rel_schema.direction.value}::{rel_schema.identifier}"
if direction_identifier in peers_per_node[node_id]:
rel_peers = [
peers.get(id, None) or nodes.get(id) for id in peers_per_node[node_id][direction_identifier]
]
rel_manager: RelationshipManager = getattr(node, rel_schema.name)
if rel_schema.cardinality == "one" and not len(rel_peers) == 1:
raise ValueError("Only one relationship expected")

rel_manager.has_fetched_relationships = True
await rel_manager.update(db=db, data=rel_peers)

return nodes

@classmethod
Expand Down
44 changes: 30 additions & 14 deletions backend/infrahub/core/query/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,7 @@ def _extract_attribute_data(self, result: QueryResult) -> AttributeFromDB:

class NodeListGetRelationshipsQuery(Query):
name: str = "node_list_get_relationship"
insert_return: bool = False

def __init__(self, ids: list[str], **kwargs):
self.ids = ids
Expand All @@ -569,28 +570,43 @@ async def query_init(self, db: InfrahubDatabase, **kwargs) -> None:
rels_filter, rels_params = self.branch.get_query_filter_path(at=self.at, branch_agnostic=self.branch_agnostic)
self.params.update(rels_params)

query = (
"""
MATCH (n) WHERE n.uuid IN $ids
MATCH p = ((n)-[r1:IS_RELATED]-(rel:Relationship)-[r2:IS_RELATED]-(peer))
WHERE all(r IN relationships(p) WHERE (%s))
"""
% rels_filter
)
query = """
MATCH (n1:Node)
WHERE n1.uuid IN $ids
MATCH paths_in = ((n1)<-[r1:IS_RELATED]-(rel1:Relationship)<-[r2:IS_RELATED]-(peer1))
WHERE all(r IN relationships(paths_in) WHERE (%(filters)s))
RETURN n1 as res4_node, rel1 as res3_rel, peer1 as res2_peer, "inbound" as res1_direction
UNION ALL
MATCH (n2:Node)
WHERE n2.uuid IN $ids
MATCH paths_out = ((n2)-[r1:IS_RELATED]->(rel2:Relationship)-[r2:IS_RELATED]->(peer2))
WHERE all(r IN relationships(paths_out) WHERE (%(filters)s))
RETURN n2 as res4_node, rel2 as res3_rel, peer2 as res2_peer, "outbound" as res1_direction
UNION ALL
MATCH (n3:Node)
WHERE n3.uuid IN $ids
MATCH paths_bidir = ((n3)-[r1:IS_RELATED]->(rel3:Relationship)<-[r2:IS_RELATED]-(peer3))
WHERE all(r IN relationships(paths_bidir) WHERE (%(filters)s))
RETURN n3 as res4_node, rel3 as res3_rel, peer3 as res2_peer, "bidirectional" as res1_direction
""" % {"filters": rels_filter}

self.add_to_query(query)

self.return_labels = ["n", "rel", "peer", "r1", "r2"]
# NOTE: For reasons that are not yet understood, when using UNION, memgraph 2.19 returns the results in reverse
# alphabetical order instead of respecting the order defined in the query.
# To get a consistent ordering, all the result names have been prefixed with res<id>
self.return_labels = ["res4_node", "res3_rel", "res2_peer", "res1_direction"]

def get_peers_group_by_node(self) -> dict[str, dict[str, list[str]]]:
peers_by_node = defaultdict(lambda: defaultdict(list))

for result in self.get_results_group_by(("n", "uuid"), ("rel", "name"), ("peer", "uuid")):
node_id = result.get("n").get("uuid")
rel_name = result.get("rel").get("name")
peer_id = result.get("peer").get("uuid")
for result in self.get_results_group_by(("res4_node", "uuid"), ("res3_rel", "name"), ("res2_peer", "uuid")):
node_id = result.get_node("res4_node").get("uuid")
rel_name = result.get_node("res3_rel").get("name")
peer_id = result.get_node("res2_peer").get("uuid")
direction = result.get_as_str("res1_direction")

peers_by_node[node_id][rel_name].append(peer_id)
peers_by_node[node_id][f"{direction}::{rel_name}"].append(peer_id)

return peers_by_node

Expand Down
6 changes: 4 additions & 2 deletions backend/infrahub/core/relationship/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,11 +943,13 @@ async def _fetch_relationships(
for peer_id in details.peer_ids_present_local_only:
await self.remove(peer_id=peer_id, db=db)

async def get(self, db: InfrahubDatabase) -> Union[Relationship, list[Relationship]]:
async def get(self, db: InfrahubDatabase) -> Relationship | list[Relationship] | None:
rels = await self.get_relationships(db=db)

if self.schema.cardinality == "one":
if self.schema.cardinality == "one" and rels:
return rels[0]
if self.schema.cardinality == "one" and not rels:
return None

return rels

Expand Down
4 changes: 2 additions & 2 deletions backend/tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2118,14 +2118,14 @@ async def hierarchical_location_schema(
@pytest.fixture
async def hierarchical_location_data_simple(
db: InfrahubDatabase, default_branch: Branch, hierarchical_location_schema_simple
) -> Dict[str, Node]:
) -> dict[str, Node]:
return await _build_hierarchical_location_data(db=db)


@pytest.fixture
async def hierarchical_location_data(
db: InfrahubDatabase, default_branch: Branch, hierarchical_location_schema
) -> Dict[str, Node]:
) -> dict[str, Node]:
return await _build_hierarchical_location_data(db=db)


Expand Down
27 changes: 27 additions & 0 deletions backend/tests/unit/core/test_manager_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from infrahub.core.node import Node
from infrahub.core.query.node import NodeToProcess
from infrahub.core.registry import registry
from infrahub.core.relationship import Relationship
from infrahub.core.schema import NodeSchema
from infrahub.core.schema.schema_branch import SchemaBranch
from infrahub.core.timestamp import Timestamp
Expand Down Expand Up @@ -263,6 +264,32 @@ async def test_get_many_prefetch(db: InfrahubDatabase, default_branch: Branch, p
assert tags[1]._peer


async def test_get_many_prefetch_hierarchical(
    db: InfrahubDatabase, default_branch: Branch, hierarchical_location_data: dict[str, Node]
):
    """Check that get_many with prefetch_relationships fills parent/children on hierarchical nodes.

    A subset of the hierarchical location fixture is queried in a single
    ``NodeManager.get_many`` call, then the ``parent`` and ``children``
    relationship managers of two of the returned nodes are inspected:

    * "paris" must expose two children and a parent relationship pointing to "europe".
    * "europe" must expose two children and no parent (``get`` returns ``None``).
    """
    nodes_to_query = ["europe", "asia", "paris", "chicago", "london-r1"]
    node_ids = [hierarchical_location_data[value].id for value in nodes_to_query]
    nodes = await NodeManager.get_many(db=db, ids=node_ids, prefetch_relationships=True)
    assert len(nodes) == 5

    # Resolve both ids once; they are reused by the paris and europe checks below.
    paris_id = hierarchical_location_data["paris"].id
    europe_id = hierarchical_location_data["europe"].id

    # "paris": has children and a single parent, which is "europe".
    assert nodes[paris_id]
    children_paris = await nodes[paris_id].children.get(db=db)
    assert len(children_paris) == 2
    parent_paris = await nodes[paris_id].parent.get(db=db)
    assert isinstance(parent_paris, Relationship)
    assert parent_paris.peer_id == europe_id

    # "europe": has children but no parent, so the cardinality-one getter returns None.
    assert nodes[europe_id]
    children_europe = await nodes[europe_id].children.get(db=db)
    assert len(children_europe) == 2
    parent_europe = await nodes[europe_id].parent.get(db=db)
    assert parent_europe is None


async def test_get_many_with_profile(db: InfrahubDatabase, default_branch: Branch, criticality_low, criticality_medium):
profile_schema = registry.schema.get("ProfileTestCriticality", branch=default_branch)
crit_profile_1 = await Node.init(db=db, schema=profile_schema)
Expand Down
23 changes: 21 additions & 2 deletions backend/tests/unit/core/test_node_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,27 @@ async def test_query_NodeListGetRelationshipsQuery(db: InfrahubDatabase, default
await query.execute(db=db)
result = query.get_peers_group_by_node()
assert person_jack_tags_main.id in result
assert "builtintag__testperson" in result[person_jack_tags_main.id]
assert len(result[person_jack_tags_main.id]["builtintag__testperson"]) == 2
assert "inbound::builtintag__testperson" in result[person_jack_tags_main.id]
assert len(result[person_jack_tags_main.id]["inbound::builtintag__testperson"]) == 2


async def test_query_NodeListGetRelationshipsQuery_hierarchical(
    db: InfrahubDatabase, default_branch: Branch, hierarchical_location_data: dict[str, Node]
):
    """Check that NodeListGetRelationshipsQuery keys peers by direction for hierarchical nodes.

    The query is run against every node of the hierarchical location fixture
    and the grouped result for "paris" is expected to contain both an
    ``inbound::parent__child`` entry (with two peers) and an
    ``outbound::parent__child`` entry.
    """
    paris_id = hierarchical_location_data["paris"].id
    all_ids = [location.id for location in hierarchical_location_data.values()]
    main_branch = await registry.get_branch(db=db, branch="main")

    query = await NodeListGetRelationshipsQuery.init(db=db, ids=all_ids, branch=main_branch)
    await query.execute(db=db)
    peers_by_node = query.get_peers_group_by_node()

    assert paris_id in peers_by_node
    # Keys are formatted as "<direction>::<relationship identifier>".
    assert "inbound::parent__child" in peers_by_node[paris_id]
    assert "outbound::parent__child" in peers_by_node[paris_id]
    assert len(peers_by_node[paris_id]["inbound::parent__child"]) == 2


async def test_query_NodeDeleteQuery(
Expand Down

0 comments on commit 5c0fc44

Please sign in to comment.