Skip to content

Commit b02892a

Browse files
LydiaXwQstephanie-wang
authored andcommitted
[core][compiled graphs] Check for unused input attributes (ray-project#49382)
If a user forgets to use a DAG's input attributes, it can lead to silent errors. This PR adds a check to make sure that there is a path between all DAG input attributes and the DAG's final output. ## Related issue number Closes ray-project#47165 --------- Signed-off-by: Lydia <[email protected]> Signed-off-by: Stephanie wang <[email protected]> Co-authored-by: Stephanie wang <[email protected]>
1 parent 70f9463 commit b02892a

File tree

2 files changed

+85
-0
lines changed

2 files changed

+85
-0
lines changed

python/ray/dag/compiled_dag_node.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,52 @@ def _shutdown_all_compiled_dags():
9999
_compiled_dags = weakref.WeakValueDictionary()
100100

101101

102+
def _check_unused_dag_input_attributes(
103+
output_node: "ray.dag.MultiOutputNode", input_attributes: Set[str]
104+
) -> Set[str]:
105+
"""
106+
Helper function to check that all input attributes are used in the DAG.
107+
For example, if the user creates an input attribute by calling
108+
InputNode()["x"], we ensure that there is a path from the
109+
InputAttributeNode corresponding to "x" to the DAG's output. If an
110+
input attribute is not used, throw an error.
111+
112+
Args:
113+
output_node: The starting node for the traversal.
114+
input_attributes: A set of attributes accessed by the InputNode.
115+
"""
116+
from ray.dag import InputAttributeNode
117+
118+
used_attributes = set()
119+
visited_nodes = set()
120+
stack: List["ray.dag.DAGNode"] = [output_node]
121+
122+
while stack:
123+
current_node = stack.pop()
124+
if current_node in visited_nodes:
125+
continue
126+
visited_nodes.add(current_node)
127+
128+
if isinstance(current_node, InputAttributeNode):
129+
used_attributes.add(current_node.key)
130+
131+
stack.extend(current_node._upstream_nodes)
132+
133+
unused_attributes = input_attributes - used_attributes
134+
if unused_attributes:
135+
unused_attributes_str = ", ".join(str(key) for key in unused_attributes)
136+
input_attributes_str = ", ".join(str(key) for key in input_attributes)
137+
unused_phrase = "is unused" if len(unused_attributes) == 1 else "are unused"
138+
139+
raise ValueError(
140+
"Compiled Graph expects input to be accessed "
141+
f"using all of attributes {input_attributes_str}, "
142+
f"but {unused_attributes_str} {unused_phrase}. "
143+
"Ensure all input attributes are used and contribute "
144+
"to the computation of the Compiled Graph output."
145+
)
146+
147+
102148
@DeveloperAPI
103149
def do_allocate_channel(
104150
self,
@@ -948,11 +994,16 @@ def _preprocess(self) -> None:
948994
nccl_actors_p2p: Set["ray.actor.ActorHandle"] = set()
949995
collective_ops: Set[_CollectiveOperation] = set()
950996

997+
input_attributes: Set[str] = set()
951998
# Find the input node and input attribute nodes in the DAG.
952999
for idx, task in self.idx_to_task.items():
9531000
if isinstance(task.dag_node, InputNode):
9541001
assert self.input_task_idx is None, "More than one InputNode found"
9551002
self.input_task_idx = idx
1003+
# handle_unused_attributes:
1004+
# Save input attributes in a set.
1005+
input_node = task.dag_node
1006+
input_attributes.update(input_node.input_attribute_nodes.keys())
9561007
elif isinstance(task.dag_node, InputAttributeNode):
9571008
self.input_attr_task_idxs.append(idx)
9581009

@@ -1132,6 +1183,10 @@ def _preprocess(self) -> None:
11321183
# Add all readers to the NCCL actors of P2P.
11331184
nccl_actors_p2p.add(downstream_actor_handle)
11341185

1186+
# Check that all specified input attributes, e.g., InputNode()["x"],
1187+
# are used in the DAG.
1188+
_check_unused_dag_input_attributes(output_node, input_attributes)
1189+
11351190
# Collect all leaf nodes.
11361191
leaf_nodes: DAGNode = []
11371192
for idx, task in self.idx_to_task.items():

python/ray/dag/tests/experimental/test_accelerated_dag.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2996,6 +2996,36 @@ def g(self, x, y, z=1):
29962996
_ = worker.g.bind(inp)
29972997

29982998

2999+
def test_missing_input_node():
3000+
@ray.remote
3001+
class Actor:
3002+
def __init__(self):
3003+
pass
3004+
3005+
def f(self, input):
3006+
return input
3007+
3008+
def add(self, a, b):
3009+
return a + b
3010+
3011+
actor = Actor.remote()
3012+
3013+
with ray.dag.InputNode() as dag_input:
3014+
input0, input1, input2 = dag_input[0], dag_input[1], dag_input[2]
3015+
_ = actor.f.bind(input1)
3016+
dag = actor.add.bind(input0, input2)
3017+
3018+
with pytest.raises(
3019+
ValueError,
3020+
match="Compiled Graph expects input to be accessed "
3021+
"using all of attributes 0, 1, 2, "
3022+
"but 1 is unused. "
3023+
"Ensure all input attributes are used and contribute "
3024+
"to the computation of the Compiled Graph output.",
3025+
):
3026+
dag.experimental_compile()
3027+
3028+
29993029
@pytest.mark.skipif(sys.platform == "win32", reason="Sigint not supported on Windows")
30003030
def test_sigint_get_dagref(ray_start_cluster):
30013031
driver_script = """

0 commit comments

Comments
 (0)