Skip to content

Commit

Permalink
create conn scope early to prevent DoS attacks
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Oct 6, 2024
1 parent 60cef4d commit ed7f901
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 35 deletions.
38 changes: 24 additions & 14 deletions p2p/net/upgrader/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,23 +84,33 @@ func (l *listener) handleIncoming() {
}
catcher.Reset()

// gate the connection if applicable
if l.upgrader.connGater != nil && !l.upgrader.connGater.InterceptAccept(maconn) {
log.Debugf("gater blocked incoming connection on local addr %s from %s",
maconn.LocalMultiaddr(), maconn.RemoteMultiaddr())
if err := maconn.Close(); err != nil {
log.Warnf("failed to close incoming connection rejected by gater: %s", err)
}
continue
var connScope network.ConnManagementScope
if sc, ok := maconn.(interface {
Scope() network.ConnManagementScope
}); ok {
connScope = sc.Scope()
}

connScope, err := l.rcmgr.OpenConnection(network.DirInbound, true, maconn.RemoteMultiaddr())
if err != nil {
log.Debugw("resource manager blocked accept of new connection", "error", err)
if err := maconn.Close(); err != nil {
log.Warnf("failed to incoming connection rejected by resource manager: %s", err)
if connScope != nil {
// gate the connection if applicable
if l.upgrader.connGater != nil && !l.upgrader.connGater.InterceptAccept(maconn) {
log.Debugf("gater blocked incoming connection on local addr %s from %s",
maconn.LocalMultiaddr(), maconn.RemoteMultiaddr())
if err := maconn.Close(); err != nil {
log.Warnf("failed to close incoming connection rejected by gater: %s", err)
}
continue
}

var err error
connScope, err = l.rcmgr.OpenConnection(network.DirInbound, true, maconn.RemoteMultiaddr())
if err != nil {
log.Debugw("resource manager blocked accept of new connection", "error", err)
if err := maconn.Close(); err != nil {
log.Warnf("failed to incoming connection rejected by resource manager: %s", err)
}
continue
}
continue
}

// The go routine below calls Release when the context is
Expand Down
26 changes: 18 additions & 8 deletions p2p/transport/tcpreuse/demultiplex.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net"
"time"

"github.com/libp2p/go-libp2p/core/network"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)
Expand Down Expand Up @@ -51,13 +52,13 @@ func (t DemultiplexedConnType) IsKnown() bool {
return t >= 1 || t <= 3
}

func getDemultiplexedConn(c net.Conn) (DemultiplexedConnType, manet.Conn, error) {
func getDemultiplexedConn(c net.Conn, scope network.ConnManagementScope) (DemultiplexedConnType, manet.Conn, error) {
if err := c.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil {
closeErr := c.Close()
return 0, nil, errors.Join(err, closeErr)
}

s, sc, err := ReadSampleFromConn(c)
s, sc, err := readSampleFromConn(c, scope)
if err != nil {
closeErr := c.Close()
return 0, nil, errors.Join(err, closeErr)
Expand All @@ -80,9 +81,9 @@ func getDemultiplexedConn(c net.Conn) (DemultiplexedConnType, manet.Conn, error)
return DemultiplexedConnType_Unknown, sc, nil
}

// ReadSampleFromConn reads a sample and returns a reader which still includes the sample, so it can be kept undamaged.
// readSampleFromConn reads a sample and returns a reader which still includes the sample, so it can be kept undamaged.
// If an error occurs it only returns the error.
func ReadSampleFromConn(c net.Conn) (Sample, manet.Conn, error) {
func readSampleFromConn(c net.Conn, scope network.ConnManagementScope) (Sample, manet.Conn, error) {
// TODO: Should we remove this? This is only implemented by bufio.Reader.
// This made sense for magiselect: https://github.com/libp2p/go-libp2p/pull/2737 as it deals with a wrapped
// ReadWriteCloser from multistream which does use a buffered reader underneath.
Expand Down Expand Up @@ -120,7 +121,11 @@ func ReadSampleFromConn(c net.Conn) (Sample, manet.Conn, error) {
return Sample{}, nil, fmt.Errorf("failed to convert nconn.RemoteAddr: %s", err)
}

sc := &sampledConn{tcpConnInterface: tcpConnLike, maEndpoints: maEndpoints{laddr: laddr, raddr: raddr}}
sc := &sampledConn{
tcpConnInterface: tcpConnLike,
maEndpoints: maEndpoints{laddr: laddr, raddr: raddr},
scope: scope,
}
_, err = io.ReadFull(c, sc.s[:])
if err != nil {
return Sample{}, nil, err
Expand Down Expand Up @@ -167,7 +172,7 @@ func (c *maEndpoints) RemoteMultiaddr() ma.Multiaddr {
type sampledConn struct {
tcpConnInterface
maEndpoints

scope network.ConnManagementScope
s Sample
readFromSample uint8
}
Expand Down Expand Up @@ -214,8 +219,13 @@ func (sc *sampledConn) WriteTo(w io.Writer) (total int64, err error) {
return total, err
}

type Matcher interface {
Match(s Sample) bool
func (sc *sampledConn) Scope() network.ConnManagementScope {
return sc.scope
}

func (sc *sampledConn) Close() error {
sc.scope.Done()
return sc.tcpConnInterface.Close()
}

// Sample is the byte sequence we use to demultiplex.
Expand Down
52 changes: 44 additions & 8 deletions p2p/transport/tcpreuse/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"sync"

logging "github.com/ipfs/go-log/v2"
"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/net/reuseport"
ma "github.com/multiformats/go-multiaddr"
Expand All @@ -22,14 +24,22 @@ var log = logging.Logger("tcp-demultiplex")
type ConnMgr struct {
disableReuseport bool
reuse reuseport.Transport
listeners map[string]*multiplexedListener
mx sync.Mutex
connGater connmgr.ConnectionGater
rcmgr network.ResourceManager

mx sync.Mutex
listeners map[string]*multiplexedListener
}

func NewConnMgr(disableReuseport bool) *ConnMgr {
func NewConnMgr(disableReuseport bool, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) *ConnMgr {
if rcmgr == nil {
rcmgr = &network.NullResourceManager{}
}
return &ConnMgr{
disableReuseport: disableReuseport,
reuse: reuseport.Transport{},
connGater: gater,
rcmgr: rcmgr,
listeners: make(map[string]*multiplexedListener),
}
}
Expand Down Expand Up @@ -104,6 +114,8 @@ func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType Demultiplexed
listeners: make(map[DemultiplexedConnType]*demultiplexedListener),
ctx: ctx,
closeFn: cancelFunc,
connGater: t.connGater,
rcmgr: t.rcmgr,
}

dl, err := ml.DemultiplexedListen(connType)
Expand All @@ -127,9 +139,11 @@ type multiplexedListener struct {
listeners map[DemultiplexedConnType]*demultiplexedListener
mx sync.RWMutex

ctx context.Context
closeFn func() error
wg sync.WaitGroup
connGater connmgr.ConnectionGater
rcmgr network.ResourceManager
ctx context.Context
closeFn func() error
wg sync.WaitGroup
}

func (m *multiplexedListener) DemultiplexedListen(connType DemultiplexedConnType) (manet.Listener, error) {
Expand Down Expand Up @@ -167,6 +181,26 @@ func (m *multiplexedListener) run() error {
if err != nil {
return err
}

// gate the connection if applicable
if m.connGater != nil && !m.connGater.InterceptAccept(c) {
log.Debugf("gater blocked incoming connection on local addr %s from %s",
c.LocalMultiaddr(), c.RemoteMultiaddr())
if err := c.Close(); err != nil {
log.Warnf("failed to close incoming connection rejected by gater: %s", err)
}
continue
}

connScope, err := m.rcmgr.OpenConnection(network.DirInbound, true, c.RemoteMultiaddr())
if err != nil {
log.Debugw("resource manager blocked accept of new connection", "error", err)
if err := c.Close(); err != nil {
log.Warnf("failed to incoming connection rejected by resource manager: %s", err)
}
continue
}

select {
case acceptQueue <- struct{}{}:
case <-m.ctx.Done():
Expand All @@ -180,18 +214,20 @@ func (m *multiplexedListener) run() error {
defer m.wg.Done()
// TODO: if/how do we want to handle stalled connections and stop them from clogging up the pipeline?
// Drop connection because the buffer is full
t, sampleC, err := getDemultiplexedConn(c)
t, sampleC, err := getDemultiplexedConn(c, connScope)
if err != nil {
connScope.Done()
closeErr := c.Close()
err = errors.Join(err, closeErr)
log.Debugf("error demultiplexing connection: %s", err.Error())
return
}

m.mx.RLock()
demux, ok := m.listeners[t]
m.mx.RUnlock()
if !ok {
closeErr := c.Close()
closeErr := sampleC.Close()
if closeErr != nil {
log.Debugf("no registered listener for demultiplex connection %s. Error closing the connection %s", t, closeErr.Error())
} else {
Expand Down
10 changes: 5 additions & 5 deletions p2p/transport/tcpreuse/listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func TestListenerSingle(t *testing.T) {
const N = 128
for _, disableReuseport := range []bool{true, false} {
t.Run(fmt.Sprintf("multistream-reuseport:%v", disableReuseport), func(t *testing.T) {
cm := NewConnMgr(disableReuseport)
cm := NewConnMgr(disableReuseport, nil, nil)
l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect)
require.NoError(t, err)
go func() {
Expand Down Expand Up @@ -112,7 +112,7 @@ func TestListenerSingle(t *testing.T) {
})

t.Run(fmt.Sprintf("WebSocket-reuseport:%v", disableReuseport), func(t *testing.T) {
cm := NewConnMgr(disableReuseport)
cm := NewConnMgr(disableReuseport, nil, nil)
l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP)
require.NoError(t, err)
wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)}
Expand Down Expand Up @@ -158,7 +158,7 @@ func TestListenerSingle(t *testing.T) {
})

t.Run(fmt.Sprintf("WebSocketTLS-reuseport:%v", disableReuseport), func(t *testing.T) {
cm := NewConnMgr(disableReuseport)
cm := NewConnMgr(disableReuseport, nil, nil)
l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS)
require.NoError(t, err)
defer l.Close()
Expand Down Expand Up @@ -211,7 +211,7 @@ func TestListenerMultiplexed(t *testing.T) {
listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0")
const N = 128
for _, disableReuseport := range []bool{true, false} {
cm := NewConnMgr(disableReuseport)
cm := NewConnMgr(disableReuseport, nil, nil)
msl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect)
require.NoError(t, err)
defer msl.Close()
Expand Down Expand Up @@ -370,7 +370,7 @@ func TestListenerClose(t *testing.T) {

testClose := func(listenAddr ma.Multiaddr) {
// listen on port 0
cm := NewConnMgr(true)
cm := NewConnMgr(true, nil, nil)
ml, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect)
require.NoError(t, err)
wl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP)
Expand Down
11 changes: 11 additions & 0 deletions p2p/transport/websocket/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,20 @@ func (c *Conn) Write(b []byte) (n int, err error) {
return len(b), nil
}

func (c *Conn) Scope() network.ConnManagementScope {
nc := c.NetConn()
if sc, ok := nc.(interface {
Scope() network.ConnManagementScope
}); ok {
return sc.Scope()
}
return nil
}

// Close closes the connection. Only the first call to Close will receive the
// close error, subsequent and concurrent calls will return nil.
// This method is thread-safe.
// TODO: Fix this ^
func (c *Conn) Close() error {
var err error
c.closeOnce.Do(func() {
Expand Down

0 comments on commit ed7f901

Please sign in to comment.