From 45211f8def972cf1367149c483d8e7cca32085a4 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 5 Sep 2023 13:06:54 +0700 Subject: [PATCH] interpret stream resets as multistream errors --- p2p/host/basic/basic_host.go | 47 +++++++++++++++++--- p2p/host/basic/basic_host_test.go | 7 ++- p2p/protocol/circuitv2/client/reservation.go | 2 +- p2p/test/transport/gating_test.go | 9 ++-- p2p/test/transport/transport_test.go | 13 ++++-- 5 files changed, 59 insertions(+), 19 deletions(-) diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 08bf67868e..116429c8f2 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -7,6 +7,7 @@ import ( "io" "net" "sync" + "sync/atomic" "time" "github.com/libp2p/go-libp2p/core/connmgr" @@ -647,12 +648,32 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I return nil, fmt.Errorf("failed to open stream: %w", err) } - pref, err := h.preferredProtocol(p, pids) - if err != nil { - _ = s.Reset() - return nil, err - } + // If pids contains only a single protocol, optimistically use that protocol (i.e. don't wait for + // multistream negotiation). + var pref protocol.ID + if len(pids) == 1 { + pref = pids[0] + } else if len(pids) > 1 { + // Wait for any in-progress identifies on the connection to finish. + // This is faster than negotiating. + // If the other side doesn't support identify, that's fine. This will just be a no-op. + select { + case <-h.ids.IdentifyWait(s.Conn()): + case <-ctx.Done(): + _ = s.Reset() + return nil, fmt.Errorf("identify failed to complete: %w", ctx.Err()) + } + // If Identify has finished, we know which protocols the peer supports. + // We don't need to do a multistream negotiation. + // Instead, we just pick the first supported protocol. + var err error + pref, err = h.preferredProtocol(p, pids) + if err != nil { + _ = s.Reset() + return nil, err + } + } if pref != "" { if err := s.SetProtocol(pref); err != nil { return nil, err @@ -1026,14 +1047,26 @@ func (h *BasicHost) Close() error { type streamWrapper struct { network.Stream rw io.ReadWriteCloser + + calledRead atomic.Bool } func (s *streamWrapper) Read(b []byte) (int, error) { - return s.rw.Read(b) + n, err := s.rw.Read(b) + if s.calledRead.CompareAndSwap(false, true) { + if errors.Is(err, network.ErrReset) { + return n, msmux.ErrNotSupported[protocol.ID]{Protos: []protocol.ID{s.Protocol()}} + } + } + return n, err } func (s *streamWrapper) Write(b []byte) (int, error) { - return s.rw.Write(b) + n, err := s.rw.Write(b) + if s.calledRead.Load() && errors.Is(err, network.ErrReset) { + return n, msmux.ErrNotSupported[protocol.ID]{Protos: []protocol.ID{s.Protocol()}} + } + return n, err } func (s *streamWrapper) Close() error { diff --git a/p2p/host/basic/basic_host_test.go b/p2p/host/basic/basic_host_test.go index 9748bb0560..556f484c3d 100644 --- a/p2p/host/basic/basic_host_test.go +++ b/p2p/host/basic/basic_host_test.go @@ -713,13 +713,12 @@ func TestHostAddrChangeDetection(t *testing.T) { } func TestNegotiationCancel(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - h1, h2 := getHostPair(t) defer h1.Close() defer h2.Close() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() // pre-negotiation so we can make the negotiation hang. h2.Network().SetStreamHandler(func(s network.Stream) { <-ctx.Done() // wait till the test is done. @@ -731,7 +730,7 @@ func TestNegotiationCancel(t *testing.T) { errCh := make(chan error, 1) go func() { - s, err := h1.NewStream(ctx2, h2.ID(), "/testing") + s, err := h1.NewStream(ctx2, h2.ID(), "/testing", "/testing2") if s != nil { errCh <- fmt.Errorf("expected to fail negotiation") return diff --git a/p2p/protocol/circuitv2/client/reservation.go b/p2p/protocol/circuitv2/client/reservation.go index dbb9241937..0cc2ad6429 100644 --- a/p2p/protocol/circuitv2/client/reservation.go +++ b/p2p/protocol/circuitv2/client/reservation.go @@ -89,7 +89,7 @@ func Reserve(ctx context.Context, h host.Host, ai peer.AddrInfo) (*Reservation, if err := rd.ReadMsg(&msg); err != nil { s.Reset() - return nil, ReservationError{Status: pbv2.Status_CONNECTION_FAILED, Reason: "error reading reservation response message: %w", err: err} + return nil, ReservationError{Status: pbv2.Status_CONNECTION_FAILED, Reason: "error reading reservation response message", err: err} } if msg.GetType() != pbv2.HopMessage_STATUS { diff --git a/p2p/test/transport/gating_test.go b/p2p/test/transport/gating_test.go index 426fc906e5..45dfccaebb 100644 --- a/p2p/test/transport/gating_test.go +++ b/p2p/test/transport/gating_test.go @@ -164,7 +164,8 @@ func TestInterceptAccept(t *testing.T) { require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr()) }) h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour) - _, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID) + // use two protocols here, so we actually enter multistream negotiation + _, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID, protocol.TestingID) require.Error(t, err) require.NotErrorIs(t, err, context.DeadlineExceeded) }) @@ -195,7 +196,8 @@ func TestInterceptSecuredIncoming(t *testing.T) { }), ) h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour) - _, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID) + // use two protocols here, so we actually enter multistream negotiation + _, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID, protocol.TestingID) require.Error(t, err) require.NotErrorIs(t, err, context.DeadlineExceeded) }) @@ -229,7 +231,8 @@ func TestInterceptUpgradedIncoming(t *testing.T) { }), ) h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour) - _, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID) + // use two protocols here, so we actually enter multistream negotiation + _, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID, protocol.TestingID) require.Error(t, err) require.NotErrorIs(t, err, context.DeadlineExceeded) }) diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index 370ef9b114..41f8f4ee5f 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -527,14 +527,19 @@ func TestListenerStreamResets(t *testing.T) { })) h1.SetStreamHandler("reset", func(s network.Stream) { + // Make sure the multistream negotiation actually succeeds before resetting. + // This is necessary because we don't have stream error codes yet. + s.Read(make([]byte, 4)) + s.Write([]byte("pong")) + s.Read(make([]byte, 4)) s.Reset() }) s, err := h2.NewStream(context.Background(), h1.ID(), "reset") - if err != nil { - require.ErrorIs(t, err, network.ErrReset) - return - } + require.NoError(t, err) + s.Write([]byte("ping")) + s.Read(make([]byte, 4)) + s.Write([]byte("ping")) _, err = s.Read([]byte{0}) require.ErrorIs(t, err, network.ErrReset)