Skip to content

Commit 6fede0d

Browse files
marten-seemannsukunrt
authored andcommitted
interpret stream resets as multistream errors
1 parent 6817a9d commit 6fede0d

File tree

5 files changed

+59
-19
lines changed

5 files changed

+59
-19
lines changed

p2p/host/basic/basic_host.go

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"io"
88
"net"
99
"sync"
10+
"sync/atomic"
1011
"time"
1112

1213
"github.com/libp2p/go-libp2p/core/connmgr"
@@ -646,12 +647,32 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I
646647
return nil, fmt.Errorf("failed to open stream: %w", err)
647648
}
648649

649-
pref, err := h.preferredProtocol(p, pids)
650-
if err != nil {
651-
_ = s.Reset()
652-
return nil, err
653-
}
650+
// If pids contains only a single protocol, optimistically use that protocol (i.e. don't wait for
651+
// multistream negotiation).
652+
var pref protocol.ID
653+
if len(pids) == 1 {
654+
pref = pids[0]
655+
} else if len(pids) > 1 {
656+
// Wait for any in-progress identifies on the connection to finish.
657+
// This is faster than negotiating.
658+
// If the other side doesn't support identify, that's fine. This will just be a no-op.
659+
select {
660+
case <-h.ids.IdentifyWait(s.Conn()):
661+
case <-ctx.Done():
662+
_ = s.Reset()
663+
return nil, fmt.Errorf("identify failed to complete: %w", ctx.Err())
664+
}
654665

666+
// If Identify has finished, we know which protocols the peer supports.
667+
// We don't need to do a multistream negotiation.
668+
// Instead, we just pick the first supported protocol.
669+
var err error
670+
pref, err = h.preferredProtocol(p, pids)
671+
if err != nil {
672+
_ = s.Reset()
673+
return nil, err
674+
}
675+
}
655676
if pref != "" {
656677
if err := s.SetProtocol(pref); err != nil {
657678
return nil, err
@@ -1025,14 +1046,26 @@ func (h *BasicHost) Close() error {
10251046
type streamWrapper struct {
10261047
network.Stream
10271048
rw io.ReadWriteCloser
1049+
1050+
calledRead atomic.Bool
10281051
}
10291052

10301053
func (s *streamWrapper) Read(b []byte) (int, error) {
1031-
return s.rw.Read(b)
1054+
n, err := s.rw.Read(b)
1055+
if s.calledRead.CompareAndSwap(false, true) {
1056+
if errors.Is(err, network.ErrReset) {
1057+
return n, msmux.ErrNotSupported[protocol.ID]{Protos: []protocol.ID{s.Protocol()}}
1058+
}
1059+
}
1060+
return n, err
10321061
}
10331062

10341063
func (s *streamWrapper) Write(b []byte) (int, error) {
1035-
return s.rw.Write(b)
1064+
n, err := s.rw.Write(b)
1065+
if s.calledRead.Load() && errors.Is(err, network.ErrReset) {
1066+
return n, msmux.ErrNotSupported[protocol.ID]{Protos: []protocol.ID{s.Protocol()}}
1067+
}
1068+
return n, err
10361069
}
10371070

10381071
func (s *streamWrapper) Close() error {

p2p/host/basic/basic_host_test.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -713,13 +713,12 @@ func TestHostAddrChangeDetection(t *testing.T) {
713713
}
714714

715715
func TestNegotiationCancel(t *testing.T) {
716-
ctx, cancel := context.WithCancel(context.Background())
717-
defer cancel()
718-
719716
h1, h2 := getHostPair(t)
720717
defer h1.Close()
721718
defer h2.Close()
722719

720+
ctx, cancel := context.WithCancel(context.Background())
721+
defer cancel()
723722
// pre-negotiation so we can make the negotiation hang.
724723
h2.Network().SetStreamHandler(func(s network.Stream) {
725724
<-ctx.Done() // wait till the test is done.
@@ -731,7 +730,7 @@ func TestNegotiationCancel(t *testing.T) {
731730

732731
errCh := make(chan error, 1)
733732
go func() {
734-
s, err := h1.NewStream(ctx2, h2.ID(), "/testing")
733+
s, err := h1.NewStream(ctx2, h2.ID(), "/testing", "/testing2")
735734
if s != nil {
736735
errCh <- fmt.Errorf("expected to fail negotiation")
737736
return

p2p/protocol/circuitv2/client/reservation.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ func Reserve(ctx context.Context, h host.Host, ai peer.AddrInfo) (*Reservation,
8989

9090
if err := rd.ReadMsg(&msg); err != nil {
9191
s.Reset()
92-
return nil, ReservationError{Status: pbv2.Status_CONNECTION_FAILED, Reason: "error reading reservation response message: %w", err: err}
92+
return nil, ReservationError{Status: pbv2.Status_CONNECTION_FAILED, Reason: "error reading reservation response message", err: err}
9393
}
9494

9595
if msg.GetType() != pbv2.HopMessage_STATUS {

p2p/test/transport/gating_test.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,8 @@ func TestInterceptAccept(t *testing.T) {
181181
}
182182

183183
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
184-
_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID)
184+
// use two protocols here, so we actually enter multistream negotiation
185+
_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID, protocol.TestingID)
185186
require.Error(t, err)
186187
if _, err := h2.Addrs()[0].ValueForProtocol(ma.P_WEBRTC_DIRECT); err != nil {
187188
// WebRTC rejects connection attempt before an error can be sent to the client.
@@ -218,7 +219,8 @@ func TestInterceptSecuredIncoming(t *testing.T) {
218219
}),
219220
)
220221
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
221-
_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID)
222+
// use two protocols here, so we actually enter multistream negotiation
223+
_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID, protocol.TestingID)
222224
require.Error(t, err)
223225
require.NotErrorIs(t, err, context.DeadlineExceeded)
224226
})
@@ -254,7 +256,8 @@ func TestInterceptUpgradedIncoming(t *testing.T) {
254256
}),
255257
)
256258
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
257-
_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID)
259+
// use two protocols here, so we actually enter multistream negotiation
260+
_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID, protocol.TestingID)
258261
require.Error(t, err)
259262
require.NotErrorIs(t, err, context.DeadlineExceeded)
260263
})

p2p/test/transport/transport_test.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -549,14 +549,19 @@ func TestListenerStreamResets(t *testing.T) {
549549
}))
550550

551551
h1.SetStreamHandler("reset", func(s network.Stream) {
552+
// Make sure the multistream negotiation actually succeeds before resetting.
553+
// This is necessary because we don't have stream error codes yet.
554+
s.Read(make([]byte, 4))
555+
s.Write([]byte("pong"))
556+
s.Read(make([]byte, 4))
552557
s.Reset()
553558
})
554559

555560
s, err := h2.NewStream(context.Background(), h1.ID(), "reset")
556-
if err != nil {
557-
require.ErrorIs(t, err, network.ErrReset)
558-
return
559-
}
561+
require.NoError(t, err)
562+
s.Write([]byte("ping"))
563+
s.Read(make([]byte, 4))
564+
s.Write([]byte("ping"))
560565

561566
_, err = s.Read([]byte{0})
562567
require.ErrorIs(t, err, network.ErrReset)

0 commit comments

Comments
 (0)