Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: basichost: Use NegotiationTimeout as fallback timeout for NewStream #3020

Merged
merged 3 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions p2p/host/basic/basic_host.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,10 @@ type HostOpts struct {
// MultistreamMuxer is essential for the *BasicHost and will use a sensible default value if omitted.
MultistreamMuxer *msmux.MultistreamMuxer[protocol.ID]

// NegotiationTimeout determines the read and write timeouts on streams.
// If 0 or omitted, it will use DefaultNegotiationTimeout.
// If below 0, timeouts on streams will be deactivated.
// NegotiationTimeout determines the read and write timeouts when negotiating
// protocols for streams. If 0 or omitted, it will use
// DefaultNegotiationTimeout. If below 0, timeouts on streams will be
// deactivated.
NegotiationTimeout time.Duration

// AddrsFactory holds a function which can be used to override or filter the result of Addrs.
Expand Down Expand Up @@ -689,6 +690,14 @@ func (h *BasicHost) RemoveStreamHandler(pid protocol.ID) {
// to create one. If ProtocolID is "", writes no header.
// (Thread-safe)
func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (str network.Stream, strErr error) {
if _, ok := ctx.Deadline(); !ok {
if h.negtimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, h.negtimeout)
defer cancel()
}
}

// If the caller wants to prevent the host from dialing, it should use the NoDial option.
if nodial, _ := network.GetNoDial(ctx); !nodial {
err := h.Connect(ctx, peer.AddrInfo{ID: p})
Expand Down
54 changes: 54 additions & 0 deletions p2p/host/basic/basic_host_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package basichost

import (
"context"
"encoding/binary"
"fmt"
"io"
"reflect"
Expand Down Expand Up @@ -941,3 +942,56 @@ func TestTrimHostAddrList(t *testing.T) {
})
}
}

func TestHostTimeoutNewStream(t *testing.T) {
h1, err := NewHost(swarmt.GenSwarm(t), nil)
require.NoError(t, err)
h1.Start()
defer h1.Close()

const proto = "/testing"
h2 := swarmt.GenSwarm(t)

h2.SetStreamHandler(func(s network.Stream) {
// First message is multistream header. Just echo it
msHeader := []byte("\x19/multistream/1.0.0\n")
_, err := s.Read(msHeader)
assert.NoError(t, err)
_, err = s.Write(msHeader)
assert.NoError(t, err)

buf := make([]byte, 1024)
n, err := s.Read(buf)
assert.NoError(t, err)

msgLen, varintN := binary.Uvarint(buf[:n])
buf = buf[varintN:]
proto := buf[:int(msgLen)]
if string(proto) == "/ipfs/id/1.0.0\n" {
// Signal we don't support identify
na := []byte("na\n")
n := binary.PutUvarint(buf, uint64(len(na)))
copy(buf[n:], na)

_, err = s.Write(buf[:int(n)+len(na)])
assert.NoError(t, err)
} else {
// Stall
time.Sleep(5 * time.Second)
}
t.Log("Resetting")
s.Reset()
})

err = h1.Connect(context.Background(), peer.AddrInfo{
ID: h2.LocalPeer(),
Addrs: h2.ListenAddresses(),
})
require.NoError(t, err)

// No context passed in, fallback to negtimeout
h1.negtimeout = time.Second
_, err = h1.NewStream(context.Background(), h2.LocalPeer(), proto)
require.Error(t, err)
require.ErrorContains(t, err, "context deadline exceeded")
}
Loading