import pytest
import numpy as np
+import os

from cugraph.gnn import FeatureStore

pylibwholegraph = import_optional("pylibwholegraph")
wmb = import_optional("pylibwholegraph.binding.wholememory_binding")
torch = import_optional("torch")
+wgth = import_optional("pylibwholegraph.torch")


-def runtest(world_rank: int, world_size: int):
-    from pylibwholegraph.torch.initialize import init_torch_env_and_create_wm_comm
+def runtest(rank: int, world_size: int):
+    torch.cuda.set_device(rank)

-    wm_comm, _ = init_torch_env_and_create_wm_comm(
-        world_rank,
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12355"
+    torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)
+
+    pylibwholegraph.torch.initialize.init(
+        rank,
        world_size,
-        world_rank,
+        rank,
        world_size,
    )
-    wm_comm = wm_comm.wmb_comm
+    wm_comm = wgth.get_global_communicator()

    generator = np.random.default_rng(62)
    arr = (
@@ -52,36 +58,32 @@ def runtest(world_rank: int, world_size: int):
    expected = arr[indices_to_fetch]
    np.testing.assert_array_equal(output_fs.cpu().numpy(), expected)

-    wmb.finalize()
+    pylibwholegraph.torch.initialize.finalize()


@pytest.mark.sg
@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
@pytest.mark.skipif(
    isinstance(pylibwholegraph, MissingModule), reason="wholegraph not available"
)
-@pytest.mark.skip(reason="broken")
def test_feature_storage_wholegraph_backend():
-    from pylibwholegraph.utils.multiprocess import multiprocess_run
+    world_size = torch.cuda.device_count()
+    print("gpu count:", world_size)
+    assert world_size > 0

-    gpu_count = wmb.fork_get_gpu_count()
-    print("gpu count:", gpu_count)
-    assert gpu_count > 0
+    print("ignoring gpu count and running on 1 GPU only")

-    multiprocess_run(1, runtest)
+    torch.multiprocessing.spawn(runtest, args=(1,), nprocs=1)


@pytest.mark.mg
@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
@pytest.mark.skipif(
    isinstance(pylibwholegraph, MissingModule), reason="wholegraph not available"
)
-@pytest.mark.skip(reason="broken")
def test_feature_storage_wholegraph_backend_mg():
-    from pylibwholegraph.utils.multiprocess import multiprocess_run
-
-    gpu_count = wmb.fork_get_gpu_count()
-    print("gpu count:", gpu_count)
-    assert gpu_count > 0
+    world_size = torch.cuda.device_count()
+    print("gpu count:", world_size)
+    assert world_size > 0

-    multiprocess_run(gpu_count, runtest)
+    torch.multiprocessing.spawn(runtest, args=(world_size,), nprocs=world_size)
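The spawn calls above rely on torch.multiprocessing.spawn's dispatch convention: spawn(fn, args=..., nprocs=...) invokes fn(rank, *args) in each child process, so runtest(rank, world_size) receives its rank implicitly while args only needs to carry world_size. A minimal sketch of that convention, separate from the test itself (the _worker name and the hard-coded world_size are illustrative only):

# Sketch: torch.multiprocessing.spawn calls fn(rank, *args), so each
# child process receives its rank as the first positional argument.
import torch.multiprocessing as mp

def _worker(rank, world_size):
    # rank is injected by spawn; world_size comes from args=(world_size,)
    print(f"worker {rank} of {world_size} started")

if __name__ == "__main__":
    world_size = 2  # illustrative; the tests use torch.cuda.device_count()
    mp.spawn(_worker, args=(world_size,), nprocs=world_size)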