From 6eae9ab059b14d2827779124130c9f2307b9c166 Mon Sep 17 00:00:00 2001 From: Adin Schmahmann Date: Mon, 11 Nov 2024 23:38:55 -0500 Subject: [PATCH] fix(tcpreuse): handle connection that failed to be sampled --- p2p/transport/tcpreuse/demultiplex.go | 18 +++++++++--------- p2p/transport/tcpreuse/listener.go | 2 -- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/p2p/transport/tcpreuse/demultiplex.go b/p2p/transport/tcpreuse/demultiplex.go index fe58243d67..3b45aec7f5 100644 --- a/p2p/transport/tcpreuse/demultiplex.go +++ b/p2p/transport/tcpreuse/demultiplex.go @@ -40,35 +40,35 @@ func (t DemultiplexedConnType) IsKnown() bool { // identifyConnType attempts to identify the connection type by peeking at the // first few bytes. -// It Callers must not use the passed in Conn after this -// function returns. if an error is returned, the connection will be closed. +// Its Callers must not use the passed in Conn after this function returns. +// If an error is returned, the connection will be closed. func identifyConnType(c manet.Conn) (DemultiplexedConnType, manet.Conn, error) { if err := c.SetReadDeadline(time.Now().Add(identifyConnTimeout)); err != nil { closeErr := c.Close() return 0, nil, errors.Join(err, closeErr) } - s, c, err := sampledconn.PeekBytes(c) + s, peekableConn, err := sampledconn.PeekBytes(c) if err != nil { closeErr := c.Close() return 0, nil, errors.Join(err, closeErr) } - if err := c.SetReadDeadline(time.Time{}); err != nil { - closeErr := c.Close() + if err := peekableConn.SetReadDeadline(time.Time{}); err != nil { + closeErr := peekableConn.Close() return 0, nil, errors.Join(err, closeErr) } if IsMultistreamSelect(s) { - return DemultiplexedConnType_MultistreamSelect, c, nil + return DemultiplexedConnType_MultistreamSelect, peekableConn, nil } if IsTLS(s) { - return DemultiplexedConnType_TLS, c, nil + return DemultiplexedConnType_TLS, peekableConn, nil } if IsHTTP(s) { - return DemultiplexedConnType_HTTP, c, nil + return DemultiplexedConnType_HTTP, peekableConn, nil } - return DemultiplexedConnType_Unknown, c, nil + return DemultiplexedConnType_Unknown, peekableConn, nil } // Matchers are implemented here instead of in the transports so we can easily fuzz them together. diff --git a/p2p/transport/tcpreuse/listener.go b/p2p/transport/tcpreuse/listener.go index 326e1e15b7..d94186e7ec 100644 --- a/p2p/transport/tcpreuse/listener.go +++ b/p2p/transport/tcpreuse/listener.go @@ -231,8 +231,6 @@ func (m *multiplexedListener) run() error { t, c, err := identifyConnType(c) if err != nil { connScope.Done() - closeErr := c.Close() - err = errors.Join(err, closeErr) log.Debugf("error demultiplexing connection: %s", err.Error()) return }