Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,52 @@ def _shutdown_all_compiled_dags():
_compiled_dags = weakref.WeakValueDictionary()


def _check_unused_dag_input_attributes(
output_node: "ray.dag.MultiOutputNode", input_attributes: Set[str]
) -> Set[str]:
"""
Helper function to check that all input attributes are used in the DAG.
For example, if the user creates an input attribute by calling
InputNode()["x"], we ensure that there is a path from the
InputAttributeNode corresponding to "x" to the DAG's output. If an
input attribute is not used, throw an error.

Args:
output_node: The starting node for the traversal.
input_attributes: A set of attributes accessed by the InputNode.
"""
from ray.dag import InputAttributeNode

used_attributes = set()
visited_nodes = set()
stack: List["ray.dag.DAGNode"] = [output_node]

while stack:
current_node = stack.pop()
if current_node in visited_nodes:
continue
visited_nodes.add(current_node)

if isinstance(current_node, InputAttributeNode):
used_attributes.add(current_node.key)

stack.extend(current_node._upstream_nodes)

unused_attributes = input_attributes - used_attributes
if unused_attributes:
unused_attributes_str = ", ".join(str(key) for key in unused_attributes)
input_attributes_str = ", ".join(str(key) for key in input_attributes)
unused_phrase = "is unused" if len(unused_attributes) == 1 else "are unused"

raise ValueError(
"Compiled Graph expects input to be accessed "
f"using all of attributes {input_attributes_str}, "
f"but {unused_attributes_str} {unused_phrase}. "
"Ensure all input attributes are used and contribute "
"to the computation of the Compiled Graph output."
)


@DeveloperAPI
def do_allocate_channel(
self,
Expand Down Expand Up @@ -948,11 +994,16 @@ def _preprocess(self) -> None:
nccl_actors_p2p: Set["ray.actor.ActorHandle"] = set()
collective_ops: Set[_CollectiveOperation] = set()

input_attributes: Set[str] = set()
# Find the input node and input attribute nodes in the DAG.
for idx, task in self.idx_to_task.items():
if isinstance(task.dag_node, InputNode):
assert self.input_task_idx is None, "More than one InputNode found"
self.input_task_idx = idx
# handle_unused_attributes:
# Save input attributes in a set.
input_node = task.dag_node
input_attributes.update(input_node.input_attribute_nodes.keys())
elif isinstance(task.dag_node, InputAttributeNode):
self.input_attr_task_idxs.append(idx)

Expand Down Expand Up @@ -1132,6 +1183,10 @@ def _preprocess(self) -> None:
# Add all readers to the NCCL actors of P2P.
nccl_actors_p2p.add(downstream_actor_handle)

# Check that all specified input attributes, e.g., InputNode()["x"],
# are used in the DAG.
_check_unused_dag_input_attributes(output_node, input_attributes)

# Collect all leaf nodes.
leaf_nodes: DAGNode = []
for idx, task in self.idx_to_task.items():
Expand Down
30 changes: 30 additions & 0 deletions python/ray/dag/tests/experimental/test_accelerated_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2996,6 +2996,36 @@ def g(self, x, y, z=1):
_ = worker.g.bind(inp)


def test_missing_input_node():
@ray.remote
class Actor:
def __init__(self):
pass

def f(self, input):
return input

def add(self, a, b):
return a + b

actor = Actor.remote()

with ray.dag.InputNode() as dag_input:
input0, input1, input2 = dag_input[0], dag_input[1], dag_input[2]
_ = actor.f.bind(input1)
dag = actor.add.bind(input0, input2)

with pytest.raises(
ValueError,
match="Compiled Graph expects input to be accessed "
"using all of attributes 0, 1, 2, "
"but 1 is unused. "
"Ensure all input attributes are used and contribute "
"to the computation of the Compiled Graph output.",
):
dag.experimental_compile()


@pytest.mark.skipif(sys.platform == "win32", reason="Sigint not supported on Windows")
def test_sigint_get_dagref(ray_start_cluster):
driver_script = """
Expand Down
Loading