From 3215f29e93698f13290e00cb2c69f41d32e050c1 Mon Sep 17 00:00:00 2001 From: sukun Date: Mon, 11 Nov 2024 20:51:53 +0530 Subject: [PATCH] eventbus: dont panic on closing Subscription twice --- p2p/host/eventbus/basic.go | 40 +++++++++++++++++---------------- p2p/host/eventbus/basic_test.go | 11 +++++++++ 2 files changed, 32 insertions(+), 19 deletions(-) diff --git a/p2p/host/eventbus/basic.go b/p2p/host/eventbus/basic.go index 42365a7916..a91f84e79b 100644 --- a/p2p/host/eventbus/basic.go +++ b/p2p/host/eventbus/basic.go @@ -145,6 +145,7 @@ type sub struct { dropper func(reflect.Type) metricsTracer MetricsTracer name string + closeOnce sync.Once } func (s *sub) Name() string { @@ -162,31 +163,32 @@ func (s *sub) Close() error { for range s.ch { } }() - - for _, n := range s.nodes { - n.lk.Lock() - - for i := 0; i < len(n.sinks); i++ { - if n.sinks[i].ch == s.ch { - n.sinks[i], n.sinks[len(n.sinks)-1] = n.sinks[len(n.sinks)-1], nil - n.sinks = n.sinks[:len(n.sinks)-1] - - if s.metricsTracer != nil { - s.metricsTracer.RemoveSubscriber(n.typ) + s.closeOnce.Do(func() { + for _, n := range s.nodes { + n.lk.Lock() + + for i := 0; i < len(n.sinks); i++ { + if n.sinks[i].ch == s.ch { + n.sinks[i], n.sinks[len(n.sinks)-1] = n.sinks[len(n.sinks)-1], nil + n.sinks = n.sinks[:len(n.sinks)-1] + + if s.metricsTracer != nil { + s.metricsTracer.RemoveSubscriber(n.typ) + } + break } - break } - } - tryDrop := len(n.sinks) == 0 && n.nEmitters.Load() == 0 + tryDrop := len(n.sinks) == 0 && n.nEmitters.Load() == 0 - n.lk.Unlock() + n.lk.Unlock() - if tryDrop { - s.dropper(n.typ) + if tryDrop { + s.dropper(n.typ) + } } - } - close(s.ch) + close(s.ch) + }) return nil } diff --git a/p2p/host/eventbus/basic_test.go b/p2p/host/eventbus/basic_test.go index 57362ce9b7..d7c9287e8a 100644 --- a/p2p/host/eventbus/basic_test.go +++ b/p2p/host/eventbus/basic_test.go @@ -481,6 +481,17 @@ func TestSubFailFully(t *testing.T) { } } +func TestSubCloseMultiple(t *testing.T) { + bus := NewBus() + + sub, err := bus.Subscribe([]interface{}{new(EventB)}) + require.NoError(t, err) + err = sub.Close() + require.NoError(t, err) + err = sub.Close() + require.NoError(t, err) +} + func testMany(t testing.TB, subs, emits, msgs int, stateful bool) { if race.WithRace() && subs+emits > 5000 { t.SkipNow()