@@ -50,6 +50,10 @@ class Graph:
50
50
__networkx_backend__ : ClassVar [str ] = "cugraph" # nx >=3.2
51
51
__networkx_plugin__ : ClassVar [str ] = "cugraph" # nx <3.2
52
52
53
+ # Allow networkx dispatch machinery to cache conversions.
54
+ # This means we should clear the cache if we ever mutate the object!
55
+ __networkx_cache__ : dict | None
56
+
53
57
# networkx properties
54
58
graph : dict
55
59
graph_attr_dict_factory : ClassVar [type ] = dict
@@ -108,6 +112,7 @@ def from_coo(
108
112
** attr ,
109
113
) -> Graph :
110
114
new_graph = object .__new__ (cls )
115
+ new_graph .__networkx_cache__ = {}
111
116
new_graph .src_indices = src_indices
112
117
new_graph .dst_indices = dst_indices
113
118
new_graph .edge_values = {} if edge_values is None else dict (edge_values )
@@ -420,13 +425,17 @@ def clear(self) -> None:
420
425
self ._node_ids = None
421
426
self .key_to_id = None
422
427
self ._id_to_key = None
428
+ if cache := self .__networkx_cache__ :
429
+ cache .clear ()
423
430
424
431
@networkx_api
425
432
def clear_edges (self ) -> None :
426
433
self .edge_values .clear ()
427
434
self .edge_masks .clear ()
428
435
self .src_indices = cp .empty (0 , self .src_indices .dtype )
429
436
self .dst_indices = cp .empty (0 , self .dst_indices .dtype )
437
+ if cache := self .__networkx_cache__ :
438
+ cache .clear ()
430
439
431
440
@networkx_api
432
441
def copy (self , as_view : bool = False ) -> Graph :
@@ -553,6 +562,12 @@ def _copy(self, as_view: bool, cls: type[Graph], reverse: bool = False):
553
562
node_masks = self .node_masks
554
563
key_to_id = self .key_to_id
555
564
id_to_key = None if key_to_id is None else self ._id_to_key
565
+ if self .__networkx_cache__ is None :
566
+ __networkx_cache__ = None
567
+ elif not reverse and cls is self .__class__ :
568
+ __networkx_cache__ = self .__networkx_cache__
569
+ else :
570
+ __networkx_cache__ = {}
556
571
if not as_view :
557
572
src_indices = src_indices .copy ()
558
573
dst_indices = dst_indices .copy ()
@@ -564,6 +579,8 @@ def _copy(self, as_view: bool, cls: type[Graph], reverse: bool = False):
564
579
key_to_id = key_to_id .copy ()
565
580
if id_to_key is not None :
566
581
id_to_key = id_to_key .copy ()
582
+ if __networkx_cache__ is not None :
583
+ __networkx_cache__ = __networkx_cache__ .copy ()
567
584
if reverse :
568
585
src_indices , dst_indices = dst_indices , src_indices
569
586
rv = cls .from_coo (
@@ -581,6 +598,7 @@ def _copy(self, as_view: bool, cls: type[Graph], reverse: bool = False):
581
598
rv .graph = self .graph
582
599
else :
583
600
rv .graph .update (deepcopy (self .graph ))
601
+ rv .__networkx_cache__ = __networkx_cache__
584
602
return rv
585
603
586
604
def _get_plc_graph (
@@ -719,18 +737,26 @@ def _become(self, other: Graph):
719
737
edge_masks = self .edge_masks
720
738
node_values = self .node_values
721
739
node_masks = self .node_masks
740
+ __networkx_cache__ = self .__networkx_cache__
722
741
graph = self .graph
723
742
edge_values .update (other .edge_values )
724
743
edge_masks .update (other .edge_masks )
725
744
node_values .update (other .node_values )
726
745
node_masks .update (other .node_masks )
727
746
graph .update (other .graph )
747
+ if other .__networkx_cache__ is None :
748
+ __networkx_cache__ = None
749
+ else :
750
+ if __networkx_cache__ is None :
751
+ __networkx_cache__ = {}
752
+ __networkx_cache__ .update (other .__networkx_cache__ )
728
753
self .__dict__ .update (other .__dict__ )
729
754
self .edge_values = edge_values
730
755
self .edge_masks = edge_masks
731
756
self .node_values = node_values
732
757
self .node_masks = node_masks
733
758
self .graph = graph
759
+ self .__networkx_cache__ = __networkx_cache__
734
760
return self
735
761
736
762
def _degrees_array (self , * , ignore_selfloops = False ):
0 commit comments