@@ -219,22 +219,12 @@ func (ht *HashTrieMap[K, V]) Swap(key K, new V) (previous V, loaded bool) {
219219
220220 slot = & i .children [(hash >> hashShift )& nChildrenMask ]
221221 n = slot .Load ()
222- if n == nil {
222+ if n == nil || n . isEntry {
223223 // We found a nil slot which is a candidate for insertion,
224224 // or an existing entry that we'll replace.
225225 haveInsertPoint = true
226226 break
227227 }
228- if n .isEntry {
229- // Swap if the keys compare.
230- old , swapped := n .entry ().swap (key , new )
231- if swapped {
232- return old , true
233- }
234- // If we fail, that means we should try to insert.
235- haveInsertPoint = true
236- break
237- }
238228 i = n .indirect ()
239229 }
240230 if ! haveInsertPoint {
@@ -261,10 +251,11 @@ func (ht *HashTrieMap[K, V]) Swap(key K, new V) (previous V, loaded bool) {
261251 var zero V
262252 var oldEntry * entry [K , V ]
263253 if n != nil {
264- // Between before and now, something got inserted. Swap if the keys compare.
254+ // Swap if the keys compare.
265255 oldEntry = n .entry ()
266- old , swapped := oldEntry .swap (key , new )
256+ newEntry , old , swapped := oldEntry .swap (key , new )
267257 if swapped {
258+ slot .Store (& newEntry .node )
268259 return old , true
269260 }
270261 }
@@ -292,30 +283,25 @@ func (ht *HashTrieMap[K, V]) CompareAndSwap(key K, old, new V) (swapped bool) {
292283 panic ("called CompareAndSwap when value is not of comparable type" )
293284 }
294285 hash := ht .keyHash (abi .NoEscape (unsafe .Pointer (& key )), ht .seed )
295- for {
296- // Find the key or return if it's not there.
297- i := ht .root .Load ()
298- hashShift := 8 * goarch .PtrSize
299- found := false
300- for hashShift != 0 {
301- hashShift -= nChildrenLog2
302286
303- slot := & i .children [(hash >> hashShift )& nChildrenMask ]
304- n := slot .Load ()
305- if n == nil {
306- // Nothing to compare with. Give up.
307- return false
308- }
309- if n .isEntry {
310- // We found an entry. Try to compare and swap directly.
311- return n .entry ().compareAndSwap (key , old , new , ht .valEqual )
312- }
313- i = n .indirect ()
314- }
315- if ! found {
316- panic ("internal/concurrent.HashMapTrie: ran out of hash bits while iterating" )
317- }
287+ // Find a node with the key and compare with it. n != nil if we found the node.
288+ i , _ , slot , n := ht .find (key , hash , ht .valEqual , old )
289+ if i != nil {
290+ defer i .mu .Unlock ()
318291 }
292+ if n == nil {
293+ return false
294+ }
295+
296+ // Try to swap the entry.
297+ e , swapped := n .entry ().compareAndSwap (key , old , new , ht .valEqual )
298+ if ! swapped {
299+ // Nothing was actually swapped, which means the node is no longer there.
300+ return false
301+ }
302+ // Store the entry back because it changed.
303+ slot .Store (& e .node )
304+ return true
319305}
320306
321307// LoadAndDelete deletes the value for a key, returning the previous value if any.
@@ -523,7 +509,7 @@ func (ht *HashTrieMap[K, V]) iter(i *indirect[K, V], yield func(key K, value V)
523509 }
524510 e := n .entry ()
525511 for e != nil {
526- if ! yield (e .key , * e .value . Load () ) {
512+ if ! yield (e .key , e .value ) {
527513 return false
528514 }
529515 e = e .overflow .Load ()
@@ -579,22 +565,21 @@ type entry[K comparable, V any] struct {
579565 node [K , V ]
580566 overflow atomic.Pointer [entry [K , V ]] // Overflow for hash collisions.
581567 key K
582- value atomic. Pointer [ V ]
568+ value V
583569}
584570
585571func newEntryNode [K comparable , V any ](key K , value V ) * entry [K , V ] {
586- e := & entry [K , V ]{
587- node : node [K , V ]{isEntry : true },
588- key : key ,
572+ return & entry [K , V ]{
573+ node : node [K , V ]{isEntry : true },
574+ key : key ,
575+ value : value ,
589576 }
590- e .value .Store (& value )
591- return e
592577}
593578
594579func (e * entry [K , V ]) lookup (key K ) (V , bool ) {
595580 for e != nil {
596581 if e .key == key {
597- return * e .value . Load () , true
582+ return e .value , true
598583 }
599584 e = e .overflow .Load ()
600585 }
@@ -603,87 +588,69 @@ func (e *entry[K, V]) lookup(key K) (V, bool) {
603588
604589func (e * entry [K , V ]) lookupWithValue (key K , value V , valEqual equalFunc ) (V , bool ) {
605590 for e != nil {
606- oldp := e .value .Load ()
607- if e .key == key && (valEqual == nil || valEqual (unsafe .Pointer (oldp ), abi .NoEscape (unsafe .Pointer (& value )))) {
608- return * oldp , true
591+ if e .key == key && (valEqual == nil || valEqual (unsafe .Pointer (& e .value ), abi .NoEscape (unsafe .Pointer (& value )))) {
592+ return e .value , true
609593 }
610594 e = e .overflow .Load ()
611595 }
612596 return * new (V ), false
613597}
614598
615- // swap replaces a value in the overflow chain if keys compare equal.
616- // Returns the old value, and whether or not anything was swapped.
599+ // swap replaces an entry in the overflow chain if keys compare equal. Returns the new entry chain,
600+ // the old value, and whether or not anything was swapped.
617601//
618602// swap must be called under the mutex of the indirect node which e is a child of.
619- func (head * entry [K , V ]) swap (key K , newv V ) (V , bool ) {
603+ func (head * entry [K , V ]) swap (key K , new V ) (* entry [ K , V ], V , bool ) {
620604 if head .key == key {
621- vp := new (V )
622- * vp = newv
623- oldp := head .value .Swap (vp )
624- return * oldp , true
605+ // Return the new head of the list.
606+ e := newEntryNode (key , new )
607+ if chain := head .overflow .Load (); chain != nil {
608+ e .overflow .Store (chain )
609+ }
610+ return e , head .value , true
625611 }
626612 i := & head .overflow
627613 e := i .Load ()
628614 for e != nil {
629615 if e .key == key {
630- vp := new ( V )
631- * vp = newv
632- oldp := e . value . Swap ( vp )
633- return * oldp , true
616+ eNew := newEntryNode ( key , new )
617+ eNew . overflow . Store ( e . overflow . Load ())
618+ i . Store ( eNew )
619+ return head , e . value , true
634620 }
635621 i = & e .overflow
636622 e = e .overflow .Load ()
637623 }
638624 var zero V
639- return zero , false
625+ return head , zero , false
640626}
641627
642- // compareAndSwap replaces a value for a matching key and existing value in the overflow chain.
643- // Returns whether or not anything was swapped.
628+ // compareAndSwap replaces an entry in the overflow chain if both the key and value compare
629+ // equal. Returns the new entry chain and whether or not anything was swapped.
644630//
645631// compareAndSwap must be called under the mutex of the indirect node which e is a child of.
646- func (head * entry [K , V ]) compareAndSwap (key K , oldv , newv V , valEqual equalFunc ) bool {
647- var vbox * V
648- outerLoop:
649- for {
650- oldvp := head .value .Load ()
651- if head .key == key && valEqual (unsafe .Pointer (oldvp ), abi .NoEscape (unsafe .Pointer (& oldv ))) {
652- // Return the new head of the list.
653- if vbox == nil {
654- // Delay explicit creation of a new value to hold newv. If we just pass &newv
655- // to CompareAndSwap, then newv will unconditionally escape, even if the CAS fails.
656- vbox = new (V )
657- * vbox = newv
658- }
659- if head .value .CompareAndSwap (oldvp , vbox ) {
660- return true
661- }
662- // We need to restart from the head of the overflow list in case, due to a removal, a node
663- // is moved up the list and we miss it.
664- continue outerLoop
632+ func (head * entry [K , V ]) compareAndSwap (key K , old , new V , valEqual equalFunc ) (* entry [K , V ], bool ) {
633+ if head .key == key && valEqual (unsafe .Pointer (& head .value ), abi .NoEscape (unsafe .Pointer (& old ))) {
634+ // Return the new head of the list.
635+ e := newEntryNode (key , new )
636+ if chain := head .overflow .Load (); chain != nil {
637+ e .overflow .Store (chain )
665638 }
666- i := & head .overflow
667- e := i .Load ()
668- for e != nil {
669- oldvp := e .value .Load ()
670- if e .key == key && valEqual (unsafe .Pointer (oldvp ), abi .NoEscape (unsafe .Pointer (& oldv ))) {
671- if vbox == nil {
672- // Delay explicit creation of a new value to hold newv. If we just pass &newv
673- // to CompareAndSwap, then newv will unconditionally escape, even if the CAS fails.
674- vbox = new (V )
675- * vbox = newv
676- }
677- if e .value .CompareAndSwap (oldvp , vbox ) {
678- return true
679- }
680- continue outerLoop
681- }
682- i = & e .overflow
683- e = e .overflow .Load ()
639+ return e , true
640+ }
641+ i := & head .overflow
642+ e := i .Load ()
643+ for e != nil {
644+ if e .key == key && valEqual (unsafe .Pointer (& e .value ), abi .NoEscape (unsafe .Pointer (& old ))) {
645+ eNew := newEntryNode (key , new )
646+ eNew .overflow .Store (e .overflow .Load ())
647+ i .Store (eNew )
648+ return head , true
684649 }
685- return false
650+ i = & e .overflow
651+ e = e .overflow .Load ()
686652 }
653+ return head , false
687654}
688655
689656// loadAndDelete deletes an entry in the overflow chain by key. Returns the value for the key, the new
@@ -693,14 +660,14 @@ outerLoop:
693660func (head * entry [K , V ]) loadAndDelete (key K ) (V , * entry [K , V ], bool ) {
694661 if head .key == key {
695662 // Drop the head of the list.
696- return * head .value . Load () , head .overflow .Load (), true
663+ return head .value , head .overflow .Load (), true
697664 }
698665 i := & head .overflow
699666 e := i .Load ()
700667 for e != nil {
701668 if e .key == key {
702669 i .Store (e .overflow .Load ())
703- return * e .value . Load () , head , true
670+ return e .value , head , true
704671 }
705672 i = & e .overflow
706673 e = e .overflow .Load ()
@@ -713,14 +680,14 @@ func (head *entry[K, V]) loadAndDelete(key K) (V, *entry[K, V], bool) {
713680//
714681// compareAndDelete must be called under the mutex of the indirect node which e is a child of.
715682func (head * entry [K , V ]) compareAndDelete (key K , value V , valEqual equalFunc ) (* entry [K , V ], bool ) {
716- if head .key == key && valEqual (unsafe .Pointer (head .value . Load () ), abi .NoEscape (unsafe .Pointer (& value ))) {
683+ if head .key == key && valEqual (unsafe .Pointer (& head .value ), abi .NoEscape (unsafe .Pointer (& value ))) {
717684 // Drop the head of the list.
718685 return head .overflow .Load (), true
719686 }
720687 i := & head .overflow
721688 e := i .Load ()
722689 for e != nil {
723- if e .key == key && valEqual (unsafe .Pointer (e .value . Load () ), abi .NoEscape (unsafe .Pointer (& value ))) {
690+ if e .key == key && valEqual (unsafe .Pointer (& e .value ), abi .NoEscape (unsafe .Pointer (& value ))) {
724691 i .Store (e .overflow .Load ())
725692 return head , true
726693 }
0 commit comments