Skip to content

Commit ab7f8a2

Browse files
Add option to get the DAG. (#9)
1 parent 8794c21 commit ab7f8a2

File tree

6 files changed

+183
-34
lines changed

6 files changed

+183
-34
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ target/
8181
profile_default/
8282
ipython_config.py
8383

84+
# VS Code
85+
.vscode
86+
8487
# pyenv
8588
.python-version
8689

.pre-commit-config.yaml

+8-8
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ repos:
66
- id: debug-statements
77
- id: end-of-file-fixer
88
- repo: https://github.com/asottile/reorder_python_imports
9-
rev: v3.1.0
9+
rev: v3.8.3
1010
hooks:
1111
- id: reorder-python-imports
1212
types: [python]
@@ -45,12 +45,12 @@ repos:
4545
additional_dependencies: [black==22.3.0]
4646
types: [rst]
4747
- repo: https://github.com/psf/black
48-
rev: 22.3.0
48+
rev: 22.8.0
4949
hooks:
5050
- id: black
5151
types: [python]
5252
- repo: https://github.com/PyCQA/flake8
53-
rev: 4.0.1
53+
rev: 5.0.4
5454
hooks:
5555
- id: flake8
5656
types: [python]
@@ -71,7 +71,7 @@ repos:
7171
Pygments,
7272
]
7373
- repo: https://github.com/PyCQA/doc8
74-
rev: 0.11.2
74+
rev: v1.0.0
7575
hooks:
7676
- id: doc8
7777
- repo: meta
@@ -86,11 +86,11 @@ repos:
8686
args: [--no-build-isolation]
8787
additional_dependencies: [setuptools-scm, toml]
8888
- repo: https://github.com/PyCQA/doc8
89-
rev: 0.11.2
89+
rev: v1.0.0
9090
hooks:
9191
- id: doc8
9292
- repo: https://github.com/asottile/setup-cfg-fmt
93-
rev: v1.20.1
93+
rev: v2.0.0
9494
hooks:
9595
- id: setup-cfg-fmt
9696
- repo: https://github.com/econchick/interrogate
@@ -100,11 +100,11 @@ repos:
100100
args: [-v, --fail-under=20]
101101
exclude: ^(tests|docs|setup\.py)
102102
- repo: https://github.com/codespell-project/codespell
103-
rev: v2.1.0
103+
rev: v2.2.1
104104
hooks:
105105
- id: codespell
106106
- repo: https://github.com/asottile/pyupgrade
107-
rev: v2.34.0
107+
rev: v2.38.2
108108
hooks:
109109
- id: pyupgrade
110110
args: [--py37-plus]

