Skip to content

Commit

Permalink
fix connection limits
Browse files Browse the repository at this point in the history
an SFTP client can start multiple transfers on a single connection

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
  • Loading branch information
drakkan committed Oct 26, 2024
1 parent c69fbe6 commit ae1487d
Show file tree
Hide file tree
Showing 24 changed files with 707 additions and 7 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ require (
github.com/bmatcuk/doublestar/v4 v4.7.1
github.com/cockroachdb/cockroach-go/v2 v2.3.8
github.com/coreos/go-oidc/v3 v3.11.0
github.com/drakkan/webdav v0.0.0-20240503091431-218ec83910bb
github.com/drakkan/webdav v0.0.0-20241026165615-b8b8f74ae71b
github.com/eikenb/pipeat v0.0.0-20210730190139-06b3e6902001
github.com/fclairamb/ftpserverlib v0.24.1
github.com/fclairamb/go-log v0.5.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ github.com/drakkan/ftp v0.0.0-20240430173938-7ba8270c8e7f h1:S9JUlrOzjK58UKoLqqb
github.com/drakkan/ftp v0.0.0-20240430173938-7ba8270c8e7f/go.mod h1:4p8lUl4vQ80L598CygL+3IFtm+3nggvvW/palOlViwE=
github.com/drakkan/ftpserverlib v0.0.0-20240603150004-6a8f643fbf2e h1:VBpqQeChkGXSV1FXCtvd3BJTyB+DcMgiu7SfkpsGuKw=
github.com/drakkan/ftpserverlib v0.0.0-20240603150004-6a8f643fbf2e/go.mod h1:aAwyOAC6IIe+IZeeGD1QjuE3GGDzqW/c5Xtn+Dp0JUM=
github.com/drakkan/webdav v0.0.0-20240503091431-218ec83910bb h1:067/Uo8cfeY7QC0yzWCr/RImuNcM0rLWAsBUyMks59o=
github.com/drakkan/webdav v0.0.0-20240503091431-218ec83910bb/go.mod h1:zOVb1QDhwwqWn2L2qZ0U3swMSO4GTSNyIwXCGO/UGWE=
github.com/drakkan/webdav v0.0.0-20241026165615-b8b8f74ae71b h1:Y1tLiQ8fnxM5f3wiBjAXsHzHNwiY9BR+mXZA75nZwrs=
github.com/drakkan/webdav v0.0.0-20241026165615-b8b8f74ae71b/go.mod h1:zOVb1QDhwwqWn2L2qZ0U3swMSO4GTSNyIwXCGO/UGWE=
github.com/eikenb/pipeat v0.0.0-20210730190139-06b3e6902001 h1:/ZshrfQzayqRSBDodmp3rhNCHJCff+utvgBuWRbiqu4=
github.com/eikenb/pipeat v0.0.0-20210730190139-06b3e6902001/go.mod h1:kltMsfRMTHSFdMbK66XdS8mfMW77+FZA1fGY1xYMF84=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
Expand Down
45 changes: 43 additions & 2 deletions internal/common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ func init() {
Connections.clients = clientsMap{
clients: make(map[string]int),
}
Connections.transfers = clientsMap{
clients: make(map[string]int),
}
Connections.perUserConns = make(map[string]int)
Connections.mapping = make(map[string]int)
Connections.sshMapping = make(map[string]int)
Expand Down Expand Up @@ -908,7 +911,9 @@ func (c *SSHConnection) Close() error {
type ActiveConnections struct {
// clients contains both authenticated and estabilished connections and the ones waiting
// for authentication
clients clientsMap
clients clientsMap
// transfers contains active transfers, total and per-user
transfers clientsMap
transfersCheckStatus atomic.Bool
sync.RWMutex
connections []ActiveConnection
Expand Down Expand Up @@ -959,6 +964,9 @@ func (conns *ActiveConnections) Add(c ActiveConnection) error {
if val := conns.perUserConns[username]; val >= maxSessions {
return fmt.Errorf("too many open sessions: %d/%d", val, maxSessions)
}
if val := conns.transfers.getTotalFrom(username); val >= maxSessions {
return fmt.Errorf("too many open transfers: %d/%d", val, maxSessions)
}
}
conns.addUserConnection(username)
}
Expand Down Expand Up @@ -1219,6 +1227,35 @@ func (conns *ActiveConnections) GetClientConnections() int32 {
return conns.clients.getTotal()
}

// GetTotalTransfers returns the total number of active transfers
func (conns *ActiveConnections) GetTotalTransfers() int32 {
return conns.transfers.getTotal()
}

// IsNewTransferAllowed returns an error if the maximum number of concurrent allowed
// transfers is exceeded
func (conns *ActiveConnections) IsNewTransferAllowed(username string) error {
if isShuttingDown.Load() {
return ErrShuttingDown
}
if Config.MaxTotalConnections == 0 && Config.MaxPerHostConnections == 0 {
return nil
}
if Config.MaxPerHostConnections > 0 {
if transfers := conns.transfers.getTotalFrom(username); transfers >= Config.MaxPerHostConnections {
logger.Info(logSender, "", "active transfers from user %q: %d/%d", username, transfers, Config.MaxPerHostConnections)
return ErrConnectionDenied
}
}
if Config.MaxTotalConnections > 0 {
if transfers := conns.transfers.getTotal(); transfers >= int32(Config.MaxTotalConnections) {
logger.Info(logSender, "", "active transfers %d/%d", transfers, Config.MaxTotalConnections)
return ErrConnectionDenied
}
}
return nil
}

// IsNewConnectionAllowed returns an error if the maximum number of concurrent allowed
// connections is exceeded or a whitelist is defined and the specified ipAddr is not listed
// or the service is shutting down
Expand Down Expand Up @@ -1259,7 +1296,11 @@ func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr, protocol string)
}

// on a single SFTP connection we could have multiple SFTP channels or commands
// so we check the estabilished connections too
// so we check the estabilished connections and active uploads too
if transfers := conns.transfers.getTotal(); transfers >= int32(Config.MaxTotalConnections) {
logger.Info(logSender, "", "active transfers %d/%d", transfers, Config.MaxTotalConnections)
return ErrConnectionDenied
}

conns.RLock()
defer conns.RUnlock()
Expand Down
13 changes: 13 additions & 0 deletions internal/common/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -626,11 +626,17 @@ func TestMaxConnections(t *testing.T) {

ipAddr := "192.168.7.8"
assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolFTP))
assert.NoError(t, Connections.IsNewTransferAllowed(userTestUsername))

