From ac35be3fc9aaa529282d6e54df3cdb845c710647 Mon Sep 17 00:00:00 2001 From: Alex Barghi <105237337+alexbarghi-nv@users.noreply.github.com> Date: Tue, 30 Jul 2024 10:20:42 -0400 Subject: [PATCH] [BUG] Use the Correct WG Communicator (#4548) cuGraph-PyG's WholeFeatureStore currently uses the local communicator, when it should be using the global communicator, as was originally intended. This PR modifies the feature store so it correctly calls `get_global_node_communicator()`. This also fixes another bug where torch.int32 was used to store the number of edges in the graph, which resulted in an overflow error when the number of edges exceeded that datatype's maximum value. The datatype is now correctly set to int64. Authors: - Alex Barghi (https://github.com/alexbarghi-nv) Approvers: - Rick Ratzel (https://github.com/rlratzel) URL: https://github.com/rapidsai/cugraph/pull/4548 --- python/cugraph-pyg/cugraph_pyg/data/feature_store.py | 2 +- python/cugraph-pyg/cugraph_pyg/data/graph_store.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/cugraph-pyg/cugraph_pyg/data/feature_store.py b/python/cugraph-pyg/cugraph_pyg/data/feature_store.py index a3715d3ddf4..b6450e7b192 100644 --- a/python/cugraph-pyg/cugraph_pyg/data/feature_store.py +++ b/python/cugraph-pyg/cugraph_pyg/data/feature_store.py @@ -169,7 +169,7 @@ def __init__(self, memory_type="distributed", location="cpu"): self.__features = {} - self.__wg_comm = wgth.get_local_node_communicator() + self.__wg_comm = wgth.get_global_communicator() self.__wg_type = memory_type self.__wg_location = location diff --git a/python/cugraph-pyg/cugraph_pyg/data/graph_store.py b/python/cugraph-pyg/cugraph_pyg/data/graph_store.py index 622b68d37e2..e086bf07b1f 100644 --- a/python/cugraph-pyg/cugraph_pyg/data/graph_store.py +++ b/python/cugraph-pyg/cugraph_pyg/data/graph_store.py @@ -271,7 +271,7 @@ def __get_edgelist(self): torch.tensor( [self.__edge_indices[et].shape[1] for et in sorted_keys], device="cuda", - dtype=torch.int32, + dtype=torch.int64, ) )