From 5c0fc4472a5522949f876df30912750d46f2bf3f Mon Sep 17 00:00:00 2001 From: Damien Garros Date: Tue, 22 Oct 2024 07:19:19 +0200 Subject: [PATCH] Fix prefetch_relationships to work with hierarchical nodes --- backend/infrahub/core/manager.py | 45 ++++++++++++-------- backend/infrahub/core/query/node.py | 44 +++++++++++++------ backend/infrahub/core/relationship/model.py | 6 ++- backend/tests/unit/conftest.py | 4 +- backend/tests/unit/core/test_manager_node.py | 27 ++++++++++++ backend/tests/unit/core/test_node_query.py | 23 +++++++++- 6 files changed, 111 insertions(+), 38 deletions(-) diff --git a/backend/infrahub/core/manager.py b/backend/infrahub/core/manager.py index 08e550223d..c14137df70 100644 --- a/backend/infrahub/core/manager.py +++ b/backend/infrahub/core/manager.py @@ -20,7 +20,7 @@ ) from infrahub.core.query.relationship import RelationshipGetPeerQuery from infrahub.core.registry import registry -from infrahub.core.relationship import Relationship +from infrahub.core.relationship import Relationship, RelationshipManager from infrahub.core.schema import GenericSchema, MainSchemaTypes, NodeSchema, ProfileSchema, RelationshipSchema from infrahub.core.timestamp import Timestamp from infrahub.exceptions import NodeNotFoundError, ProcessingError, SchemaNotFoundError @@ -1138,8 +1138,8 @@ async def get_many( # pylint: disable=too-many-branches,too-many-statements # if prefetch_relationships is enabled # Query all the peers associated with all nodes at once. - peers_per_node = None - peers = None + peers_per_node: dict[str, dict[str, list[str]]] = {} + peers: dict[str, Node] = {} if prefetch_relationships: query = await NodeListGetRelationshipsQuery.init( db=db, ids=ids, branch=branch, at=at, branch_agnostic=branch_agnostic @@ -1152,7 +1152,8 @@ async def get_many( # pylint: disable=too-many-branches,too-many-statements for node_peers in node_data.values(): peer_ids.extend(node_peers) - peer_ids = list(set(peer_ids)) + # query the peers that are not already part of the main list + peer_ids = list(set(peer_ids) - set(ids)) peers = await cls.get_many( ids=peer_ids, branch=branch, @@ -1162,7 +1163,7 @@ async def get_many( # pylint: disable=too-many-branches,too-many-statements include_source=include_source, ) - nodes = {} + nodes: dict[str, Node] = {} for node_id in ids: # pylint: disable=too-many-nested-blocks if node_id not in nodes_info_by_id: @@ -1189,19 +1190,6 @@ async def get_many( # pylint: disable=too-many-branches,too-many-statements for attr_name, attr in node_attributes[node_id].attrs.items(): new_node_data[attr_name] = attr - # -------------------------------------------------------- - # Relationships - # -------------------------------------------------------- - if prefetch_relationships and peers: - for rel_schema in node.schema.relationships: - if node_id in peers_per_node and rel_schema.identifier in peers_per_node[node_id]: - rel_peers = [peers.get(id) for id in peers_per_node[node_id][rel_schema.identifier]] - if rel_schema.cardinality == "one": - if len(rel_peers) == 1: - new_node_data[rel_schema.name] = rel_peers[0] - elif rel_schema.cardinality == "many": - new_node_data[rel_schema.name] = rel_peers - new_node_data_with_profile_overrides = profile_index.apply_profiles(new_node_data) node_class = identify_node_class(node=node) node_branch = await registry.get_branch(db=db, branch=node.branch) @@ -1210,6 +1198,27 @@ async def get_many( # pylint: disable=too-many-branches,too-many-statements nodes[node_id] = item + # -------------------------------------------------------- + # Relationships + # -------------------------------------------------------- + if prefetch_relationships: + for node_id, node in nodes.items(): + if node_id not in peers_per_node.keys(): + continue + + for rel_schema in node._schema.relationships: + direction_identifier = f"{rel_schema.direction.value}::{rel_schema.identifier}" + if direction_identifier in peers_per_node[node_id]: + rel_peers = [ + peers.get(id, None) or nodes.get(id) for id in peers_per_node[node_id][direction_identifier] + ] + rel_manager: RelationshipManager = getattr(node, rel_schema.name) + if rel_schema.cardinality == "one" and not len(rel_peers) == 1: + raise ValueError("Only one relationship expected") + + rel_manager.has_fetched_relationships = True + await rel_manager.update(db=db, data=rel_peers) + return nodes @classmethod diff --git a/backend/infrahub/core/query/node.py b/backend/infrahub/core/query/node.py index a990458f1c..f04417ceca 100644 --- a/backend/infrahub/core/query/node.py +++ b/backend/infrahub/core/query/node.py @@ -557,6 +557,7 @@ def _extract_attribute_data(self, result: QueryResult) -> AttributeFromDB: class NodeListGetRelationshipsQuery(Query): name: str = "node_list_get_relationship" + insert_return: bool = False def __init__(self, ids: list[str], **kwargs): self.ids = ids @@ -569,28 +570,43 @@ async def query_init(self, db: InfrahubDatabase, **kwargs) -> None: rels_filter, rels_params = self.branch.get_query_filter_path(at=self.at, branch_agnostic=self.branch_agnostic) self.params.update(rels_params) - query = ( - """ - MATCH (n) WHERE n.uuid IN $ids - MATCH p = ((n)-[r1:IS_RELATED]-(rel:Relationship)-[r2:IS_RELATED]-(peer)) - WHERE all(r IN relationships(p) WHERE (%s)) - """ - % rels_filter - ) + query = """ + MATCH (n1:Node) + WHERE n1.uuid IN $ids + MATCH paths_in = ((n1)<-[r1:IS_RELATED]-(rel1:Relationship)<-[r2:IS_RELATED]-(peer1)) + WHERE all(r IN relationships(paths_in) WHERE (%(filters)s)) + RETURN n1 as res4_node, rel1 as res3_rel, peer1 as res2_peer, "inbound" as res1_direction + UNION ALL + MATCH (n2:Node) + WHERE n2.uuid IN $ids + MATCH paths_out = ((n2)-[r1:IS_RELATED]->(rel2:Relationship)-[r2:IS_RELATED]->(peer2)) + WHERE all(r IN relationships(paths_out) WHERE (%(filters)s)) + RETURN n2 as res4_node, rel2 as res3_rel, peer2 as res2_peer, "outbound" as res1_direction + UNION ALL + MATCH (n3:Node) + WHERE n3.uuid IN $ids + MATCH paths_bidir = ((n3)-[r1:IS_RELATED]->(rel3:Relationship)<-[r2:IS_RELATED]-(peer3)) + WHERE all(r IN relationships(paths_bidir) WHERE (%(filters)s)) + RETURN n3 as res4_node, rel3 as res3_rel, peer3 as res2_peer, "bidirectional" as res1_direction + """ % {"filters": rels_filter} self.add_to_query(query) - self.return_labels = ["n", "rel", "peer", "r1", "r2"] + # NOTE Not sure why but when using UNION memgraph 2.19 is returning the result in alphabetically reverse order + # instead of respecting the order defined in the query + # In order to have a consistent ordering, all the results have been prepended with res + self.return_labels = ["res4_node", "res3_rel", "res2_peer", "res1_direction"] def get_peers_group_by_node(self) -> dict[str, dict[str, list[str]]]: peers_by_node = defaultdict(lambda: defaultdict(list)) - for result in self.get_results_group_by(("n", "uuid"), ("rel", "name"), ("peer", "uuid")): - node_id = result.get("n").get("uuid") - rel_name = result.get("rel").get("name") - peer_id = result.get("peer").get("uuid") + for result in self.get_results_group_by(("res4_node", "uuid"), ("res3_rel", "name"), ("res2_peer", "uuid")): + node_id = result.get_node("res4_node").get("uuid") + rel_name = result.get_node("res3_rel").get("name") + peer_id = result.get_node("res2_peer").get("uuid") + direction = result.get_as_str("res1_direction") - peers_by_node[node_id][rel_name].append(peer_id) + peers_by_node[node_id][f"{direction}::{rel_name}"].append(peer_id) return peers_by_node diff --git a/backend/infrahub/core/relationship/model.py b/backend/infrahub/core/relationship/model.py index 4c1a083347..e3d6180da4 100644 --- a/backend/infrahub/core/relationship/model.py +++ b/backend/infrahub/core/relationship/model.py @@ -943,11 +943,13 @@ async def _fetch_relationships( for peer_id in details.peer_ids_present_local_only: await self.remove(peer_id=peer_id, db=db) - async def get(self, db: InfrahubDatabase) -> Union[Relationship, list[Relationship]]: + async def get(self, db: InfrahubDatabase) -> Relationship | list[Relationship] | None: rels = await self.get_relationships(db=db) - if self.schema.cardinality == "one": + if self.schema.cardinality == "one" and rels: return rels[0] + if self.schema.cardinality == "one" and not rels: + return None return rels diff --git a/backend/tests/unit/conftest.py b/backend/tests/unit/conftest.py index 9b84684649..262a8ced53 100644 --- a/backend/tests/unit/conftest.py +++ b/backend/tests/unit/conftest.py @@ -2118,14 +2118,14 @@ async def hierarchical_location_schema( @pytest.fixture async def hierarchical_location_data_simple( db: InfrahubDatabase, default_branch: Branch, hierarchical_location_schema_simple -) -> Dict[str, Node]: +) -> dict[str, Node]: return await _build_hierarchical_location_data(db=db) @pytest.fixture async def hierarchical_location_data( db: InfrahubDatabase, default_branch: Branch, hierarchical_location_schema -) -> Dict[str, Node]: +) -> dict[str, Node]: return await _build_hierarchical_location_data(db=db) diff --git a/backend/tests/unit/core/test_manager_node.py b/backend/tests/unit/core/test_manager_node.py index f0b5c476de..d858730780 100644 --- a/backend/tests/unit/core/test_manager_node.py +++ b/backend/tests/unit/core/test_manager_node.py @@ -7,6 +7,7 @@ from infrahub.core.node import Node from infrahub.core.query.node import NodeToProcess from infrahub.core.registry import registry +from infrahub.core.relationship import Relationship from infrahub.core.schema import NodeSchema from infrahub.core.schema.schema_branch import SchemaBranch from infrahub.core.timestamp import Timestamp @@ -263,6 +264,32 @@ async def test_get_many_prefetch(db: InfrahubDatabase, default_branch: Branch, p assert tags[1]._peer +async def test_get_many_prefetch_hierarchical( + db: InfrahubDatabase, default_branch: Branch, hierarchical_location_data: dict[str, Node] +): + nodes_to_query = ["europe", "asia", "paris", "chicago", "london-r1"] + node_ids = [hierarchical_location_data[value].id for value in nodes_to_query] + nodes = await NodeManager.get_many(db=db, ids=node_ids, prefetch_relationships=True) + assert len(nodes) == 5 + + paris_id = hierarchical_location_data["paris"].id + europe_id = hierarchical_location_data["europe"].id + + assert nodes[paris_id] + children_paris = await nodes[paris_id].children.get(db=db) + assert len(children_paris) == 2 + parent_paris = await nodes[paris_id].parent.get(db=db) + assert isinstance(parent_paris, Relationship) + assert parent_paris.peer_id == europe_id + + europe_id = hierarchical_location_data["europe"].id + assert nodes[europe_id] + children_europe = await nodes[europe_id].children.get(db=db) + assert len(children_europe) == 2 + parent_europe = await nodes[europe_id].parent.get(db=db) + assert parent_europe is None + + async def test_get_many_with_profile(db: InfrahubDatabase, default_branch: Branch, criticality_low, criticality_medium): profile_schema = registry.schema.get("ProfileTestCriticality", branch=default_branch) crit_profile_1 = await Node.init(db=db, schema=profile_schema) diff --git a/backend/tests/unit/core/test_node_query.py b/backend/tests/unit/core/test_node_query.py index 622c1ca0b0..ef7ba66314 100644 --- a/backend/tests/unit/core/test_node_query.py +++ b/backend/tests/unit/core/test_node_query.py @@ -353,8 +353,27 @@ async def test_query_NodeListGetRelationshipsQuery(db: InfrahubDatabase, default await query.execute(db=db) result = query.get_peers_group_by_node() assert person_jack_tags_main.id in result - assert "builtintag__testperson" in result[person_jack_tags_main.id] - assert len(result[person_jack_tags_main.id]["builtintag__testperson"]) == 2 + assert "inbound::builtintag__testperson" in result[person_jack_tags_main.id] + assert len(result[person_jack_tags_main.id]["inbound::builtintag__testperson"]) == 2 + + +async def test_query_NodeListGetRelationshipsQuery_hierarchical( + db: InfrahubDatabase, default_branch: Branch, hierarchical_location_data: dict[str, Node] +): + node_ids = [value.id for value in hierarchical_location_data.values()] + paris_id = hierarchical_location_data["paris"].id + default_branch = await registry.get_branch(db=db, branch="main") + query = await NodeListGetRelationshipsQuery.init( + db=db, + ids=node_ids, + branch=default_branch, + ) + await query.execute(db=db) + result = query.get_peers_group_by_node() + assert paris_id in result + assert "inbound::parent__child" in result[paris_id] + assert "outbound::parent__child" in result[paris_id] + assert len(result[paris_id]["inbound::parent__child"]) == 2 async def test_query_NodeDeleteQuery(