diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_common.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn.go similarity index 69% rename from p2p/transport/tcpreuse/internal/sampledconn/sampledconn_common.go rename to p2p/transport/tcpreuse/internal/sampledconn/sampledconn.go index 7324b45849..ff1f8caf44 100644 --- a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_common.go +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn.go @@ -14,30 +14,17 @@ const peekSize = 3 type PeekedBytes = [peekSize]byte -var errNotSupported = errors.New("not supported on this platform") - var ErrNotTCPConn = errors.New("passed conn is not a TCPConn") func PeekBytes(conn manet.Conn) (PeekedBytes, manet.Conn, error) { - if c, ok := conn.(syscall.Conn); ok { - b, err := OSPeekConn(c) - if err == nil { - return b, conn, nil - } - if err != errNotSupported { - return PeekedBytes{}, nil, err - } - // Fallback to wrapping the coonn - } - if c, ok := conn.(ManetTCPConnInterface); ok { - return newFallbackSampledConn(c) + return newWrappedSampledConn(c) } return PeekedBytes{}, nil, ErrNotTCPConn } -type fallbackPeekingConn struct { +type wrappedSampledConn struct { ManetTCPConnInterface peekedBytes PeekedBytes bytesPeeked uint8 @@ -69,16 +56,19 @@ type ManetTCPConnInterface interface { tcpConnInterface } -func newFallbackSampledConn(conn ManetTCPConnInterface) (PeekedBytes, *fallbackPeekingConn, error) { - s := &fallbackPeekingConn{ManetTCPConnInterface: conn} - _, err := io.ReadFull(conn, s.peekedBytes[:]) +func newWrappedSampledConn(conn ManetTCPConnInterface) (PeekedBytes, *wrappedSampledConn, error) { + s := &wrappedSampledConn{ManetTCPConnInterface: conn} + n, err := io.ReadFull(conn, s.peekedBytes[:]) if err != nil { + if n == 0 && err == io.EOF { + err = io.ErrUnexpectedEOF + } return s.peekedBytes, nil, err } return s.peekedBytes, s, nil } -func (sc *fallbackPeekingConn) Read(b []byte) (int, error) { +func (sc *wrappedSampledConn) Read(b []byte) (int, error) { if int(sc.bytesPeeked) != len(sc.peekedBytes) { red := copy(b, sc.peekedBytes[sc.bytesPeeked:]) sc.bytesPeeked += uint8(red) diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_other.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_other.go deleted file mode 100644 index 5197052fab..0000000000 --- a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_other.go +++ /dev/null @@ -1,11 +0,0 @@ -//go:build !unix - -package sampledconn - -import ( - "syscall" -) - -func OSPeekConn(conn syscall.Conn) (PeekedBytes, error) { - return PeekedBytes{}, errNotSupported -} diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go index d5b31009e2..6c4e989b16 100644 --- a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go @@ -10,6 +10,7 @@ import ( manet "github.com/multiformats/go-multiaddr/net" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestSampledConn(t *testing.T) { @@ -63,7 +64,7 @@ func TestSampledConn(t *testing.T) { assert.Equal(t, "hello", string(buf)) } else { // Wrap the client connection in SampledConn - sample, sampledConn, err := newFallbackSampledConn(clientConn.(ManetTCPConnInterface)) + sample, sampledConn, err := newWrappedSampledConn(clientConn.(ManetTCPConnInterface)) assert.NoError(t, err) assert.Equal(t, "hel", string(sample[:])) @@ -76,3 +77,102 @@ func TestSampledConn(t *testing.T) { }) } } + +func spawnServerAndClientConn(t *testing.T) (serverConn manet.Conn, clientConn manet.Conn) { + serverConnChan := make(chan manet.Conn, 1) + + listener, err := manet.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0")) + assert.NoError(t, err) + defer listener.Close() + + serverAddr := listener.Multiaddr() + + // Server goroutine + go func() { + conn, err := listener.Accept() + assert.NoError(t, err) + serverConnChan <- conn + }() + + // Give the server a moment to start + time.Sleep(100 * time.Millisecond) + + // Create a TCP client + clientConn, err = manet.Dial(serverAddr) + assert.NoError(t, err) + + return <-serverConnChan, clientConn +} + +func TestHandleNoBytes(t *testing.T) { + serverConn, clientConn := spawnServerAndClientConn(t) + defer clientConn.Close() + + // Server goroutine + go func() { + serverConn.Close() + }() + _, _, err := PeekBytes(clientConn.(interface { + manet.Conn + syscall.Conn + })) + assert.ErrorIs(t, err, io.ErrUnexpectedEOF) +} + +func TestHandle1ByteAndClose(t *testing.T) { + serverConn, clientConn := spawnServerAndClientConn(t) + defer clientConn.Close() + + // Server goroutine + go func() { + defer serverConn.Close() + _, err := serverConn.Write([]byte("h")) + assert.NoError(t, err) + }() + _, _, err := PeekBytes(clientConn.(interface { + manet.Conn + syscall.Conn + })) + assert.ErrorIs(t, err, io.ErrUnexpectedEOF) +} + +func TestSlowBytes(t *testing.T) { + serverConn, clientConn := spawnServerAndClientConn(t) + + interval := 100 * time.Millisecond + + // Server goroutine + go func() { + defer serverConn.Close() + + time.Sleep(interval) + _, err := serverConn.Write([]byte("h")) + assert.NoError(t, err) + time.Sleep(interval) + _, err = serverConn.Write([]byte("e")) + assert.NoError(t, err) + time.Sleep(interval) + _, err = serverConn.Write([]byte("l")) + assert.NoError(t, err) + time.Sleep(interval) + _, err = serverConn.Write([]byte("lo")) + assert.NoError(t, err) + }() + + defer clientConn.Close() + + err := clientConn.SetReadDeadline(time.Now().Add(interval * 10)) + require.NoError(t, err) + + peeked, clientConn, err := PeekBytes(clientConn.(interface { + manet.Conn + syscall.Conn + })) + assert.NoError(t, err) + assert.Equal(t, "hel", string(peeked[:])) + + buf := make([]byte, 5) + _, err = io.ReadFull(clientConn, buf) + assert.NoError(t, err) + assert.Equal(t, "hello", string(buf)) +} diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_unix.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_unix.go deleted file mode 100644 index 9847e8d4be..0000000000 --- a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_unix.go +++ /dev/null @@ -1,42 +0,0 @@ -//go:build unix - -package sampledconn - -import ( - "errors" - "syscall" -) - -func OSPeekConn(conn syscall.Conn) (PeekedBytes, error) { - s := PeekedBytes{} - - rawConn, err := conn.SyscallConn() - if err != nil { - return s, err - } - - readBytes := 0 - var readErr error - err = rawConn.Read(func(fd uintptr) bool { - for readBytes < peekSize { - var n int - n, _, readErr = syscall.Recvfrom(int(fd), s[readBytes:], syscall.MSG_PEEK) - if errors.Is(readErr, syscall.EAGAIN) { - return false - } - if readErr != nil { - return true - } - readBytes += n - } - return true - }) - if readErr != nil { - return s, readErr - } - if err != nil { - return s, err - } - - return s, nil -}