Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(sampledconn): Correctly handle slow bytes and closed conns #3080

Merged
merged 11 commits into from
Dec 10, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,17 @@ const peekSize = 3

type PeekedBytes = [peekSize]byte

var errNotSupported = errors.New("not supported on this platform")

var ErrNotTCPConn = errors.New("passed conn is not a TCPConn")

func PeekBytes(conn manet.Conn) (PeekedBytes, manet.Conn, error) {
if c, ok := conn.(syscall.Conn); ok {
b, err := OSPeekConn(c)
if err == nil {
return b, conn, nil
}
if err != errNotSupported {
return PeekedBytes{}, nil, err
}
// Fallback to wrapping the coonn
}

if c, ok := conn.(ManetTCPConnInterface); ok {
return newFallbackSampledConn(c)
return newWrappedSampledConn(c)
}

return PeekedBytes{}, nil, ErrNotTCPConn
}

type fallbackPeekingConn struct {
type wrappedSampledConn struct {
ManetTCPConnInterface
peekedBytes PeekedBytes
bytesPeeked uint8
Expand Down Expand Up @@ -69,16 +56,19 @@ type ManetTCPConnInterface interface {
tcpConnInterface
}

func newFallbackSampledConn(conn ManetTCPConnInterface) (PeekedBytes, *fallbackPeekingConn, error) {
s := &fallbackPeekingConn{ManetTCPConnInterface: conn}
_, err := io.ReadFull(conn, s.peekedBytes[:])
func newWrappedSampledConn(conn ManetTCPConnInterface) (PeekedBytes, *wrappedSampledConn, error) {
s := &wrappedSampledConn{ManetTCPConnInterface: conn}
n, err := io.ReadFull(conn, s.peekedBytes[:])
if err != nil {
if n == 0 && err == io.EOF {
err = io.ErrUnexpectedEOF
}
return s.peekedBytes, nil, err
}
return s.peekedBytes, s, nil
}

func (sc *fallbackPeekingConn) Read(b []byte) (int, error) {
func (sc *wrappedSampledConn) Read(b []byte) (int, error) {
if int(sc.bytesPeeked) != len(sc.peekedBytes) {
red := copy(b, sc.peekedBytes[sc.bytesPeeked:])
sc.bytesPeeked += uint8(red)
Expand Down
11 changes: 0 additions & 11 deletions p2p/transport/tcpreuse/internal/sampledconn/sampledconn_other.go

This file was deleted.

102 changes: 101 additions & 1 deletion p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
manet "github.com/multiformats/go-multiaddr/net"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestSampledConn(t *testing.T) {
Expand Down Expand Up @@ -63,7 +64,7 @@ func TestSampledConn(t *testing.T) {
assert.Equal(t, "hello", string(buf))
} else {
// Wrap the client connection in SampledConn
sample, sampledConn, err := newFallbackSampledConn(clientConn.(ManetTCPConnInterface))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we rename fallbackPeekingConn to sampledConn?

sample, sampledConn, err := newWrappedSampledConn(clientConn.(ManetTCPConnInterface))
assert.NoError(t, err)
assert.Equal(t, "hel", string(sample[:]))

Expand All @@ -76,3 +77,102 @@ func TestSampledConn(t *testing.T) {
})
}
}

func spawnServerAndClientConn(t *testing.T) (serverConn manet.Conn, clientConn manet.Conn) {
serverConnChan := make(chan manet.Conn, 1)

listener, err := manet.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0"))
assert.NoError(t, err)
defer listener.Close()

serverAddr := listener.Multiaddr()

// Server goroutine
go func() {
conn, err := listener.Accept()
assert.NoError(t, err)
serverConnChan <- conn
}()

// Give the server a moment to start
time.Sleep(100 * time.Millisecond)

// Create a TCP client
clientConn, err = manet.Dial(serverAddr)
assert.NoError(t, err)

return <-serverConnChan, clientConn
}

func TestHandleNoBytes(t *testing.T) {
serverConn, clientConn := spawnServerAndClientConn(t)
defer clientConn.Close()

// Server goroutine
go func() {
serverConn.Close()
}()
_, _, err := PeekBytes(clientConn.(interface {
manet.Conn
syscall.Conn
}))
assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
}

func TestHandle1ByteAndClose(t *testing.T) {
serverConn, clientConn := spawnServerAndClientConn(t)
defer clientConn.Close()

// Server goroutine
go func() {
defer serverConn.Close()
_, err := serverConn.Write([]byte("h"))
assert.NoError(t, err)
}()
_, _, err := PeekBytes(clientConn.(interface {
manet.Conn
syscall.Conn
}))
assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
}

func TestSlowBytes(t *testing.T) {
serverConn, clientConn := spawnServerAndClientConn(t)

interval := 100 * time.Millisecond

// Server goroutine
go func() {
defer serverConn.Close()

time.Sleep(interval)
_, err := serverConn.Write([]byte("h"))
assert.NoError(t, err)
time.Sleep(interval)
_, err = serverConn.Write([]byte("e"))
assert.NoError(t, err)
time.Sleep(interval)
_, err = serverConn.Write([]byte("l"))
assert.NoError(t, err)
time.Sleep(interval)
_, err = serverConn.Write([]byte("lo"))
assert.NoError(t, err)
}()

defer clientConn.Close()

err := clientConn.SetReadDeadline(time.Now().Add(interval * 10))
require.NoError(t, err)

peeked, clientConn, err := PeekBytes(clientConn.(interface {
manet.Conn
syscall.Conn
}))
assert.NoError(t, err)
assert.Equal(t, "hel", string(peeked[:]))

buf := make([]byte, 5)
_, err = io.ReadFull(clientConn, buf)
assert.NoError(t, err)
assert.Equal(t, "hello", string(buf))
}
42 changes: 0 additions & 42 deletions p2p/transport/tcpreuse/internal/sampledconn/sampledconn_unix.go

This file was deleted.

Loading