Skip to content

Commit

Permalink
interpret stream resets as multistream errors
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Sep 6, 2023
1 parent b2d7d03 commit 45211f8
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 19 deletions.
47 changes: 40 additions & 7 deletions p2p/host/basic/basic_host.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"net"
"sync"
"sync/atomic"
"time"

"github.com/libp2p/go-libp2p/core/connmgr"
Expand Down Expand Up @@ -647,12 +648,32 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I
return nil, fmt.Errorf("failed to open stream: %w", err)
}

pref, err := h.preferredProtocol(p, pids)
if err != nil {
_ = s.Reset()
return nil, err
}
// If pids contains only a single protocol, optimistically use that protocol (i.e. don't wait for
// multistream negotiation).
var pref protocol.ID
if len(pids) == 1 {
pref = pids[0]
} else if len(pids) > 1 {
// Wait for any in-progress identifies on the connection to finish.
// This is faster than negotiating.
// If the other side doesn't support identify, that's fine. This will just be a no-op.
select {
case <-h.ids.IdentifyWait(s.Conn()):
case <-ctx.Done():
_ = s.Reset()
return nil, fmt.Errorf("identify failed to complete: %w", ctx.Err())
}

// If Identify has finished, we know which protocols the peer supports.
// We don't need to do a multistream negotiation.
// Instead, we just pick the first supported protocol.
var err error
pref, err = h.preferredProtocol(p, pids)
if err != nil {
_ = s.Reset()
return nil, err
}
}
if pref != "" {
if err := s.SetProtocol(pref); err != nil {
return nil, err
Expand Down Expand Up @@ -1026,14 +1047,26 @@ func (h *BasicHost) Close() error {
type streamWrapper struct {
network.Stream
rw io.ReadWriteCloser

calledRead atomic.Bool
}

func (s *streamWrapper) Read(b []byte) (int, error) {
return s.rw.Read(b)
n, err := s.rw.Read(b)
if s.calledRead.CompareAndSwap(false, true) {
if errors.Is(err, network.ErrReset) {
return n, msmux.ErrNotSupported[protocol.ID]{Protos: []protocol.ID{s.Protocol()}}
}
}
return n, err
}

func (s *streamWrapper) Write(b []byte) (int, error) {
return s.rw.Write(b)
n, err := s.rw.Write(b)
if s.calledRead.Load() && errors.Is(err, network.ErrReset) {
return n, msmux.ErrNotSupported[protocol.ID]{Protos: []protocol.ID{s.Protocol()}}
}
return n, err
}

func (s *streamWrapper) Close() error {
Expand Down
7 changes: 3 additions & 4 deletions p2p/host/basic/basic_host_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -713,13 +713,12 @@ func TestHostAddrChangeDetection(t *testing.T) {
}

func TestNegotiationCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

h1, h2 := getHostPair(t)
defer h1.Close()
defer h2.Close()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// pre-negotiation so we can make the negotiation hang.
h2.Network().SetStreamHandler(func(s network.Stream) {
<-ctx.Done() // wait till the test is done.
Expand All @@ -731,7 +730,7 @@ func TestNegotiationCancel(t *testing.T) {

errCh := make(chan error, 1)
go func() {
s, err := h1.NewStream(ctx2, h2.ID(), "/testing")
s, err := h1.NewStream(ctx2, h2.ID(), "/testing", "/testing2")
if s != nil {
errCh <- fmt.Errorf("expected to fail negotiation")
return
Expand Down
2 changes: 1 addition & 1 deletion p2p/protocol/circuitv2/client/reservation.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func Reserve(ctx context.Context, h host.Host, ai peer.AddrInfo) (*Reservation,

if err := rd.ReadMsg(&msg); err != nil {
s.Reset()
return nil, ReservationError{Status: pbv2.Status_CONNECTION_FAILED, Reason: "error reading reservation response message: %w", err: err}
return nil, ReservationError{Status: pbv2.Status_CONNECTION_FAILED, Reason: "error reading reservation response message", err: err}
}

if msg.GetType() != pbv2.HopMessage_STATUS {
Expand Down
9 changes: 6 additions & 3 deletions p2p/test/transport/gating_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ func TestInterceptAccept(t *testing.T) {
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
})
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID)
// use two protocols here, so we actually enter multistream negotiation
_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID, protocol.TestingID)
require.Error(t, err)
require.NotErrorIs(t, err, context.DeadlineExceeded)
})
Expand Down Expand Up @@ -195,7 +196,8 @@ func TestInterceptSecuredIncoming(t *testing.T) {
}),
)
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID)
// use two protocols here, so we actually enter multistream negotiation
_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID, protocol.TestingID)
require.Error(t, err)
require.NotErrorIs(t, err, context.DeadlineExceeded)
})
Expand Down Expand Up @@ -229,7 +231,8 @@ func TestInterceptUpgradedIncoming(t *testing.T) {
}),
)
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID)
// use two protocols here, so we actually enter multistream negotiation
_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID, protocol.TestingID)
require.Error(t, err)
require.NotErrorIs(t, err, context.DeadlineExceeded)
})
Expand Down
13 changes: 9 additions & 4 deletions p2p/test/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -527,14 +527,19 @@ func TestListenerStreamResets(t *testing.T) {
}))

h1.SetStreamHandler("reset", func(s network.Stream) {
// Make sure the multistream negotiation actually succeeds before resetting.
// This is necessary because we don't have stream error codes yet.
s.Read(make([]byte, 4))
s.Write([]byte("pong"))
s.Read(make([]byte, 4))
s.Reset()
})

s, err := h2.NewStream(context.Background(), h1.ID(), "reset")
if err != nil {
require.ErrorIs(t, err, network.ErrReset)
return
}
require.NoError(t, err)
s.Write([]byte("ping"))
s.Read(make([]byte, 4))
s.Write([]byte("ping"))

_, err = s.Read([]byte{0})
require.ErrorIs(t, err, network.ErrReset)
Expand Down

0 comments on commit 45211f8

Please sign in to comment.