Skip to content

Commit c941748

Browse files
authored
Support non p2p configuration when initializing the comms (#4543)
closes #4490 Authors: - Joseph Nke (https://github.com/jnke2016) - Ralph Liu (https://github.com/nv-rliu) - Chuck Hastings (https://github.com/ChuckHastings) Approvers: - Chuck Hastings (https://github.com/ChuckHastings) - Rick Ratzel (https://github.com/rlratzel) URL: #4543
1 parent 8f7fec9 commit c941748

File tree

3 files changed

+22
-5
lines changed

3 files changed

+22
-5
lines changed

python/cugraph/cugraph/dask/comms/comms.py

-2
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,6 @@ def initialize(comms=None, p2p=False, prows=None, pcols=None, partition_type=1):
146146
__default_handle = None
147147
if comms is None:
148148
# Initialize communicator
149-
if not p2p:
150-
raise Exception("Set p2p to True for running mnmg algorithms")
151149
__instance = raftComms(comms_p2p=p2p)
152150
__instance.init()
153151
# Initialize subcommunicator

python/cugraph/cugraph/testing/mg_utils.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
1+
# Copyright (c) 2022-2024, NVIDIA CORPORATION.
22
# Licensed under the Apache License, Version 2.0 (the "License");
33
# you may not use this file except in compliance with the License.
44
# You may obtain a copy of the License at
@@ -35,6 +35,7 @@ def start_dask_client(
3535
jit_unspill=False,
3636
worker_class=None,
3737
device_memory_limit=0.8,
38+
p2p=True,
3839
):
3940
"""
4041
Creates a new dask client, and possibly also a cluster, and returns them as
@@ -95,6 +96,9 @@ def start_dask_client(
9596
dask_cuda.LocalCUDACluster for details. This parameter is ignored if
9697
the env var SCHEDULER_FILE is set which implies the dask cluster has
9798
already been created.
99+
100+
p2p : bool, optional (default=True)
101+
Initialize UCX endpoints if True.
98102
"""
99103
dask_scheduler_file = os.environ.get("SCHEDULER_FILE")
100104
dask_local_directory = os.getenv("DASK_LOCAL_DIRECTORY")
@@ -164,7 +168,7 @@ def start_dask_client(
164168
# FIXME: use proper logging, INFO or DEBUG level
165169
print("\nDask client/cluster created using LocalCUDACluster")
166170

167-
Comms.initialize(p2p=True)
171+
Comms.initialize(p2p=p2p)
168172

169173
return (client, cluster)
170174

python/cugraph/cugraph/tests/conftest.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2021-2023, NVIDIA CORPORATION.
1+
# Copyright (c) 2021-2024, NVIDIA CORPORATION.
22
# Licensed under the Apache License, Version 2.0 (the "License");
33
# you may not use this file except in compliance with the License.
44
# You may obtain a copy of the License at
@@ -52,6 +52,21 @@ def dask_client():
5252
stop_dask_client(dask_client, dask_cluster)
5353

5454

55+
# FIXME: Add tests leveraging this fixture
56+
@pytest.fixture(scope="module")
57+
def dask_client_non_p2p():
58+
# start_dask_client will check for the SCHEDULER_FILE and
59+
# DASK_WORKER_DEVICES env vars and use them when creating a client if
60+
# set. start_dask_client will also initialize the Comms singleton.
61+
dask_client, dask_cluster = start_dask_client(
62+
worker_class=IncreasedCloseTimeoutNanny, p2p=False
63+
)
64+
65+
yield dask_client
66+
67+
stop_dask_client(dask_client, dask_cluster)
68+
69+
5570
@pytest.fixture(scope="module")
5671
def scratch_dir():
5772
# This should always be set if doing MG testing, since temporary

0 commit comments

Comments
 (0)