@@ -24,15 +24,16 @@ def concatenate_functions(
24
24
25
25
Functions that are not required to produce the targets will simply be ignored.
26
26
27
- The arguments of the combined function are all arguments of relevant functions
28
- that are not themselves function names, in alphabetical order.
27
+ The arguments of the combined function are all arguments of relevant functions that
28
+ are not themselves function names, in alphabetical order.
29
29
30
30
Args:
31
31
functions (dict or list): Dict or list of functions. If a list, the function
32
- name is inferred from the __name__ attribute of the entries. If a dict,
33
- the name of the function is set to the dictionary key.
34
- targets (str | None): Name of the function that produces the target or list of
35
- such function names. If the value is `None`, all variables are returned.
32
+ name is inferred from the __name__ attribute of the entries. If a dict, the
33
+ name of the function is set to the dictionary key.
34
+ targets (str or list or None): Name of the function that produces the target or
35
+ list of such function names. If the value is `None`, all variables are
36
+ returned.
36
37
return_type (str): One of "tuple", "list", "dict". This is ignored if the
37
38
targets are a single string or if an aggregator is provided.
38
39
aggregator (callable or None): Binary reduction function that is used to
@@ -45,19 +46,99 @@ def concatenate_functions(
45
46
function: A function that produces targets when called with suitable arguments.
46
47
47
48
"""
48
- _functions = _harmonize_functions (functions )
49
- _targets = _harmonize_targets (targets , list (_functions ))
50
- _fail_if_targets_have_wrong_types (_targets )
51
- _fail_if_functions_are_missing (_functions , _targets )
52
49
50
+ # Create the DAG.
51
+ dag = create_dag (functions , targets )
52
+
53
+ # Build combined function.
54
+ out = _create_combined_function_from_dag (
55
+ dag , functions , targets , return_type , aggregator , enforce_signature
56
+ )
57
+
58
+ return out
59
+
60
+
61
+ def create_dag (functions , targets ):
62
+ """Build a directed acyclic graph (DAG) from functions.
63
+
64
+ Functions can depend on the output of other functions as inputs, as long as the
65
+ dependencies can be described by a directed acyclic graph (DAG).
66
+
67
+ Functions that are not required to produce the targets will simply be ignored.
68
+
69
+ Args:
70
+ functions (dict or list): Dict or list of functions. If a list, the function
71
+ name is inferred from the __name__ attribute of the entries. If a dict, the
72
+ name of the function is set to the dictionary key.
73
+ targets (str or list or None): Name of the function that produces the target or
74
+ list of such function names. If the value is `None`, all variables are
75
+ returned.
76
+
77
+ Returns:
78
+ dag: the DAG (as networkx.DiGraph object)
79
+
80
+ """
81
+ # Harmonize and check arguments.
82
+ _functions , _targets = _harmonize_and_check_functions_and_targets (
83
+ functions , targets
84
+ )
85
+
86
+ # Create the DAG
53
87
_raw_dag = _create_complete_dag (_functions )
54
- _dag = _limit_dag_to_targets_and_their_ancestors (_raw_dag , _targets )
55
- _arglist = _create_arguments_of_concatenated_function (_functions , _dag )
56
- _exec_info = _create_execution_info (_functions , _dag )
88
+ dag = _limit_dag_to_targets_and_their_ancestors (_raw_dag , _targets )
89
+
90
+ # Check if there are cycles in the DAG
91
+ _fail_if_dag_contains_cycle (dag )
92
+
93
+ return dag
94
+
95
+
96
+ def _create_combined_function_from_dag (
97
+ dag ,
98
+ functions ,
99
+ targets ,
100
+ return_type = "tuple" ,
101
+ aggregator = None ,
102
+ enforce_signature = True ,
103
+ ):
104
+ """Create combined function which allows to execute a complete directed acyclic
105
+ graph (DAG) in one function call.
106
+
107
+ The arguments of the combined function are all arguments of relevant functions that
108
+ are not themselves function names, in alphabetical order.
109
+
110
+ Args:
111
+ dag (networkx.DiGraph): a DAG of functions
112
+ functions (dict or list): Dict or list of functions. If a list, the function
113
+ name is inferred from the __name__ attribute of the entries. If a dict, the
114
+ name of the function is set to the dictionary key.
115
+ targets (str or list or None): Name of the function that produces the target or
116
+ list of such function names. If the value is `None`, all variables are
117
+ returned.
118
+ return_type (str): One of "tuple", "list", "dict". This is ignored if the
119
+ targets are a single string or if an aggregator is provided.
120
+ aggregator (callable or None): Binary reduction function that is used to
121
+ aggregate the targets into a single target.
122
+ enforce_signature (bool): If True, the signature of the concatenated function
123
+ is enforced. Otherwise it is only provided for introspection purposes.
124
+ Enforcing the signature has a small runtime overhead.
125
+
126
+ Returns:
127
+ function: A function that produces targets when called with suitable arguments.
128
+
129
+ """
130
+ # Harmonize and check arguments.
131
+ _functions , _targets = _harmonize_and_check_functions_and_targets (
132
+ functions , targets
133
+ )
134
+
135
+ _arglist = _create_arguments_of_concatenated_function (_functions , dag )
136
+ _exec_info = _create_execution_info (_functions , dag )
57
137
_concatenated = _create_concatenated_function (
58
138
_exec_info , _arglist , _targets , enforce_signature
59
139
)
60
140
141
+ # Return function in specified format.
61
142
if isinstance (targets , str ) or (aggregator is not None and len (_targets ) == 1 ):
62
143
out = single_output (_concatenated )
63
144
elif aggregator is not None :
@@ -70,7 +151,7 @@ def concatenate_functions(
70
151
out = dict_output (_concatenated , keys = _targets )
71
152
else :
72
153
raise ValueError (
73
- f"Invalid return type { return_type } . Must be 'list', 'tuple', or 'dict'. "
154
+ f"Invalid return type { return_type } . Must be 'list', 'tuple', or 'dict'. "
74
155
f"You provided { return_type } ."
75
156
)
76
157
@@ -91,13 +172,14 @@ def get_ancestors(functions, targets, include_targets=False):
91
172
set: The ancestors
92
173
93
174
"""
94
- _functions = _harmonize_functions (functions )
95
- _targets = _harmonize_targets (targets , list (_functions ))
96
- _fail_if_targets_have_wrong_types (_targets )
97
- _fail_if_functions_are_missing (_functions , _targets )
98
175
99
- raw_dag = _create_complete_dag (_functions )
100
- dag = _limit_dag_to_targets_and_their_ancestors (raw_dag , _targets )
176
+ # Harmonize and check arguments.
177
+ _functions , _targets = _harmonize_and_check_functions_and_targets (
178
+ functions , targets
179
+ )
180
+
181
+ # Create the DAG.
182
+ dag = create_dag (functions , targets )
101
183
102
184
ancestors = set ()
103
185
for target in _targets :
@@ -107,6 +189,29 @@ def get_ancestors(functions, targets, include_targets=False):
107
189
return ancestors
108
190
109
191
192
+ def _harmonize_and_check_functions_and_targets (functions , targets ):
193
+ """Harmonize the type of specified functions and targets and do some checks.
194
+
195
+ Args:
196
+ functions (dict or list): Dict or list of functions. If a list, the function
197
+ name is inferred from the __name__ attribute of the entries. If a dict, the
198
+ name of the function is set to the dictionary key.
199
+ targets (str or list): Name of the function that produces the target or list of
200
+ such function names.
201
+
202
+ Returns:
203
+ functions_harmonized: harmonized functions
204
+ targets_harmonized: harmonized targets
205
+
206
+ """
207
+ functions_harmonized = _harmonize_functions (functions )
208
+ targets_harmonized = _harmonize_targets (targets , list (functions_harmonized ))
209
+ _fail_if_targets_have_wrong_types (targets_harmonized )
210
+ _fail_if_functions_are_missing (functions_harmonized , targets_harmonized )
211
+
212
+ return functions_harmonized , targets_harmonized
213
+
214
+
110
215
def _harmonize_functions (functions ):
111
216
if isinstance (functions , (list , tuple )):
112
217
functions = {func .__name__ : func for func in functions }
@@ -141,6 +246,15 @@ def _fail_if_functions_are_missing(functions, targets):
141
246
return functions , targets
142
247
143
248
249
+ def _fail_if_dag_contains_cycle (dag ):
250
+ """Check for cycles in DAG"""
251
+ cycles = list (nx .simple_cycles (dag ))
252
+
253
+ if len (cycles ) > 0 :
254
+ formatted = _format_list_linewise (cycles )
255
+ raise ValueError (f"The DAG contains one or more cycles:\n { formatted } " )
256
+
257
+
144
258
def _create_complete_dag (functions ):
145
259
"""Create the complete DAG.
146
260
@@ -275,7 +389,7 @@ def concatenated(*args, **kwargs):
275
389
276
390
277
391
def _format_list_linewise (list_ ):
278
- formatted_list = '",\n "' .join (list_ )
392
+ formatted_list = '",\n "' .join ([ str ( c ) for c in list_ ] )
279
393
return textwrap .dedent (
280
394
"""
281
395
[
0 commit comments