
Commit ce90c1e

Hapixrtmiz authored
Add workflow provenance tracking with yProv4WFS integration (#98)
Co-authored-by: Hapix <[email protected]>
Co-authored-by: Gerald Walter Irsiegler <[email protected]>
Co-authored-by: Gerald Walter Irsiegler <[email protected]>
1 parent: b6f367c · commit: ce90c1e

File tree: 3 files changed, +216 −5 lines

openeo_pg_parser_networkx/graph.py
pyproject.toml
tests/test_pg_provenance.py

openeo_pg_parser_networkx/graph.py

Lines changed: 146 additions & 4 deletions
@@ -4,17 +4,27 @@
 
 sys.setrecursionlimit(16385)  # Necessary when parsing really big graphs
 import functools
+
+## For yprov4wfs
 import json
 import logging
+import os
 import random
+import uuid
 from collections import namedtuple
 from dataclasses import dataclass, field
-from functools import partial
+from datetime import datetime
+from functools import partial, wraps
 from pathlib import Path
 from typing import Callable, Optional, Union
 from uuid import UUID
 
+import dask.array as da
 import networkx as nx
+import xarray as xr
+from yprov4wfs.datamodel.data import Data
+from yprov4wfs.datamodel.task import Task
+from yprov4wfs.datamodel.workflow import Workflow
 
 from openeo_pg_parser_networkx.pg_schema import (
     PGEdgeType,
@@ -70,6 +80,10 @@ def __repr__(self):
 
 class OpenEOProcessGraph:
     def __init__(self, pg_data: dict):
+        # Make a workflow object
+        self.workflow = Workflow('openeo_workflow', 'OpenEO Workflow')
+        self.workflow._engineWMS = "Openeo-Workflow"
+        self.workflow._level = "0"
         self.G = nx.DiGraph()
 
         # Save pg_data for resolving later on
@@ -377,7 +391,7 @@ def node_callable(*args, parent_callables, named_parameters=None, **kwargs):
             # The node needs to first call all its parents, so that results are prepopulated in the results_cache
             for func in parent_callables:
                 func(*args, named_parameters=named_parameters, **kwargs)
-
+            cache_users = {}
             try:
                 # If this node has already been computed once, just grab that result from the results_cache instead of recomputing it.
                 # This cannot be done for aggregated data as the wrapped function has to be called multiple times with different values.
@@ -411,13 +425,108 @@ def node_callable(*args, parent_callables, named_parameters=None, **kwargs):
                             kwargs[arg_sub.arg_name] = self.G.nodes(data=True)[node][
                                 "resolved_kwargs"
                             ].__getitem__(arg_sub.arg_name)
-
-                result = prebaked_process_impl(
+                            # Make a dictionary from the nodes that uses the outputs of the other nodes
+                            if source_node not in cache_users:
+                                cache_users[source_node] = []
+                            cache_users[source_node].append(node)
+                # Make the tasks
+                task = Task(node, node_with_data['process_id'])
+                # result = prebaked_process_impl(
+                #     *args, named_parameters=named_parameters, **kwargs
+                # )
+                result, execution_data = self.profile_function(prebaked_process_impl)(
                     *args, named_parameters=named_parameters, **kwargs
                 )
 
+                if isinstance(result, xr.DataArray):
+                    processed_result = {
+                        "entity_type": "xarray.DataArray",
+                        "info": {
+                            "shape": result.shape,
+                            "dimensions": list(result.dims),
+                            # "attributes": result.attrs,
+                            "dtype": str(result.dtype),
+                        },
+                    }
+
+                elif isinstance(result, da.Array):
+                    processed_result = {
+                        "entity_type": "dask.Array",
+                        "info": {
+                            "shape": result.shape,
+                            "dtype": str(result.dtype),
+                            "chunk_size": result.chunksize,
+                            "chunk_type": type(result._meta).__name__,
+                        },
+                    }
+                else:
+                    processed_result = {}
+                    processed_result['info'] = result
+                    processed_result['entity_type'] = type(result).__name__
+                if result is not None:
+                    results_cache_node = Data(
+                        str(uuid.uuid4()), processed_result['entity_type']
+                    )
+                    results_cache_node._info = processed_result['info']
+                    task.add_output(results_cache_node)
+                    self.workflow.add_data(results_cache_node)
                 results_cache[node] = result
 
+                # Loading data info
+                process_id = node_with_data.get("process_id")
+                resolved_kwargs = node_with_data.get("resolved_kwargs", {})
+
+                if process_id in ("load_stac", "load_collection"):
+                    key = "url" if process_id == "load_stac" else "id"
+                    raw_source = resolved_kwargs.get(key, "")
+                    data_source = raw_source.split("\\")[-1]
+
+                    data_src = Data(str(uuid.uuid4()), data_source)
+                    # Extract extra information
+                    if process_id == "load_stac":
+                        data_src._info = resolved_kwargs
+
+                task._start_time = execution_data['start_time']
+                task._end_time = execution_data['end_time']
+                task._status = execution_data['task_status']
+                task._level = "1"
+
+                # This is just for load stac ( for the temporary usage)
+                if node_with_data['process_id'] in ["load_stac", "load_collection"]:
+                    task.add_input(data_src)
+
+                self.workflow.add_task(task)
+
+                if cache_users:
+                    for source_node, target_node in cache_users.items():
+                        output_data_from_source = (
+                            self.workflow.get_task_by_id(source_node)._outputs[0]._id
+                        )
+                        for target in target_node:
+                            self.workflow.get_task_by_id(target).add_input(
+                                self.workflow.get_data_by_id(output_data_from_source)
+                            )
+
+                edges = [
+                    {"source": source, "target": target, "type": data["reference_type"]}
+                    for source, target, data in self.G.edges(node, data=True)
+                ]
+
+                for edge in edges:
+                    self.workflow.get_task_by_id(edge['source']).set_next(
+                        self.workflow.get_task_by_id(edge['target'])
+                    )
+
+                if node == self.result_node:
+                    self.workflow._status = "Ok"
+
+                # To save the provenance
+                # timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+                # save_path = os.path.join(os.getcwd(), f"run_{timestamp}")
+                # print(f"Provenance file saved to: {save_path}")
+                # os.makedirs(save_path, exist_ok=True)
+                # self.workflow.prov_to_json(directory_path=save_path)
+
             return result
 
         return partial(node_callable, parent_callables=parent_callables)
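The commented-out block at the end of this hunk sketches how the collected provenance is meant to be persisted. A minimal end-to-end sketch along those lines, combining it with the dummy registry used in the new test below (the graph path is a placeholder, not part of this commit):

import json
import os
from datetime import datetime

from openeo_pg_parser_networkx import OpenEOProcessGraph
from openeo_pg_parser_networkx.process_registry import Process, ProcessRegistry

# Placeholder path: any flat openEO process graph JSON.
with open("flat_process_graph.json") as f:
    flat_pg = json.load(f)

pg = OpenEOProcessGraph(flat_pg)

# Echo registry as used in tests/test_pg_provenance.py: every process returns its first argument.
registry = ProcessRegistry(wrap_funcs=[])
for process_id in pg.required_processes:
    registry[process_id] = Process(
        spec={},
        implementation=lambda *args, **kwargs: args[0] if args else None,
        namespace="predefined",
    )

result = pg.to_callable(registry)()

# pg.workflow is populated as a side effect of executing the callable;
# saving it mirrors the commented-out block above.
save_path = os.path.join(os.getcwd(), f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
os.makedirs(save_path, exist_ok=True)
pg.workflow.prov_to_json(directory_path=save_path)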
@@ -516,3 +625,36 @@ def plot(self, reverse=False):
 
         if reverse:
             self.G = self.G.reverse()
+
+    @staticmethod
+    def profile_function(func):
+        """Decorator to track execution performance and return both result and profiling data.
+        In the case in the future there will be some more metrics of intrest (like cpu and memory
+        usage) to extract."""
+
+        @wraps(func)
+        def wrapper(*args, named_parameters, **kwargs):
+            start_dt = datetime.now()
+            start_timestamp = start_dt.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
+
+            try:
+                result = func(*args, named_parameters, **kwargs)
+                status = "Ok"
+            except Exception as e:
+                result = str(e)
+                status = f"Error: {result[:70]}"
+
+            end_dt = datetime.now()
+            end_timestamp = end_dt.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
+            execution_time = (end_dt - start_dt).total_seconds()
+            execution_data = {
+                # "function": func.__name__,
+                "task_status": status,
+                "start_time": start_timestamp,
+                "end_time": end_timestamp,
+                "execution_time_sec": round(execution_time, 4),
+            }
+            # Return both the result and profiling data
+            return result, execution_data
+
+        return wrapper
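Since profile_function is a @staticmethod with no dependence on graph state, its (result, execution_data) contract can be exercised on its own. A toy sketch (the add function is illustrative, not part of the commit); note that the wrapper takes named_parameters as keyword-only but forwards it positionally to the wrapped function:

from openeo_pg_parser_networkx import OpenEOProcessGraph

def add(x, y, named_parameters=None):
    return x + y

result, execution_data = OpenEOProcessGraph.profile_function(add)(2, 3, named_parameters={})
print(result)                                # 5
print(execution_data["task_status"])         # "Ok"
print(execution_data["start_time"])          # e.g. "2025-05-01 12:00:00.000"
print(execution_data["execution_time_sec"])  # small float, e.g. 0.0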

pyproject.toml

Lines changed: 4 additions & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "openeo-pg-parser-networkx"
-version = "2024.10.1"
+version = "2025.5.1"
 
 description = "Parse OpenEO process graphs from JSON to traversible Python objects."
 authors = ["Lukas Weidenholzer <[email protected]>", "Sean Hoyal <[email protected]>", "Valentina Hutter <[email protected]>", "Gerald Irsiegler <[email protected]>"]
@@ -33,6 +33,9 @@ numpy = "^1.20.3"
 pendulum = "^2.1.2"
 matplotlib = { version = "^3.7.1", optional = true }
 traitlets = "<=5.9.0"
+yprov4wfs = ">=0.0.8"
+xarray = ">=2022.11.0,<=2024.3.0"
+dask = ">=2023.4.0,<2025.2.0"
 
 [tool.poetry.group.dev.dependencies]
 matplotlib = "^3.7.1"
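A quick way to confirm that the three new runtime dependencies resolve in an existing environment (a small optional check, not part of the commit):

from importlib.metadata import version

for pkg in ("yprov4wfs", "xarray", "dask"):
    print(pkg, version(pkg))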

tests/test_pg_provenance.py

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+import json
+
+import pytest
+from yprov4wfs.datamodel.data import Data
+from yprov4wfs.datamodel.task import Task
+from yprov4wfs.datamodel.workflow import Workflow
+
+from openeo_pg_parser_networkx import OpenEOProcessGraph
+from openeo_pg_parser_networkx.process_registry import Process, ProcessRegistry
+
+
+def test_execute_returns_result_and_workflow(process_graph_path):
+    """
+    Test that OpenEOProcessGraph returns result and workflow correctly
+    for all sample graphs, using a mock registry based on required processes.
+    """
+
+    with open(process_graph_path) as f:
+        flat_pg = json.load(f)
+
+    pg = OpenEOProcessGraph(flat_pg)
+
+    mock_registry = ProcessRegistry(wrap_funcs=[])
+    for process_id in pg.required_processes:
+        mock_registry[process_id] = Process(
+            spec={},
+            implementation=lambda *args, **kwargs: args[0] if args else None,
+            namespace="predefined",
+        )
+
+    # Create callable and execute
+    result = pg.to_callable(mock_registry)()
+    workflow = pg.workflow
+
+    # Assertions
+    assert result is not None, "Result should not be None"
+    assert workflow is not None, "Workflow should not be None"
+    assert isinstance(
+        workflow, Workflow
+    ), "Workflow should be a yprov4wfs.Workflow instance"
+    assert len(workflow._tasks) > 0, "Workflow should have at least one task"
+    assert workflow._status in ["Ok", "Error"], "Workflow status should be Ok or Error"
+
+    # Test the tasks
+    assert isinstance(workflow._tasks, list), "Workflow._tasks should be a list"
+    for task in workflow._tasks:
+        # Each task should be a Task instance
+        assert isinstance(
+            task, Task
+        ), f"Each task should be a Task instance but got {type(task)}"
+        assert hasattr(task, "_id"), "Task must have an _id"
+        assert hasattr(task, "_name"), "Task must have a _name"
+        assert hasattr(task, "_start_time"), "Task must have a start_time"
+        assert hasattr(task, "_end_time"), "Task must have an end_time"
+        assert hasattr(task, "_status"), "Task must have a status"
+        assert hasattr(task, "_inputs"), "Task must have _inputs"
+        assert hasattr(task, "_outputs"), "Task must have _outputs"
+
+    # Test the data
+    assert isinstance(workflow._data, list), "Workflow._data should be a list"
+    for data in workflow._data:
+        assert isinstance(
+            data, Data
+        ), f"Each data node should be a Data instance but got {type(data)}"
+        assert hasattr(data, "_id"), "Data must have an _id"
+        assert hasattr(data, "_name"), "Data must have a _name"
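The process_graph_path argument is a pytest fixture that this commit does not add; it is presumably defined in the tests' conftest and parametrized over the repository's sample flat process graphs. A hypothetical sketch of such a fixture, with placeholder paths:

from pathlib import Path

import pytest

# Hypothetical location of the sample flat process graphs used by the test suite.
GRAPHS_DIR = Path(__file__).parent / "data" / "graphs"


@pytest.fixture(params=sorted(GRAPHS_DIR.glob("*.json")), ids=lambda p: p.name)
def process_graph_path(request):
    """Yield one sample process-graph path per parametrized run."""
    return request.param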