Config.MaxTotalConnections = 1
Config.MaxPerHostConnections = perHost

assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolHTTP))
assert.NoError(t, Connections.IsNewTransferAllowed(userTestUsername))
isShuttingDown.Store(true)
assert.ErrorIs(t, Connections.IsNewTransferAllowed(userTestUsername), ErrShuttingDown)
isShuttingDown.Store(false)

c := NewBaseConnection("id", ProtocolSFTP, "", "", dataprovider.User{})
fakeConn := &fakeConnection{
BaseConnection: c,
Expand All @@ -639,6 +645,10 @@ func TestMaxConnections(t *testing.T) {
assert.NoError(t, err)
assert.Len(t, Connections.GetStats(""), 1)
assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolSSH))
Connections.transfers.add(userTestUsername)
assert.Error(t, Connections.IsNewTransferAllowed(userTestUsername))
Connections.transfers.remove(userTestUsername)
assert.Equal(t, int32(0), Connections.GetTotalTransfers())

res := Connections.Close(fakeConn.GetID(), "")
assert.True(t, res)
Expand All @@ -650,6 +660,9 @@ func TestMaxConnections(t *testing.T) {
assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolSSH))
Connections.RemoveClientConnection(ipAddr)
assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolWebDAV))
Connections.transfers.add(userTestUsername)
assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolSSH))
Connections.transfers.remove(userTestUsername)
Connections.RemoveClientConnection(ipAddr)

