Skip to content

Commit

Permalink
fix: basichost: Use NegotiationTimeout as fallback timeout for NewStr…
Browse files Browse the repository at this point in the history
…eam (#3020)
  • Loading branch information
MarcoPolo authored Nov 4, 2024
1 parent 5a47a90 commit c31f093
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 3 deletions.
15 changes: 12 additions & 3 deletions p2p/host/basic/basic_host.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,10 @@ type HostOpts struct {
// MultistreamMuxer is essential for the *BasicHost and will use a sensible default value if omitted.
MultistreamMuxer *msmux.MultistreamMuxer[protocol.ID]

// NegotiationTimeout determines the read and write timeouts on streams.
// If 0 or omitted, it will use DefaultNegotiationTimeout.
// If below 0, timeouts on streams will be deactivated.
// NegotiationTimeout determines the read and write timeouts when negotiating
// protocols for streams. If 0 or omitted, it will use
// DefaultNegotiationTimeout. If below 0, timeouts on streams will be
// deactivated.
NegotiationTimeout time.Duration

// AddrsFactory holds a function which can be used to override or filter the result of Addrs.
Expand Down Expand Up @@ -689,6 +690,14 @@ func (h *BasicHost) RemoveStreamHandler(pid protocol.ID) {
// to create one. If ProtocolID is "", writes no header.
// (Thread-safe)
func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (str network.Stream, strErr error) {
if _, ok := ctx.Deadline(); !ok {
if h.negtimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, h.negtimeout)
defer cancel()
}
}

// If the caller wants to prevent the host from dialing, it should use the NoDial option.
if nodial, _ := network.GetNoDial(ctx); !nodial {
err := h.Connect(ctx, peer.AddrInfo{ID: p})
Expand Down
54 changes: 54 additions & 0 deletions p2p/host/basic/basic_host_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package basichost

import (
"context"
"encoding/binary"
"fmt"
"io"
"reflect"
Expand Down Expand Up @@ -941,3 +942,56 @@ func TestTrimHostAddrList(t *testing.T) {
})
}
}

func TestHostTimeoutNewStream(t *testing.T) {
h1, err := NewHost(swarmt.GenSwarm(t), nil)
require.NoError(t, err)
h1.Start()
defer h1.Close()

const proto = "/testing"
h2 := swarmt.GenSwarm(t)

h2.SetStreamHandler(func(s network.Stream) {
// First message is multistream header. Just echo it
msHeader := []byte("\x19/multistream/1.0.0\n")
_, err := s.Read(msHeader)
assert.NoError(t, err)
_, err = s.Write(msHeader)
assert.NoError(t, err)

buf := make([]byte, 1024)
n, err := s.Read(buf)
assert.NoError(t, err)

msgLen, varintN := binary.Uvarint(buf[:n])
buf = buf[varintN:]
proto := buf[:int(msgLen)]
if string(proto) == "/ipfs/id/1.0.0\n" {
// Signal we don't support identify
na := []byte("na\n")
n := binary.PutUvarint(buf, uint64(len(na)))
copy(buf[n:], na)

_, err = s.Write(buf[:int(n)+len(na)])
assert.NoError(t, err)
} else {
// Stall
time.Sleep(5 * time.Second)
}
t.Log("Resetting")
s.Reset()
})

err = h1.Connect(context.Background(), peer.AddrInfo{
ID: h2.LocalPeer(),
Addrs: h2.ListenAddresses(),
})
require.NoError(t, err)

// No context passed in, fallback to negtimeout
h1.negtimeout = time.Second
_, err = h1.NewStream(context.Background(), h2.LocalPeer(), proto)
require.Error(t, err)
require.ErrorContains(t, err, "context deadline exceeded")
}

0 comments on commit c31f093

Please sign in to comment.