Skip to content

Commit 5a2e008

Browse files
mknyszekwyf9661
authored andcommitted
Revert "internal/sync: optimize CompareAndSwap and Swap"
This reverts CL 606462. Reason for revert: Breaks atomicity between operations. See golang#70970. Change-Id: I1a899f2784da5a0f9da3193e3267275c23aea661 Reviewed-on: https://go-review.googlesource.com/c/go/+/638615 Auto-Submit: Michael Knyszek <[email protected]> Commit-Queue: Michael Knyszek <[email protected]> LUCI-TryBot-Result: Go LUCI <[email protected]> Reviewed-by: David Chase <[email protected]>
1 parent c551b7f commit 5a2e008

File tree

1 file changed

+69
-102
lines changed

1 file changed

+69
-102
lines changed

src/internal/sync/hashtriemap.go

+69-102
Original file line numberDiff line numberDiff line change
@@ -219,22 +219,12 @@ func (ht *HashTrieMap[K, V]) Swap(key K, new V) (previous V, loaded bool) {
219219

220220
slot = &i.children[(hash>>hashShift)&nChildrenMask]
221221
n = slot.Load()
222-
if n == nil {
222+
if n == nil || n.isEntry {
223223
// We found a nil slot which is a candidate for insertion,
224224
// or an existing entry that we'll replace.
225225
haveInsertPoint = true
226226
break
227227
}
228-
if n.isEntry {
229-
// Swap if the keys compare.
230-
old, swapped := n.entry().swap(key, new)
231-
if swapped {
232-
return old, true
233-
}
234-
// If we fail, that means we should try to insert.
235-
haveInsertPoint = true
236-
break
237-
}
238228
i = n.indirect()
239229
}
240230
if !haveInsertPoint {
@@ -261,10 +251,11 @@ func (ht *HashTrieMap[K, V]) Swap(key K, new V) (previous V, loaded bool) {
261251
var zero V
262252
var oldEntry *entry[K, V]
263253
if n != nil {
264-
// Between before and now, something got inserted. Swap if the keys compare.
254+
// Swap if the keys compare.
265255
oldEntry = n.entry()
266-
old, swapped := oldEntry.swap(key, new)
256+
newEntry, old, swapped := oldEntry.swap(key, new)
267257
if swapped {
258+
slot.Store(&newEntry.node)
268259
return old, true
269260
}
270261
}
@@ -292,30 +283,25 @@ func (ht *HashTrieMap[K, V]) CompareAndSwap(key K, old, new V) (swapped bool) {
292283
panic("called CompareAndSwap when value is not of comparable type")
293284
}
294285
hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
295-
for {
296-
// Find the key or return if it's not there.
297-
i := ht.root.Load()
298-
hashShift := 8 * goarch.PtrSize
299-
found := false
300-
for hashShift != 0 {
301-
hashShift -= nChildrenLog2
302286

303-
slot := &i.children[(hash>>hashShift)&nChildrenMask]
304-
n := slot.Load()
305-
if n == nil {
306-
// Nothing to compare with. Give up.
307-
return false
308-
}
309-
if n.isEntry {
310-
// We found an entry. Try to compare and swap directly.
311-
return n.entry().compareAndSwap(key, old, new, ht.valEqual)
312-
}
313-
i = n.indirect()
314-
}
315-
if !found {
316-
panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
317-
}
287+
// Find a node with the key and compare with it. n != nil if we found the node.
288+
i, _, slot, n := ht.find(key, hash, ht.valEqual, old)
289+
if i != nil {
290+
defer i.mu.Unlock()
318291
}
292+
if n == nil {
293+
return false
294+
}
295+
296+
// Try to swap the entry.
297+
e, swapped := n.entry().compareAndSwap(key, old, new, ht.valEqual)
298+
if !swapped {
299+
// Nothing was actually swapped, which means the node is no longer there.
300+
return false
301+
}
302+
// Store the entry back because it changed.
303+
slot.Store(&e.node)
304+
return true
319305
}
320306

321307
// LoadAndDelete deletes the value for a key, returning the previous value if any.
@@ -523,7 +509,7 @@ func (ht *HashTrieMap[K, V]) iter(i *indirect[K, V], yield func(key K, value V)
523509
}
524510
e := n.entry()
525511
for e != nil {
526-
if !yield(e.key, *e.value.Load()) {
512+
if !yield(e.key, e.value) {
527513
return false
528514
}
529515
e = e.overflow.Load()
@@ -579,22 +565,21 @@ type entry[K comparable, V any] struct {
579565
node[K, V]
580566
overflow atomic.Pointer[entry[K, V]] // Overflow for hash collisions.
581567
key K
582-
value atomic.Pointer[V]
568+
value V
583569
}
584570

585571
func newEntryNode[K comparable, V any](key K, value V) *entry[K, V] {
586-
e := &entry[K, V]{
587-
node: node[K, V]{isEntry: true},
588-
key: key,
572+
return &entry[K, V]{
573+
node: node[K, V]{isEntry: true},
574+
key: key,
575+
value: value,
589576
}
590-
e.value.Store(&value)
591-
return e
592577
}
593578

594579
func (e *entry[K, V]) lookup(key K) (V, bool) {
595580
for e != nil {
596581
if e.key == key {
597-
return *e.value.Load(), true
582+
return e.value, true
598583
}
599584
e = e.overflow.Load()
600585
}
@@ -603,87 +588,69 @@ func (e *entry[K, V]) lookup(key K) (V, bool) {
603588

604589
func (e *entry[K, V]) lookupWithValue(key K, value V, valEqual equalFunc) (V, bool) {
605590
for e != nil {
606-
oldp := e.value.Load()
607-
if e.key == key && (valEqual == nil || valEqual(unsafe.Pointer(oldp), abi.NoEscape(unsafe.Pointer(&value)))) {
608-
return *oldp, true
591+
if e.key == key && (valEqual == nil || valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&value)))) {
592+
return e.value, true
609593
}
610594
e = e.overflow.Load()
611595
}
612596
return *new(V), false
613597
}
614598

615-
// swap replaces a value in the overflow chain if keys compare equal.
616-
// Returns the old value, and whether or not anything was swapped.
599+
// swap replaces an entry in the overflow chain if keys compare equal. Returns the new entry chain,
600+
// the old value, and whether or not anything was swapped.
617601
//
618602
// swap must be called under the mutex of the indirect node which e is a child of.
619-
func (head *entry[K, V]) swap(key K, newv V) (V, bool) {
603+
func (head *entry[K, V]) swap(key K, new V) (*entry[K, V], V, bool) {
620604
if head.key == key {
621-
vp := new(V)
622-
*vp = newv
623-
oldp := head.value.Swap(vp)
624-
return *oldp, true
605+
// Return the new head of the list.
606+
e := newEntryNode(key, new)
607+
if chain := head.overflow.Load(); chain != nil {
608+
e.overflow.Store(chain)
609+
}
610+
return e, head.value, true
625611
}
626612
i := &head.overflow
627613
e := i.Load()
628614
for e != nil {
629615
if e.key == key {
630-
vp := new(V)
631-
*vp = newv
632-
oldp := e.value.Swap(vp)
633-
return *oldp, true
616+
eNew := newEntryNode(key, new)
617+
eNew.overflow.Store(e.overflow.Load())
618+
i.Store(eNew)
619+
return head, e.value, true
634620
}
635621
i = &e.overflow
636622
e = e.overflow.Load()
637623
}
638624
var zero V
639-
return zero, false
625+
return head, zero, false
640626
}
641627

642-
// compareAndSwap replaces a value for a matching key and existing value in the overflow chain.
643-
// Returns whether or not anything was swapped.
628+
// compareAndSwap replaces an entry in the overflow chain if both the key and value compare
629+
// equal. Returns the new entry chain and whether or not anything was swapped.
644630
//
645631
// compareAndSwap must be called under the mutex of the indirect node which e is a child of.
646-
func (head *entry[K, V]) compareAndSwap(key K, oldv, newv V, valEqual equalFunc) bool {
647-
var vbox *V
648-
outerLoop:
649-
for {
650-
oldvp := head.value.Load()
651-
if head.key == key && valEqual(unsafe.Pointer(oldvp), abi.NoEscape(unsafe.Pointer(&oldv))) {
652-
// Return the new head of the list.
653-
if vbox == nil {
654-
// Delay explicit creation of a new value to hold newv. If we just pass &newv
655-
// to CompareAndSwap, then newv will unconditionally escape, even if the CAS fails.
656-
vbox = new(V)
657-
*vbox = newv
658-
}
659-
if head.value.CompareAndSwap(oldvp, vbox) {
660-
return true
661-
}
662-
// We need to restart from the head of the overflow list in case, due to a removal, a node
663-
// is moved up the list and we miss it.
664-
continue outerLoop
632+
func (head *entry[K, V]) compareAndSwap(key K, old, new V, valEqual equalFunc) (*entry[K, V], bool) {
633+
if head.key == key && valEqual(unsafe.Pointer(&head.value), abi.NoEscape(unsafe.Pointer(&old))) {
634+
// Return the new head of the list.
635+
e := newEntryNode(key, new)
636+
if chain := head.overflow.Load(); chain != nil {
637+
e.overflow.Store(chain)
665638
}
666-
i := &head.overflow
667-
e := i.Load()
668-
for e != nil {
669-
oldvp := e.value.Load()
670-
if e.key == key && valEqual(unsafe.Pointer(oldvp), abi.NoEscape(unsafe.Pointer(&oldv))) {
671-
if vbox == nil {
672-
// Delay explicit creation of a new value to hold newv. If we just pass &newv
673-
// to CompareAndSwap, then newv will unconditionally escape, even if the CAS fails.
674-
vbox = new(V)
675-
*vbox = newv
676-
}
677-
if e.value.CompareAndSwap(oldvp, vbox) {
678-
return true
679-
}
680-
continue outerLoop
681-
}
682-
i = &e.overflow
683-
e = e.overflow.Load()
639+
return e, true
640+
}
641+
i := &head.overflow
642+
e := i.Load()
643+
for e != nil {
644+
if e.key == key && valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&old))) {
645+
eNew := newEntryNode(key, new)
646+
eNew.overflow.Store(e.overflow.Load())
647+
i.Store(eNew)
648+
return head, true
684649
}
685-
return false
650+
i = &e.overflow
651+
e = e.overflow.Load()
686652
}
653+
return head, false
687654
}
688655

689656
// loadAndDelete deletes an entry in the overflow chain by key. Returns the value for the key, the new
@@ -693,14 +660,14 @@ outerLoop:
693660
func (head *entry[K, V]) loadAndDelete(key K) (V, *entry[K, V], bool) {
694661
if head.key == key {
695662
// Drop the head of the list.
696-
return *head.value.Load(), head.overflow.Load(), true
663+
return head.value, head.overflow.Load(), true
697664
}
698665
i := &head.overflow
699666
e := i.Load()
700667
for e != nil {
701668
if e.key == key {
702669
i.Store(e.overflow.Load())
703-
return *e.value.Load(), head, true
670+
return e.value, head, true
704671
}
705672
i = &e.overflow
706673
e = e.overflow.Load()
@@ -713,14 +680,14 @@ func (head *entry[K, V]) loadAndDelete(key K) (V, *entry[K, V], bool) {
713680
//
714681
// compareAndDelete must be called under the mutex of the indirect node which e is a child of.
715682
func (head *entry[K, V]) compareAndDelete(key K, value V, valEqual equalFunc) (*entry[K, V], bool) {
716-
if head.key == key && valEqual(unsafe.Pointer(head.value.Load()), abi.NoEscape(unsafe.Pointer(&value))) {
683+
if head.key == key && valEqual(unsafe.Pointer(&head.value), abi.NoEscape(unsafe.Pointer(&value))) {
717684
// Drop the head of the list.
718685
return head.overflow.Load(), true
719686
}
720687
i := &head.overflow
721688
e := i.Load()
722689
for e != nil {
723-
if e.key == key && valEqual(unsafe.Pointer(e.value.Load()), abi.NoEscape(unsafe.Pointer(&value))) {
690+
if e.key == key && valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&value))) {
724691
i.Store(e.overflow.Load())
725692
return head, true
726693
}

0 commit comments

Comments
 (0)