CHANGES.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ releases are available on `Anaconda.org
1414
- :gh:`7` improves the examples in the test cases.
1515
- :gh:`10` turns ``targets`` into an optional argument. All variables in the DAG are
1616
returned by default.
17-
17+
- :gh:`9` Add function to return the DAG. Check for cycles in DAG.
18+
(:ghuser:`ChristianZimpelmann`)
1819

1920
0.2.1 - 2022-03-29
2021
------------------

setup.cfg

-4
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,6 @@ classifiers =
1717
Operating System :: POSIX
1818
Programming Language :: Python :: 3
1919
Programming Language :: Python :: 3 :: Only
20-
Programming Language :: Python :: 3.7
21-
Programming Language :: Python :: 3.8
22-
Programming Language :: Python :: 3.9
23-
Programming Language :: Python :: 3.10
2420
Topic :: Utilities
2521

2622
[options]

src/dags/dag.py

+135-21
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,16 @@ def concatenate_functions(
2424
2525
Functions that are not required to produce the targets will simply be ignored.
2626
27-
The arguments of the combined function are all arguments of relevant functions
28-
that are not themselves function names, in alphabetical order.
27+
The arguments of the combined function are all arguments of relevant functions that
28+
are not themselves function names, in alphabetical order.
2929
3030
Args:
3131
functions (dict or list): Dict or list of functions. If a list, the function
32-
name is inferred from the __name__ attribute of the entries. If a dict,
33-
the name of the function is set to the dictionary key.
34-
targets (str | None): Name of the function that produces the target or list of
35-
such function names. If the value is `None`, all variables are returned.
32+
name is inferred from the __name__ attribute of the entries. If a dict, the
33+
name of the function is set to the dictionary key.
34+
targets (str or list or None): Name of the function that produces the target or
35+
list of such function names. If the value is `None`, all variables are
36+
returned.
3637
return_type (str): One of "tuple", "list", "dict". This is ignored if the
3738
targets are a single string or if an aggregator is provided.
3839
aggregator (callable or None): Binary reduction function that is used to
@@ -45,19 +46,99 @@ def concatenate_functions(
4546
function: A function that produces targets when called with suitable arguments.
4647
4748
"""
48-
_functions = _harmonize_functions(functions)
49-
_targets = _harmonize_targets(targets, list(_functions))
50-
_fail_if_targets_have_wrong_types(_targets)
51-
_fail_if_functions_are_missing(_functions, _targets)
5249

50+
# Create the DAG.
51+
dag = create_dag(functions, targets)
52+
53+
# Build combined function.
54+
out = _create_combined_function_from_dag(
55+
dag, functions, targets, return_type, aggregator, enforce_signature
56+
)
57+
58+
return out
59+
60+
61+
def create_dag(functions, targets):
62+
"""Build a directed acyclic graph (DAG) from functions.
63+
64+
Functions can depend on the output of other functions as inputs, as long as the
65+
dependencies can be described by a directed acyclic graph (DAG).
66+
67+
Functions that are not required to produce the targets will simply be ignored.
68+
69+
Args:
70+
functions (dict or list): Dict or list of functions. If a list, the function
71+
name is inferred from the __name__ attribute of the entries. If a dict, the
72+
name of the function is set to the dictionary key.
73+
targets (str or list or None): Name of the function that produces the target or
74+
list of such function names. If the value is `None`, all variables are
75+
returned.
76+
77+
Returns:
78+
dag: the DAG (as networkx.DiGraph object)
79+
80+
"""
81+
# Harmonize and check arguments.
82+
_functions, _targets = _harmonize_and_check_functions_and_targets(
83+
functions, targets
84+
)
85+
86+
# Create the DAG
5387
_raw_dag = _create_complete_dag(_functions)
54-
_dag = _limit_dag_to_targets_and_their_ancestors(_raw_dag, _targets)
55-
_arglist = _create_arguments_of_concatenated_function(_functions, _dag)
56-
_exec_info = _create_execution_info(_functions, _dag)
88+
dag = _limit_dag_to_targets_and_their_ancestors(_raw_dag, _targets)
89+
90+
# Check if there are cycles in the DAG
91+
_fail_if_dag_contains_cycle(dag)
92+
93+
return dag
94+
95+
96+
def _create_combined_function_from_dag(
97+
dag,
98+
functions,
99+
targets,
100+
return_type="tuple",
101+
aggregator=None,
102+
enforce_signature=True,
103+
):
104+
"""Create combined function which allows to execute a complete directed acyclic
105+
graph (DAG) in one function call.
106+
107+
The arguments of the combined function are all arguments of relevant functions that
108+
are not themselves function names, in alphabetical order.
109+
110+
Args:
111+
dag (networkx.DiGraph): a DAG of functions
112+
functions (dict or list): Dict or list of functions. If a list, the function
113+
name is inferred from the __name__ attribute of the entries. If a dict, the
114+
name of the function is set to the dictionary key.
115+
targets (str or list or None): Name of the function that produces the target or
116+
list of such function names. If the value is `None`, all variables are
117+
returned.
118+
return_type (str): One of "tuple", "list", "dict". This is ignored if the
119+
targets are a single string or if an aggregator is provided.
120+
aggregator (callable or None): Binary reduction function that is used to
121+
aggregate the targets into a single target.
122+
enforce_signature (bool): If True, the signature of the concatenated function
123+
is enforced. Otherwise it is only provided for introspection purposes.
124+
Enforcing the signature has a small runtime overhead.
125+
126+
Returns:
127+
function: A function that produces targets when called with suitable arguments.
128+
129+
"""
130+
# Harmonize and check arguments.
131+
_functions, _targets = _harmonize_and_check_functions_and_targets(
132+
functions, targets
133+
)
134+
135+
_arglist = _create_arguments_of_concatenated_function(_functions, dag)
136+
_exec_info = _create_execution_info(_functions, dag)
57137
_concatenated = _create_concatenated_function(
58138
_exec_info, _arglist, _targets, enforce_signature
59139
)
60140

141+
# Return function in specified format.
61142
if isinstance(targets, str) or (aggregator is not None and len(_targets) == 1):
62143
out = single_output(_concatenated)
63144
elif aggregator is not None:
@@ -70,7 +151,7 @@ def concatenate_functions(
70151
out = dict_output(_concatenated, keys=_targets)
71152
else:
72153
raise ValueError(
73-
f"Invalid return type {return_type}. Must be 'list', 'tuple', or 'dict'. "
154+
f"Invalid return type {return_type}. Must be 'list', 'tuple', or 'dict'. "
74155
f"You provided {return_type}."
75156
)
76157

@@ -91,13 +172,14 @@ def get_ancestors(functions, targets, include_targets=False):
91172
set: The ancestors
92173
93174
"""
94-
_functions = _harmonize_functions(functions)
95-
_targets = _harmonize_targets(targets, list(_functions))
96-
_fail_if_targets_have_wrong_types(_targets)
97-
_fail_if_functions_are_missing(_functions, _targets)
98175

99-
raw_dag = _create_complete_dag(_functions)
100-
dag = _limit_dag_to_targets_and_their_ancestors(raw_dag, _targets)
176+
# Harmonize and check arguments.
177+
_functions, _targets = _harmonize_and_check_functions_and_targets(
178+
functions, targets
179+
)
180+
181+
# Create the DAG.
182+
dag = create_dag(functions, targets)
101183

102184
ancestors = set()
103185
for target in _targets:
@@ -107,6 +189,29 @@ def get_ancestors(functions, targets, include_targets=False):
107189
return ancestors
108190

109191

192+
def _harmonize_and_check_functions_and_targets(functions, targets):
193+
"""Harmonize the type of specified functions and targets and do some checks.
194+
195+
Args:
196+
functions (dict or list): Dict or list of functions. If a list, the function
197+
name is inferred from the __name__ attribute of the entries. If a dict, the
198+
name of the function is set to the dictionary key.
199+
targets (str or list): Name of the function that produces the target or list of
200+
such function names.
201+
202+
Returns:
203+
functions_harmonized: harmonized functions
204+
targets_harmonized: harmonized targets
205+
206+
"""
207+
functions_harmonized = _harmonize_functions(functions)
208+
targets_harmonized = _harmonize_targets(targets, list(functions_harmonized))
209+
_fail_if_targets_have_wrong_types(targets_harmonized)
210+
_fail_if_functions_are_missing(functions_harmonized, targets_harmonized)
211+
212+
return functions_harmonized, targets_harmonized
213+
214+
110215
def _harmonize_functions(functions):
111216
if isinstance(functions, (list, tuple)):
112217
functions = {func.__name__: func for func in functions}
@@ -141,6 +246,15 @@ def _fail_if_functions_are_missing(functions, targets):
141246
return functions, targets
142247

143248

249+
def _fail_if_dag_contains_cycle(dag):
250+
"""Check for cycles in DAG"""
251+
cycles = list(nx.simple_cycles(dag))
252+
253+
if len(cycles) > 0:
254+
formatted = _format_list_linewise(cycles)
255+
raise ValueError(f"The DAG contains one or more cycles:\n{formatted}")
256+
257+
144258
def _create_complete_dag(functions):
145259
"""Create the complete DAG.
146260
@@ -275,7 +389,7 @@ def concatenated(*args, **kwargs):
275389

276390

277391
def _format_list_linewise(list_):
278-
formatted_list = '",\n "'.join(list_)
392+
formatted_list = '",\n "'.join([str(c) for c in list_])
279393
return textwrap.dedent(
280394
"""
281395
[

tests/test_dag.py

+35
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import pytest
55
from dags.dag import concatenate_functions
6+
from dags.dag import create_dag
67
from dags.dag import get_ancestors
78

89

@@ -22,6 +23,14 @@ def _unrelated(working_hours): # noqa: U100
2223
raise NotImplementedError()
2324

2425

26+
def _leisure_cycle(working_hours, _utility):
27+
return 24 - working_hours + _utility
28+
29+
30+
def _consumption_cycle(working_hours, wage, _utility):
31+
return wage * working_hours + _utility
32+
33+
2534
def _complete_utility(wage, working_hours, leisure_weight):
2635
"""The function that we try to generate dynamically."""
2736
leis = _leisure(working_hours)
@@ -157,3 +166,29 @@ def g(f, d):
157166

158167
assert list(inspect.signature(concatenated).parameters) == ["c", "d"]
159168
assert concatenated(3, 4) == 10
169+
170+
171+
@pytest.mark.parametrize(
172+
"funcs",
173+
[
174+
{
175+
"_utility": _utility,
176+
"_leisure": _leisure,
177+
"_consumption": _consumption_cycle,
178+
},
179+
{
180+
"_utility": _utility,
181+
"_leisure": _leisure_cycle,
182+
"_consumption": _consumption_cycle,
183+
},
184+
],
185+
)
186+
def test_fail_if_cycle_in_dag(funcs):
187+
with pytest.raises(
188+
ValueError,
189+
match="The DAG contains one or more cycles:",
190+
):
191+
create_dag(
192+
functions=funcs,
193+
targets=["_utility"],
194+
)

0 commit comments

Comments
 (0)