Config.MaxTotalConnections = oldValue
Expand Down
4 changes: 4 additions & 0 deletions internal/common/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ func (c *BaseConnection) CloseFS() error {

// AddTransfer associates a new transfer to this connection
func (c *BaseConnection) AddTransfer(t ActiveTransfer) {
Connections.transfers.add(c.User.Username)

c.Lock()
defer c.Unlock()

Expand Down Expand Up @@ -190,6 +192,8 @@ func (c *BaseConnection) AddTransfer(t ActiveTransfer) {

// RemoveTransfer removes the specified transfer from the active ones
func (c *BaseConnection) RemoveTransfer(t ActiveTransfer) {
Connections.transfers.remove(c.User.Username)

c.Lock()
defer c.Unlock()

Expand Down
80 changes: 80 additions & 0 deletions internal/common/protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8130,6 +8130,86 @@ func TestRetentionAPI(t *testing.T) {
assert.NoError(t, err)
}

func TestPerUserTransferLimits(t *testing.T) {
oldMaxPerHostConns := common.Config.MaxPerHostConnections

common.Config.MaxPerHostConnections = 2

u := getTestUser()
u.UploadBandwidth = 32
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()

var wg sync.WaitGroup
numErrors := 0
for i := 0; i <= 2; i++ {
wg.Add(1)
go func(counter int) {
defer wg.Done()

time.Sleep(20 * time.Millisecond)
err := writeSFTPFile(fmt.Sprintf("%s_%d", testFileName, counter), 64*1024, client)
if err != nil {
numErrors++
}
}(i)
}
wg.Wait()

assert.Equal(t, 1, numErrors)
}
_, err = httpdtest.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)

common.Config.MaxPerHostConnections = oldMaxPerHostConns
}

func TestMaxSessionsSameConnection(t *testing.T) {
u := getTestUser()
u.UploadBandwidth = 32
u.MaxSessions = 2
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()

var wg sync.WaitGroup
numErrors := 0
for i := 0; i <= 2; i++ {
wg.Add(1)
go func(counter int) {
defer wg.Done()

time.Sleep(20 * time.Millisecond)
var err error
if counter < 2 {
err = writeSFTPFile(fmt.Sprintf("%s_%d", testFileName, counter), 64*1024, client)
} else {
_, _, err = getSftpClient(user)
}
if err != nil {
numErrors++
}
}(i)
}

wg.Wait()
assert.Equal(t, 1, numErrors)
}
_, err = httpdtest.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
}

func TestRenameDir(t *testing.T) {
u := getTestUser()
testDir := "/dir-to-rename"
Expand Down
8 changes: 8 additions & 0 deletions internal/common/transfer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,9 @@ func TestRemovePartialCryptoFile(t *testing.T) {
assert.Equal(t, int64(0), size)
assert.Equal(t, 1, deletedFiles)
assert.NoFileExists(t, testFile)
err = transfer.Close()
assert.Error(t, err)
assert.Len(t, conn.GetTransfers(), 0)
}

func TestFTPMode(t *testing.T) {
Expand Down Expand Up @@ -434,6 +437,11 @@ func TestTransferQuota(t *testing.T) {
}
err = transfer.CheckWrite()
assert.True(t, conn.IsQuotaExceededError(err))

err = transfer.Close()
assert.NoError(t, err)
assert.Len(t, conn.GetTransfers(), 0)
assert.Equal(t, int32(0), Connections.GetTotalTransfers())
}

func TestUploadOutsideHomeRenameError(t *testing.T) {
Expand Down
6 changes: 6 additions & 0 deletions internal/common/transferschecker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
Connections.Remove(fakeConn5.GetID())
stats := Connections.GetStats("")
assert.Len(t, stats, 0)
assert.Equal(t, int32(0), Connections.GetTotalTransfers())

err = dataprovider.DeleteUser(user.Username, "", "", "")
assert.NoError(t, err)
Expand Down Expand Up @@ -368,11 +369,16 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
if assert.Error(t, transfer4.errAbort) {
assert.Contains(t, transfer4.errAbort.Error(), ErrReadQuotaExceeded.Error())
}
err = transfer3.Close()
assert.NoError(t, err)
err = transfer4.Close()
assert.NoError(t, err)

Connections.Remove(fakeConn3.GetID())
Connections.Remove(fakeConn4.GetID())
stats := Connections.GetStats("")
assert.Len(t, stats, 0)
assert.Equal(t, int32(0), Connections.GetTotalTransfers())

err = dataprovider.DeleteUser(user.Username, "", "", "")
assert.NoError(t, err)
Expand Down
2 changes: 2 additions & 0 deletions internal/ftpd/cryptfs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ func TestBasicFTPHandlingCryptFs(t *testing.T) {
assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 50*time.Millisecond)
assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond,
50*time.Millisecond)
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
}

func TestBufferedCryptFs(t *testing.T) {
Expand Down Expand Up @@ -179,6 +180,7 @@ func TestBufferedCryptFs(t *testing.T) {
assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 50*time.Millisecond)
assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond,
50*time.Millisecond)
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
}

func TestZeroBytesTransfersCryptFs(t *testing.T) {
Expand Down
Loading

0 comments on commit ae1487d

Please sign in to comment.