From ed7f901037991b84b651a7e42ba808f1a7f79614 Mon Sep 17 00:00:00 2001 From: sukun Date: Sun, 6 Oct 2024 21:30:18 +0530 Subject: [PATCH] create conn scope early to prevent DoS attacks --- p2p/net/upgrader/listener.go | 38 +++++++++++------- p2p/transport/tcpreuse/demultiplex.go | 26 +++++++++---- p2p/transport/tcpreuse/listener.go | 52 +++++++++++++++++++++---- p2p/transport/tcpreuse/listener_test.go | 10 ++--- p2p/transport/websocket/conn.go | 11 ++++++ 5 files changed, 102 insertions(+), 35 deletions(-) diff --git a/p2p/net/upgrader/listener.go b/p2p/net/upgrader/listener.go index 8af2791b36..0530bde292 100644 --- a/p2p/net/upgrader/listener.go +++ b/p2p/net/upgrader/listener.go @@ -84,23 +84,33 @@ func (l *listener) handleIncoming() { } catcher.Reset() - // gate the connection if applicable - if l.upgrader.connGater != nil && !l.upgrader.connGater.InterceptAccept(maconn) { - log.Debugf("gater blocked incoming connection on local addr %s from %s", - maconn.LocalMultiaddr(), maconn.RemoteMultiaddr()) - if err := maconn.Close(); err != nil { - log.Warnf("failed to close incoming connection rejected by gater: %s", err) - } - continue + var connScope network.ConnManagementScope + if sc, ok := maconn.(interface { + Scope() network.ConnManagementScope + }); ok { + connScope = sc.Scope() } - connScope, err := l.rcmgr.OpenConnection(network.DirInbound, true, maconn.RemoteMultiaddr()) - if err != nil { - log.Debugw("resource manager blocked accept of new connection", "error", err) - if err := maconn.Close(); err != nil { - log.Warnf("failed to incoming connection rejected by resource manager: %s", err) + if connScope != nil { + // gate the connection if applicable + if l.upgrader.connGater != nil && !l.upgrader.connGater.InterceptAccept(maconn) { + log.Debugf("gater blocked incoming connection on local addr %s from %s", + maconn.LocalMultiaddr(), maconn.RemoteMultiaddr()) + if err := maconn.Close(); err != nil { + log.Warnf("failed to close incoming connection rejected by gater: %s", err) + } + continue + } + + var err error + connScope, err = l.rcmgr.OpenConnection(network.DirInbound, true, maconn.RemoteMultiaddr()) + if err != nil { + log.Debugw("resource manager blocked accept of new connection", "error", err) + if err := maconn.Close(); err != nil { + log.Warnf("failed to incoming connection rejected by resource manager: %s", err) + } + continue } - continue } // The go routine below calls Release when the context is diff --git a/p2p/transport/tcpreuse/demultiplex.go b/p2p/transport/tcpreuse/demultiplex.go index 2036c91437..342e7de0b3 100644 --- a/p2p/transport/tcpreuse/demultiplex.go +++ b/p2p/transport/tcpreuse/demultiplex.go @@ -9,6 +9,7 @@ import ( "net" "time" + "github.com/libp2p/go-libp2p/core/network" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" ) @@ -51,13 +52,13 @@ func (t DemultiplexedConnType) IsKnown() bool { return t >= 1 || t <= 3 } -func getDemultiplexedConn(c net.Conn) (DemultiplexedConnType, manet.Conn, error) { +func getDemultiplexedConn(c net.Conn, scope network.ConnManagementScope) (DemultiplexedConnType, manet.Conn, error) { if err := c.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil { closeErr := c.Close() return 0, nil, errors.Join(err, closeErr) } - s, sc, err := ReadSampleFromConn(c) + s, sc, err := readSampleFromConn(c, scope) if err != nil { closeErr := c.Close() return 0, nil, errors.Join(err, closeErr) @@ -80,9 +81,9 @@ func getDemultiplexedConn(c net.Conn) (DemultiplexedConnType, manet.Conn, error) return DemultiplexedConnType_Unknown, sc, nil } -// ReadSampleFromConn reads a sample and returns a reader which still includes the sample, so it can be kept undamaged. +// readSampleFromConn reads a sample and returns a reader which still includes the sample, so it can be kept undamaged. // If an error occurs it only returns the error. -func ReadSampleFromConn(c net.Conn) (Sample, manet.Conn, error) { +func readSampleFromConn(c net.Conn, scope network.ConnManagementScope) (Sample, manet.Conn, error) { // TODO: Should we remove this? This is only implemented by bufio.Reader. // This made sense for magiselect: https://github.com/libp2p/go-libp2p/pull/2737 as it deals with a wrapped // ReadWriteCloser from multistream which does use a buffered reader underneath. @@ -120,7 +121,11 @@ func ReadSampleFromConn(c net.Conn) (Sample, manet.Conn, error) { return Sample{}, nil, fmt.Errorf("failed to convert nconn.RemoteAddr: %s", err) } - sc := &sampledConn{tcpConnInterface: tcpConnLike, maEndpoints: maEndpoints{laddr: laddr, raddr: raddr}} + sc := &sampledConn{ + tcpConnInterface: tcpConnLike, + maEndpoints: maEndpoints{laddr: laddr, raddr: raddr}, + scope: scope, + } _, err = io.ReadFull(c, sc.s[:]) if err != nil { return Sample{}, nil, err @@ -167,7 +172,7 @@ func (c *maEndpoints) RemoteMultiaddr() ma.Multiaddr { type sampledConn struct { tcpConnInterface maEndpoints - + scope network.ConnManagementScope s Sample readFromSample uint8 } @@ -214,8 +219,13 @@ func (sc *sampledConn) WriteTo(w io.Writer) (total int64, err error) { return total, err } -type Matcher interface { - Match(s Sample) bool +func (sc *sampledConn) Scope() network.ConnManagementScope { + return sc.scope +} + +func (sc *sampledConn) Close() error { + sc.scope.Done() + return sc.tcpConnInterface.Close() } // Sample is the byte sequence we use to demultiplex. diff --git a/p2p/transport/tcpreuse/listener.go b/p2p/transport/tcpreuse/listener.go index 73286d5006..4a5bfa119b 100644 --- a/p2p/transport/tcpreuse/listener.go +++ b/p2p/transport/tcpreuse/listener.go @@ -8,6 +8,8 @@ import ( "sync" logging "github.com/ipfs/go-log/v2" + "github.com/libp2p/go-libp2p/core/connmgr" + "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/net/reuseport" ma "github.com/multiformats/go-multiaddr" @@ -22,14 +24,22 @@ var log = logging.Logger("tcp-demultiplex") type ConnMgr struct { disableReuseport bool reuse reuseport.Transport - listeners map[string]*multiplexedListener - mx sync.Mutex + connGater connmgr.ConnectionGater + rcmgr network.ResourceManager + + mx sync.Mutex + listeners map[string]*multiplexedListener } -func NewConnMgr(disableReuseport bool) *ConnMgr { +func NewConnMgr(disableReuseport bool, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) *ConnMgr { + if rcmgr == nil { + rcmgr = &network.NullResourceManager{} + } return &ConnMgr{ disableReuseport: disableReuseport, reuse: reuseport.Transport{}, + connGater: gater, + rcmgr: rcmgr, listeners: make(map[string]*multiplexedListener), } } @@ -104,6 +114,8 @@ func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType Demultiplexed listeners: make(map[DemultiplexedConnType]*demultiplexedListener), ctx: ctx, closeFn: cancelFunc, + connGater: t.connGater, + rcmgr: t.rcmgr, } dl, err := ml.DemultiplexedListen(connType) @@ -127,9 +139,11 @@ type multiplexedListener struct { listeners map[DemultiplexedConnType]*demultiplexedListener mx sync.RWMutex - ctx context.Context - closeFn func() error - wg sync.WaitGroup + connGater connmgr.ConnectionGater + rcmgr network.ResourceManager + ctx context.Context + closeFn func() error + wg sync.WaitGroup } func (m *multiplexedListener) DemultiplexedListen(connType DemultiplexedConnType) (manet.Listener, error) { @@ -167,6 +181,26 @@ func (m *multiplexedListener) run() error { if err != nil { return err } + + // gate the connection if applicable + if m.connGater != nil && !m.connGater.InterceptAccept(c) { + log.Debugf("gater blocked incoming connection on local addr %s from %s", + c.LocalMultiaddr(), c.RemoteMultiaddr()) + if err := c.Close(); err != nil { + log.Warnf("failed to close incoming connection rejected by gater: %s", err) + } + continue + } + + connScope, err := m.rcmgr.OpenConnection(network.DirInbound, true, c.RemoteMultiaddr()) + if err != nil { + log.Debugw("resource manager blocked accept of new connection", "error", err) + if err := c.Close(); err != nil { + log.Warnf("failed to incoming connection rejected by resource manager: %s", err) + } + continue + } + select { case acceptQueue <- struct{}{}: case <-m.ctx.Done(): @@ -180,18 +214,20 @@ func (m *multiplexedListener) run() error { defer m.wg.Done() // TODO: if/how do we want to handle stalled connections and stop them from clogging up the pipeline? // Drop connection because the buffer is full - t, sampleC, err := getDemultiplexedConn(c) + t, sampleC, err := getDemultiplexedConn(c, connScope) if err != nil { + connScope.Done() closeErr := c.Close() err = errors.Join(err, closeErr) log.Debugf("error demultiplexing connection: %s", err.Error()) return } + m.mx.RLock() demux, ok := m.listeners[t] m.mx.RUnlock() if !ok { - closeErr := c.Close() + closeErr := sampleC.Close() if closeErr != nil { log.Debugf("no registered listener for demultiplex connection %s. Error closing the connection %s", t, closeErr.Error()) } else { diff --git a/p2p/transport/tcpreuse/listener_test.go b/p2p/transport/tcpreuse/listener_test.go index f9d1fe589c..b61ffce09d 100644 --- a/p2p/transport/tcpreuse/listener_test.go +++ b/p2p/transport/tcpreuse/listener_test.go @@ -65,7 +65,7 @@ func TestListenerSingle(t *testing.T) { const N = 128 for _, disableReuseport := range []bool{true, false} { t.Run(fmt.Sprintf("multistream-reuseport:%v", disableReuseport), func(t *testing.T) { - cm := NewConnMgr(disableReuseport) + cm := NewConnMgr(disableReuseport, nil, nil) l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) require.NoError(t, err) go func() { @@ -112,7 +112,7 @@ func TestListenerSingle(t *testing.T) { }) t.Run(fmt.Sprintf("WebSocket-reuseport:%v", disableReuseport), func(t *testing.T) { - cm := NewConnMgr(disableReuseport) + cm := NewConnMgr(disableReuseport, nil, nil) l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) require.NoError(t, err) wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)} @@ -158,7 +158,7 @@ func TestListenerSingle(t *testing.T) { }) t.Run(fmt.Sprintf("WebSocketTLS-reuseport:%v", disableReuseport), func(t *testing.T) { - cm := NewConnMgr(disableReuseport) + cm := NewConnMgr(disableReuseport, nil, nil) l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS) require.NoError(t, err) defer l.Close() @@ -211,7 +211,7 @@ func TestListenerMultiplexed(t *testing.T) { listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0") const N = 128 for _, disableReuseport := range []bool{true, false} { - cm := NewConnMgr(disableReuseport) + cm := NewConnMgr(disableReuseport, nil, nil) msl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) require.NoError(t, err) defer msl.Close() @@ -370,7 +370,7 @@ func TestListenerClose(t *testing.T) { testClose := func(listenAddr ma.Multiaddr) { // listen on port 0 - cm := NewConnMgr(true) + cm := NewConnMgr(true, nil, nil) ml, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) require.NoError(t, err) wl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) diff --git a/p2p/transport/websocket/conn.go b/p2p/transport/websocket/conn.go index 30b70055d0..19d4e46ec5 100644 --- a/p2p/transport/websocket/conn.go +++ b/p2p/transport/websocket/conn.go @@ -99,9 +99,20 @@ func (c *Conn) Write(b []byte) (n int, err error) { return len(b), nil } +func (c *Conn) Scope() network.ConnManagementScope { + nc := c.NetConn() + if sc, ok := nc.(interface { + Scope() network.ConnManagementScope + }); ok { + return sc.Scope() + } + return nil +} + // Close closes the connection. Only the first call to Close will receive the // close error, subsequent and concurrent calls will return nil. // This method is thread-safe. +// TODO: Fix this ^ func (c *Conn) Close() error { var err error c.closeOnce.Do(func() {