diff --git a/p2p/transport/tcpreuse/demultiplex.go b/p2p/transport/tcpreuse/demultiplex.go index 342e7de0b3..864dc76040 100644 --- a/p2p/transport/tcpreuse/demultiplex.go +++ b/p2p/transport/tcpreuse/demultiplex.go @@ -4,13 +4,9 @@ import ( "bufio" "errors" "fmt" - "io" - "math" - "net" "time" - "github.com/libp2p/go-libp2p/core/network" - ma "github.com/multiformats/go-multiaddr" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse/internal/sampledconn" manet "github.com/multiformats/go-multiaddr/net" ) @@ -52,13 +48,17 @@ func (t DemultiplexedConnType) IsKnown() bool { return t >= 1 || t <= 3 } -func getDemultiplexedConn(c net.Conn, scope network.ConnManagementScope) (DemultiplexedConnType, manet.Conn, error) { +// identifyConnType attempts to identify the connection type by peeking at the +// first few bytes. +// It Callers must not use the passed in Conn after this +// function returns. if an error is returned, the connection will be closed. +func identifyConnType(c manet.Conn) (DemultiplexedConnType, manet.Conn, error) { if err := c.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil { closeErr := c.Close() return 0, nil, errors.Join(err, closeErr) } - s, sc, err := readSampleFromConn(c, scope) + s, c, err := sampledconn.PeekBytes(c) if err != nil { closeErr := c.Close() return 0, nil, errors.Join(err, closeErr) @@ -70,174 +70,25 @@ func getDemultiplexedConn(c net.Conn, scope network.ConnManagementScope) (Demult } if IsMultistreamSelect(s) { - return DemultiplexedConnType_MultistreamSelect, sc, nil + return DemultiplexedConnType_MultistreamSelect, c, nil } if IsTLS(s) { - return DemultiplexedConnType_TLS, sc, nil + return DemultiplexedConnType_TLS, c, nil } if IsHTTP(s) { - return DemultiplexedConnType_HTTP, sc, nil + return DemultiplexedConnType_HTTP, c, nil } - return DemultiplexedConnType_Unknown, sc, nil + return DemultiplexedConnType_Unknown, c, nil } -// readSampleFromConn reads a sample and returns a reader which still includes the sample, so it can be kept undamaged. -// If an error occurs it only returns the error. -func readSampleFromConn(c net.Conn, scope network.ConnManagementScope) (Sample, manet.Conn, error) { - // TODO: Should we remove this? This is only implemented by bufio.Reader. - // This made sense for magiselect: https://github.com/libp2p/go-libp2p/pull/2737 as it deals with a wrapped - // ReadWriteCloser from multistream which does use a buffered reader underneath. - // For our present purpose, we have a net.Conn and no net.Conn implementation offers peeking. - if peekAble, ok := c.(peekAble); ok { - b, err := peekAble.Peek(len(Sample{})) - switch { - case err == nil: - mac, err := manet.WrapNetConn(c) - if err != nil { - return Sample{}, nil, err - } - - return Sample(b), mac, nil - case errors.Is(err, bufio.ErrBufferFull): - // We can only peek < len(Sample{}) data. - // fallback to sampledConn - default: - return Sample{}, nil, err - } - } - - tcpConnLike, ok := c.(tcpConnInterface) - if !ok { - return Sample{}, nil, fmt.Errorf("expected tcp-like connection") - } - - laddr, err := manet.FromNetAddr(c.LocalAddr()) - if err != nil { - return Sample{}, nil, fmt.Errorf("failed to convert nconn.LocalAddr: %s", err) - } - - raddr, err := manet.FromNetAddr(c.RemoteAddr()) - if err != nil { - return Sample{}, nil, fmt.Errorf("failed to convert nconn.RemoteAddr: %s", err) - } - - sc := &sampledConn{ - tcpConnInterface: tcpConnLike, - maEndpoints: maEndpoints{laddr: laddr, raddr: raddr}, - scope: scope, - } - _, err = io.ReadFull(c, sc.s[:]) - if err != nil { - return Sample{}, nil, err - } - return sc.s, sc, nil -} - -// tcpConnInterface is the interface for TCPConn's functions -// Note: Skipping `SyscallConn() (syscall.RawConn, error)` since it can be misused given we've read a few bytes from the connection. -// TODO: allow SyscallConn? Disallowing it breaks metrics tracking in TCP Transport. -type tcpConnInterface interface { - net.Conn - - CloseRead() error - CloseWrite() error - - SetLinger(sec int) error - SetKeepAlive(keepalive bool) error - SetKeepAlivePeriod(d time.Duration) error - SetNoDelay(noDelay bool) error - MultipathTCP() (bool, error) - - io.ReaderFrom - io.WriterTo -} - -type maEndpoints struct { - laddr ma.Multiaddr - raddr ma.Multiaddr -} - -// LocalMultiaddr returns the local address associated with -// this connection -func (c *maEndpoints) LocalMultiaddr() ma.Multiaddr { - return c.laddr -} - -// RemoteMultiaddr returns the remote address associated with -// this connection -func (c *maEndpoints) RemoteMultiaddr() ma.Multiaddr { - return c.raddr -} - -type sampledConn struct { - tcpConnInterface - maEndpoints - scope network.ConnManagementScope - s Sample - readFromSample uint8 -} - -var _ = [math.MaxUint8]struct{}{}[len(Sample{})] // compiletime assert sampledConn.readFromSample wont overflow -var _ io.ReaderFrom = (*sampledConn)(nil) -var _ io.WriterTo = (*sampledConn)(nil) - -func (sc *sampledConn) Read(b []byte) (int, error) { - if int(sc.readFromSample) != len(sc.s) { - red := copy(b, sc.s[sc.readFromSample:]) - sc.readFromSample += uint8(red) - return red, nil - } - - return sc.tcpConnInterface.Read(b) -} - -// TODO: Do we need these? - -func (sc *sampledConn) ReadFrom(r io.Reader) (int64, error) { - return io.Copy(sc.tcpConnInterface, r) -} - -func (sc *sampledConn) WriteTo(w io.Writer) (total int64, err error) { - if int(sc.readFromSample) != len(sc.s) { - b := sc.s[sc.readFromSample:] - written, err := w.Write(b) - if written < 0 || len(b) < written { - // buggy writer, harden against this - sc.readFromSample = uint8(len(sc.s)) - total = int64(len(sc.s)) - } else { - sc.readFromSample += uint8(written) - total += int64(written) - } - if err != nil { - return total, err - } - } - - written, err := io.Copy(w, sc.tcpConnInterface) - total += written - return total, err -} - -func (sc *sampledConn) Scope() network.ConnManagementScope { - return sc.scope -} - -func (sc *sampledConn) Close() error { - sc.scope.Done() - return sc.tcpConnInterface.Close() -} - -// Sample is the byte sequence we use to demultiplex. -type Sample [3]byte - // Matchers are implemented here instead of in the transports so we can easily fuzz them together. +type Prefix = [3]byte -func IsMultistreamSelect(s Sample) bool { +func IsMultistreamSelect(s Prefix) bool { return string(s[:]) == "\x13/m" } -func IsHTTP(s Sample) bool { +func IsHTTP(s Prefix) bool { switch string(s[:]) { case "GET", "HEA", "POS", "PUT", "DEL", "CON", "OPT", "TRA", "PAT": return true @@ -246,7 +97,7 @@ func IsHTTP(s Sample) bool { } } -func IsTLS(s Sample) bool { +func IsTLS(s Prefix) bool { switch string(s[:]) { case "\x16\x03\x01", "\x16\x03\x02", "\x16\x03\x03", "\x16\x03\x04": return true diff --git a/p2p/transport/tcpreuse/demultiplex_test.go b/p2p/transport/tcpreuse/demultiplex_test.go index 3d6e91f35a..e201f2ca75 100644 --- a/p2p/transport/tcpreuse/demultiplex_test.go +++ b/p2p/transport/tcpreuse/demultiplex_test.go @@ -25,7 +25,7 @@ func FuzzClash(f *testing.F) { add('\x16', '\x03', '\x04') f.Fuzz(func(t *testing.T, a, b, c byte) { - s := Sample{a, b, c} + s := Prefix{a, b, c} var total uint ms := IsMultistreamSelect(s) diff --git a/p2p/transport/tcpreuse/listener.go b/p2p/transport/tcpreuse/listener.go index 55fd85ed56..2aa61a0fb0 100644 --- a/p2p/transport/tcpreuse/listener.go +++ b/p2p/transport/tcpreuse/listener.go @@ -220,7 +220,7 @@ func (m *multiplexedListener) run() error { go func() { defer func() { <-acceptQueue }() defer m.wg.Done() - t, sampleC, err := getDemultiplexedConn(c, connScope) + t, c, err := identifyConnType(c) if err != nil { connScope.Done() closeErr := c.Close() @@ -229,11 +229,22 @@ func (m *multiplexedListener) run() error { return } + // TODO: Add a test that makes sure we can get the SyscallConn in Unix platforms. + // Wrap the scope into the conn. + connWithScope, err := manetConnWithScope(c, connScope) + if err != nil { + connScope.Done() + closeErr := c.Close() + err = errors.Join(err, closeErr) + log.Debugf("error wrapping connection with scope: %s", err.Error()) + return + } + m.mx.RLock() demux, ok := m.listeners[t] m.mx.RUnlock() if !ok { - closeErr := sampleC.Close() + closeErr := connWithScope.Close() if closeErr != nil { log.Debugf("no registered listener for demultiplex connection %s. Error closing the connection %s", t, closeErr.Error()) } else { @@ -243,9 +254,9 @@ func (m *multiplexedListener) run() error { } select { - case demux.buffer <- sampleC: + case demux.buffer <- connWithScope: case <-m.ctx.Done(): - sampleC.Close() + connWithScope.Close() return } }()