Skip to content

Commit 9fb6509

Browse files
committed
fix: prevent an exception when the pipeline contains multiple nested cycles, caused by the cycle-breaking step trying to remove the same edge more than once (ref deepset-ai#8657)
1 parent 3ea128c commit 9fb6509

File tree

3 files changed

+55
-21
lines changed

3 files changed

+55
-21
lines changed

haystack/core/pipeline/base.py

+26-21
Original file line numberDiff line numberDiff line change
@@ -1209,27 +1209,32 @@ def _break_supported_cycles_in_graph(self) -> Tuple[networkx.MultiDiGraph, Dict[
12091209
# sender_comp will be the last element of cycle and receiver_comp will be the first.
12101210
# So if cycle is [1, 2, 3, 4] we would call zip([1, 2, 3, 4], [2, 3, 4, 1]).
12111211
for sender_comp, receiver_comp in zip(cycle, cycle[1:] + cycle[:1]):
1212-
# We get the key and iterate those as we want to edit the graph data while
1213-
# iterating the edges and that would raise.
1214-
# Even though the connection key set in Pipeline.connect() uses only the
1215-
# sockets name we don't have clashes since it's only used to differentiate
1216-
# multiple edges between two nodes.
1217-
edge_keys = list(temp_graph.get_edge_data(sender_comp, receiver_comp).keys())
1218-
for edge_key in edge_keys:
1219-
edge_data = temp_graph.get_edge_data(sender_comp, receiver_comp)[edge_key]
1220-
receiver_socket = edge_data["to_socket"]
1221-
if not receiver_socket.is_variadic and receiver_socket.is_mandatory:
1222-
continue
1223-
1224-
# We found a breakable edge
1225-
sender_socket = edge_data["from_socket"]
1226-
edges_removed[sender_comp].append(sender_socket.name)
1227-
temp_graph.remove_edge(sender_comp, receiver_comp, edge_key)
1228-
1229-
graph_has_cycles = not networkx.is_directed_acyclic_graph(temp_graph)
1230-
if not graph_has_cycles:
1231-
# We removed all the cycles, we can stop
1232-
break
1212+
# In graphs with multiple nested cycles the same edge can belong to more than one cycle,
1213+
# so it may already have been removed; only process it if it is still in the graph.
1214+
edge_data = temp_graph.get_edge_data(sender_comp, receiver_comp)
1215+
if edge_data is not None:
1216+
# We get the key and iterate those as we want to edit the graph data while
1217+
# iterating the edges and that would raise.
1218+
# Even though the connection key set in Pipeline.connect() uses only the
1219+
# sockets name we don't have clashes since it's only used to differentiate
1220+
# multiple edges between two nodes.
1221+
edge_keys = list(edge_data.keys())
1222+
1223+
for edge_key in edge_keys:
1224+
edge_data = temp_graph.get_edge_data(sender_comp, receiver_comp)[edge_key]
1225+
receiver_socket = edge_data["to_socket"]
1226+
if not receiver_socket.is_variadic and receiver_socket.is_mandatory:
1227+
continue
1228+
1229+
# We found a breakable edge
1230+
sender_socket = edge_data["from_socket"]
1231+
edges_removed[sender_comp].append(sender_socket.name)
1232+
temp_graph.remove_edge(sender_comp, receiver_comp, edge_key)
1233+
1234+
graph_has_cycles = not networkx.is_directed_acyclic_graph(temp_graph)
1235+
if not graph_has_cycles:
1236+
# We removed all the cycles, we can stop
1237+
break
12331238

12341239
if not graph_has_cycles:
12351240
# We removed all the cycles, nice
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
fixes:
3+
- |
4+
Prevents the pipeline from raising an exception when there are multiple nested cycles in the graph.

test/core/pipeline/test_pipeline.py

+25
Original file line numberDiff line numberDiff line change
@@ -1581,3 +1581,28 @@ def test__find_receivers_from(self):
15811581
),
15821582
)
15831583
]
1584+
1585+
def test__break_supported_cycles_in_graph(self):
1586+
# the following pipeline has a nested cycle, which is supported by Haystack
1587+
# but was causing an exception to be raised in the _break_supported_cycles_in_graph method
1588+
comp1 = component_class("Comp1", input_types={"value": int}, output_types={"value": int})()
1589+
comp2 = component_class("Comp2", input_types={"value": Variadic[int]}, output_types={"value": int})()
1590+
comp3 = component_class("Comp3", input_types={"value": Variadic[int]}, output_types={"value": int})()
1591+
comp4 = component_class("Comp4", input_types={"value": Optional[int]}, output_types={"value": int})()
1592+
comp5 = component_class("Comp5", input_types={"value": Variadic[int]}, output_types={"value": int})()
1593+
pipe = Pipeline()
1594+
pipe.add_component("comp1", comp1)
1595+
pipe.add_component("comp2", comp2)
1596+
pipe.add_component("comp3", comp3)
1597+
pipe.add_component("comp4", comp4)
1598+
pipe.add_component("comp5", comp5)
1599+
pipe.connect("comp1.value", "comp2.value")
1600+
pipe.connect("comp2.value", "comp3.value")
1601+
pipe.connect("comp3.value", "comp4.value")
1602+
pipe.connect("comp3.value", "comp5.value")
1603+
pipe.connect("comp4.value", "comp5.value")
1604+
pipe.connect("comp4.value", "comp3.value")
1605+
pipe.connect("comp5.value", "comp2.value")
1606+
1607+
# the following call should not raise an exception
1608+
pipe._break_supported_cycles_in_graph()

0 commit comments

Comments
 (0)