Skip to content

Commit f647f8b

Browse files
authored
Ignore partialled arguments of functions. (#4)
1 parent ee35f56 commit f647f8b

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

src/dags/dag.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import functools
12
import inspect
23
import textwrap
34

@@ -24,7 +25,7 @@ def concatenate_functions(
2425
Functions that are not required to produce the targets will simply be ignored.
2526
2627
The arguments of the combined function are all arguments of relevant functions
27-
that are not themselves function names.
28+
that are not themselves function names, in alphabetical order.
2829
2930
Args:
3031
functions (dict or list): Dict or list of functions. If a list, the function
@@ -152,14 +153,22 @@ def _create_complete_dag(functions):
152153
153154
"""
154155
functions_arguments_dict = {
155-
name: list(inspect.signature(function).parameters)
156-
for name, function in functions.items()
156+
name: _get_free_arguments(function) for name, function in functions.items()
157157
}
158158
dag = nx.DiGraph(functions_arguments_dict).reverse()
159159

160160
return dag
161161

162162

163+
def _get_free_arguments(func):
164+
arguments = list(inspect.signature(func).parameters)
165+
if isinstance(func, functools.partial):
166+
non_free = set(func.args) | set(func.keywords)
167+
arguments = [arg for arg in arguments if arg not in non_free]
168+
169+
return arguments
170+
171+
163172
def _limit_dag_to_targets_and_their_ancestors(dag, targets):
164173
"""Limit DAG to targets and their ancestors.
165174
@@ -216,7 +225,7 @@ def _create_execution_info(functions, dag):
216225
out = {}
217226
for node in nx.topological_sort(dag):
218227
if node in functions:
219-
arguments = list(inspect.signature(functions[node]).parameters)
228+
arguments = _get_free_arguments(functions[node])
220229
info = {}
221230
info["func"] = functions[node]
222231
info["arguments"] = arguments

tests/test_dag.py

+17
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import inspect
2+
from functools import partial
23

34
import pytest
45
from dags.dag import concatenate_functions
@@ -119,3 +120,19 @@ def test_concatenate_functions_with_aggregation_via_or():
119120
aggregator=lambda a, b: a or b,
120121
)
121122
assert aggregated()
123+
124+
125+
def test_partialled_argument_is_ignored():
126+
def f(a, b):
127+
return a + b
128+
129+
def g(f, c):
130+
return f + c
131+
132+
concatenated = concatenate_functions(
133+
functions={"f": partial(f, b=2), "g": g},
134+
targets="g",
135+
)
136+
137+
assert list(inspect.signature(concatenated).parameters) == ["a", "c"]
138+
assert concatenated(1, 3) == 6

0 commit comments

Comments
 (0)