diff --git a/kedro-airflow/kedro_airflow/grouping.py b/kedro-airflow/kedro_airflow/grouping.py index 581c84f41..1604e3b64 100644 --- a/kedro-airflow/kedro_airflow/grouping.py +++ b/kedro-airflow/kedro_airflow/grouping.py @@ -36,6 +36,7 @@ def group_memory_nodes(catalog: DataCatalog, pipeline: Pipeline): together. Essentially, this computes connected components over the graph of nodes connected by MemoryDatasets. """ + # get all memory datasets in the pipeline memory_datasets = get_memory_datasets(catalog, pipeline) @@ -58,8 +59,12 @@ def group_memory_nodes(catalog: DataCatalog, pipeline: Pipeline): sequence_id = None for i in node.inputs: if i in memory_datasets: - assert sequence_id is None or sequence_id == sequence_map[i] - sequence_id = sequence_map[i] + if sequence_id is None: + sequence_id = sequence_map[i] + else: + # merge sequences + node_sequences[sequence_id].extend(node_sequences[sequence_map[i]]) + node_sequences[sequence_map[i]] = None # Append to map node_sequences[sequence_id].append(node) @@ -73,6 +78,7 @@ def group_memory_nodes(catalog: DataCatalog, pipeline: Pipeline): nodes = { node_sequence_name(node_sequence): node_sequence for node_sequence in node_sequences + if node_sequence is not None } # Inverted mapping