Skip to content

Commit 8794c21

Browse files
authored
Make targets optional and return all variables if targets=None. (#10)
1 parent 96d9dce commit 8794c21

File tree

4 files changed

+36
-15
lines changed

4 files changed

+36
-15
lines changed

CHANGES.rst

+5-3
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@ releases are available on `Anaconda.org
77
<https://anaconda.org/OpenSourceEconomics/dags>`_.
88

99

10-
11-
0.2.2
12-
-----
10+
0.2.2 - 2022-xx-xx
11+
------------------
1312

1413
- :gh:`5` Updates examples used in tests (:ghuser:`janosg`)
14+
- :gh:`7` improves the examples in the test cases.
15+
- :gh:`10` turns ``targets`` into an optional argument. All variables in the DAG are
16+
returned by default.
1517

1618

1719
0.2.1 - 2022-03-29

src/dags/dag.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
def concatenate_functions(
1414
functions,
15-
targets,
15+
targets=None,
1616
return_type="tuple",
1717
aggregator=None,
1818
enforce_signature=True,
@@ -31,8 +31,8 @@ def concatenate_functions(
3131
functions (dict or list): Dict or list of functions. If a list, the function
3232
name is inferred from the __name__ attribute of the entries. If a dict,
3333
the name of the function is set to the dictionary key.
34-
targets (str): Name of the function that produces the target or list of such
35-
function names.
34+
targets (str | None): Name of the function that produces the target or list of
35+
such function names. If the value is `None`, all variables are returned.
3636
return_type (str): One of "tuple", "list", "dict". This is ignored if the
3737
targets are a single string or if an aggregator is provided.
3838
aggregator (callable or None): Binary reduction function that is used to
@@ -45,8 +45,8 @@ def concatenate_functions(
4545
function: A function that produces targets when called with suitable arguments.
4646
4747
"""
48-
_targets = _harmonize_targets(targets)
4948
_functions = _harmonize_functions(functions)
49+
_targets = _harmonize_targets(targets, list(_functions))
5050
_fail_if_targets_have_wrong_types(_targets)
5151
_fail_if_functions_are_missing(_functions, _targets)
5252

@@ -91,8 +91,8 @@ def get_ancestors(functions, targets, include_targets=False):
9191
set: The ancestors
9292
9393
"""
94-
_targets = _harmonize_targets(targets)
9594
_functions = _harmonize_functions(functions)
95+
_targets = _harmonize_targets(targets, list(_functions))
9696
_fail_if_targets_have_wrong_types(_targets)
9797
_fail_if_functions_are_missing(_functions, _targets)
9898

@@ -113,8 +113,10 @@ def _harmonize_functions(functions):
113113
return functions
114114

115115

116-
def _harmonize_targets(targets):
117-
if isinstance(targets, str):
116+
def _harmonize_targets(targets, function_names):
117+
if targets is None:
118+
targets = function_names
119+
elif isinstance(targets, str):
118120
targets = [targets]
119121
return targets
120122

tests/test_dag.py

+21
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,27 @@ def _complete_utility(wage, working_hours, leisure_weight):
3030
return util
3131

3232

33+
def test_concatenate_functions_no_target():
34+
concatenated = concatenate_functions(functions=[_utility, _leisure, _consumption])
35+
36+
calculated_result = concatenated(wage=5, working_hours=8, leisure_weight=2)
37+
38+
expected_utility = _complete_utility(wage=5, working_hours=8, leisure_weight=2)
39+
expected_leisure = _leisure(working_hours=8)
40+
expected_consumption = _consumption(working_hours=8, wage=5)
41+
42+
assert calculated_result == (
43+
expected_utility,
44+
expected_leisure,
45+
expected_consumption,
46+
)
47+
48+
calculated_args = set(inspect.signature(concatenated).parameters)
49+
expected_args = {"leisure_weight", "wage", "working_hours"}
50+
51+
assert calculated_args == expected_args
52+
53+
3354
def test_concatenate_functions_single_target():
3455
concatenated = concatenate_functions(
3556
functions=[_utility, _unrelated, _leisure, _consumption],

tox.ini

+1-5
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
[tox]
22
envlist = pytest, sphinx
3-
skipsdist = True
4-
skip_missing_interpreters = True
53

64
[testenv]
7-
basepython = python
5+
usedevelop = true
86

97
[testenv:pytest]
10-
setenv =
11-
CONDA_DLL_SEARCH_MODIFICATION_ENABLE = 1
128
conda_channels =
139
conda-forge
1410
nodefaults

0 commit comments

Comments
 (0)