Skip to content

Commit

Permalink
simplify demultiplex a bit
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoPolo committed Oct 30, 2024
1 parent 2c053d8 commit 7e34d05
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 169 deletions.
179 changes: 15 additions & 164 deletions p2p/transport/tcpreuse/demultiplex.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,9 @@ import (
"bufio"
"errors"
"fmt"
"io"
"math"
"net"
"time"

"github.com/libp2p/go-libp2p/core/network"
ma "github.com/multiformats/go-multiaddr"
"github.com/libp2p/go-libp2p/p2p/transport/tcpreuse/internal/sampledconn"
manet "github.com/multiformats/go-multiaddr/net"
)

Expand Down Expand Up @@ -52,13 +48,17 @@ func (t DemultiplexedConnType) IsKnown() bool {
return t >= 1 || t <= 3
}

func getDemultiplexedConn(c net.Conn, scope network.ConnManagementScope) (DemultiplexedConnType, manet.Conn, error) {
// identifyConnType attempts to identify the connection type by peeking at the
// first few bytes.
// It Callers must not use the passed in Conn after this
// function returns. if an error is returned, the connection will be closed.
func identifyConnType(c manet.Conn) (DemultiplexedConnType, manet.Conn, error) {
if err := c.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil {
closeErr := c.Close()
return 0, nil, errors.Join(err, closeErr)
}

s, sc, err := readSampleFromConn(c, scope)
s, c, err := sampledconn.PeekBytes(c)
if err != nil {
closeErr := c.Close()
return 0, nil, errors.Join(err, closeErr)
Expand All @@ -70,174 +70,25 @@ func getDemultiplexedConn(c net.Conn, scope network.ConnManagementScope) (Demult
}

if IsMultistreamSelect(s) {
return DemultiplexedConnType_MultistreamSelect, sc, nil
return DemultiplexedConnType_MultistreamSelect, c, nil
}
if IsTLS(s) {
return DemultiplexedConnType_TLS, sc, nil
return DemultiplexedConnType_TLS, c, nil
}
if IsHTTP(s) {
return DemultiplexedConnType_HTTP, sc, nil
return DemultiplexedConnType_HTTP, c, nil
}
return DemultiplexedConnType_Unknown, sc, nil
return DemultiplexedConnType_Unknown, c, nil
}

// readSampleFromConn reads a sample and returns a reader which still includes the sample, so it can be kept undamaged.
// If an error occurs it only returns the error.
func readSampleFromConn(c net.Conn, scope network.ConnManagementScope) (Sample, manet.Conn, error) {
// TODO: Should we remove this? This is only implemented by bufio.Reader.
// This made sense for magiselect: https://github.com/libp2p/go-libp2p/pull/2737 as it deals with a wrapped
// ReadWriteCloser from multistream which does use a buffered reader underneath.
// For our present purpose, we have a net.Conn and no net.Conn implementation offers peeking.
if peekAble, ok := c.(peekAble); ok {
b, err := peekAble.Peek(len(Sample{}))
switch {
case err == nil:
mac, err := manet.WrapNetConn(c)
if err != nil {
return Sample{}, nil, err
}

return Sample(b), mac, nil
case errors.Is(err, bufio.ErrBufferFull):
// We can only peek < len(Sample{}) data.
// fallback to sampledConn
default:
return Sample{}, nil, err
}
}

tcpConnLike, ok := c.(tcpConnInterface)
if !ok {
return Sample{}, nil, fmt.Errorf("expected tcp-like connection")
}

laddr, err := manet.FromNetAddr(c.LocalAddr())
if err != nil {
return Sample{}, nil, fmt.Errorf("failed to convert nconn.LocalAddr: %s", err)
}

raddr, err := manet.FromNetAddr(c.RemoteAddr())
if err != nil {
return Sample{}, nil, fmt.Errorf("failed to convert nconn.RemoteAddr: %s", err)
}

sc := &sampledConn{
tcpConnInterface: tcpConnLike,
maEndpoints: maEndpoints{laddr: laddr, raddr: raddr},
scope: scope,
}
_, err = io.ReadFull(c, sc.s[:])
if err != nil {
return Sample{}, nil, err
}
return sc.s, sc, nil
}

// tcpConnInterface is the interface for TCPConn's functions
// Note: Skipping `SyscallConn() (syscall.RawConn, error)` since it can be misused given we've read a few bytes from the connection.
// TODO: allow SyscallConn? Disallowing it breaks metrics tracking in TCP Transport.
type tcpConnInterface interface {
net.Conn

CloseRead() error
CloseWrite() error

SetLinger(sec int) error
SetKeepAlive(keepalive bool) error
SetKeepAlivePeriod(d time.Duration) error
SetNoDelay(noDelay bool) error
MultipathTCP() (bool, error)

io.ReaderFrom
io.WriterTo
}

type maEndpoints struct {
laddr ma.Multiaddr
raddr ma.Multiaddr
}

// LocalMultiaddr returns the local address associated with
// this connection
func (c *maEndpoints) LocalMultiaddr() ma.Multiaddr {
return c.laddr
}

// RemoteMultiaddr returns the remote address associated with
// this connection
func (c *maEndpoints) RemoteMultiaddr() ma.Multiaddr {
return c.raddr
}

type sampledConn struct {
tcpConnInterface
maEndpoints
scope network.ConnManagementScope
s Sample
readFromSample uint8
}

var _ = [math.MaxUint8]struct{}{}[len(Sample{})] // compiletime assert sampledConn.readFromSample wont overflow
var _ io.ReaderFrom = (*sampledConn)(nil)
var _ io.WriterTo = (*sampledConn)(nil)

func (sc *sampledConn) Read(b []byte) (int, error) {
if int(sc.readFromSample) != len(sc.s) {
red := copy(b, sc.s[sc.readFromSample:])
sc.readFromSample += uint8(red)
return red, nil
}

return sc.tcpConnInterface.Read(b)
}

// TODO: Do we need these?

func (sc *sampledConn) ReadFrom(r io.Reader) (int64, error) {
return io.Copy(sc.tcpConnInterface, r)
}

func (sc *sampledConn) WriteTo(w io.Writer) (total int64, err error) {
if int(sc.readFromSample) != len(sc.s) {
b := sc.s[sc.readFromSample:]
written, err := w.Write(b)
if written < 0 || len(b) < written {
// buggy writer, harden against this
sc.readFromSample = uint8(len(sc.s))
total = int64(len(sc.s))
} else {
sc.readFromSample += uint8(written)
total += int64(written)
}
if err != nil {
return total, err
}
}

written, err := io.Copy(w, sc.tcpConnInterface)
total += written
return total, err
}

func (sc *sampledConn) Scope() network.ConnManagementScope {
return sc.scope
}

func (sc *sampledConn) Close() error {
sc.scope.Done()
return sc.tcpConnInterface.Close()
}

// Sample is the byte sequence we use to demultiplex.
type Sample [3]byte

// Matchers are implemented here instead of in the transports so we can easily fuzz them together.
type Prefix = [3]byte

func IsMultistreamSelect(s Sample) bool {
func IsMultistreamSelect(s Prefix) bool {
return string(s[:]) == "\x13/m"
}

func IsHTTP(s Sample) bool {
func IsHTTP(s Prefix) bool {
switch string(s[:]) {
case "GET", "HEA", "POS", "PUT", "DEL", "CON", "OPT", "TRA", "PAT":
return true
Expand All @@ -246,7 +97,7 @@ func IsHTTP(s Sample) bool {
}
}

func IsTLS(s Sample) bool {
func IsTLS(s Prefix) bool {
switch string(s[:]) {
case "\x16\x03\x01", "\x16\x03\x02", "\x16\x03\x03", "\x16\x03\x04":
return true
Expand Down
2 changes: 1 addition & 1 deletion p2p/transport/tcpreuse/demultiplex_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func FuzzClash(f *testing.F) {
add('\x16', '\x03', '\x04')

f.Fuzz(func(t *testing.T, a, b, c byte) {
s := Sample{a, b, c}
s := Prefix{a, b, c}
var total uint

ms := IsMultistreamSelect(s)
Expand Down
19 changes: 15 additions & 4 deletions p2p/transport/tcpreuse/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ func (m *multiplexedListener) run() error {
go func() {
defer func() { <-acceptQueue }()
defer m.wg.Done()
t, sampleC, err := getDemultiplexedConn(c, connScope)
t, c, err := identifyConnType(c)
if err != nil {
connScope.Done()
closeErr := c.Close()
Expand All @@ -229,11 +229,22 @@ func (m *multiplexedListener) run() error {
return
}

// TODO: Add a test that makes sure we can get the SyscallConn in Unix platforms.
// Wrap the scope into the conn.
connWithScope, err := manetConnWithScope(c, connScope)
if err != nil {
connScope.Done()
closeErr := c.Close()
err = errors.Join(err, closeErr)
log.Debugf("error wrapping connection with scope: %s", err.Error())
return
}

m.mx.RLock()
demux, ok := m.listeners[t]
m.mx.RUnlock()
if !ok {
closeErr := sampleC.Close()
closeErr := connWithScope.Close()
if closeErr != nil {
log.Debugf("no registered listener for demultiplex connection %s. Error closing the connection %s", t, closeErr.Error())
} else {
Expand All @@ -243,9 +254,9 @@ func (m *multiplexedListener) run() error {
}

select {
case demux.buffer <- sampleC:
case demux.buffer <- connWithScope:
case <-m.ctx.Done():
sampleC.Close()
connWithScope.Close()
return
}
}()
Expand Down

0 comments on commit 7e34d05

Please sign in to comment.