Skip to content

Commit 4a47cae

Browse files
authored
Merge pull request #4574 from rapidsai/branch-24.08
Forward-merge branch-24.08 into branch-24.10
2 parents 2a79a83 + fae0d3e commit 4a47cae

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

python/nx-cugraph/nx_cugraph/classes/graph.py

+26
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ class Graph:
5050
__networkx_backend__: ClassVar[str] = "cugraph" # nx >=3.2
5151
__networkx_plugin__: ClassVar[str] = "cugraph" # nx <3.2
5252

53+
# Allow networkx dispatch machinery to cache conversions.
54+
# This means we should clear the cache if we ever mutate the object!
55+
__networkx_cache__: dict | None
56+
5357
# networkx properties
5458
graph: dict
5559
graph_attr_dict_factory: ClassVar[type] = dict
@@ -108,6 +112,7 @@ def from_coo(
108112
**attr,
109113
) -> Graph:
110114
new_graph = object.__new__(cls)
115+
new_graph.__networkx_cache__ = {}
111116
new_graph.src_indices = src_indices
112117
new_graph.dst_indices = dst_indices
113118
new_graph.edge_values = {} if edge_values is None else dict(edge_values)
@@ -420,13 +425,17 @@ def clear(self) -> None:
420425
self._node_ids = None
421426
self.key_to_id = None
422427
self._id_to_key = None
428+
if cache := self.__networkx_cache__:
429+
cache.clear()
423430

424431
@networkx_api
425432
def clear_edges(self) -> None:
426433
self.edge_values.clear()
427434
self.edge_masks.clear()
428435
self.src_indices = cp.empty(0, self.src_indices.dtype)
429436
self.dst_indices = cp.empty(0, self.dst_indices.dtype)
437+
if cache := self.__networkx_cache__:
438+
cache.clear()
430439

431440
@networkx_api
432441
def copy(self, as_view: bool = False) -> Graph:
@@ -553,6 +562,12 @@ def _copy(self, as_view: bool, cls: type[Graph], reverse: bool = False):
553562
node_masks = self.node_masks
554563
key_to_id = self.key_to_id
555564
id_to_key = None if key_to_id is None else self._id_to_key
565+
if self.__networkx_cache__ is None:
566+
__networkx_cache__ = None
567+
elif not reverse and cls is self.__class__:
568+
__networkx_cache__ = self.__networkx_cache__
569+
else:
570+
__networkx_cache__ = {}
556571
if not as_view:
557572
src_indices = src_indices.copy()
558573
dst_indices = dst_indices.copy()
@@ -564,6 +579,8 @@ def _copy(self, as_view: bool, cls: type[Graph], reverse: bool = False):
564579
key_to_id = key_to_id.copy()
565580
if id_to_key is not None:
566581
id_to_key = id_to_key.copy()
582+
if __networkx_cache__ is not None:
583+
__networkx_cache__ = __networkx_cache__.copy()
567584
if reverse:
568585
src_indices, dst_indices = dst_indices, src_indices
569586
rv = cls.from_coo(
@@ -581,6 +598,7 @@ def _copy(self, as_view: bool, cls: type[Graph], reverse: bool = False):
581598
rv.graph = self.graph
582599
else:
583600
rv.graph.update(deepcopy(self.graph))
601+
rv.__networkx_cache__ = __networkx_cache__
584602
return rv
585603

586604
def _get_plc_graph(
@@ -719,18 +737,26 @@ def _become(self, other: Graph):
719737
edge_masks = self.edge_masks
720738
node_values = self.node_values
721739
node_masks = self.node_masks
740+
__networkx_cache__ = self.__networkx_cache__
722741
graph = self.graph
723742
edge_values.update(other.edge_values)
724743
edge_masks.update(other.edge_masks)
725744
node_values.update(other.node_values)
726745
node_masks.update(other.node_masks)
727746
graph.update(other.graph)
747+
if other.__networkx_cache__ is None:
748+
__networkx_cache__ = None
749+
else:
750+
if __networkx_cache__ is None:
751+
__networkx_cache__ = {}
752+
__networkx_cache__.update(other.__networkx_cache__)
728753
self.__dict__.update(other.__dict__)
729754
self.edge_values = edge_values
730755
self.edge_masks = edge_masks
731756
self.node_values = node_values
732757
self.node_masks = node_masks
733758
self.graph = graph
759+
self.__networkx_cache__ = __networkx_cache__
734760
return self
735761

736762
def _degrees_array(self, *, ignore_selfloops=False):

python/nx-cugraph/nx_cugraph/classes/multigraph.py

+9
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,12 @@ def _copy(self, as_view: bool, cls: type[Graph], reverse: bool = False):
415415
key_to_id = self.key_to_id
416416
id_to_key = None if key_to_id is None else self._id_to_key
417417
edge_keys = self.edge_keys
418+
if self.__networkx_cache__ is None:
419+
__networkx_cache__ = None
420+
elif not reverse and cls is self.__class__:
421+
__networkx_cache__ = self.__networkx_cache__
422+
else:
423+
__networkx_cache__ = {}
418424
if not as_view:
419425
src_indices = src_indices.copy()
420426
dst_indices = dst_indices.copy()
@@ -429,6 +435,8 @@ def _copy(self, as_view: bool, cls: type[Graph], reverse: bool = False):
429435
id_to_key = id_to_key.copy()
430436
if edge_keys is not None:
431437
edge_keys = edge_keys.copy()
438+
if __networkx_cache__ is not None:
439+
__networkx_cache__ = __networkx_cache__.copy()
432440
if reverse:
433441
src_indices, dst_indices = dst_indices, src_indices
434442
rv = cls.from_coo(
@@ -448,6 +456,7 @@ def _copy(self, as_view: bool, cls: type[Graph], reverse: bool = False):
448456
rv.graph = self.graph
449457
else:
450458
rv.graph.update(deepcopy(self.graph))
459+
rv.__networkx_cache__ = __networkx_cache__
451460
return rv
452461

453462
def _sort_edge_indices(self, primary="src"):

0 commit comments

Comments
 (0)