@@ -263,8 +263,6 @@ pub struct ClusterMap<S = gxhash::GxBuildHasher> {
263263 version : AtomicU64 ,
264264}
265265
266- type DashMapRef < ' inner > = dashmap:: mapref:: one:: Ref < ' inner , Option < Locality > , EndpointSet > ;
267-
268266impl ClusterMap {
269267 pub fn new ( ) -> Self {
270268 Self :: default ( )
@@ -367,15 +365,39 @@ where
367365 }
368366
369367 #[ inline]
370- pub fn get ( & self , key : & Option < Locality > ) -> Option < DashMapRef < ' _ > > {
371- self . map . get ( key)
368+ pub fn exists ( & self , key : & Option < Locality > ) -> bool {
369+ self . map . get ( key) . is_some ( )
370+ }
371+
372+ #[ inline]
373+ pub fn with_value < F , T > ( & self , key : & Option < Locality > , func : F ) -> Option < T >
374+ where
375+ F : FnOnce ( dashmap:: mapref:: one:: Ref < ' _ , Option < Locality > , EndpointSet > ) -> T ,
376+ {
377+ self . map . get ( key) . map ( func)
372378 }
373379
374380 #[ inline]
375381 pub fn insert_default ( & self , endpoints : BTreeSet < Endpoint > ) {
376382 self . insert ( None , None , endpoints) ;
377383 }
378384
385+ /// Iterates over the entries in the `ClusterMap` with the given func
386+ ///
387+ /// This ensures that the dashmap entry references are never held across await boundaries as
388+ /// the func cannot be async.
389+ #[ inline]
390+ pub fn iter_with < F , T > ( & self , func : F ) -> Vec < T >
391+ where
392+ F : for < ' a > Fn ( & ' a Option < Locality > , & ' a EndpointSet ) -> T ,
393+ {
394+ let mut results: Vec < T > = Vec :: new ( ) ;
395+ for entry in self . map . iter ( ) {
396+ results. push ( func ( entry. key ( ) , entry. value ( ) ) ) ;
397+ }
398+ results
399+ }
400+
379401 #[ inline]
380402 pub fn remove_endpoint ( & self , needle : & Endpoint ) -> bool {
381403 let locality = ' l: {
@@ -430,11 +452,6 @@ where
430452 true
431453 }
432454
433- #[ inline]
434- pub fn iter ( & self ) -> dashmap:: iter:: Iter < ' _ , Option < Locality > , EndpointSet , S > {
435- self . map . iter ( )
436- }
437-
438455 #[ inline]
439456 pub fn replace (
440457 & self ,
@@ -477,7 +494,7 @@ where
477494 }
478495
479496 pub fn nth_endpoint ( & self , mut index : usize ) -> Option < Endpoint > {
480- for set in self . iter ( ) {
497+ for set in self . map . iter ( ) {
481498 let set = & set. value ( ) . endpoints ;
482499 if index < set. len ( ) {
483500 return set. iter ( ) . nth ( index) . cloned ( ) ;
@@ -492,7 +509,7 @@ where
492509 pub fn filter_endpoints ( & self , f : impl Fn ( & Endpoint ) -> bool ) -> Vec < Endpoint > {
493510 let mut endpoints = Vec :: new ( ) ;
494511
495- for set in self . iter ( ) {
512+ for set in self . map . iter ( ) {
496513 for endpoint in set. endpoints . iter ( ) . filter ( |e| ( f) ( e) ) {
497514 endpoints. push ( endpoint. clone ( ) ) ;
498515 }
@@ -639,8 +656,9 @@ where
639656 S : Default + std:: hash:: BuildHasher + Clone ,
640657{
641658 fn eq ( & self , rhs : & Self ) -> bool {
642- for a in self . iter ( ) {
659+ for a in self . map . iter ( ) {
643660 match rhs
661+ . map
644662 . get ( a. key ( ) )
645663 . filter ( |b| a. value ( ) . endpoints == b. endpoints )
646664 {
@@ -808,16 +826,18 @@ mod tests {
808826 cluster1. insert ( None , Some ( nl1. clone ( ) ) , [ endpoint. clone ( ) ] . into ( ) ) ;
809827 cluster1. insert ( None , Some ( de1. clone ( ) ) , [ endpoint. clone ( ) ] . into ( ) ) ;
810828
811- assert_eq ! ( cluster1. get( & Some ( nl1. clone( ) ) ) . unwrap( ) . len( ) , 1 ) ;
829+ assert_eq ! ( cluster1. map . get( & Some ( nl1. clone( ) ) ) . unwrap( ) . len( ) , 1 ) ;
812830 assert ! (
813831 cluster1
832+ . map
814833 . get( & Some ( nl1. clone( ) ) )
815834 . unwrap( )
816835 . contains( & endpoint)
817836 ) ;
818- assert_eq ! ( cluster1. get( & Some ( de1. clone( ) ) ) . unwrap( ) . len( ) , 1 ) ;
837+ assert_eq ! ( cluster1. map . get( & Some ( de1. clone( ) ) ) . unwrap( ) . len( ) , 1 ) ;
819838 assert ! (
820839 cluster1
840+ . map
821841 . get( & Some ( de1. clone( ) ) )
822842 . unwrap( )
823843 . contains( & endpoint)
@@ -827,19 +847,20 @@ mod tests {
827847
828848 cluster1. insert ( None , Some ( de1. clone ( ) ) , [ endpoint. clone ( ) ] . into ( ) ) ;
829849
830- assert_eq ! ( cluster1. get( & Some ( nl1. clone( ) ) ) . unwrap( ) . len( ) , 1 ) ;
831- assert_eq ! ( cluster1. get( & Some ( de1. clone( ) ) ) . unwrap( ) . len( ) , 1 ) ;
850+ assert_eq ! ( cluster1. map . get( & Some ( nl1. clone( ) ) ) . unwrap( ) . len( ) , 1 ) ;
851+ assert_eq ! ( cluster1. map . get( & Some ( de1. clone( ) ) ) . unwrap( ) . len( ) , 1 ) ;
832852 assert ! (
833853 cluster1
854+ . map
834855 . get( & Some ( de1. clone( ) ) )
835856 . unwrap( )
836857 . contains( & endpoint)
837858 ) ;
838859
839860 cluster1. insert ( None , Some ( de1. clone ( ) ) , <_ >:: default ( ) ) ;
840861
841- assert_eq ! ( cluster1. get( & Some ( nl1. clone( ) ) ) . unwrap( ) . len( ) , 1 ) ;
842- assert ! ( cluster1. get( & Some ( de1. clone( ) ) ) . unwrap( ) . is_empty( ) ) ;
862+ assert_eq ! ( cluster1. map . get( & Some ( nl1. clone( ) ) ) . unwrap( ) . len( ) , 1 ) ;
863+ assert ! ( cluster1. map . get( & Some ( de1. clone( ) ) ) . unwrap( ) . is_empty( ) ) ;
843864 }
844865
845866 #[ test]
@@ -862,12 +883,15 @@ mod tests {
862883 [ Endpoint :: new ( ( Ipv4Addr :: new ( 20 , 20 , 20 , 20 ) , 1234 ) . into ( ) ) ] . into ( ) ;
863884
864885 cluster. insert ( Some ( nl02. into ( ) ) , Some ( nl1. clone ( ) ) , not_expected. clone ( ) ) ;
865- assert_eq ! ( cluster. get( & Some ( nl1. clone( ) ) ) . unwrap( ) . endpoints, expected) ;
886+ assert_eq ! (
887+ cluster. map. get( & Some ( nl1. clone( ) ) ) . unwrap( ) . endpoints,
888+ expected
889+ ) ;
866890
867891 cluster. remove_locality ( Some ( nl01. into ( ) ) , & Some ( nl1. clone ( ) ) ) ;
868892
869893 cluster. insert ( Some ( nl02. into ( ) ) , Some ( nl1. clone ( ) ) , not_expected. clone ( ) ) ;
870- assert_eq ! ( cluster. get( & Some ( nl1) ) . unwrap( ) . endpoints, not_expected) ;
894+ assert_eq ! ( cluster. map . get( & Some ( nl1) ) . unwrap( ) . endpoints, not_expected) ;
871895 }
872896
873897 #[ test]
@@ -886,7 +910,7 @@ mod tests {
886910 assert ! ( cluster. remove_endpoint( ep) ) ;
887911 }
888912
889- assert ! ( cluster. get( & None ) . is_none( ) ) ;
913+ assert ! ( cluster. map . get( & None ) . is_none( ) ) ;
890914 }
891915
892916 {
@@ -895,7 +919,7 @@ mod tests {
895919 assert ! ( cluster. remove_endpoint_if( |_ep| true ) ) ;
896920 }
897921
898- assert ! ( cluster. get( & None ) . is_none( ) ) ;
922+ assert ! ( cluster. map . get( & None ) . is_none( ) ) ;
899923 }
900924 }
901925}
0 commit comments