diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index f7d3c5275a..820411bd27 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -122,9 +122,10 @@ type HostOpts struct { // MultistreamMuxer is essential for the *BasicHost and will use a sensible default value if omitted. MultistreamMuxer *msmux.MultistreamMuxer[protocol.ID] - // NegotiationTimeout determines the read and write timeouts on streams. - // If 0 or omitted, it will use DefaultNegotiationTimeout. - // If below 0, timeouts on streams will be deactivated. + // NegotiationTimeout determines the read and write timeouts when negotiating + // protocols for streams. If 0 or omitted, it will use + // DefaultNegotiationTimeout. If below 0, timeouts on streams will be + // deactivated. NegotiationTimeout time.Duration // AddrsFactory holds a function which can be used to override or filter the result of Addrs. @@ -689,6 +690,14 @@ func (h *BasicHost) RemoveStreamHandler(pid protocol.ID) { // to create one. If ProtocolID is "", writes no header. // (Thread-safe) func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (str network.Stream, strErr error) { + if _, ok := ctx.Deadline(); !ok { + if h.negtimeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, h.negtimeout) + defer cancel() + } + } + // If the caller wants to prevent the host from dialing, it should use the NoDial option. if nodial, _ := network.GetNoDial(ctx); !nodial { err := h.Connect(ctx, peer.AddrInfo{ID: p}) diff --git a/p2p/host/basic/basic_host_test.go b/p2p/host/basic/basic_host_test.go index 1ab98aae9d..2a7a772976 100644 --- a/p2p/host/basic/basic_host_test.go +++ b/p2p/host/basic/basic_host_test.go @@ -2,6 +2,7 @@ package basichost import ( "context" + "encoding/binary" "fmt" "io" "reflect" @@ -941,3 +942,56 @@ func TestTrimHostAddrList(t *testing.T) { }) } } + +func TestHostTimeoutNewStream(t *testing.T) { + h1, err := NewHost(swarmt.GenSwarm(t), nil) + require.NoError(t, err) + h1.Start() + defer h1.Close() + + const proto = "/testing" + h2 := swarmt.GenSwarm(t) + + h2.SetStreamHandler(func(s network.Stream) { + // First message is multistream header. Just echo it + msHeader := []byte("\x19/multistream/1.0.0\n") + _, err := s.Read(msHeader) + assert.NoError(t, err) + _, err = s.Write(msHeader) + assert.NoError(t, err) + + buf := make([]byte, 1024) + n, err := s.Read(buf) + assert.NoError(t, err) + + msgLen, varintN := binary.Uvarint(buf[:n]) + buf = buf[varintN:] + proto := buf[:int(msgLen)] + if string(proto) == "/ipfs/id/1.0.0\n" { + // Signal we don't support identify + na := []byte("na\n") + n := binary.PutUvarint(buf, uint64(len(na))) + copy(buf[n:], na) + + _, err = s.Write(buf[:int(n)+len(na)]) + assert.NoError(t, err) + } else { + // Stall + time.Sleep(5 * time.Second) + } + t.Log("Resetting") + s.Reset() + }) + + err = h1.Connect(context.Background(), peer.AddrInfo{ + ID: h2.LocalPeer(), + Addrs: h2.ListenAddresses(), + }) + require.NoError(t, err) + + // No context passed in, fallback to negtimeout + h1.negtimeout = time.Second + _, err = h1.NewStream(context.Background(), h2.LocalPeer(), proto) + require.Error(t, err) + require.ErrorContains(t, err, "context deadline exceeded") +}