Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,6 +946,39 @@ def _add_node(self, node: "ray.dag.DAGNode") -> None:
self.dag_node_to_idx[node] = idx
self.counter += 1

def find_unused_input_attributes(
    self, output_node: "ray.dag.MultiOutputNode", input_attributes: Set[str]
) -> Set[str]:
    """Return the input attributes that nothing upstream of the output reads.

    Helper for the unused-attribute check. Starting at ``output_node``, the
    DAG is walked backwards through each node's ``_upstream_nodes``. The
    ``key`` of every ``InputAttributeNode`` met on the way is recorded as
    used; whatever remains of ``input_attributes`` afterwards is unused.

    NOTE(review): the annotation says ``Set[str]``, but callers appear to
    stringify the keys with ``str(key)`` before printing them, which
    suggests non-string keys (e.g. integer indices) are possible -- confirm.

    Args:
        output_node: The node the backward traversal starts from.
        input_attributes: The attribute keys to check for usage.

    Returns:
        The members of ``input_attributes`` whose key was not found on any
        ``InputAttributeNode`` reached from ``output_node``.
    """
    from ray.dag import InputAttributeNode

    seen = set()
    reachable_keys = set()
    pending = [output_node]

    while pending:
        node = pending.pop()
        # A node can be reached along several paths; expand it only once.
        if node not in seen:
            seen.add(node)
            if isinstance(node, InputAttributeNode):
                reachable_keys.add(node.key)
            pending.extend(node._upstream_nodes)

    return input_attributes - reachable_keys

def _preprocess(self) -> None:
"""Before compiling, preprocess the DAG to build an index from task to
upstream and downstream tasks, and to set the input and output node(s)
Expand All @@ -969,11 +1002,16 @@ def _preprocess(self) -> None:
nccl_actors_p2p: Set["ray.actor.ActorHandle"] = set()
collective_ops: Set[_CollectiveOperation] = set()

input_attributes: Set[str] = set()
# Find the input node and input attribute nodes in the DAG.
for idx, task in self.idx_to_task.items():
if isinstance(task.dag_node, InputNode):
assert self.input_task_idx is None, "More than one InputNode found"
self.input_task_idx = idx
# handle_unused_attributes:
# Save input attributes in a set.
input_node = task.dag_node
input_attributes.update(input_node.input_attribute_nodes.keys())
elif isinstance(task.dag_node, InputAttributeNode):
self.input_attr_task_idxs.append(idx)

Expand Down Expand Up @@ -1153,6 +1191,24 @@ def _preprocess(self) -> None:
# Add all readers to the NCCL actors of P2P.
nccl_actors_p2p.add(downstream_actor_handle)

# handle_unused_attributes:
unused_attributes = self.find_unused_input_attributes(
output_node, input_attributes
)

if unused_attributes:
unused_attributes_str = ", ".join(str(key) for key in unused_attributes)
input_attributes_str = ", ".join(str(key) for key in input_attributes)
unused_phrase = "is unused" if len(unused_attributes) == 1 else "are unused"

raise ValueError(
"Compiled Graph expects input to be accessed "
f"using all of attributes {input_attributes_str}, "
f"but {unused_attributes_str} {unused_phrase}. "
"Ensure all input attributes are used and contribute "
"to the computation of the Compiled Graph output."
)

# Collect all leaf nodes.
leaf_nodes: DAGNode = []
for idx, task in self.idx_to_task.items():
Expand Down
30 changes: 30 additions & 0 deletions python/ray/dag/tests/test_input_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,36 @@ def _apply_recursive_with_counter(self, fn):
DAGNode.apply_recursive = original_apply_recursive


def test_missing_input_node():
    """Compiling a DAG must fail when an input attribute never reaches the output.

    Three attributes (0, 1, 2) are read from the InputNode. Attributes 0 and
    2 feed ``combine``, whose result is the DAG output; attribute 1 is only
    bound to ``f``, and that bound node is discarded. ``experimental_compile``
    is therefore expected to raise a ValueError naming attribute 1 as unused.
    """
    import re

    @ray.remote
    class Actor:
        def __init__(self):
            pass

        def f(self, input):
            return input

        def combine(self, a, b):
            return a + b

    actor = Actor.remote()

    with ray.dag.InputNode() as dag_input:
        input0, input1, input2 = dag_input[0], dag_input[1], dag_input[2]
        # The node built from input1 is thrown away, so attribute 1 does not
        # contribute to the output node created on the next line.
        _ = actor.f.bind(input1)
        dag = actor.combine.bind(input0, input2)

    expected_message = (
        "Compiled Graph expects input to be accessed "
        "using all of attributes 0, 1, 2, "
        "but 1 is unused. "
        "Ensure all input attributes are used and contribute "
        "to the computation of the Compiled Graph output."
    )

    # `match` is applied as a regular expression (re.search), so the literal
    # message is escaped; otherwise each "." would match any character and
    # the assertion would accept messages that differ from the expected one.
    with pytest.raises(ValueError, match=re.escape(expected_message)):
        dag.experimental_compile()


if __name__ == "__main__":
import sys

Expand Down
Loading