diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go
index 8e6e8efe7c..313f441cb9 100644
--- a/p2p/host/basic/basic_host.go
+++ b/p2p/host/basic/basic_host.go
@@ -7,6 +7,7 @@ import (
 	"io"
 	"net"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/libp2p/go-libp2p/core/connmgr"
@@ -646,24 +647,32 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I
 		return nil, fmt.Errorf("failed to open stream: %w", err)
 	}
-	// Wait for any in-progress identifies on the connection to finish. This
-	// is faster than negotiating.
-	//
-	// If the other side doesn't support identify, that's fine. This will
-	// just be a no-op.
-	select {
-	case <-h.ids.IdentifyWait(s.Conn()):
-	case <-ctx.Done():
-		_ = s.Reset()
-		return nil, fmt.Errorf("identify failed to complete: %w", ctx.Err())
-	}
+	// If pids contains only a single protocol, optimistically use that protocol (i.e. don't wait for
+	// multistream negotiation).
+	var pref protocol.ID
+	if len(pids) == 1 {
+		pref = pids[0]
+	} else if len(pids) > 1 {
+		// Wait for any in-progress identifies on the connection to finish.
+		// This is faster than negotiating.
+		// If the other side doesn't support identify, that's fine. This will just be a no-op.
+		select {
+		case <-h.ids.IdentifyWait(s.Conn()):
+		case <-ctx.Done():
+			_ = s.Reset()
+			return nil, fmt.Errorf("identify failed to complete: %w", ctx.Err())
+		}
 
-	pref, err := h.preferredProtocol(p, pids)
-	if err != nil {
-		_ = s.Reset()
-		return nil, err
+		// If Identify has finished, we know which protocols the peer supports.
+		// We don't need to do a multistream negotiation.
+		// Instead, we just pick the first supported protocol.
+		var err error
+		pref, err = h.preferredProtocol(p, pids)
+		if err != nil {
+			_ = s.Reset()
+			return nil, err
+		}
 	}
-
 	if pref != "" {
 		if err := s.SetProtocol(pref); err != nil {
 			return nil, err
 		}
@@ -736,22 +745,10 @@ func (h *BasicHost) Connect(ctx context.Context, pi peer.AddrInfo) error {
 // the connection once it has been opened.
 func (h *BasicHost) dialPeer(ctx context.Context, p peer.ID) error {
 	log.Debugf("host %s dialing %s", h.ID(), p)
-	c, err := h.Network().DialPeer(ctx, p)
-	if err != nil {
+	if _, err := h.Network().DialPeer(ctx, p); err != nil {
 		return fmt.Errorf("failed to dial: %w", err)
 	}
 
-	// TODO: Consider removing this? On one hand, it's nice because we can
-	// assume that things like the agent version are usually set when this
-	// returns. On the other hand, we don't _really_ need to wait for this.
-	//
-	// This is mostly here to preserve existing behavior.
-	select {
-	case <-h.ids.IdentifyWait(c):
-	case <-ctx.Done():
-		return fmt.Errorf("identify failed to complete: %w", ctx.Err())
-	}
-
 	log.Debugf("host %s finished dialing %s", h.ID(), p)
 	return nil
 }
@@ -1049,14 +1046,26 @@ func (h *BasicHost) Close() error {
 type streamWrapper struct {
 	network.Stream
 	rw io.ReadWriteCloser
+
+	calledRead atomic.Bool
 }
 
 func (s *streamWrapper) Read(b []byte) (int, error) {
-	return s.rw.Read(b)
+	n, err := s.rw.Read(b)
+	if s.calledRead.CompareAndSwap(false, true) {
+		if errors.Is(err, network.ErrReset) {
+			return n, msmux.ErrNotSupported[protocol.ID]{Protos: []protocol.ID{s.Protocol()}}
+		}
+	}
+	return n, err
 }
 
 func (s *streamWrapper) Write(b []byte) (int, error) {
-	return s.rw.Write(b)
+	n, err := s.rw.Write(b)
+	if s.calledRead.Load() && errors.Is(err, network.ErrReset) {
+		return n, msmux.ErrNotSupported[protocol.ID]{Protos: []protocol.ID{s.Protocol()}}
+	}
+	return n, err
 }
 
 func (s *streamWrapper) Close() error {
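
Editorial illustration, not part of the patch: with the optimistic path above, a NewStream call that names a single protocol returns before the remote peer has confirmed support, and a rejection only surfaces on the first use of the stream, where streamWrapper maps the reset to multistream's ErrNotSupported. A minimal caller-side sketch, assuming an already-connected host h and peer p; the function name and protocol ID are hypothetical:

package main

import (
	"context"
	"errors"

	"github.com/libp2p/go-libp2p/core/host"
	"github.com/libp2p/go-libp2p/core/peer"
	"github.com/libp2p/go-libp2p/core/protocol"
	multistream "github.com/multiformats/go-multistream"
)

// sendPing opens a stream with a single protocol ID, so NewStream returns
// optimistically; protocol rejection is only observed on first use.
func sendPing(ctx context.Context, h host.Host, p peer.ID) error {
	s, err := h.NewStream(ctx, p, "/myapp/ping/1.0.0") // hypothetical protocol
	if err != nil {
		return err
	}
	defer s.Close()
	if _, err := s.Write([]byte("ping")); err != nil {
		return err
	}
	buf := make([]byte, 4)
	if _, err := s.Read(buf); err != nil {
		// A reset on the first read may mean the peer rejected the protocol
		// during the lazy multistream negotiation; the wrapper surfaces that
		// as multistream.ErrNotSupported rather than a bare stream reset.
		var notSupported multistream.ErrNotSupported[protocol.ID]
		if errors.As(err, &notSupported) {
			return errors.New("peer does not support /myapp/ping/1.0.0")
		}
		return err
	}
	return nil
}
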
diff --git a/p2p/host/basic/basic_host_test.go b/p2p/host/basic/basic_host_test.go
index 1fb3b4a397..673a1fc23c 100644
--- a/p2p/host/basic/basic_host_test.go
+++ b/p2p/host/basic/basic_host_test.go
@@ -535,8 +535,17 @@ func TestProtoDowngrade(t *testing.T) {
 	// This is _almost_ instantaneous, but this test fails once every ~1k runs without this.
 	time.Sleep(time.Millisecond)
 
+	sub, err := h1.EventBus().Subscribe(&event.EvtPeerIdentificationCompleted{})
+	require.NoError(t, err)
+	defer sub.Close()
+
 	h2pi := h2.Peerstore().PeerInfo(h2.ID())
 	require.NoError(t, h1.Connect(ctx, h2pi))
+	select {
+	case <-sub.Out():
+	case <-time.After(5 * time.Second):
+		t.Fatal("timeout")
+	}
 
 	s2, err := h1.NewStream(ctx, h2.ID(), "/testing/1.0.0", "/testing")
 	require.NoError(t, err)
@@ -704,13 +713,12 @@ func TestHostAddrChangeDetection(t *testing.T) {
 }
 
 func TestNegotiationCancel(t *testing.T) {
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-
 	h1, h2 := getHostPair(t)
 	defer h1.Close()
 	defer h2.Close()
 
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
 	// pre-negotiation so we can make the negotiation hang.
 	h2.Network().SetStreamHandler(func(s network.Stream) {
 		<-ctx.Done() // wait till the test is done.
@@ -722,7 +730,7 @@ func TestNegotiationCancel(t *testing.T) {
 
 	errCh := make(chan error, 1)
 	go func() {
-		s, err := h1.NewStream(ctx2, h2.ID(), "/testing")
+		s, err := h1.NewStream(ctx2, h2.ID(), "/testing", "/testing2")
 		if s != nil {
 			errCh <- fmt.Errorf("expected to fail negotiation")
 			return
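
A side effect visible in the tests above: Connect (via dialPeer) no longer blocks until Identify completes, so code that needs identify results, such as the peer's agent version or supported protocols, must wait for them explicitly. A sketch of that pattern as a reusable helper, assuming only the standard event bus API; the helper name is hypothetical:

package main

import (
	"context"
	"fmt"

	"github.com/libp2p/go-libp2p/core/event"
	"github.com/libp2p/go-libp2p/core/host"
	"github.com/libp2p/go-libp2p/core/peer"
)

// connectAndIdentify mirrors the test pattern above: connect, then block until
// the target peer has been identified (or the context expires).
func connectAndIdentify(ctx context.Context, h host.Host, pi peer.AddrInfo) error {
	// Subscribe before connecting so the completion event cannot be missed.
	sub, err := h.EventBus().Subscribe(&event.EvtPeerIdentificationCompleted{})
	if err != nil {
		return err
	}
	defer sub.Close()
	if err := h.Connect(ctx, pi); err != nil {
		return err
	}
	for {
		select {
		case e := <-sub.Out():
			// Events for other peers may arrive first; keep waiting for ours.
			if evt, ok := e.(event.EvtPeerIdentificationCompleted); ok && evt.Peer == pi.ID {
				return nil
			}
		case <-ctx.Done():
			return fmt.Errorf("identify did not complete: %w", ctx.Err())
		}
	}
}
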
diff --git a/p2p/protocol/circuitv2/client/reservation.go b/p2p/protocol/circuitv2/client/reservation.go
index dbb9241937..0cc2ad6429 100644
--- a/p2p/protocol/circuitv2/client/reservation.go
+++ b/p2p/protocol/circuitv2/client/reservation.go
@@ -89,7 +89,7 @@ func Reserve(ctx context.Context, h host.Host, ai peer.AddrInfo) (*Reservation,
 
 	if err := rd.ReadMsg(&msg); err != nil {
 		s.Reset()
-		return nil, ReservationError{Status: pbv2.Status_CONNECTION_FAILED, Reason: "error reading reservation response message: %w", err: err}
+		return nil, ReservationError{Status: pbv2.Status_CONNECTION_FAILED, Reason: "error reading reservation response message", err: err}
 	}
 
 	if msg.GetType() != pbv2.HopMessage_STATUS {
diff --git a/p2p/protocol/holepunch/holepunch_test.go b/p2p/protocol/holepunch/holepunch_test.go
index 1f3d9263df..af7244ac86 100644
--- a/p2p/protocol/holepunch/holepunch_test.go
+++ b/p2p/protocol/holepunch/holepunch_test.go
@@ -123,7 +123,8 @@ func TestDirectDialWorks(t *testing.T) {
 	require.Empty(t, h1.Network().ConnsToPeer(h2.ID()))
 	require.NoError(t, h1ps.DirectConnect(h2.ID()))
 	require.GreaterOrEqual(t, len(h1.Network().ConnsToPeer(h2.ID())), 1)
-	require.GreaterOrEqual(t, len(h2.Network().ConnsToPeer(h1.ID())), 1)
+	// h1 might finish the handshake first, but h2 should see the connection shortly after
+	require.Eventually(t, func() bool { return len(h2.Network().ConnsToPeer(h1.ID())) > 0 }, time.Second, 25*time.Millisecond)
 	events := tr.getEvents()
 	require.Len(t, events, 1)
 	require.Equal(t, holepunch.DirectDialEvtT, events[0].Type)
@@ -340,9 +341,10 @@ func TestFailuresOnResponder(t *testing.T) {
 			defer relay.Close()
 
 			s, err := h2.NewStream(network.WithUseTransient(context.Background(), "holepunch"), h1.ID(), holepunch.Protocol)
-			require.NoError(t, err)
-
-			go tc.initiator(s)
+			// h1 will reset the stream. This might or might not happen before multistream has finished.
+			if err == nil {
+				go tc.initiator(s)
+			}
 
 			getTracerError := func(tr *mockEventTracer) []string {
 				var errs []string
@@ -487,6 +489,7 @@ func makeRelayedHosts(t *testing.T, h1opt, h2opt []holepunch.Option, addHolePunc
 		ID:    h2.ID(),
 		Addrs: []ma.Multiaddr{raddr},
 	}))
+	require.Eventually(t, func() bool { return len(h2.Network().ConnsToPeer(h1.ID())) > 0 }, time.Second, 50*time.Millisecond)
 	return
 }
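
The require.Eventually checks added above share one pattern: once the dialing side returns, the listening side may not have registered the connection yet, so tests poll instead of asserting immediately. If more tests need it, the poll could live in a small helper like this hypothetical one (not part of the patch):

package main

import (
	"testing"
	"time"

	"github.com/libp2p/go-libp2p/core/host"
	"github.com/libp2p/go-libp2p/core/peer"
	"github.com/stretchr/testify/require"
)

// waitForConn polls until h has at least one connection to p. One side may
// finish the handshake first; the other should observe it shortly after.
func waitForConn(t *testing.T, h host.Host, p peer.ID) {
	t.Helper()
	require.Eventually(t, func() bool {
		return len(h.Network().ConnsToPeer(p)) > 0
	}, time.Second, 25*time.Millisecond)
}
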
diff --git a/p2p/protocol/identify/id_test.go b/p2p/protocol/identify/id_test.go
index 02e4251434..0f9695ecb8 100644
--- a/p2p/protocol/identify/id_test.go
+++ b/p2p/protocol/identify/id_test.go
@@ -473,25 +473,25 @@ func TestUserAgent(t *testing.T) {
 	defer cancel()
 
 	h1, err := libp2p.New(libp2p.UserAgent("foo"), libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0"))
-	if err != nil {
-		t.Fatal(err)
-	}
+	require.NoError(t, err)
 	defer h1.Close()
 
 	h2, err := libp2p.New(libp2p.UserAgent("bar"), libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0"))
-	if err != nil {
-		t.Fatal(err)
-	}
+	require.NoError(t, err)
 	defer h2.Close()
 
-	err = h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()})
-	if err != nil {
-		t.Fatal(err)
+	sub, err := h1.EventBus().Subscribe(&event.EvtPeerIdentificationCompleted{})
+	require.NoError(t, err)
+	defer sub.Close()
+
+	require.NoError(t, h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()}))
+	select {
+	case <-sub.Out():
+	case <-time.After(time.Second):
+		t.Fatal("timeout")
 	}
 	av, err := h1.Peerstore().Get(h2.ID(), "AgentVersion")
-	if err != nil {
-		t.Fatal(err)
-	}
+	require.NoError(t, err)
 	if ver, ok := av.(string); !ok || ver != "bar" {
 		t.Errorf("expected agent version %q, got %q", "bar", av)
 	}
diff --git a/p2p/test/quic/quic_test.go b/p2p/test/quic/quic_test.go
index fe52119b89..c88bb581db 100644
--- a/p2p/test/quic/quic_test.go
+++ b/p2p/test/quic/quic_test.go
@@ -61,6 +61,7 @@ func TestQUICAndWebTransport(t *testing.T) {
 	)
 	require.NoError(t, err)
 	require.NoError(t, h2.Connect(ctx, peer.AddrInfo{ID: h1.ID(), Addrs: h1.Addrs()}))
+	require.Eventually(t, func() bool { return len(h1.Network().ConnsToPeer(h2.ID())) > 0 }, time.Second, 25*time.Millisecond)
 	for _, conns := range [][]network.Conn{h2.Network().ConnsToPeer(h1.ID()), h1.Network().ConnsToPeer(h2.ID())} {
 		require.Len(t, conns, 1)
 		if _, err := conns[0].LocalMultiaddr().ValueForProtocol(ma.P_WEBTRANSPORT); err == nil {
@@ -78,6 +79,7 @@ func TestQUICAndWebTransport(t *testing.T) {
 	)
 	require.NoError(t, err)
 	require.NoError(t, h3.Connect(ctx, peer.AddrInfo{ID: h1.ID(), Addrs: h1.Addrs()}))
+	require.Eventually(t, func() bool { return len(h1.Network().ConnsToPeer(h3.ID())) > 0 }, time.Second, 25*time.Millisecond)
 	for _, conns := range [][]network.Conn{h3.Network().ConnsToPeer(h1.ID()), h1.Network().ConnsToPeer(h3.ID())} {
 		require.Len(t, conns, 1)
 		if _, err := conns[0].LocalMultiaddr().ValueForProtocol(ma.P_WEBTRANSPORT); err != nil {
diff --git a/p2p/test/swarm/swarm_test.go b/p2p/test/swarm/swarm_test.go
index 8027cebe53..6aab669048 100644
--- a/p2p/test/swarm/swarm_test.go
+++ b/p2p/test/swarm/swarm_test.go
@@ -193,6 +193,7 @@ func TestLimitStreamsWhenHangingHandlers(t *testing.T) {
 
 	// Open streamLimit streams
 	success := 0
+	errCnt := 0
 	// we make a lot of tries because identify and identify push take up a few streams
 	for i := 0; i < 1000 && success < streamLimit; i++ {
 		mgr, err = rcmgr.NewResourceManager(rcmgr.NewFixedLimiter(rcmgr.InfiniteLimits))
@@ -206,6 +207,7 @@ func TestLimitStreamsWhenHangingHandlers(t *testing.T) {
 
 		s, err := sender.NewStream(context.Background(), receiver.ID(), pid)
 		if err != nil {
+			errCnt++
 			continue
 		}
@@ -227,7 +229,11 @@ func TestLimitStreamsWhenHangingHandlers(t *testing.T) {
 
 	sender.Peerstore().AddAddrs(receiver.ID(), receiver.Addrs(), peerstore.PermanentAddrTTL)
 
-	_, err = sender.NewStream(context.Background(), receiver.ID(), pid)
+	s, err := sender.NewStream(context.Background(), receiver.ID(), pid)
+	// stream is not received by the peer before the first write or read
+	require.NoError(t, err)
+	var b [1]byte
+	_, err = io.ReadFull(s, b[:])
 	require.Error(t, err)
 
 	// Close the open streams
@@ -236,6 +242,10 @@ func TestLimitStreamsWhenHangingHandlers(t *testing.T) {
 	// Next call should succeed
 	require.Eventually(t, func() bool {
 		s, err := sender.NewStream(context.Background(), receiver.ID(), pid)
+		// stream is not received by the peer before the first write or read
+		require.NoError(t, err)
+		var b [1]byte
+		_, err = io.ReadFull(s, b[:])
 		if err == nil {
 			s.Close()
 			return true
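
The swarm test changes rely on the same consequence of optimistic negotiation: a successful NewStream no longer proves the receiver accepted, or even saw, the stream. A hedged sketch of how a test might force the stream to become "real" before asserting on it, assuming the remote handler writes at least one byte when it accepts; the helper is hypothetical and not part of the patch:

package main

import (
	"context"
	"io"

	"github.com/libp2p/go-libp2p/core/host"
	"github.com/libp2p/go-libp2p/core/network"
	"github.com/libp2p/go-libp2p/core/peer"
	"github.com/libp2p/go-libp2p/core/protocol"
)

// openAndConfirm opens a stream and then forces the lazy negotiation through
// by reading one byte (assumed to be sent by the remote handler on accept).
// An error here, rather than from NewStream, now signals that the stream never
// made it past the remote's limiter or handler.
func openAndConfirm(ctx context.Context, h host.Host, p peer.ID, pid protocol.ID) (network.Stream, error) {
	s, err := h.NewStream(ctx, p, pid)
	if err != nil {
		return nil, err
	}
	var b [1]byte
	if _, err := io.ReadFull(s, b[:]); err != nil {
		s.Reset()
		return nil, err
	}
	return s, nil
}
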
diff --git a/p2p/test/transport/gating_test.go b/p2p/test/transport/gating_test.go
index df53da6eeb..d323c5383b 100644
--- a/p2p/test/transport/gating_test.go
+++ b/p2p/test/transport/gating_test.go
@@ -181,7 +181,8 @@ func TestInterceptAccept(t *testing.T) {
 		}
 
 		h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
-		_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID)
+		// use two protocols here, so we actually enter multistream negotiation
+		_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID, protocol.TestingID)
 		require.Error(t, err)
 		if _, err := h2.Addrs()[0].ValueForProtocol(ma.P_WEBRTC_DIRECT); err != nil {
 			// WebRTC rejects connection attempt before an error can be sent to the client.
@@ -218,7 +219,8 @@ func TestInterceptSecuredIncoming(t *testing.T) {
 			}),
 		)
 		h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
-		_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID)
+		// use two protocols here, so we actually enter multistream negotiation
+		_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID, protocol.TestingID)
 		require.Error(t, err)
 		require.NotErrorIs(t, err, context.DeadlineExceeded)
 	})
@@ -254,7 +256,8 @@ func TestInterceptUpgradedIncoming(t *testing.T) {
 			}),
 		)
 		h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
-		_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID)
+		// use two protocols here, so we actually enter multistream negotiation
+		_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID, protocol.TestingID)
 		require.Error(t, err)
 		require.NotErrorIs(t, err, context.DeadlineExceeded)
 	})
diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go
index e39b72a71a..4d5c3b0702 100644
--- a/p2p/test/transport/transport_test.go
+++ b/p2p/test/transport/transport_test.go
@@ -549,14 +549,19 @@ func TestListenerStreamResets(t *testing.T) {
 		}))
 
 		h1.SetStreamHandler("reset", func(s network.Stream) {
+			// Make sure the multistream negotiation actually succeeds before resetting.
+			// This is necessary because we don't have stream error codes yet.
+			s.Read(make([]byte, 4))
+			s.Write([]byte("pong"))
+			s.Read(make([]byte, 4))
 			s.Reset()
 		})
 
 		s, err := h2.NewStream(context.Background(), h1.ID(), "reset")
-		if err != nil {
-			require.ErrorIs(t, err, network.ErrReset)
-			return
-		}
+		require.NoError(t, err)
+		s.Write([]byte("ping"))
+		s.Read(make([]byte, 4))
+		s.Write([]byte("ping"))
 
 		_, err = s.Read([]byte{0})
 		require.ErrorIs(t, err, network.ErrReset)
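
Taken together, the patch changes which error a caller sees when a freshly opened stream dies: a reset observed before the negotiation is confirmed is reported as protocol rejection, while a reset after the handler accepted stays network.ErrReset (the transport_test change above makes the handler complete a round trip precisely so the test can rely on the latter). A hypothetical helper summarizing the distinction, not part of the patch:

package main

import (
	"errors"
	"fmt"

	"github.com/libp2p/go-libp2p/core/network"
	"github.com/libp2p/go-libp2p/core/protocol"
	multistream "github.com/multiformats/go-multistream"
)

// classifyStreamErr mirrors the mapping the patched streamWrapper performs:
// a reset before the lazy negotiation is confirmed arrives as
// multistream.ErrNotSupported; a later reset stays network.ErrReset.
func classifyStreamErr(err error) string {
	var notSupported multistream.ErrNotSupported[protocol.ID]
	switch {
	case errors.As(err, &notSupported):
		return fmt.Sprintf("protocol not supported by peer: %v", notSupported.Protos)
	case errors.Is(err, network.ErrReset):
		return "stream reset after successful negotiation"
	default:
		return fmt.Sprintf("other stream error: %v", err)
	}
}
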