[BUG] Use the Correct WG Communicator (#4548)
cuGraph-PyG's WholeFeatureStore currently uses the local communicator when it should use the global communicator, as originally intended. This PR modifies the feature store so it correctly calls `get_global_communicator()`, matching the change in the diff below.
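For context, the two calls select communicators with different scopes. A minimal sketch of the distinction, assuming `wgth` is `pylibwholegraph.torch` as imported in the cuGraph-PyG source, and that the WholeGraph environment has already been initialized by the usual launch utilities:

```python
import pylibwholegraph.torch as wgth

# Assumes the WholeGraph/torch.distributed environment is already set up;
# these getters only return communicators for an initialized environment.

# Local node communicator: spans only the GPUs on one physical node, so
# memory sharded across it is invisible to workers on other nodes.
local_comm = wgth.get_local_node_communicator()

# Global communicator: spans every GPU across all nodes in the job, which
# is what a feature store shared by every worker requires.
global_comm = wgth.get_global_communicator()
```

A WholeMemory allocation is partitioned across the ranks of the communicator it is created on, so creating the feature store on the local node communicator restricted each shard's visibility to a single node.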

This also fixes another bug where torch.int32 was used to store the number of edges in the graph, which caused an overflow error when the edge count exceeded that dtype's maximum value (2**31 - 1). The dtype is now correctly set to torch.int64.
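The limit itself is easy to check in isolation; a short, self-contained illustration (the edge count below is hypothetical, chosen to exceed the int32 maximum):

```python
import torch

# torch.int32 tops out at 2**31 - 1, which large graphs can exceed.
print(torch.iinfo(torch.int32).max)  # 2147483647
print(torch.iinfo(torch.int64).max)  # 9223372036854775807

# Hypothetical edge count above the int32 limit.
num_edges = 3_000_000_000

# int64 holds the count exactly; building the same tensor with
# dtype=torch.int32 would raise an overflow error instead.
edge_count = torch.tensor([num_edges], dtype=torch.int64)
print(edge_count)  # tensor([3000000000])
```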

Authors:
  - Alex Barghi (https://github.com/alexbarghi-nv)

Approvers:
  - Rick Ratzel (https://github.com/rlratzel)

URL: #4548
alexbarghi-nv authored Jul 30, 2024
1 parent 94e60f0 commit ac35be3
Showing 2 changed files with 2 additions and 2 deletions.
python/cugraph-pyg/cugraph_pyg/data/feature_store.py (1 addition, 1 deletion)
@@ -169,7 +169,7 @@ def __init__(self, memory_type="distributed", location="cpu"):
 
         self.__features = {}
 
-        self.__wg_comm = wgth.get_local_node_communicator()
+        self.__wg_comm = wgth.get_global_communicator()
         self.__wg_type = memory_type
         self.__wg_location = location
 
python/cugraph-pyg/cugraph_pyg/data/graph_store.py (1 addition, 1 deletion)
@@ -271,7 +271,7 @@ def __get_edgelist(self):
             torch.tensor(
                 [self.__edge_indices[et].shape[1] for et in sorted_keys],
                 device="cuda",
-                dtype=torch.int32,
+                dtype=torch.int64,
             )
         )
