Commit 20bf682

[core][rdt] Abort NIXL and allow actor reuse on failed transfers (ray-project#56783)

Signed-off-by: dayshah <[email protected]>

1 parent: 89a329c

File tree

11 files changed: +360 -149 lines

doc/source/ray-core/direct-transport.rst

Lines changed: 19 additions & 8 deletions

@@ -12,12 +12,12 @@ For example, passing a CUDA ``torch.Tensor`` from one Ray task to another would
 *Ray Direct Transport (RDT)* is a new feature that allows Ray to store and pass objects directly between Ray actors.
 This feature augments the familiar Ray :class:`ObjectRef <ray.ObjectRef>` API by:
 
-- Keeping GPU data in GPU memory until a transfer is needed
+- Keeping GPU data in GPU memory until a transfer is necessary
 - Avoiding expensive serialization and copies to and from the Ray object store
 - Using efficient data transports like collective communication libraries (`Gloo <https://github.com/pytorch/gloo>`__ or `NCCL <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/index.html>`__) or point-to-point RDMA (via `NVIDIA's NIXL <https://github.com/ai-dynamo/nixl>`__) to transfer data directly between devices, including both CPU and GPUs
 
 .. note::
-    RDT is currently in **alpha**. Not all Ray Core APIs are supported yet. Future releases may introduce breaking API changes. See the :ref:`limitations <limitations>` section for more details.
+    RDT is currently in **alpha** and doesn't support all Ray Core APIs yet. Future releases may introduce breaking API changes. See the :ref:`limitations <limitations>` section for more details.
 
 Getting started
 ===============
@@ -290,12 +290,6 @@ For collective-based tensor transports (Gloo and NCCL):
 * Similarly, the process that created the collective group cannot serialize and pass RDT :class:`ray.ObjectRefs <ray.ObjectRef>` to other Ray tasks or actors. Instead, the :class:`ray.ObjectRef`\s can only be passed as direct arguments to other actor tasks, and those actors must be in the same collective group.
 * Each actor can only be in one collective group per tensor transport at a time.
 * No support for :func:`ray.put <ray.put>`.
-* If a system-level error occurs during a collective operation, the collective group will be destroyed and the actors will no longer be able to communicate via the collective group. Note that application-level errors, i.e. exceptions raised by user code, will not destroy the collective group and will instead be propagated to any dependent task(s), as for non-RDT Ray objects. System-level errors include:
-
-  * Errors internal to the third-party transport, e.g., NCCL network errors
-  * Actor and node failure
-  * Tensors returned by the user that are located on an unsupported device, e.g., a CPU tensor when using NCCL
-  * Any unexpected system bugs
 
 
 Due to a known issue, for NIXL, we currently do not support storing different GPU objects at the same actor, where the objects contain an overlapping but not equal set of tensors. To support this pattern, ensure that the first `ObjectRef` has gone out of scope before storing the same tensor(s) again in a second object.
@@ -305,6 +299,23 @@ Due to a known issue, for NIXL, we currently do not support storing different GP
     :start-after: __nixl_limitations_start__
     :end-before: __nixl_limitations_end__
 
+Error handling
+==============
+
+* Application-level errors, i.e. exceptions raised by user code, will not destroy the collective group and will instead be propagated to any dependent task(s), as for non-RDT Ray objects.
+
+* If a system-level error occurs during a GLOO or NCCL collective operation, the collective group will be destroyed and the actors will be killed to prevent any hanging.
+
+* If a system-level error occurs during a NIXL transfer, Ray or NIXL will abort the transfer with an exception and Ray will raise the exception in the dependent task or on the ray.get on the NIXL ref.
+
+* System-level errors include:
+  * Errors internal to the third-party transport, e.g., NCCL network errors
+  * Actor or node failures
+  * Transport errors due to tensor device / transport mismatches, e.g., a CPU tensor when using NCCL
+  * Ray object fetch timeouts (can be overridden by setting the ``RAY_fetch_fail_timeout_milliseconds`` environment variable)
+  * Any unexpected system bugs
+
+
 Advanced: RDT Internals
 =======================
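One of the new bullets points at the RAY_fetch_fail_timeout_milliseconds environment variable. As a minimal sketch of overriding it, assuming a local cluster started by ray.init() (on a pre-existing cluster the variable has to be set when the nodes start):

# Hedged example: raise the RDT fetch-failure timeout to 10 minutes. Set the
# variable before importing/initializing Ray so the new value is picked up.
import os

os.environ["RAY_fetch_fail_timeout_milliseconds"] = "600000"

import ray  # noqa: E402

ray.init()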

doc/source/ray-core/doc_code/direct_transport_nixl.py

Lines changed: 1 addition & 1 deletion

@@ -88,6 +88,6 @@ def sum_dict(self, dict):
 result2 = receiver.sum_dict.remote(ref2)
 try:
     print(ray.get(result2))
-except ActorDiedError as e:
+except ValueError as e:
     print("Error caught:", e)
 # __nixl_limitations_end__
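For readers without the rest of the doc_code file, here is a hedged, self-contained sketch of the pattern the snippet exercises. The actor definitions and the tensor_transport="nixl" annotation are assumptions drawn from the surrounding RDT docs rather than from this diff; the exception type mirrors the new except clause above.

# Hypothetical end-to-end version of the doc snippet; assumes the RDT alpha API
# where an actor method opts into NIXL via @ray.method(tensor_transport="nixl").
import ray
import torch


@ray.remote(num_gpus=1)
class Sender:
    @ray.method(tensor_transport="nixl")
    def make_tensors(self):
        # The tensors stay on the GPU; only a lightweight reference is returned.
        return {"a": torch.ones(1024, device="cuda")}


@ray.remote(num_gpus=1)
class Receiver:
    def sum_dict(self, d):
        return sum(t.sum().item() for t in d.values())


sender, receiver = Sender.remote(), Receiver.remote()
ref = sender.make_tensors.remote()
try:
    # If the NIXL transfer fails, the transfer is aborted and the error is
    # raised here; with this commit the actors survive and can be reused.
    print(ray.get(receiver.sum_dict.remote(ref)))
except ValueError as e:
    print("Error caught:", e)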

python/ray/actor.py

Lines changed: 1 addition & 0 deletions

@@ -1180,6 +1180,7 @@ def _process_option_dict(actor_options, has_tensor_transport_methods):
     if _filled_options.get("concurrency_groups", None) is None:
         _filled_options["concurrency_groups"] = {}
     _filled_options["concurrency_groups"]["_ray_system"] = 1
+    _filled_options["concurrency_groups"]["_ray_system_error"] = 1
 
     return _filled_options
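The extra group gives every RDT actor a reserved execution slot, so the abort call dispatched during failure handling is not queued behind user methods. A rough illustration of the mechanism using Ray's public concurrency-group API (the Worker actor and the "control" group are made up for this example; only the _ray_system_error name comes from the diff):

# Illustration of why a reserved concurrency group helps: "control" calls run on
# their own thread and are not blocked behind the busy default group, which is
# analogous to how __ray_abort_transport__ is dispatched on "_ray_system_error".
import time

import ray


@ray.remote(concurrency_groups={"control": 1})
class Worker:
    def busy(self):
        time.sleep(30)  # ties up the default concurrency group

    @ray.method(concurrency_group="control")
    def ping(self):
        return "pong"


w = Worker.remote()
w.busy.remote()
print(ray.get(w.ping.remote()))  # returns promptly despite the busy() call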

python/ray/experimental/collective/collective_tensor_transport.py

Lines changed: 14 additions & 0 deletions

@@ -26,6 +26,10 @@ def tensor_transport_backend(self) -> Backend:
     def is_one_sided() -> bool:
         return False
 
+    @staticmethod
+    def can_abort_transport() -> bool:
+        return False
+
     def actor_has_tensor_transport(self, actor: "ray.actor.ActorHandle") -> bool:
         from ray.experimental.collective import get_collective_groups
 
@@ -137,6 +141,7 @@ def get_communicator_metadata(
     @staticmethod
     def recv_multiple_tensors(
         tensors,
+        obj_id: str,
         tensor_transport_metadata: CollectiveTransportMetadata,
         communicator_metadata: CollectiveCommunicatorMetadata,
     ):
@@ -183,3 +188,12 @@ def garbage_collect(
         obj_id: str, tensor_transport_meta: CollectiveTransportMetadata
     ):
         pass
+
+    @staticmethod
+    def abort_transport(
+        obj_id: str,
+        communicator_metadata: CollectiveCommunicatorMetadata,
+    ):
+        raise NotImplementedError(
+            "Collective transport does not support abort_transport for now."
+        )

python/ray/experimental/collective/nixl_tensor_transport.py

Lines changed: 17 additions & 0 deletions

@@ -24,6 +24,10 @@ def tensor_transport_backend(self) -> Backend:
     def is_one_sided() -> bool:
         return True
 
+    @staticmethod
+    def can_abort_transport() -> bool:
+        return True
+
     def actor_has_tensor_transport(self, actor: "ray.actor.ActorHandle") -> bool:
         def __ray_actor_has_tensor_transport__(
             self: "ray.actor.ActorHandle",
@@ -134,6 +138,7 @@ def get_communicator_metadata(
     @staticmethod
     def recv_multiple_tensors(
         tensors,
+        obj_id: str,
         tensor_transport_metadata: NixlTransportMetadata,
         communicator_metadata: NixlCommunicatorMetadata,
     ):
@@ -152,6 +157,7 @@ def recv_multiple_tensors(
 
         g.recv(
             tensors,
+            obj_id,
             tensor_transport_metadata.nixl_serialized_descs,
             tensor_transport_metadata.nixl_agent_meta,
         )
@@ -178,3 +184,14 @@ def garbage_collect(obj_id: str, tensor_transport_meta: NixlTransportMetadata):
         if descs is not None:
             nixl_backend = get_group_handle(NIXL_GROUP_NAME)
             nixl_backend.deregister_memory(descs)
+
+    @staticmethod
+    def abort_transport(
+        obj_id: str,
+        communicator_metadata: NixlCommunicatorMetadata,
+    ):
+        from ray.util.collective.collective import get_group_handle
+
+        g = get_group_handle(communicator_metadata.communicator_name)
+        if g:
+            g.abort(obj_id)

python/ray/experimental/collective/tensor_transport_manager.py

Lines changed: 27 additions & 0 deletions

@@ -31,6 +31,17 @@ def is_one_sided() -> bool:
             bool: True if the backend is one-sided, False otherwise.
         """
 
+    @staticmethod
+    @abstractmethod
+    def can_abort_transport() -> bool:
+        """
+        Whether the backend can abort the transport.
+        If this returns False, then Ray will kill involved actors upon system errors to avoid hanging.
+
+        Returns:
+            bool: True if the backend can abort the transport.
+        """
+
     @abstractmethod
     def actor_has_tensor_transport(self, actor: "ray.actor.ActorHandle") -> bool:
         """Whether the actor has the tensor transport available.
@@ -102,6 +113,7 @@ def get_communicator_metadata(
     @abstractmethod
     def recv_multiple_tensors(
         tensors: List["torch.Tensor"],
+        obj_id: str,
         tensor_transport_metadata: TensorTransportMetadata,
         communicator_metadata: CommunicatorMetadata,
     ):
@@ -110,6 +122,7 @@ def recv_multiple_tensors(
 
         Args:
             tensors: The pre-allocated tensor space to receive the tensors.
+            obj_id: The object ID for related GPU object.
             tensor_transport_metadata: The tensor transport metadata for the GPU object.
             communicator_metadata: The communicator metadata for the send/recv operation.
 
@@ -139,3 +152,17 @@ def garbage_collect(obj_id: str, tensor_transport_meta: TensorTransportMetadata)
             obj_id: The ID of the GPU object to garbage collect.
             tensor_transport_meta: The tensor transport metadata.
         """
+
+    @staticmethod
+    @abstractmethod
+    def abort_transport(
+        obj_id: str,
+        communicator_metadata: CommunicatorMetadata,
+    ):
+        """
+        Abort the transport.
+
+        Args:
+            obj_id: The object ID for related GPU object.
+            communicator_metadata: The communicator metadata for the send/recv operation.
+        """

python/ray/experimental/gpu_object_manager/gpu_object_manager.py

Lines changed: 48 additions & 15 deletions

@@ -47,6 +47,7 @@ class TransferMetadata(NamedTuple):
     recv_ref: ObjectRef
     communicator_meta: "CommunicatorMetadata"
     backend: str
+    obj_id: str
     timeout: float
 
 
@@ -179,28 +180,59 @@ def _abort_transport(
     Cleans up the ref_info_map, kill the src and dst actors, and destroy the
     collective group if necessary.
     """
-    from ray.experimental.collective import destroy_collective_group
+    from ray.experimental.collective import (
+        destroy_collective_group,
+        get_tensor_transport_manager,
+    )
+    from ray.experimental.gpu_object_manager.gpu_object_store import (
+        __ray_abort_transport__,
+    )
     from ray.util.collective.types import CollectiveCommunicatorMetadata
 
     ref_info = ref_info_map.pop(failed_ref.hex(), None)
     if ref_info is None:
         return
 
-    logger.error(
-        "RDT transfer with src actor %s and dst actor %s failed. Killing the actors. "
-        "Transfer failed with exception: %s",
-        ref_info.src_actor,
-        ref_info.dst_actor,
-        exception,
-    )
-
     if ref_info.send_ref:
         ref_info_map.pop(ref_info.send_ref.hex(), None)
     ref_info_map.pop(ref_info.recv_ref.hex(), None)
 
-    # TODO(#51276): Kill all actors in the collective group when we support more collective operations
-    ray.kill(ref_info.src_actor)
-    ray.kill(ref_info.dst_actor)
+    tensor_transport_manager = get_tensor_transport_manager(ref_info.backend)
+    if tensor_transport_manager.can_abort_transport():
+        if not tensor_transport_manager.is_one_sided():
+            # This is dead code until we implement a NCCL abort since NIXL
+            # is the only abortable transport for now and is one-sided.
+            ref_info.src_actor.__ray_call__.options(
+                concurrency_group="_ray_system_error"
+            ).remote(
+                __ray_abort_transport__,
+                ref_info.obj_id,
+                ref_info.communicator_meta,
+            )
+        ref_info.dst_actor.__ray_call__.options(
+            concurrency_group="_ray_system_error"
+        ).remote(
+            __ray_abort_transport__,
+            ref_info.obj_id,
+            ref_info.communicator_meta,
+        )
+        logger.info(
+            "RDT transfer with src actor %s and dst actor %s failed due to %s.",
+            ref_info.src_actor,
+            ref_info.dst_actor,
+            exception,
+        )
+    else:
+        # TODO(#51276): Kill all actors in the collective group when we support more collective operations
+        ray.kill(ref_info.src_actor)
+        ray.kill(ref_info.dst_actor)
+        logger.error(
+            "RDT transfer with src actor %s and dst actor %s failed. Killing the actors. "
+            "Transfer failed with exception: %s",
+            ref_info.src_actor,
+            ref_info.dst_actor,
+            exception,
+        )
 
     # isinstance does an implicit cast and makes communicator_name inaccessible
     # so we have to get communicator_name before the cast.
@@ -336,7 +368,7 @@ def _fetch_object(
                     __ray_fetch_gpu_object__, obj_id
                 )
             )
-            self.gpu_object_store.add_object(obj_id, tensors)
+            self.gpu_object_store.add_object(obj_id, tensors, is_primary=False)
         else:
             if isinstance(gpu_object_meta.tensor_transport_meta, ObjectRef):
                 # If the tensor transport meta is an ObjectRef, gpu object manager
@@ -358,7 +390,7 @@ def _fetch_object(
                 None, None, tensor_transport_backend
             )
             __ray_recv__(
-                None, obj_id, gpu_object_meta.tensor_transport_meta, communicator_meta
+                None, obj_id, [gpu_object_meta.tensor_transport_meta], communicator_meta
            )
 
     def trigger_out_of_band_tensor_transfer(
@@ -474,7 +506,7 @@ def trigger_out_of_band_tensor_transfer(
         ).remote(
             __ray_recv__,
             obj_id,
-            tensor_transport_meta,
+            [tensor_transport_meta],
             communicator_meta,
         )
 
@@ -486,6 +518,7 @@ def trigger_out_of_band_tensor_transfer(
                 recv_ref=recv_ref,
                 communicator_meta=communicator_meta,
                 backend=gpu_object_meta.tensor_transport_backend,
+                obj_id=obj_id,
                 timeout=time.time() + ray_constants.FETCH_FAIL_TIMEOUT_SECONDS,
            )
        )
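Condensing the new _abort_transport branching for review purposes, a hedged restatement follows (not the actual Ray code; the helper function name is made up, only the imports and attribute names come from the diff):

# Hedged restatement of the failure-handling decision, not the real implementation.
import ray
from ray.experimental.collective import get_tensor_transport_manager
from ray.experimental.gpu_object_manager.gpu_object_store import (
    __ray_abort_transport__,
)


def handle_failed_transfer(ref_info, exception):
    manager = get_tensor_transport_manager(ref_info.backend)
    if manager.can_abort_transport():
        # Abortable (currently NIXL, one-sided): cancel the in-flight recv on the
        # destination actor via the reserved "_ray_system_error" group; both
        # actors stay alive and can be reused for later transfers.
        ref_info.dst_actor.__ray_call__.options(
            concurrency_group="_ray_system_error"
        ).remote(__ray_abort_transport__, ref_info.obj_id, ref_info.communicator_meta)
    else:
        # Non-abortable (Gloo/NCCL): kill both actors to avoid hanging on the
        # stuck collective operation.
        ray.kill(ref_info.src_actor)
        ray.kill(ref_info.dst_actor)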
