Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] feat: conntracker #3032

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/libp2p/go-libp2p/p2p/host/autorelay"
bhost "github.com/libp2p/go-libp2p/p2p/host/basic"
blankhost "github.com/libp2p/go-libp2p/p2p/host/blank"
"github.com/libp2p/go-libp2p/p2p/host/conntracker"
"github.com/libp2p/go-libp2p/p2p/host/eventbus"
"github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoremem"
rcmgr "github.com/libp2p/go-libp2p/p2p/host/resource-manager"
Expand Down Expand Up @@ -413,7 +414,15 @@ func (cfg *Config) addTransports() ([]fx.Option, error) {
return fxopts, nil
}

func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus) (*bhost.BasicHost, error) {
type basicHostParams struct {
fx.In
Swarm *swarm.Swarm
EventBus event.Bus

ConnTracker *conntracker.ConnTracker `optional:"true"`
}

func (cfg *Config) newBasicHost(params basicHostParams) (*bhost.BasicHost, error) {
var autonatv2Dialer host.Host
if cfg.EnableAutoNATv2 {
ah, err := cfg.makeAutoNATV2Host()
Expand All @@ -422,8 +431,8 @@ func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus) (*bhost.B
}
autonatv2Dialer = ah
}
h, err := bhost.NewHost(swrm, &bhost.HostOpts{
EventBus: eventBus,
h, err := bhost.NewHost(params.Swarm, &bhost.HostOpts{
EventBus: params.EventBus,
ConnManager: cfg.ConnManager,
AddrsFactory: cfg.AddrsFactory,
NATManager: cfg.NATManager,
Expand All @@ -439,6 +448,7 @@ func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus) (*bhost.B
DisableIdentifyAddressDiscovery: cfg.DisableIdentifyAddressDiscovery,
EnableAutoNATv2: cfg.EnableAutoNATv2,
AutoNATv2Dialer: autonatv2Dialer,
ConnTracker: params.ConnTracker,
})
if err != nil {
return nil, err
Expand Down Expand Up @@ -493,6 +503,15 @@ func (cfg *Config) NewNode() (host.Host, error) {
})
return sw, nil
}),
fx.Provide(func(l fx.Lifecycle, eb event.Bus, s *swarm.Swarm) *conntracker.ConnTracker {
ct := &conntracker.ConnTracker{}
l.Append(fx.StartStopHook(
func() error {
return ct.Start(eb, s)
}, ct.Stop,
))
return ct
}),
fx.Provide(cfg.newBasicHost),
fx.Provide(func(bh *bhost.BasicHost) identify.IDService {
return bh.IDService()
Expand Down
3 changes: 3 additions & 0 deletions core/event/identify.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,7 @@ type EvtPeerIdentificationFailed struct {
Peer peer.ID
// Reason is the reason why identification failed.
Reason error

// Conn is the connection we failed to identify.
Conn network.Conn
}
17 changes: 17 additions & 0 deletions core/event/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package event
import (
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
)

// EvtPeerConnectednessChanged should be emitted every time the "connectedness" to a
Expand Down Expand Up @@ -53,3 +54,19 @@ type EvtPeerConnectednessChanged struct {
// Connectedness is the new connectedness state.
Connectedness network.Connectedness
}

// EvtProtocolNegotiationSuccess is emitted when we learn about a protocol a
// peer supports via protocol negotiation (i.e. MultiStream).
//
// This is only emitted if we learned about a protocol during negotiation. It
// is not emitted if we already expected a peer to support the protocol.
type EvtProtocolNegotiationSuccess struct {
// Peer is the remote peer who we negotiated the protocol with.
Peer peer.ID

// Conn is the connection we opened the stream on.
Conn network.Conn

// Protocol is protocol we've successfully negotiated.
Protocol protocol.ID
}
27 changes: 27 additions & 0 deletions libp2p_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -775,3 +775,30 @@ func TestSharedTCPAddr(t *testing.T) {
require.True(t, sawWS)
h.Close()
}

func TestMinimalEcho(t *testing.T) {
h1, err := New()
require.NoError(t, err)
defer h1.Close()

h2, err := New()
require.NoError(t, err)
defer h2.Close()

h2.SetStreamHandler("/testing/echo", func(s network.Stream) {
defer s.Close()
io.Copy(s, s)
})

h1.Connect(context.Background(), peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()})

s, err := h1.NewStream(context.Background(), h2.ID(), "/testing/echo")
require.NoError(t, err)
body := []byte("hello")
s.Write(body)
s.CloseWrite()
resp, err := io.ReadAll(s)
require.NoError(t, err)
require.Equal(t, body, resp)
s.Close()
}
180 changes: 147 additions & 33 deletions p2p/host/basic/basic_host.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/host/autonat"
"github.com/libp2p/go-libp2p/p2p/host/basic/internal/backoff"
"github.com/libp2p/go-libp2p/p2p/host/conntracker"
"github.com/libp2p/go-libp2p/p2p/host/eventbus"
"github.com/libp2p/go-libp2p/p2p/host/pstoremanager"
"github.com/libp2p/go-libp2p/p2p/host/relaysvc"
Expand Down Expand Up @@ -74,6 +75,7 @@ type BasicHost struct {
// keep track of resources we need to wait on before shutting down
refCount sync.WaitGroup

connTracker *conntracker.ConnTracker
network network.Network
psManager *pstoremanager.PeerstoreManager
mux *msmux.MultistreamMuxer[protocol.ID]
Expand All @@ -92,6 +94,7 @@ type BasicHost struct {
emitters struct {
evtLocalProtocolsUpdated event.Emitter
evtLocalAddrsUpdated event.Emitter
evtProtoNegotiation event.Emitter
}

addrChangeChan chan struct{}
Expand Down Expand Up @@ -170,6 +173,8 @@ type HostOpts struct {
DisableIdentifyAddressDiscovery bool
EnableAutoNATv2 bool
AutoNATv2Dialer host.Host

ConnTracker *conntracker.ConnTracker
}

// NewHost constructs a new *BasicHost and activates it by attaching its stream and connection handlers to the given inet.Network.
Expand All @@ -185,7 +190,21 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) {
if err != nil {
return nil, err
}

hostCtx, cancel := context.WithCancel(context.Background())

// FIXME:
// Remove this. Hack to try the tests quickly
if opts.ConnTracker == nil {
opts.ConnTracker = &conntracker.ConnTracker{}
opts.ConnTracker.Start(opts.EventBus, n)
go func() {
<-hostCtx.Done()
opts.ConnTracker.Stop()

}()
}

h := &BasicHost{
network: n,
psManager: psManager,
Expand All @@ -197,6 +216,7 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) {
ctx: hostCtx,
ctxCancel: cancel,
disableSignedPeerRecord: opts.DisableSignedPeerRecord,
connTracker: opts.ConnTracker,
}

h.updateLocalIpAddr()
Expand All @@ -207,6 +227,9 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) {
if h.emitters.evtLocalAddrsUpdated, err = h.eventbus.Emitter(&event.EvtLocalAddressesUpdated{}, eventbus.Stateful); err != nil {
return nil, err
}
if h.emitters.evtProtoNegotiation, err = h.eventbus.Emitter(&event.EvtProtocolNegotiationSuccess{}); err != nil {
return nil, err
}

if !h.disableSignedPeerRecord {
cab, ok := peerstore.GetCertifiedAddrBook(n.Peerstore())
Expand Down Expand Up @@ -685,6 +708,80 @@ func (h *BasicHost) RemoveStreamHandler(pid protocol.ID) {
})
}

func (h *BasicHost) newConnFromConnTracker(ctx context.Context, p peer.ID) (conntracker.ConnWithMeta, error) {
connFilter := conntracker.NoLimitedConnFilter
if canUseLimitedConn, _ := network.GetAllowLimitedConn(ctx); canUseLimitedConn {
connFilter = nil
}

// requiredProtos is nil because we will fallback to MSS to negotiate the
// protocol if none of our prefrred protocols were broadcasted by identify.
var requiredProtos []protocol.ID

// Do we have a conn?
// TODO: for both usages of conntracker, use a sort fn that sorts by more streams.
conn, err := h.connTracker.GetBestConn(ctx, p, conntracker.GetBestConnOpts{
OneOf: requiredProtos,
FilterFn: connFilter,
WaitForIdentify: true,
AllowNoConn: true,
})
if err == nil {
return conn, nil
}

var errCh chan error
if nodial, _ := network.GetNoDial(ctx); !nodial {
errCh = make(chan error, 1)
go func() {
err := h.Connect(ctx, peer.AddrInfo{ID: p})
if err != nil {
select {
case errCh <- err:
default:
}
}
}()
}

// Wait for a connection that works for us
connChan, err := h.connTracker.GetBestConnChan(ctx, p, conntracker.GetBestConnOpts{
OneOf: requiredProtos,
FilterFn: connFilter,
WaitForIdentify: true, // Old behavior
})
if err != nil {
return conntracker.ConnWithMeta{}, err
}

select {
case <-ctx.Done():
return conntracker.ConnWithMeta{}, ctx.Err()
case <-h.ctx.Done():
return conntracker.ConnWithMeta{}, h.ctx.Err()
case err = <-errCh:
return conntracker.ConnWithMeta{}, err
case connWithMeta := <-connChan:
return connWithMeta, nil
}
}

func (h *BasicHost) newStreamWithConnTracker(ctx context.Context, p peer.ID, pids ...protocol.ID) (network.Stream, protocol.ID, error) {
var preferredProto protocol.ID
connWithMeta, err := h.newConnFromConnTracker(ctx, p)
if err != nil {
return nil, "", err
}
for _, proto := range pids {
if connWithMeta.SupportsProtocol(proto) {
preferredProto = proto
break
}
}
s, err := connWithMeta.NewStream(ctx)
return s, preferredProto, err
}

// NewStream opens a new stream to given peer p, and writes a p2p/protocol
// header with given protocol.ID. If there is no connection to p, attempts
// to create one. If ProtocolID is "", writes no header.
Expand All @@ -698,56 +795,71 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I
}
}

// If the caller wants to prevent the host from dialing, it should use the NoDial option.
if nodial, _ := network.GetNoDial(ctx); !nodial {
err := h.Connect(ctx, peer.AddrInfo{ID: p})
var (
err error
s network.Stream
preferredProto protocol.ID
)

if h.connTracker != nil {
s, preferredProto, err = h.newStreamWithConnTracker(ctx, p, pids...)
if err != nil {
return nil, err
}
}

s, err := h.Network().NewStream(network.WithNoDial(ctx, "already dialed"), p)
if err != nil {
// TODO: It would be nicer to get the actual error from the swarm,
// but this will require some more work.
if errors.Is(err, network.ErrNoConn) {
return nil, errors.New("connection failed")
} else {
// If the caller wants to prevent the host from dialing, it should use the NoDial option.
if nodial, _ := network.GetNoDial(ctx); !nodial {
err := h.Connect(ctx, peer.AddrInfo{ID: p})
if err != nil {
return nil, err
}
}
return nil, fmt.Errorf("failed to open stream: %w", err)
}
defer func() {
if strErr != nil && s != nil {
s.Reset()

s, err = h.Network().NewStream(network.WithNoDial(ctx, "already dialed"), p)
if err != nil {
// TODO: It would be nicer to get the actual error from the swarm,
// but this will require some more work.
if errors.Is(err, network.ErrNoConn) {
return nil, errors.New("connection failed")
}
return nil, fmt.Errorf("failed to open stream: %w", err)
}
}()
defer func() {
if strErr != nil && s != nil {
s.Reset()
}
}()

// Wait for any in-progress identifies on the connection to finish. This
// is faster than negotiating.
//
// If the other side doesn't support identify, that's fine. This will
// just be a no-op.
select {
case <-h.ids.IdentifyWait(s.Conn()):
case <-ctx.Done():
return nil, fmt.Errorf("identify failed to complete: %w", ctx.Err())
}
// Wait for any in-progress identifies on the connection to finish. This
// is faster than negotiating.
//
// If the other side doesn't support identify, that's fine. This will
// just be a no-op.
select {
case <-h.ids.IdentifyWait(s.Conn()):
case <-ctx.Done():
return nil, fmt.Errorf("identify failed to complete: %w", ctx.Err())
}

pref, err := h.preferredProtocol(p, pids)
if err != nil {
return nil, err
preferredProto, err = h.preferredProtocol(p, pids)
if err != nil {
return nil, err
}
}

if pref != "" {
if err := s.SetProtocol(pref); err != nil {
if preferredProto != "" {
if err := s.SetProtocol(preferredProto); err != nil {
return nil, err
}
lzcon := msmux.NewMSSelect(s, pref)
lzcon := msmux.NewMSSelect(s, preferredProto)
return &streamWrapper{
Stream: s,
rw: lzcon,
}, nil
}

// Fallback to MultiStreamSelect.

// Negotiate the protocol in the background, obeying the context.
var selected protocol.ID
errCh := make(chan error, 1)
Expand All @@ -771,6 +883,8 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I
return nil, err
}
_ = h.Peerstore().AddProtocols(p, selected) // adding the protocol to the peerstore isn't critical
h.emitters.evtProtoNegotiation.Emit(event.EvtProtocolNegotiationSuccess{Peer: p, Conn: s.Conn(), Protocol: selected})

return s, nil
}

Expand Down
Loading
Loading