Skip to content

Commit fae0d3e

Browse files
authored
nx-cugraph: add G.__networkx_cache__ to enable graph conversion caching (#4567)
Adding this now, because this was an oversight. By having a `G.__networkx_cache__` MutableMapping, we give NetworkX the ability to cache graph conversions. This will be most useful in dev and future versions of NetworkX. As such, it would be nice (but not strictly essential) to get this in 24.08. Authors: - Erik Welch (https://github.com/eriknw) Approvers: - Rick Ratzel (https://github.com/rlratzel) URL: #4567
1 parent 5458e76 commit fae0d3e

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

python/nx-cugraph/nx_cugraph/classes/graph.py

+26
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ class Graph:
5050
__networkx_backend__: ClassVar[str] = "cugraph" # nx >=3.2
5151
__networkx_plugin__: ClassVar[str] = "cugraph" # nx <3.2
5252

53+
# Allow networkx dispatch machinery to cache conversions.
54+
# This means we should clear the cache if we ever mutate the object!
55+
__networkx_cache__: dict | None
56+
5357
# networkx properties
5458
graph: dict
5559
graph_attr_dict_factory: ClassVar[type] = dict
@@ -108,6 +112,7 @@ def from_coo(
108112
**attr,
109113
) -> Graph:
110114
new_graph = object.__new__(cls)
115+
new_graph.__networkx_cache__ = {}
111116
new_graph.src_indices = src_indices
112117
new_graph.dst_indices = dst_indices
113118
new_graph.edge_values = {} if edge_values is None else dict(edge_values)
@@ -420,13 +425,17 @@ def clear(self) -> None:
420425
self._node_ids = None
421426
self.key_to_id = None
422427
self._id_to_key = None
428+
if cache := self.__networkx_cache__:
429+
cache.clear()
423430

424431
@networkx_api
425432
def clear_edges(self) -> None:
426433
self.edge_values.clear()
427434
self.edge_masks.clear()
428435
self.src_indices = cp.empty(0, self.src_indices.dtype)
429436
self.dst_indices = cp.empty(0, self.dst_indices.dtype)
437+
if cache := self.__networkx_cache__:
438+
cache.clear()
430439

431440
@networkx_api
432441
def copy(self, as_view: bool = False) -> Graph:
@@ -553,6 +562,12 @@ def _copy(self, as_view: bool, cls: type[Graph], reverse: bool = False):
553562
node_masks = self.node_masks
554563
key_to_id = self.key_to_id
555564
id_to_key = None if key_to_id is None else self._id_to_key
565+
if self.__networkx_cache__ is None:
566+
__networkx_cache__ = None
567+
elif not reverse and cls is self.__class__:
568+
__networkx_cache__ = self.__networkx_cache__
569+
else:
570+
__networkx_cache__ = {}
556571
if not as_view:
557572
src_indices = src_indices.copy()
558573
dst_indices = dst_indices.copy()
@@ -564,6 +579,8 @@ def _copy(self, as_view: bool, cls: type[Graph], reverse: bool = False):
564579
key_to_id = key_to_id.copy()
565580
if id_to_key is not None:
566581
id_to_key = id_to_key.copy()
582+
if __networkx_cache__ is not None:
583+
__networkx_cache__ = __networkx_cache__.copy()
567584
if reverse:
568585
src_indices, dst_indices = dst_indices, src_indices
569586
rv = cls.from_coo(
@@ -581,6 +598,7 @@ def _copy(self, as_view: bool, cls: type[Graph], reverse: bool = False):
581598
rv.graph = self.graph
582599
else:
583600
rv.graph.update(deepcopy(self.graph))
601+
rv.__networkx_cache__ = __networkx_cache__
584602
return rv
585603

586604
def _get_plc_graph(
@@ -719,18 +737,26 @@ def _become(self, other: Graph):
719737
edge_masks = self.edge_masks
720738
node_values = self.node_values
721739
node_masks = self.node_masks
740+
__networkx_cache__ = self.__networkx_cache__
722741
graph = self.graph
723742
edge_values.update(other.edge_values)
724743
edge_masks.update(other.edge_masks)
725744
node_values.update(other.node_values)
726745
node_masks.update(other.node_masks)
727746
graph.update(other.graph)
747+
if other.__networkx_cache__ is None:
748+
__networkx_cache__ = None
749+
else:
750+
if __networkx_cache__ is None:
751+
__networkx_cache__ = {}
752+
__networkx_cache__.update(other.__networkx_cache__)
728753
self.__dict__.update(other.__dict__)
729754
self.edge_values = edge_values
730755
self.edge_masks = edge_masks
731756
self.node_values = node_values
732757
self.node_masks = node_masks
733758
self.graph = graph
759+
self.__networkx_cache__ = __networkx_cache__
734760
return self
735761

736762
def _degrees_array(self, *, ignore_selfloops=False):

python/nx-cugraph/nx_cugraph/classes/multigraph.py

+9
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,12 @@ def _copy(self, as_view: bool, cls: type[Graph], reverse: bool = False):
415415
key_to_id = self.key_to_id
416416
id_to_key = None if key_to_id is None else self._id_to_key
417417
edge_keys = self.edge_keys
418+
if self.__networkx_cache__ is None:
419+
__networkx_cache__ = None
420+
elif not reverse and cls is self.__class__:
421+
__networkx_cache__ = self.__networkx_cache__
422+
else:
423+
__networkx_cache__ = {}
418424
if not as_view:
419425
src_indices = src_indices.copy()
420426
dst_indices = dst_indices.copy()
@@ -429,6 +435,8 @@ def _copy(self, as_view: bool, cls: type[Graph], reverse: bool = False):
429435
id_to_key = id_to_key.copy()
430436
if edge_keys is not None:
431437
edge_keys = edge_keys.copy()
438+
if __networkx_cache__ is not None:
439+
__networkx_cache__ = __networkx_cache__.copy()
432440
if reverse:
433441
src_indices, dst_indices = dst_indices, src_indices
434442
rv = cls.from_coo(
@@ -448,6 +456,7 @@ def _copy(self, as_view: bool, cls: type[Graph], reverse: bool = False):
448456
rv.graph = self.graph
449457
else:
450458
rv.graph.update(deepcopy(self.graph))
459+
rv.__networkx_cache__ = __networkx_cache__
451460
return rv
452461

453462
def _sort_edge_indices(self, primary="src"):

0 commit comments

Comments
 (0)