@@ -219,22 +219,12 @@ func (ht *HashTrieMap[K, V]) Swap(key K, new V) (previous V, loaded bool) {
219
219
220
220
slot = & i .children [(hash >> hashShift )& nChildrenMask ]
221
221
n = slot .Load ()
222
- if n == nil {
222
+ if n == nil || n . isEntry {
223
223
// We found a nil slot which is a candidate for insertion,
224
224
// or an existing entry that we'll replace.
225
225
haveInsertPoint = true
226
226
break
227
227
}
228
- if n .isEntry {
229
- // Swap if the keys compare.
230
- old , swapped := n .entry ().swap (key , new )
231
- if swapped {
232
- return old , true
233
- }
234
- // If we fail, that means we should try to insert.
235
- haveInsertPoint = true
236
- break
237
- }
238
228
i = n .indirect ()
239
229
}
240
230
if ! haveInsertPoint {
@@ -261,10 +251,11 @@ func (ht *HashTrieMap[K, V]) Swap(key K, new V) (previous V, loaded bool) {
261
251
var zero V
262
252
var oldEntry * entry [K , V ]
263
253
if n != nil {
264
- // Between before and now, something got inserted. Swap if the keys compare.
254
+ // Swap if the keys compare.
265
255
oldEntry = n .entry ()
266
- old , swapped := oldEntry .swap (key , new )
256
+ newEntry , old , swapped := oldEntry .swap (key , new )
267
257
if swapped {
258
+ slot .Store (& newEntry .node )
268
259
return old , true
269
260
}
270
261
}
@@ -292,30 +283,25 @@ func (ht *HashTrieMap[K, V]) CompareAndSwap(key K, old, new V) (swapped bool) {
292
283
panic ("called CompareAndSwap when value is not of comparable type" )
293
284
}
294
285
hash := ht .keyHash (abi .NoEscape (unsafe .Pointer (& key )), ht .seed )
295
- for {
296
- // Find the key or return if it's not there.
297
- i := ht .root .Load ()
298
- hashShift := 8 * goarch .PtrSize
299
- found := false
300
- for hashShift != 0 {
301
- hashShift -= nChildrenLog2
302
286
303
- slot := & i .children [(hash >> hashShift )& nChildrenMask ]
304
- n := slot .Load ()
305
- if n == nil {
306
- // Nothing to compare with. Give up.
307
- return false
308
- }
309
- if n .isEntry {
310
- // We found an entry. Try to compare and swap directly.
311
- return n .entry ().compareAndSwap (key , old , new , ht .valEqual )
312
- }
313
- i = n .indirect ()
314
- }
315
- if ! found {
316
- panic ("internal/concurrent.HashMapTrie: ran out of hash bits while iterating" )
317
- }
287
+ // Find a node with the key and compare with it. n != nil if we found the node.
288
+ i , _ , slot , n := ht .find (key , hash , ht .valEqual , old )
289
+ if i != nil {
290
+ defer i .mu .Unlock ()
318
291
}
292
+ if n == nil {
293
+ return false
294
+ }
295
+
296
+ // Try to swap the entry.
297
+ e , swapped := n .entry ().compareAndSwap (key , old , new , ht .valEqual )
298
+ if ! swapped {
299
+ // Nothing was actually swapped, which means the node is no longer there.
300
+ return false
301
+ }
302
+ // Store the entry back because it changed.
303
+ slot .Store (& e .node )
304
+ return true
319
305
}
320
306
321
307
// LoadAndDelete deletes the value for a key, returning the previous value if any.
@@ -523,7 +509,7 @@ func (ht *HashTrieMap[K, V]) iter(i *indirect[K, V], yield func(key K, value V)
523
509
}
524
510
e := n .entry ()
525
511
for e != nil {
526
- if ! yield (e .key , * e .value . Load () ) {
512
+ if ! yield (e .key , e .value ) {
527
513
return false
528
514
}
529
515
e = e .overflow .Load ()
@@ -579,22 +565,21 @@ type entry[K comparable, V any] struct {
579
565
node [K , V ]
580
566
overflow atomic.Pointer [entry [K , V ]] // Overflow for hash collisions.
581
567
key K
582
- value atomic. Pointer [ V ]
568
+ value V
583
569
}
584
570
585
571
func newEntryNode [K comparable , V any ](key K , value V ) * entry [K , V ] {
586
- e := & entry [K , V ]{
587
- node : node [K , V ]{isEntry : true },
588
- key : key ,
572
+ return & entry [K , V ]{
573
+ node : node [K , V ]{isEntry : true },
574
+ key : key ,
575
+ value : value ,
589
576
}
590
- e .value .Store (& value )
591
- return e
592
577
}
593
578
594
579
func (e * entry [K , V ]) lookup (key K ) (V , bool ) {
595
580
for e != nil {
596
581
if e .key == key {
597
- return * e .value . Load () , true
582
+ return e .value , true
598
583
}
599
584
e = e .overflow .Load ()
600
585
}
@@ -603,87 +588,69 @@ func (e *entry[K, V]) lookup(key K) (V, bool) {
603
588
604
589
func (e * entry [K , V ]) lookupWithValue (key K , value V , valEqual equalFunc ) (V , bool ) {
605
590
for e != nil {
606
- oldp := e .value .Load ()
607
- if e .key == key && (valEqual == nil || valEqual (unsafe .Pointer (oldp ), abi .NoEscape (unsafe .Pointer (& value )))) {
608
- return * oldp , true
591
+ if e .key == key && (valEqual == nil || valEqual (unsafe .Pointer (& e .value ), abi .NoEscape (unsafe .Pointer (& value )))) {
592
+ return e .value , true
609
593
}
610
594
e = e .overflow .Load ()
611
595
}
612
596
return * new (V ), false
613
597
}
614
598
615
- // swap replaces a value in the overflow chain if keys compare equal.
616
- // Returns the old value, and whether or not anything was swapped.
599
+ // swap replaces an entry in the overflow chain if keys compare equal. Returns the new entry chain,
600
+ // the old value, and whether or not anything was swapped.
617
601
//
618
602
// swap must be called under the mutex of the indirect node which e is a child of.
619
- func (head * entry [K , V ]) swap (key K , newv V ) (V , bool ) {
603
+ func (head * entry [K , V ]) swap (key K , new V ) (* entry [ K , V ], V , bool ) {
620
604
if head .key == key {
621
- vp := new (V )
622
- * vp = newv
623
- oldp := head .value .Swap (vp )
624
- return * oldp , true
605
+ // Return the new head of the list.
606
+ e := newEntryNode (key , new )
607
+ if chain := head .overflow .Load (); chain != nil {
608
+ e .overflow .Store (chain )
609
+ }
610
+ return e , head .value , true
625
611
}
626
612
i := & head .overflow
627
613
e := i .Load ()
628
614
for e != nil {
629
615
if e .key == key {
630
- vp := new ( V )
631
- * vp = newv
632
- oldp := e . value . Swap ( vp )
633
- return * oldp , true
616
+ eNew := newEntryNode ( key , new )
617
+ eNew . overflow . Store ( e . overflow . Load ())
618
+ i . Store ( eNew )
619
+ return head , e . value , true
634
620
}
635
621
i = & e .overflow
636
622
e = e .overflow .Load ()
637
623
}
638
624
var zero V
639
- return zero , false
625
+ return head , zero , false
640
626
}
641
627
642
- // compareAndSwap replaces a value for a matching key and existing value in the overflow chain.
643
- // Returns whether or not anything was swapped.
628
+ // compareAndSwap replaces an entry in the overflow chain if both the key and value compare
629
+ // equal. Returns the new entry chain and whether or not anything was swapped.
644
630
//
645
631
// compareAndSwap must be called under the mutex of the indirect node which e is a child of.
646
- func (head * entry [K , V ]) compareAndSwap (key K , oldv , newv V , valEqual equalFunc ) bool {
647
- var vbox * V
648
- outerLoop:
649
- for {
650
- oldvp := head .value .Load ()
651
- if head .key == key && valEqual (unsafe .Pointer (oldvp ), abi .NoEscape (unsafe .Pointer (& oldv ))) {
652
- // Return the new head of the list.
653
- if vbox == nil {
654
- // Delay explicit creation of a new value to hold newv. If we just pass &newv
655
- // to CompareAndSwap, then newv will unconditionally escape, even if the CAS fails.
656
- vbox = new (V )
657
- * vbox = newv
658
- }
659
- if head .value .CompareAndSwap (oldvp , vbox ) {
660
- return true
661
- }
662
- // We need to restart from the head of the overflow list in case, due to a removal, a node
663
- // is moved up the list and we miss it.
664
- continue outerLoop
632
+ func (head * entry [K , V ]) compareAndSwap (key K , old , new V , valEqual equalFunc ) (* entry [K , V ], bool ) {
633
+ if head .key == key && valEqual (unsafe .Pointer (& head .value ), abi .NoEscape (unsafe .Pointer (& old ))) {
634
+ // Return the new head of the list.
635
+ e := newEntryNode (key , new )
636
+ if chain := head .overflow .Load (); chain != nil {
637
+ e .overflow .Store (chain )
665
638
}
666
- i := & head .overflow
667
- e := i .Load ()
668
- for e != nil {
669
- oldvp := e .value .Load ()
670
- if e .key == key && valEqual (unsafe .Pointer (oldvp ), abi .NoEscape (unsafe .Pointer (& oldv ))) {
671
- if vbox == nil {
672
- // Delay explicit creation of a new value to hold newv. If we just pass &newv
673
- // to CompareAndSwap, then newv will unconditionally escape, even if the CAS fails.
674
- vbox = new (V )
675
- * vbox = newv
676
- }
677
- if e .value .CompareAndSwap (oldvp , vbox ) {
678
- return true
679
- }
680
- continue outerLoop
681
- }
682
- i = & e .overflow
683
- e = e .overflow .Load ()
639
+ return e , true
640
+ }
641
+ i := & head .overflow
642
+ e := i .Load ()
643
+ for e != nil {
644
+ if e .key == key && valEqual (unsafe .Pointer (& e .value ), abi .NoEscape (unsafe .Pointer (& old ))) {
645
+ eNew := newEntryNode (key , new )
646
+ eNew .overflow .Store (e .overflow .Load ())
647
+ i .Store (eNew )
648
+ return head , true
684
649
}
685
- return false
650
+ i = & e .overflow
651
+ e = e .overflow .Load ()
686
652
}
653
+ return head , false
687
654
}
688
655
689
656
// loadAndDelete deletes an entry in the overflow chain by key. Returns the value for the key, the new
@@ -693,14 +660,14 @@ outerLoop:
693
660
func (head * entry [K , V ]) loadAndDelete (key K ) (V , * entry [K , V ], bool ) {
694
661
if head .key == key {
695
662
// Drop the head of the list.
696
- return * head .value . Load () , head .overflow .Load (), true
663
+ return head .value , head .overflow .Load (), true
697
664
}
698
665
i := & head .overflow
699
666
e := i .Load ()
700
667
for e != nil {
701
668
if e .key == key {
702
669
i .Store (e .overflow .Load ())
703
- return * e .value . Load () , head , true
670
+ return e .value , head , true
704
671
}
705
672
i = & e .overflow
706
673
e = e .overflow .Load ()
@@ -713,14 +680,14 @@ func (head *entry[K, V]) loadAndDelete(key K) (V, *entry[K, V], bool) {
713
680
//
714
681
// compareAndDelete must be called under the mutex of the indirect node which e is a child of.
715
682
func (head * entry [K , V ]) compareAndDelete (key K , value V , valEqual equalFunc ) (* entry [K , V ], bool ) {
716
- if head .key == key && valEqual (unsafe .Pointer (head .value . Load () ), abi .NoEscape (unsafe .Pointer (& value ))) {
683
+ if head .key == key && valEqual (unsafe .Pointer (& head .value ), abi .NoEscape (unsafe .Pointer (& value ))) {
717
684
// Drop the head of the list.
718
685
return head .overflow .Load (), true
719
686
}
720
687
i := & head .overflow
721
688
e := i .Load ()
722
689
for e != nil {
723
- if e .key == key && valEqual (unsafe .Pointer (e .value . Load () ), abi .NoEscape (unsafe .Pointer (& value ))) {
690
+ if e .key == key && valEqual (unsafe .Pointer (& e .value ), abi .NoEscape (unsafe .Pointer (& value ))) {
724
691
i .Store (e .overflow .Load ())
725
692
return head , true
726
693
}
0 commit comments