
Commit 5458e76

[BUG] Fix Failing WholeGraph Tests (#4560)
This PR properly uses PyTorch DDP to initialize a process group and test the WholeGraph feature store. Previously, the test relied on a WholeGraph API that no longer appears to work.

Authors:
  - Alex Barghi (https://github.com/alexbarghi-nv)

Approvers:
  - Rick Ratzel (https://github.com/rlratzel)

URL: #4560
1 parent aa0347c · commit 5458e76
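The sketch below outlines the worker setup the updated test uses: each process spawned by torch.multiprocessing.spawn binds a GPU, joins an NCCL process group via torch.distributed, and then initializes WholeGraph on top of that group. This is a minimal illustration, not the test itself; the MASTER_ADDR/MASTER_PORT values and the FeatureStore placeholder comment are illustrative.

import os

import torch
import pylibwholegraph.torch as wgth
import pylibwholegraph.torch.initialize


def worker(rank: int, world_size: int):
    # Bind this worker to its GPU before any collective setup.
    torch.cuda.set_device(rank)

    # Single-node rendezvous settings (illustrative values).
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)

    # WholeGraph reuses the rank/size of the torch.distributed group.
    pylibwholegraph.torch.initialize.init(rank, world_size, rank, world_size)
    comm = wgth.get_global_communicator()

    # ... construct a FeatureStore(backend="wholegraph") and run assertions here ...

    pylibwholegraph.torch.initialize.finalize()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    # spawn() passes each worker's rank as the first positional argument.
    torch.multiprocessing.spawn(worker, args=(world_size,), nprocs=world_size)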

File tree

1 file changed: +22 −20 lines

python/cugraph/cugraph/tests/data_store/test_gnn_feat_storage_wholegraph.py

+22 −20

@@ -13,6 +13,7 @@
 
 import pytest
 import numpy as np
+import os
 
 from cugraph.gnn import FeatureStore
 
@@ -21,18 +22,23 @@
 pylibwholegraph = import_optional("pylibwholegraph")
 wmb = import_optional("pylibwholegraph.binding.wholememory_binding")
 torch = import_optional("torch")
+wgth = import_optional("pylibwholegraph.torch")
 
 
-def runtest(world_rank: int, world_size: int):
-    from pylibwholegraph.torch.initialize import init_torch_env_and_create_wm_comm
+def runtest(rank: int, world_size: int):
+    torch.cuda.set_device(rank)
 
-    wm_comm, _ = init_torch_env_and_create_wm_comm(
-        world_rank,
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12355"
+    torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)
+
+    pylibwholegraph.torch.initialize.init(
+        rank,
         world_size,
-        world_rank,
+        rank,
         world_size,
     )
-    wm_comm = wm_comm.wmb_comm
+    wm_comm = wgth.get_global_communicator()
 
     generator = np.random.default_rng(62)
     arr = (
@@ -52,36 +58,32 @@ def runtest(world_rank: int, world_size: int):
     expected = arr[indices_to_fetch]
     np.testing.assert_array_equal(output_fs.cpu().numpy(), expected)
 
-    wmb.finalize()
+    pylibwholegraph.torch.initialize.finalize()
 
 
 @pytest.mark.sg
 @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
 @pytest.mark.skipif(
     isinstance(pylibwholegraph, MissingModule), reason="wholegraph not available"
 )
-@pytest.mark.skip(reason="broken")
 def test_feature_storage_wholegraph_backend():
-    from pylibwholegraph.utils.multiprocess import multiprocess_run
+    world_size = torch.cuda.device_count()
+    print("gpu count:", world_size)
+    assert world_size > 0
 
-    gpu_count = wmb.fork_get_gpu_count()
-    print("gpu count:", gpu_count)
-    assert gpu_count > 0
+    print("ignoring gpu count and running on 1 GPU only")
 
-    multiprocess_run(1, runtest)
+    torch.multiprocessing.spawn(runtest, args=(1,), nprocs=1)
 
 
 @pytest.mark.mg
 @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
 @pytest.mark.skipif(
     isinstance(pylibwholegraph, MissingModule), reason="wholegraph not available"
 )
-@pytest.mark.skip(reason="broken")
 def test_feature_storage_wholegraph_backend_mg():
-    from pylibwholegraph.utils.multiprocess import multiprocess_run
-
-    gpu_count = wmb.fork_get_gpu_count()
-    print("gpu count:", gpu_count)
-    assert gpu_count > 0
+    world_size = torch.cuda.device_count()
+    print("gpu count:", world_size)
+    assert world_size > 0
 
-    multiprocess_run(gpu_count, runtest)
+    torch.multiprocessing.spawn(runtest, args=(world_size,), nprocs=world_size)
