diff --git a/config/config.go b/config/config.go index 900c06bc30..adf123fe84 100644 --- a/config/config.go +++ b/config/config.go @@ -27,6 +27,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/host/autorelay" bhost "github.com/libp2p/go-libp2p/p2p/host/basic" blankhost "github.com/libp2p/go-libp2p/p2p/host/blank" + "github.com/libp2p/go-libp2p/p2p/host/conntracker" "github.com/libp2p/go-libp2p/p2p/host/eventbus" "github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoremem" rcmgr "github.com/libp2p/go-libp2p/p2p/host/resource-manager" @@ -413,7 +414,15 @@ func (cfg *Config) addTransports() ([]fx.Option, error) { return fxopts, nil } -func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus) (*bhost.BasicHost, error) { +type basicHostParams struct { + fx.In + Swarm *swarm.Swarm + EventBus event.Bus + + ConnTracker *conntracker.ConnTracker `optional:"true"` +} + +func (cfg *Config) newBasicHost(params basicHostParams) (*bhost.BasicHost, error) { var autonatv2Dialer host.Host if cfg.EnableAutoNATv2 { ah, err := cfg.makeAutoNATV2Host() @@ -422,8 +431,8 @@ func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus) (*bhost.B } autonatv2Dialer = ah } - h, err := bhost.NewHost(swrm, &bhost.HostOpts{ - EventBus: eventBus, + h, err := bhost.NewHost(params.Swarm, &bhost.HostOpts{ + EventBus: params.EventBus, ConnManager: cfg.ConnManager, AddrsFactory: cfg.AddrsFactory, NATManager: cfg.NATManager, @@ -439,6 +448,7 @@ func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus) (*bhost.B DisableIdentifyAddressDiscovery: cfg.DisableIdentifyAddressDiscovery, EnableAutoNATv2: cfg.EnableAutoNATv2, AutoNATv2Dialer: autonatv2Dialer, + ConnTracker: params.ConnTracker, }) if err != nil { return nil, err @@ -493,6 +503,15 @@ func (cfg *Config) NewNode() (host.Host, error) { }) return sw, nil }), + fx.Provide(func(l fx.Lifecycle, eb event.Bus, s *swarm.Swarm) *conntracker.ConnTracker { + ct := &conntracker.ConnTracker{} + l.Append(fx.StartStopHook( + func() error { + return ct.Start(eb, s) + }, ct.Stop, + )) + return ct + }), fx.Provide(cfg.newBasicHost), fx.Provide(func(bh *bhost.BasicHost) identify.IDService { return bh.IDService() diff --git a/core/event/identify.go b/core/event/identify.go index 888572a2d5..2180ad5ec8 100644 --- a/core/event/identify.go +++ b/core/event/identify.go @@ -43,4 +43,7 @@ type EvtPeerIdentificationFailed struct { Peer peer.ID // Reason is the reason why identification failed. Reason error + + // Conn is the connection we failed to identify. + Conn network.Conn } diff --git a/core/event/network.go b/core/event/network.go index 37dd09ca9a..8ce32e3cd3 100644 --- a/core/event/network.go +++ b/core/event/network.go @@ -3,6 +3,7 @@ package event import ( "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" ) // EvtPeerConnectednessChanged should be emitted every time the "connectedness" to a @@ -53,3 +54,19 @@ type EvtPeerConnectednessChanged struct { // Connectedness is the new connectedness state. Connectedness network.Connectedness } + +// EvtProtocolNegotiationSuccess is emitted when we learn about a protocol a +// peer supports via protocol negotiation (i.e. MultiStream). +// +// This is only emitted if we learned about a protocol during negotiation. It +// is not emitted if we already expected a peer to support the protocol. +type EvtProtocolNegotiationSuccess struct { + // Peer is the remote peer who we negotiated the protocol with. 
+ Peer peer.ID + + // Conn is the connection we opened the stream on. + Conn network.Conn + + // Protocol is protocol we've successfully negotiated. + Protocol protocol.ID +} diff --git a/libp2p_test.go b/libp2p_test.go index 3de82946d8..7ccabd02f9 100644 --- a/libp2p_test.go +++ b/libp2p_test.go @@ -775,3 +775,30 @@ func TestSharedTCPAddr(t *testing.T) { require.True(t, sawWS) h.Close() } + +func TestMinimalEcho(t *testing.T) { + h1, err := New() + require.NoError(t, err) + defer h1.Close() + + h2, err := New() + require.NoError(t, err) + defer h2.Close() + + h2.SetStreamHandler("/testing/echo", func(s network.Stream) { + defer s.Close() + io.Copy(s, s) + }) + + h1.Connect(context.Background(), peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()}) + + s, err := h1.NewStream(context.Background(), h2.ID(), "/testing/echo") + require.NoError(t, err) + body := []byte("hello") + s.Write(body) + s.CloseWrite() + resp, err := io.ReadAll(s) + require.NoError(t, err) + require.Equal(t, body, resp) + s.Close() +} diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 820411bd27..04aec1da75 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -22,6 +22,7 @@ import ( "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/host/autonat" "github.com/libp2p/go-libp2p/p2p/host/basic/internal/backoff" + "github.com/libp2p/go-libp2p/p2p/host/conntracker" "github.com/libp2p/go-libp2p/p2p/host/eventbus" "github.com/libp2p/go-libp2p/p2p/host/pstoremanager" "github.com/libp2p/go-libp2p/p2p/host/relaysvc" @@ -74,6 +75,7 @@ type BasicHost struct { // keep track of resources we need to wait on before shutting down refCount sync.WaitGroup + connTracker *conntracker.ConnTracker network network.Network psManager *pstoremanager.PeerstoreManager mux *msmux.MultistreamMuxer[protocol.ID] @@ -92,6 +94,7 @@ type BasicHost struct { emitters struct { evtLocalProtocolsUpdated event.Emitter evtLocalAddrsUpdated event.Emitter + evtProtoNegotiation event.Emitter } addrChangeChan chan struct{} @@ -170,6 +173,8 @@ type HostOpts struct { DisableIdentifyAddressDiscovery bool EnableAutoNATv2 bool AutoNATv2Dialer host.Host + + ConnTracker *conntracker.ConnTracker } // NewHost constructs a new *BasicHost and activates it by attaching its stream and connection handlers to the given inet.Network. @@ -185,7 +190,21 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { if err != nil { return nil, err } + hostCtx, cancel := context.WithCancel(context.Background()) + + // FIXME: + // Remove this. 
Hack to try the tests quickly + if opts.ConnTracker == nil { + opts.ConnTracker = &conntracker.ConnTracker{} + opts.ConnTracker.Start(opts.EventBus, n) + go func() { + <-hostCtx.Done() + opts.ConnTracker.Stop() + + }() + } + h := &BasicHost{ network: n, psManager: psManager, @@ -197,6 +216,7 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { ctx: hostCtx, ctxCancel: cancel, disableSignedPeerRecord: opts.DisableSignedPeerRecord, + connTracker: opts.ConnTracker, } h.updateLocalIpAddr() @@ -207,6 +227,9 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { if h.emitters.evtLocalAddrsUpdated, err = h.eventbus.Emitter(&event.EvtLocalAddressesUpdated{}, eventbus.Stateful); err != nil { return nil, err } + if h.emitters.evtProtoNegotiation, err = h.eventbus.Emitter(&event.EvtProtocolNegotiationSuccess{}); err != nil { + return nil, err + } if !h.disableSignedPeerRecord { cab, ok := peerstore.GetCertifiedAddrBook(n.Peerstore()) @@ -685,6 +708,80 @@ func (h *BasicHost) RemoveStreamHandler(pid protocol.ID) { }) } +func (h *BasicHost) newConnFromConnTracker(ctx context.Context, p peer.ID) (conntracker.ConnWithMeta, error) { + connFilter := conntracker.NoLimitedConnFilter + if canUseLimitedConn, _ := network.GetAllowLimitedConn(ctx); canUseLimitedConn { + connFilter = nil + } + + // requiredProtos is nil because we will fallback to MSS to negotiate the + // protocol if none of our prefrred protocols were broadcasted by identify. + var requiredProtos []protocol.ID + + // Do we have a conn? + // TODO: for both usages of conntracker, use a sort fn that sorts by more streams. + conn, err := h.connTracker.GetBestConn(ctx, p, conntracker.GetBestConnOpts{ + OneOf: requiredProtos, + FilterFn: connFilter, + WaitForIdentify: true, + AllowNoConn: true, + }) + if err == nil { + return conn, nil + } + + var errCh chan error + if nodial, _ := network.GetNoDial(ctx); !nodial { + errCh = make(chan error, 1) + go func() { + err := h.Connect(ctx, peer.AddrInfo{ID: p}) + if err != nil { + select { + case errCh <- err: + default: + } + } + }() + } + + // Wait for a connection that works for us + connChan, err := h.connTracker.GetBestConnChan(ctx, p, conntracker.GetBestConnOpts{ + OneOf: requiredProtos, + FilterFn: connFilter, + WaitForIdentify: true, // Old behavior + }) + if err != nil { + return conntracker.ConnWithMeta{}, err + } + + select { + case <-ctx.Done(): + return conntracker.ConnWithMeta{}, ctx.Err() + case <-h.ctx.Done(): + return conntracker.ConnWithMeta{}, h.ctx.Err() + case err = <-errCh: + return conntracker.ConnWithMeta{}, err + case connWithMeta := <-connChan: + return connWithMeta, nil + } +} + +func (h *BasicHost) newStreamWithConnTracker(ctx context.Context, p peer.ID, pids ...protocol.ID) (network.Stream, protocol.ID, error) { + var preferredProto protocol.ID + connWithMeta, err := h.newConnFromConnTracker(ctx, p) + if err != nil { + return nil, "", err + } + for _, proto := range pids { + if connWithMeta.SupportsProtocol(proto) { + preferredProto = proto + break + } + } + s, err := connWithMeta.NewStream(ctx) + return s, preferredProto, err +} + // NewStream opens a new stream to given peer p, and writes a p2p/protocol // header with given protocol.ID. If there is no connection to p, attempts // to create one. If ProtocolID is "", writes no header. @@ -698,56 +795,71 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I } } - // If the caller wants to prevent the host from dialing, it should use the NoDial option. 
- if nodial, _ := network.GetNoDial(ctx); !nodial { - err := h.Connect(ctx, peer.AddrInfo{ID: p}) + var ( + err error + s network.Stream + preferredProto protocol.ID + ) + + if h.connTracker != nil { + s, preferredProto, err = h.newStreamWithConnTracker(ctx, p, pids...) if err != nil { return nil, err } - } - - s, err := h.Network().NewStream(network.WithNoDial(ctx, "already dialed"), p) - if err != nil { - // TODO: It would be nicer to get the actual error from the swarm, - // but this will require some more work. - if errors.Is(err, network.ErrNoConn) { - return nil, errors.New("connection failed") + } else { + // If the caller wants to prevent the host from dialing, it should use the NoDial option. + if nodial, _ := network.GetNoDial(ctx); !nodial { + err := h.Connect(ctx, peer.AddrInfo{ID: p}) + if err != nil { + return nil, err + } } - return nil, fmt.Errorf("failed to open stream: %w", err) - } - defer func() { - if strErr != nil && s != nil { - s.Reset() + + s, err = h.Network().NewStream(network.WithNoDial(ctx, "already dialed"), p) + if err != nil { + // TODO: It would be nicer to get the actual error from the swarm, + // but this will require some more work. + if errors.Is(err, network.ErrNoConn) { + return nil, errors.New("connection failed") + } + return nil, fmt.Errorf("failed to open stream: %w", err) } - }() + defer func() { + if strErr != nil && s != nil { + s.Reset() + } + }() - // Wait for any in-progress identifies on the connection to finish. This - // is faster than negotiating. - // - // If the other side doesn't support identify, that's fine. This will - // just be a no-op. - select { - case <-h.ids.IdentifyWait(s.Conn()): - case <-ctx.Done(): - return nil, fmt.Errorf("identify failed to complete: %w", ctx.Err()) - } + // Wait for any in-progress identifies on the connection to finish. This + // is faster than negotiating. + // + // If the other side doesn't support identify, that's fine. This will + // just be a no-op. + select { + case <-h.ids.IdentifyWait(s.Conn()): + case <-ctx.Done(): + return nil, fmt.Errorf("identify failed to complete: %w", ctx.Err()) + } - pref, err := h.preferredProtocol(p, pids) - if err != nil { - return nil, err + preferredProto, err = h.preferredProtocol(p, pids) + if err != nil { + return nil, err + } } - if pref != "" { - if err := s.SetProtocol(pref); err != nil { + if preferredProto != "" { + if err := s.SetProtocol(preferredProto); err != nil { return nil, err } - lzcon := msmux.NewMSSelect(s, pref) + lzcon := msmux.NewMSSelect(s, preferredProto) return &streamWrapper{ Stream: s, rw: lzcon, }, nil } + // Fallback to MultiStreamSelect. + // Negotiate the protocol in the background, obeying the context. 
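+	// On success, the negotiated protocol is added to the peerstore and an
+	// event.EvtProtocolNegotiationSuccess is emitted (see the Emit call below)
+	// so that services such as the ConnTracker learn about protocols
+	// discovered outside of identify.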
var selected protocol.ID errCh := make(chan error, 1) @@ -771,6 +883,8 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I return nil, err } _ = h.Peerstore().AddProtocols(p, selected) // adding the protocol to the peerstore isn't critical + h.emitters.evtProtoNegotiation.Emit(event.EvtProtocolNegotiationSuccess{Peer: p, Conn: s.Conn(), Protocol: selected}) + return s, nil } diff --git a/p2p/host/basic/basic_host_test.go b/p2p/host/basic/basic_host_test.go index 2a7a772976..f4b9729b1a 100644 --- a/p2p/host/basic/basic_host_test.go +++ b/p2p/host/basic/basic_host_test.go @@ -312,10 +312,13 @@ func TestAllAddrsUnique(t *testing.T) { func getHostPair(t *testing.T) (host.Host, host.Host) { t.Helper() - h1, err := NewHost(swarmt.GenSwarm(t), nil) + eb1 := eventbus.NewBus() + h1, err := NewHost(swarmt.GenSwarm(t, swarmt.EventBus(eb1)), &HostOpts{EventBus: eb1}) require.NoError(t, err) h1.Start() - h2, err := NewHost(swarmt.GenSwarm(t), nil) + + eb2 := eventbus.NewBus() + h2, err := NewHost(swarmt.GenSwarm(t, swarmt.EventBus(eb2)), &HostOpts{EventBus: eb2}) require.NoError(t, err) h2.Start() @@ -587,7 +590,8 @@ func TestProtoDowngrade(t *testing.T) { assertWait(t, connectedOn, "/testing/1.0.0") require.NoError(t, s.Close()) - h1.Network().ClosePeer(h2.ID()) + protosSub, err := h2.EventBus().Subscribe(&event.EvtLocalProtocolsUpdated{}) + require.NoError(t, err) h2.RemoveStreamHandler("/testing/1.0.0") h2.SetStreamHandler("/testing", func(s network.Stream) { defer s.Close() @@ -597,10 +601,24 @@ func TestProtoDowngrade(t *testing.T) { connectedOn <- s.Protocol() }) + // Wait for protos updated + for { + protoUpdate := <-protosSub.Out() + if len(protoUpdate.(event.EvtLocalProtocolsUpdated).Removed) > 0 { + break + } + } + // Give us a second to update our protocol list. This happens async through the event bus. // This is _almost_ instantaneous, but this test fails once every ~1k runs without this. time.Sleep(time.Millisecond) + // Wait for disconnect + disconnectSub, err := h1.EventBus().Subscribe(&event.EvtPeerConnectednessChanged{}) + require.NoError(t, err) + h1.Network().ClosePeer(h2.ID()) + <-disconnectSub.Out() + h2pi := h2.Peerstore().PeerInfo(h2.ID()) require.NoError(t, h1.Connect(ctx, h2pi)) diff --git a/p2p/host/conntracker/conntracker.go b/p2p/host/conntracker/conntracker.go new file mode 100644 index 0000000000..646a613ce2 --- /dev/null +++ b/p2p/host/conntracker/conntracker.go @@ -0,0 +1,528 @@ +// Package conntracker holds the ConnTracker service. Which tracks a peer's +// connections and supported protocols. +package conntracker + +import ( + "context" + "errors" + "time" + + "slices" + + "github.com/libp2p/go-libp2p/core/event" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" + multiaddr "github.com/multiformats/go-multiaddr" + + logging "github.com/ipfs/go-log/v2" +) + +const pendingReqBound = 1_000_000 +const gcInterval = time.Minute + +// If we reach this many pending cleanup requests, we'll cleanup immediately. +const maxPendingCleanup = 1_000 + +// If we don't have maxPendingCleanup requests to cleanup, then we'll cleanup +// up at most this frequency if requested. +const maxCleanupFrequency = 10 * time.Second + +var log = logging.Logger("bestconn") + +// ConnTracker lets other services find the best connection to a peer. It relies +// on Identify to identify the protocols a peer supports on a given connection. 
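+// Identify results reach the tracker via EvtPeerIdentificationCompleted events
+// on the event bus.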
+// If identify fails, it will still track the connection, but not record any +// protocols for that connection. If a protocol is found via Multistream Select +// (or any other protocol negotiation mechanism), the conntracker will update +// the supported protocols when the EvtProtocolNegotiationSuccess event is +// received. +type ConnTracker struct { + svcCtx context.Context + stop context.CancelFunc + stopped chan struct{} + clock clock + sub event.Subscription + trackedConns map[peer.ID]map[network.Conn]connMeta + pendingReqs map[peer.ID][]req + totalPendingReqs int + + notifier connNotifier + connNotifs chan connNotif + reqCh chan req + cleanupCh chan peer.ID +} + +type connNotifKind int + +const ( + connNotifConnected connNotifKind = iota + connNotifDisconnected +) + +type connNotif struct { + kind connNotifKind + conn network.Conn +} + +type clock interface { + Now() time.Time + Since(time.Time) time.Duration + NewTicker(d time.Duration) *time.Ticker +} + +type realClock struct{} + +func (realClock) Now() time.Time { + return time.Now() +} +func (realClock) Since(t time.Time) time.Duration { + return time.Since(t) +} +func (realClock) NewTicker(d time.Duration) *time.Ticker { + return time.NewTicker(d) +} + +type connMeta struct { + // Identify ran on this connection + identified bool + protos map[protocol.ID]struct{} +} + +type req struct { + ctx context.Context + + // These are request parameters + immediate bool + waitForIdentify bool + p peer.ID + oneOf []protocol.ID + filter func(c network.Conn) bool + sort func(a, b network.Conn) int + + // These are internal plumbing + resCh chan ConnWithMeta + onFulfilled func() +} + +type connNotifier interface { + Notify(network.Notifiee) + StopNotify(network.Notifiee) +} + +func (ct *ConnTracker) Start(eb event.Bus, notifier connNotifier) error { + sub, err := eb.Subscribe([]any{ + new(event.EvtPeerIdentificationCompleted), + new(event.EvtPeerIdentificationFailed), + new(event.EvtPeerConnectednessChanged), + new(event.EvtProtocolNegotiationSuccess), + }) + if err != nil { + return err + } + + ct.svcCtx, ct.stop = context.WithCancel(context.Background()) + ct.stopped = make(chan struct{}) + ct.sub = sub + ct.trackedConns = make(map[peer.ID]map[network.Conn]connMeta) + ct.pendingReqs = make(map[peer.ID][]req) + ct.reqCh = make(chan req, 1) + ct.cleanupCh = make(chan peer.ID, 16) + ct.connNotifs = make(chan connNotif, 8) + ct.notifier = notifier + ct.notifier.Notify(ct) + + if ct.clock == nil { + ct.clock = realClock{} + } + + go ct.loop() + + return nil +} + +func (ct *ConnTracker) Stop() { + ct.notifier.StopNotify(ct) + ct.stop() + <-ct.stopped + ct.notifier = nil + ct.trackedConns = nil + ct.pendingReqs = nil + ct.connNotifs = nil + ct.sub.Close() +} + +func (ct *ConnTracker) gc() { + for p, rs := range ct.pendingReqs { + rs, n := clearCancelledReqs(rs) + if n > 0 { + ct.pendingReqs[p] = rs + ct.totalPendingReqs -= n + } + } +} + +func (ct *ConnTracker) updateProtos(p peer.ID, conn network.Conn, protos []protocol.ID, replace bool, identified bool) { + meta := connMeta{identified: identified} + + // Check if we already have tracked this conn + if conns, ok := ct.trackedConns[p]; ok { + if connMeta, ok := conns[conn]; ok { + // Keep the identified status + if connMeta.identified { + meta.identified = true + } + // reuse the map + meta.protos = connMeta.protos + if replace { + clear(meta.protos) + } + } + } + if meta.protos == nil { + meta.protos = make(map[protocol.ID]struct{}, len(protos)) + } + + for _, p := range protos { + 
meta.protos[p] = struct{}{} + } + + if _, ok := ct.trackedConns[p]; !ok { + ct.trackedConns[p] = make(map[network.Conn]connMeta) + } + ct.trackedConns[p][conn] = meta +} + +func (ct *ConnTracker) loop() { + defer close(ct.stopped) + gcTicker := ct.clock.NewTicker(gcInterval) + defer gcTicker.Stop() + + // debounce many recurring cleanup requests. + var lastCleanupTime time.Time + pendingCleanup := make(map[peer.ID]struct{}, maxPendingCleanup) + + for { + select { + case <-ct.svcCtx.Done(): + return + case <-gcTicker.C: + ct.gc() + case p := <-ct.cleanupCh: + pendingCleanup[p] = struct{}{} + if ct.clock.Since(lastCleanupTime) < maxCleanupFrequency && len(pendingCleanup) < maxPendingCleanup { + continue + } + + lastCleanupTime = ct.clock.Now() + for p := range pendingCleanup { + rs, n := clearCancelledReqs(ct.pendingReqs[p]) + if n > 0 { + ct.pendingReqs[p] = rs + ct.totalPendingReqs -= n + } + } + case notif := <-ct.connNotifs: + switch notif.kind { + case connNotifConnected: + ct.updateProtos(notif.conn.RemotePeer(), notif.conn, nil, false, false) + case connNotifDisconnected: + if m, ok := ct.trackedConns[notif.conn.RemotePeer()]; ok { + delete(m, notif.conn) + } + } + case evt := <-ct.sub.Out(): + switch evt := evt.(type) { + case event.EvtPeerConnectednessChanged: + // Clean up if a peer has disconnected + switch evt.Connectedness { + case network.Connected, network.Limited: + // Do nothing. We'll add this connection when we get the identify. + default: + // clean up + delete(ct.trackedConns, evt.Peer) + } + case event.EvtPeerIdentificationFailed: + ct.updateProtos(evt.Peer, evt.Conn, nil, false, true) + ct.totalPendingReqs -= ct.tryFulfillPendingReqs(evt.Peer) + case event.EvtPeerIdentificationCompleted: + ct.updateProtos(evt.Peer, evt.Conn, evt.Protocols, true, true) + ct.totalPendingReqs -= ct.tryFulfillPendingReqs(evt.Peer) + case event.EvtProtocolNegotiationSuccess: + ct.updateProtos(evt.Peer, evt.Conn, []protocol.ID{evt.Protocol}, false, false) + ct.totalPendingReqs -= ct.tryFulfillPendingReqs(evt.Peer) + default: + log.Debug("unknown event", evt) + continue + } + case req := <-ct.reqCh: + fulfilled := ct.fulfillReq(req) + if fulfilled && req.onFulfilled != nil { + req.onFulfilled() + } + if !fulfilled { + if req.immediate { + req.resCh <- ConnWithMeta{} + continue + } + + if ct.totalPendingReqs >= pendingReqBound { + // Drop the request + log.Warn("dropping request. Too many pending requests") + continue + } + + ct.totalPendingReqs++ + ct.pendingReqs[req.p] = append(ct.pendingReqs[req.p], req) + } + } + } +} + +type ConnWithMeta struct { + network.Conn + Identified bool + supportedProtocols map[protocol.ID]struct{} + MatchingProtocols []protocol.ID +} + +func (c *ConnWithMeta) SupportsProtocol(p protocol.ID) bool { + _, ok := c.supportedProtocols[p] + return ok +} + +// wrapConnWithMeta wraps a network.Conn with the supported protocols that +// intersect the requested oneOf protocols. It preserves the order of the oneOf +// protocols. +func wrapConnWithMeta(c network.Conn, meta connMeta, oneOf []protocol.ID) ConnWithMeta { + supportedProtocols := make(map[protocol.ID]struct{}, len(meta.protos)) + for p := range meta.protos { + supportedProtocols[p] = struct{}{} + } + + var matchingProtocols []protocol.ID + if len(oneOf) == 0 { + // If no oneOf protocols are provided, return all supported protocols. 
+ matchingProtocols = make([]protocol.ID, 0, len(meta.protos)) + for p := range meta.protos { + matchingProtocols = append(matchingProtocols, p) + } + } else { + matchingProtocols = make([]protocol.ID, 0, len(oneOf)) + for _, p := range oneOf { + if _, ok := meta.protos[p]; ok { + matchingProtocols = append(matchingProtocols, p) + } + } + } + + return ConnWithMeta{ + Conn: c, + Identified: meta.identified, + MatchingProtocols: matchingProtocols, + supportedProtocols: supportedProtocols, + } +} + +// fulfillReq returns true if the request was fulfilled +func (ct *ConnTracker) fulfillReq(r req) bool { + if r.ctx.Err() != nil { + // Request has been cancelled + return true + } + + conns := make([]network.Conn, 0, len(ct.trackedConns[r.p])) + for c, m := range ct.trackedConns[r.p] { + if c.IsClosed() { + delete(ct.trackedConns[r.p], c) + continue + } + if r.waitForIdentify && !m.identified { + continue + } + if r.filter != nil && !r.filter(c) { + continue + } + if len(r.oneOf) != 0 { + found := false + for _, p := range r.oneOf { + if _, ok := m.protos[p]; ok { + found = true + } + } + if !found { + continue + } + } + if r.sort == nil { + r.resCh <- wrapConnWithMeta(c, m, r.oneOf) + return true + } + conns = append(conns, c) + } + if r.sort != nil { + slices.SortFunc(conns, r.sort) + } + + if len(conns) > 0 { + r.resCh <- wrapConnWithMeta(conns[0], ct.trackedConns[r.p][conns[0]], r.oneOf) + return true + } + return false +} + +// tryFulfillPendingReqs will attempt to fulfill pending requests. +// returns the number of requests fulfilled +func (ct *ConnTracker) tryFulfillPendingReqs(p peer.ID) int { + l := len(ct.pendingReqs[p]) + ct.pendingReqs[p] = slices.DeleteFunc(ct.pendingReqs[p], ct.fulfillReq) + if len(ct.pendingReqs[p]) == 0 { + ct.pendingReqs[p] = nil + } + return l - len(ct.pendingReqs[p]) +} + +// clearCancelledReqs will clear all cancelled requests for a peer +// Returns the new slice and the number of cancelled requests +func clearCancelledReqs(rs []req) ([]req, int) { + l := len(rs) + rs = slices.DeleteFunc(rs, func(r req) bool { + return r.ctx.Err() != nil + }) + if len(rs) == 0 { + rs = nil + } + return rs, l - len(rs) +} + +func (ct *ConnTracker) cleanup(p peer.ID) { + select { + case ct.cleanupCh <- p: + case <-ct.stopped: + log.Debug("dropping cleanup request: service stopped") + default: + log.Debug("dropping cleanup request: channel full") + } +} + +var _ network.Notifiee = (*ConnTracker)(nil) + +func (ct *ConnTracker) Connected(_ network.Network, c network.Conn) { + select { + case ct.connNotifs <- connNotif{kind: connNotifConnected, conn: c}: + case <-ct.svcCtx.Done(): + log.Debug("dropping connection notification: service stopped") + } +} + +func (ct *ConnTracker) Disconnected(_ network.Network, c network.Conn) { + select { + case ct.connNotifs <- connNotif{kind: connNotifDisconnected, conn: c}: + case <-ct.svcCtx.Done(): + log.Debug("dropping disconnection notification: service stopped") + } +} + +// Listen implements network.Notifiee. +func (ct *ConnTracker) Listen(network.Network, multiaddr.Multiaddr) { + // unused +} + +// ListenClose implements network.Notifiee. +func (ct *ConnTracker) ListenClose(network.Network, multiaddr.Multiaddr) { + // unused +} + +type GetBestConnOpts struct { + OneOf []protocol.ID + + // Optional. If a filter function is provided, it will further filter the connections. + // The filter function should return true if the connection is acceptable. + FilterFn func(c network.Conn) bool + // Optional. 
If a sort function is provided, it will be used to sort the connections.
+	// Refer to slices.SortFunc for the signature of the sort function.
+	// The first connection in the sorted list will be returned.
+	SortFn func(a, b network.Conn) int
+	// If true, only connections on which identify has completed are returned.
+	WaitForIdentify bool
+	// If true, will return a response as soon as possible, even if no connection is available.
+	AllowNoConn bool
+}
+
+// GetBestConn will return the best conn to a peer capable of using the
+// provided protocol.
+// If no connection is currently available, this will block unless AllowNoConn is set.
+// If an empty OneOf protocol list is passed, any connection that has done
+// `Identify` will be returned.
+func (ct *ConnTracker) GetBestConn(ctx context.Context, peer peer.ID, opts GetBestConnOpts) (ConnWithMeta, error) {
+	r := req{
+		ctx:             ctx,
+		p:               peer,
+		oneOf:           opts.OneOf,
+		filter:          opts.FilterFn,
+		sort:            opts.SortFn,
+		immediate:       opts.AllowNoConn,
+		waitForIdentify: opts.WaitForIdentify,
+	}
+	res, err := ct.sendReq(ctx, r)
+	if err != nil {
+		return ConnWithMeta{}, err
+	}
+	select {
+	case <-ctx.Done():
+		return ConnWithMeta{}, ctx.Err()
+	case conn := <-res:
+		if conn.Conn == nil {
+			return ConnWithMeta{}, ErrNoConn
+		}
+		return conn, nil
+	}
+}
+
+var (
+	ErrNoConn  = errors.New("no connection available")
+	ErrStopped = errors.New("conntracker is stopped")
+)
+
+// GetBestConnChan is like GetBestConn but returns a channel that contains the best connection.
+func (ct *ConnTracker) GetBestConnChan(ctx context.Context, peer peer.ID, opts GetBestConnOpts) (<-chan ConnWithMeta, error) {
+	r := req{
+		ctx:             ctx,
+		p:               peer,
+		oneOf:           opts.OneOf,
+		filter:          opts.FilterFn,
+		sort:            opts.SortFn,
+		immediate:       opts.AllowNoConn,
+		waitForIdentify: opts.WaitForIdentify,
+	}
+	return ct.sendReq(ctx, r)
+}
+
+func (ct *ConnTracker) sendReq(ctx context.Context, r req) (<-chan ConnWithMeta, error) {
+	if r.resCh == nil {
+		r.resCh = make(chan ConnWithMeta, 1)
+	}
+
+	stopCleanup := context.AfterFunc(ctx, func() {
+		ct.cleanup(r.p)
+	})
+
+	r.onFulfilled = func() {
+		// No need to clean up; we fulfilled the request.
+		stopCleanup()
+	}
+
+	select {
+	case ct.reqCh <- r:
+		return r.resCh, nil
+	case <-ct.stopped:
+		stopCleanup()
+		// In case the conntracker is stopped, don't block; return an error.
+ return nil, ErrStopped + } +} + +func NoLimitedConnFilter(c network.Conn) bool { + return !c.Stat().Limited +} + +// TODO: add a basic sort function that essentially copies [isBetterConn] in swarm.go diff --git a/p2p/host/conntracker/conntracker_test.go b/p2p/host/conntracker/conntracker_test.go new file mode 100644 index 0000000000..bf675ee137 --- /dev/null +++ b/p2p/host/conntracker/conntracker_test.go @@ -0,0 +1,295 @@ +package conntracker + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/event" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" + "github.com/libp2p/go-libp2p/core/test" + "github.com/libp2p/go-libp2p/p2p/host/eventbus" + multiaddr "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TODO: +// - Test proto negotiation event +// - Test Identify failed +// - Test Notifier + +// MockClock is a mock implementation of the clock interface +type MockClock struct { + mu sync.Mutex + current time.Time + tickers []tickerMeta + tickerWG sync.WaitGroup +} + +type tickerMeta struct { + clockTick chan time.Time +} + +func (m *MockClock) Close() { + m.mu.Lock() + defer m.mu.Unlock() + + for _, t := range m.tickers { + close(t.clockTick) + } +} + +// Now returns the fixed time +func (m *MockClock) Now() time.Time { + m.mu.Lock() + defer m.mu.Unlock() + return m.current +} + +// Since returns the duration since the fixed time +func (m *MockClock) Since(t time.Time) time.Duration { + m.mu.Lock() + defer m.mu.Unlock() + return time.Since(m.current) +} + +// NewTicker returns a ticker that ticks at the specified duration +func (m *MockClock) NewTicker(d time.Duration) *time.Ticker { + m.mu.Lock() + defer m.mu.Unlock() + + c := make(chan time.Time) + toFire := m.current.Add(d) + ticker := &time.Ticker{ + C: c, + } + + clockTick := make(chan time.Time, 1) + m.tickers = append(m.tickers, tickerMeta{ + clockTick: clockTick, + }) + + go func() { + for t := range clockTick { + for t.After(toFire) || t.Equal(toFire) { + c <- t + toFire = toFire.Add(d) + } + m.tickerWG.Done() + } + }() + + return ticker +} + +func (m *MockClock) Advance(d time.Duration) { + m.mu.Lock() + m.current = m.current.Add(d) + current := m.current + tickers := m.tickers + m.mu.Unlock() + + m.tickerWG.Add(len(tickers)) + for _, t := range tickers { + t.clockTick <- current + } + m.tickerWG.Wait() +} + +type mockNotifier struct { +} + +func (m *mockNotifier) Notify(n network.Notifiee) {} +func (m *mockNotifier) StopNotify(n network.Notifiee) {} + +func TestGetConnTracker(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + eb := eventbus.NewBus() + + clk := &MockClock{} + connTracker := ConnTracker{clock: clk} + + err := connTracker.Start(eb, &mockNotifier{}) + require.NoError(t, err) + defer connTracker.Stop() + + idEmitter, err := eb.Emitter(new(event.EvtPeerIdentificationCompleted)) + require.NoError(t, err) + defer idEmitter.Close() + connectednessEmitter, err := eb.Emitter(new(event.EvtPeerConnectednessChanged)) + require.NoError(t, err) + defer connectednessEmitter.Close() + + peerA := test.RandPeerIDFatal(t) + connA := &mockConn{} + + var wg sync.WaitGroup + defer func() { + sem := make(chan struct{}) + go func() { + defer close(sem) + wg.Wait() + }() + select { + case <-sem: + case <-time.After(1 * time.Second): + assert.Fail(t, 
"WaitGroup was not completed") + } + }() + wg.Add(1) + + go func() { + defer wg.Done() + // Asking for a conn before we have one will block + c, err := connTracker.GetBestConn(ctx, peerA, GetBestConnOpts{ + OneOf: []protocol.ID{"/test/1.0.0"}, + FilterFn: NoLimitedConnFilter, + }) + assert.NoError(t, err) + assert.Equal(t, connA, c.Conn) + }() + + // We've connected to a peer + evt := event.EvtPeerConnectednessChanged{ + Connectedness: network.Connected, + Peer: peerA, + } + err = connectednessEmitter.Emit(evt) + require.NoError(t, err) + + idEvt := event.EvtPeerIdentificationCompleted{ + Peer: peerA, + Conn: connA, + Protocols: []protocol.ID{"/test/1.0.0"}, + } + err = idEmitter.Emit(idEvt) + require.NoError(t, err) + + // Getting a connection to peerA should return the connection we just added + c, err := connTracker.GetBestConn(ctx, peerA, GetBestConnOpts{ + OneOf: []protocol.ID{"/test/1.0.0"}, + FilterFn: NoLimitedConnFilter, + }) + require.NoError(t, err) + require.Equal(t, connA, c.Conn) + + // Advance the clock to trigger the GC + clk.Advance(1 * time.Hour) + + // Still have the connection + c, err = connTracker.GetBestConn(ctx, peerA, GetBestConnOpts{ + OneOf: []protocol.ID{"/test/1.0.0"}, + FilterFn: NoLimitedConnFilter, + }) + require.NoError(t, err) + require.Equal(t, connA, c.Conn) + + // Disconnect from the peer + evt = event.EvtPeerConnectednessChanged{ + Connectedness: network.NotConnected, + Peer: peerA, + } + err = connectednessEmitter.Emit(evt) + require.NoError(t, err) + + // Advance the clock to trigger the GC + clk.Advance(1 * time.Hour) + time.Sleep(100 * time.Millisecond) // Wait for the other goroutine to GC + + // Should block since we don't have any connection + ctx, timeoutCancel := context.WithTimeout(ctx, 100*time.Millisecond) + defer timeoutCancel() + _, err = connTracker.GetBestConn(ctx, peerA, GetBestConnOpts{ + OneOf: []protocol.ID{"/test/1.0.0"}, + FilterFn: NoLimitedConnFilter, + }) + require.ErrorContains(t, err, "context deadline exceeded") + + _, err = connTracker.GetBestConn(context.Background(), peerA, GetBestConnOpts{ + OneOf: []protocol.ID{"/test/1.0.0"}, + AllowNoConn: true, + FilterFn: NoLimitedConnFilter, + }) + require.ErrorIs(t, err, ErrNoConn) +} + +// TODO test that we get the meta with the conn + +type mockConn struct { + closed bool +} + +// Close implements network.Conn. +func (m *mockConn) Close() error { + m.closed = true + return nil +} + +// ConnState implements network.Conn. +func (m *mockConn) ConnState() network.ConnectionState { + panic("unimplemented") +} + +// GetStreams implements network.Conn. +func (m *mockConn) GetStreams() []network.Stream { + panic("unimplemented") +} + +// ID implements network.Conn. +func (m *mockConn) ID() string { + panic("unimplemented") +} + +// IsClosed implements network.Conn. +func (m *mockConn) IsClosed() bool { + return m.closed +} + +// LocalMultiaddr implements network.Conn. +func (m *mockConn) LocalMultiaddr() multiaddr.Multiaddr { + panic("unimplemented") +} + +// LocalPeer implements network.Conn. +func (m *mockConn) LocalPeer() peer.ID { + panic("unimplemented") +} + +// NewStream implements network.Conn. +func (m *mockConn) NewStream(context.Context) (network.Stream, error) { + panic("unimplemented") +} + +// RemoteMultiaddr implements network.Conn. +func (m *mockConn) RemoteMultiaddr() multiaddr.Multiaddr { + panic("unimplemented") +} + +// RemotePeer implements network.Conn. 
+func (m *mockConn) RemotePeer() peer.ID { + panic("unimplemented") +} + +// RemotePublicKey implements network.Conn. +func (m *mockConn) RemotePublicKey() crypto.PubKey { + panic("unimplemented") +} + +// Scope implements network.Conn. +func (m *mockConn) Scope() network.ConnScope { + panic("unimplemented") +} + +// Stat implements network.Conn. +func (m *mockConn) Stat() network.ConnStats { + return network.ConnStats{} +} + +var _ network.Conn = &mockConn{} diff --git a/p2p/protocol/identify/id.go b/p2p/protocol/identify/id.go index 1733a4166c..35332f5bfa 100644 --- a/p2p/protocol/identify/id.go +++ b/p2p/protocol/identify/id.go @@ -347,6 +347,10 @@ func (ids *idService) sendPushes(ctx context.Context) { defer func() { <-sem }() ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() + + // We only want to send an identify push if we already have an open + // connection. + ctx = network.WithNoDial(ctx, "id push") str, err := ids.Host.NewStream(ctx, c.RemotePeer(), IDPush) if err != nil { // connection might have been closed recently return @@ -429,7 +433,7 @@ func (ids *idService) IdentifyWait(c network.Conn) <-chan struct{} { defer close(e.IdentifyWaitChan) if err := ids.identifyConn(c); err != nil { log.Warnf("failed to identify %s: %s", c.RemotePeer(), err) - ids.emitters.evtPeerIdentificationFailed.Emit(event.EvtPeerIdentificationFailed{Peer: c.RemotePeer(), Reason: err}) + ids.emitters.evtPeerIdentificationFailed.Emit(event.EvtPeerIdentificationFailed{Peer: c.RemotePeer(), Reason: err, Conn: c}) return } }() diff --git a/p2p/test/reconnects/reconnect_test.go b/p2p/test/reconnects/reconnect_test.go index cf05c80f37..4d3e9b8670 100644 --- a/p2p/test/reconnects/reconnect_test.go +++ b/p2p/test/reconnects/reconnect_test.go @@ -2,6 +2,7 @@ package reconnect import ( "context" + "fmt" "io" "math/rand" "runtime" @@ -13,6 +14,7 @@ import ( "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/protocol" bhost "github.com/libp2p/go-libp2p/p2p/host/basic" + "github.com/libp2p/go-libp2p/p2p/host/eventbus" swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" "github.com/stretchr/testify/require" @@ -34,7 +36,8 @@ func TestReconnect5(t *testing.T) { hosts := make([]host.Host, 0, num) for i := 0; i < num; i++ { - h, err := bhost.NewHost(swarmt.GenSwarm(t, swarmOpt), nil) + eb := eventbus.NewBus() + h, err := bhost.NewHost(swarmt.GenSwarm(t, swarmOpt, swarmt.EventBus(eb)), &bhost.HostOpts{EventBus: eb}) require.NoError(t, err) defer h.Close() h.Start() @@ -42,8 +45,11 @@ func TestReconnect5(t *testing.T) { h.SetStreamHandler(protocol.TestingID, EchoStreamHandler) } - for i := 0; i < 4; i++ { - runRound(t, hosts) + const numTimes = 5 + for i := 0; i < numTimes; i++ { + t.Run(fmt.Sprintf("round %d", i), func(t *testing.T) { + runRound(t, hosts) + }) } } diff --git a/p2p/test/transport/gating_test.go b/p2p/test/transport/gating_test.go index 99ce67b521..fce8d026c1 100644 --- a/p2p/test/transport/gating_test.go +++ b/p2p/test/transport/gating_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/libp2p/go-libp2p/core/control" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/protocol" @@ -16,6 +17,7 @@ import ( "github.com/libp2p/go-libp2p-testing/race" ma "github.com/multiformats/go-multiaddr" + "github.com/multiformats/go-multistream" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) @@ -269,17 +271,18 @@ func TestInterceptUpgradedIncoming(t *testing.T) { 
gomock.InOrder( connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true), connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Return(true), - connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) { + connGater.EXPECT().InterceptUpgraded(gomock.Any()).DoAndReturn(func(c network.Conn) (bool, control.DisconnectReason) { // remove the certhash component from WebTransport addresses require.Equal(t, stripCertHash(h2.Addrs()[0]), c.LocalMultiaddr()) require.Equal(t, h1.ID(), c.RemotePeer()) require.Equal(t, h2.ID(), c.LocalPeer()) + + return true, 0 }), ) h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour) _, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID) - require.Error(t, err) - require.NotErrorIs(t, err, context.DeadlineExceeded) + require.ErrorAs(t, err, &multistream.ErrNotSupported[protocol.ID]{}) }) } }
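
A minimal usage sketch of the conntracker API introduced above (not part of the diff; the package name, the openEcho helper, and the "/testing/echo" protocol ID are illustrative assumptions):

package example // illustrative only

import (
	"context"

	"github.com/libp2p/go-libp2p/core/network"
	"github.com/libp2p/go-libp2p/core/peer"
	"github.com/libp2p/go-libp2p/core/protocol"
	"github.com/libp2p/go-libp2p/p2p/host/conntracker"
)

// openEcho opens a stream to p over the best available connection: one that is
// not limited, has completed identify, and advertises the echo protocol.
func openEcho(ctx context.Context, ct *conntracker.ConnTracker, p peer.ID) (network.Stream, error) {
	// Without AllowNoConn, GetBestConn blocks until a matching connection is
	// available or ctx is cancelled.
	conn, err := ct.GetBestConn(ctx, p, conntracker.GetBestConnOpts{
		OneOf:           []protocol.ID{"/testing/echo"},
		FilterFn:        conntracker.NoLimitedConnFilter,
		WaitForIdentify: true,
	})
	if err != nil {
		return nil, err
	}
	// conn embeds network.Conn; MatchingProtocols and SupportsProtocol report
	// what identify (or protocol negotiation) has recorded for it.
	return conn.NewStream(ctx)
}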