Improve SOCKS lifecycle management and resolve leak
navsec committed Nov 7, 2024
1 parent 79a8758 commit c18c193
Showing 2 changed files with 175 additions and 42 deletions.
4 changes: 2 additions & 2 deletions client/core/socks.go
@@ -211,8 +211,8 @@ const leakyBufSize = 4108 // data.len(2) + hmacsha1(10) + data(4096)
var leakyBuf = leaky.NewLeakyBuf(2048, leakyBufSize)

func connect(conn net.Conn, stream rpcpb.SliverRPC_SocksProxyClient, frame *sliverpb.SocksData) {
-// Client Rate Limiter: 20 operations per second, burst of 1
-limiter := rate.NewLimiter(rate.Limit(20), 1)
+// Client Rate Limiter: 5 operations per second, burst of 1
+limiter := rate.NewLimiter(rate.Limit(5), 1)

SocksConnPool.Store(frame.TunnelID, conn)

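The client-side change is just the limiter rate. The standalone sketch below (not part of the commit) shows what a golang.org/x/time/rate limiter configured with Limit(5) and a burst of 1 does: after the first token is spent, calls are spaced roughly 200ms apart. Whether the tunnel loop consumes tokens with Wait or Allow is not visible in this hunk.

package main

import (
	"context"
	"fmt"
	"time"

	"golang.org/x/time/rate"
)

func main() {
	// Same parameters as the new client limiter: 5 operations per second, burst of 1.
	limiter := rate.NewLimiter(rate.Limit(5), 1)

	start := time.Now()
	for i := 0; i < 5; i++ {
		// Wait blocks until a token is available or the context is cancelled.
		if err := limiter.Wait(context.Background()); err != nil {
			fmt.Println("limiter error:", err)
			return
		}
		fmt.Printf("op %d at %v\n", i, time.Since(start).Round(10*time.Millisecond))
	}
}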
213 changes: 173 additions & 40 deletions server/rpc/rpc-socks.go
@@ -35,15 +35,16 @@ import (

var (
// SessionID->Tunnels[TunnelID]->Tunnel->Cache map[uint64]*sliverpb.SocksData{}
toImplantCacheSocks = socksDataCache{mutex: &sync.RWMutex{}, cache: map[uint64]map[uint64]*sliverpb.SocksData{}}
toImplantCacheSocks = socksDataCache{mutex: &sync.RWMutex{}, cache: map[uint64]map[uint64]*sliverpb.SocksData{}, lastActivity: map[uint64]time.Time{}}

// SessionID->Tunnels[TunnelID]->Tunnel->Cache
fromImplantCacheSocks = socksDataCache{mutex: &sync.RWMutex{}, cache: map[uint64]map[uint64]*sliverpb.SocksData{}}
fromImplantCacheSocks = socksDataCache{mutex: &sync.RWMutex{}, cache: map[uint64]map[uint64]*sliverpb.SocksData{}, lastActivity: map[uint64]time.Time{}}
)

type socksDataCache struct {
mutex *sync.RWMutex
cache map[uint64]map[uint64]*sliverpb.SocksData
mutex *sync.RWMutex
cache map[uint64]map[uint64]*sliverpb.SocksData
lastActivity map[uint64]time.Time
}

func (c *socksDataCache) Add(tunnelID uint64, sequence uint64, tunnelData *sliverpb.SocksData) {
@@ -75,6 +76,7 @@ func (c *socksDataCache) DeleteTun(tunnelID uint64) {
defer c.mutex.Unlock()

delete(c.cache, tunnelID)
delete(c.lastActivity, tunnelID)
}

func (c *socksDataCache) DeleteSeq(tunnelID uint64, sequence uint64) {
@@ -88,70 +90,171 @@ func (c *socksDataCache) DeleteSeq(tunnelID uint64, sequence uint64) {
delete(c.cache[tunnelID], sequence)
}

func (c *socksDataCache) recordActivity(tunnelID uint64) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.lastActivity[tunnelID] = time.Now()
}

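The lastActivity map added above is the hook for the leak fix: cache writes stamp the tunnel, and the monitor goroutine further down compares that stamp against an idle threshold. A trimmed-down sketch of the same bookkeeping, with hypothetical names (activityCache, idleFor) rather than the server's real types:

package socksdemo

import (
	"sync"
	"time"
)

// activityCache is a reduced, illustrative form of socksDataCache that keeps
// only the activity bookkeeping introduced in this commit.
type activityCache struct {
	mu           sync.RWMutex
	lastActivity map[uint64]time.Time
}

func newActivityCache() *activityCache {
	return &activityCache{lastActivity: map[uint64]time.Time{}}
}

// recordActivity stamps the tunnel with the current time on every cache write.
func (c *activityCache) recordActivity(tunnelID uint64) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.lastActivity[tunnelID] = time.Now()
}

// idleFor reports whether the tunnel has been quiet longer than timeout;
// a tunnel that was never stamped counts as idle.
func (c *activityCache) idleFor(tunnelID uint64, timeout time.Duration) bool {
	c.mu.RLock()
	defer c.mu.RUnlock()
	last, ok := c.lastActivity[tunnelID]
	return !ok || time.Since(last) > timeout
}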
// Socks - Open an in-band port forward

const (
writeTimeout = 5 * time.Second
batchSize = 100 // Maximum number of sequences to batch
writeTimeout = 5 * time.Second
batchSize = 100 // Maximum number of sequences to batch
inactivityCheckInterval = 5 * time.Second
inactivityTimeout = 15 * time.Second
)

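These two constants drive the new per-tunnel reaper: a ticker fires every inactivityCheckInterval, and the tunnel is closed once both caches have been quiet for inactivityTimeout (or the client or session has gone away). A rough standalone sketch of that loop, with hypothetical callbacks (lastSeen, closeTunnel) standing in for the cache lookups and the CloseSocks call:

package socksdemo

import (
	"context"
	"time"
)

// monitorIdle is an illustrative stand-in for the per-tunnel monitor goroutine.
func monitorIdle(ctx context.Context, checkInterval, timeout time.Duration,
	lastSeen func() time.Time, closeTunnel func()) {

	ticker := time.NewTicker(checkInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return // proxy stream is shutting down; stop watching
		case <-ticker.C:
			if time.Since(lastSeen()) > timeout {
				closeTunnel() // idle past the threshold: reap the tunnel
				return
			}
		}
	}
}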
func (s *Server) SocksProxy(stream rpcpb.SliverRPC_SocksProxyServer) error {
errChan := make(chan error, 2)
defer close(errChan)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

connDone := make(chan struct{})
defer close(connDone)

// Track all goroutines spawned for this session
var wg sync.WaitGroup
defer wg.Wait()

// Track all tunnels created for this session
activeTunnels := make(map[uint64]bool)
var tunnelMutex sync.Mutex

// Cleanup all tunnels on SocksProxy closure
defer func() {
if r := recover(); r != nil {
rpcLog.Errorf("Recovered from panic in SocksProxy: %v", r)
}
tunnelMutex.Lock()
for tunnelID := range activeTunnels {
if tunnel := core.SocksTunnels.Get(tunnelID); tunnel != nil {
rpcLog.Infof("Cleaning up tunnel %d on proxy closure", tunnelID)
close(tunnel.FromImplant)
tunnel.Client = nil
s.CloseSocks(context.Background(), &sliverpb.Socks{TunnelID: tunnelID})
}
}
tunnelMutex.Unlock()
}()

for {
select {
case err := <-errChan:
rpcLog.Errorf("SocksProxy error: %v", err)
return err
default:
}

fromClient, err := stream.Recv()
if err == io.EOF {
break
return nil
}
//fmt.Println("Send Agent 1 ",fromClient.TunnelID,len(fromClient.Data))
if err != nil {
rpcLog.Warnf("Error on stream recv %s", err)
return err
}
tunnelLog.Debugf("Tunnel %d: From client %d byte(s)",
fromClient.TunnelID, len(fromClient.Data))
socks := core.SocksTunnels.Get(fromClient.TunnelID)
if socks == nil {
return nil

tunnelMutex.Lock()
activeTunnels[fromClient.TunnelID] = true // Mark this as an active tunnel
tunnelMutex.Unlock()

tunnel := core.SocksTunnels.Get(fromClient.TunnelID)
if tunnel == nil {
continue
}
if socks.Client == nil {
socks.Client = stream // Bind client to tunnel

if tunnel.Client == nil {
tunnel.Client = stream
tunnel.FromImplant = make(chan *sliverpb.SocksData, 100)

// Monitor tunnel goroutines for inactivity and cleanup
wg.Add(1)
go func(tunnelID uint64) {
defer wg.Done()
defer func() {
if r := recover(); r != nil {
rpcLog.Errorf("Recovered from panic in monitor: %v", r)
errChan <- fmt.Errorf("monitor goroutine panic: %v", r)
cancel() // Cancel context in case of a panic
}
}()

ticker := time.NewTicker(inactivityCheckInterval)
defer ticker.Stop()

for {
select {
case <-ctx.Done():
return
case <-connDone:
return
case <-ticker.C:
tunnel := core.SocksTunnels.Get(tunnelID)
if tunnel == nil || tunnel.Client == nil {
return
}
session := core.Sessions.Get(tunnel.SessionID)

// Check both caches for activity
toImplantCacheSocks.mutex.RLock()
fromImplantCacheSocks.mutex.RLock()
toLastActivity := toImplantCacheSocks.lastActivity[tunnelID]
fromLastActivity := fromImplantCacheSocks.lastActivity[tunnelID]
toImplantCacheSocks.mutex.RUnlock()
fromImplantCacheSocks.mutex.RUnlock()

// Clean up goroutine if both directions have hit the idle threshold or if client has disconnected
if time.Since(toLastActivity) > inactivityTimeout &&
time.Since(fromLastActivity) > inactivityTimeout ||
tunnel.Client == nil || session == nil {
s.CloseSocks(context.Background(), &sliverpb.Socks{TunnelID: tunnelID})
return
}

}
}
}(fromClient.TunnelID)

// Send Client
wg.Add(1)
go func() {
defer wg.Done()
defer func() {
if r := recover(); r != nil {
errChan <- fmt.Errorf("client sender panic: %v", r)
rpcLog.Errorf("Recovered from panic in client sender: %v", r)
errChan <- fmt.Errorf("client sender panic: %v", r)
cancel() // Cancel context in case of a panic
}
}()

pendingData := make(map[uint64]*sliverpb.SocksData)
ticker := time.NewTicker(1 * time.Millisecond) // 1ms ticker - data coming back from implant is usually larger response data
ticker := time.NewTicker(50 * time.Millisecond) // 50ms ticker - data coming back from implant is usually larger response data
defer ticker.Stop()

for {
select {
case tunnelData, ok := <-socks.FromImplant:
case <-ctx.Done():
return
case <-connDone:
return
case tunnelData, ok := <-tunnel.FromImplant:
if !ok {
rpcLog.Debug("FromImplant channel closed")
return
}
sequence := tunnelData.Sequence
fromImplantCacheSocks.Add(fromClient.TunnelID, sequence, tunnelData)
pendingData[sequence] = tunnelData
fromImplantCacheSocks.recordActivity(fromClient.TunnelID)

case <-ticker.C:
if tunnel.Client == nil {
return
}
if len(pendingData) == 0 {
continue
}

expectedSequence := atomic.LoadUint64(&socks.FromImplantSequence)
expectedSequence := atomic.LoadUint64(&tunnel.FromImplantSequence)
processed := 0

// Perform Batching
Expand All @@ -164,8 +267,7 @@ func (s *Server) SocksProxy(stream rpcpb.SliverRPC_SocksProxyServer) error {
func() {
defer func() {
if r := recover(); r != nil {
errChan <- fmt.Errorf("client sender panic: %v", r)
rpcLog.Errorf("Recovered from panic in client sender: %v", r)
errChan <- fmt.Errorf("Client sender panic: %v", r)
}
}()

Expand All @@ -182,32 +284,43 @@ func (s *Server) SocksProxy(stream rpcpb.SliverRPC_SocksProxyServer) error {

delete(pendingData, expectedSequence)
fromImplantCacheSocks.DeleteSeq(fromClient.TunnelID, expectedSequence)
atomic.AddUint64(&socks.FromImplantSequence, 1)
atomic.AddUint64(&tunnel.FromImplantSequence, 1)
expectedSequence++
processed++
}()
}

}
}
}()

// Send Agent
wg.Add(1)
go func() {
defer wg.Done()
defer func() {
if r := recover(); r != nil {
errChan <- fmt.Errorf("agent sender panic: %v", r)
rpcLog.Errorf("Recovered from panic in agent sender: %v", r)
errChan <- fmt.Errorf("agent sender panic: %v", r)
cancel() // Cancel context in case of a panic
}
}()

pendingData := make(map[uint64]*sliverpb.SocksData)
ticker := time.NewTicker(10 * time.Millisecond) // 10ms ticker - data going towards implact is usually smaller request data
ticker := time.NewTicker(100 * time.Millisecond) // 100ms ticker - data going towards implant is usually smaller request data
defer ticker.Stop()

for {
select {
case <-ctx.Done():
return
case <-connDone:
return
case <-ticker.C:
sequence := atomic.LoadUint64(&socks.ToImplantSequence)
if tunnel.Client == nil {
return
}
sequence := atomic.LoadUint64(&tunnel.ToImplantSequence)

func() {
defer func() {
@@ -240,7 +353,7 @@ func (s *Server) SocksProxy(stream rpcpb.SliverRPC_SocksProxyServer) error {
Data: data,
}:
toImplantCacheSocks.DeleteSeq(fromClient.TunnelID, sequence)
atomic.AddUint64(&socks.ToImplantSequence, 1)
atomic.AddUint64(&tunnel.ToImplantSequence, 1)
sequence++
case <-time.After(writeTimeout):
rpcLog.Error("Write timeout to implant")
Expand All @@ -256,14 +369,6 @@ func (s *Server) SocksProxy(stream rpcpb.SliverRPC_SocksProxyServer) error {

toImplantCacheSocks.Add(fromClient.TunnelID, fromClient.Sequence, fromClient)
}

select {
case err := <-errChan:
rpcLog.Errorf("SocksProxy Goroutine error: %v", err)
default:
}

return nil
}

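Both sender goroutines in SocksProxy use the same ordering trick: frames are parked in a pendingData map keyed by sequence number, and each tick forwards only the contiguous run starting at the expected sequence, up to batchSize frames. A condensed sketch of that flush step, with the payload reduced to a byte slice and the gRPC send reduced to a callback:

package socksdemo

// flushInOrder drains pendingData starting at *expected, stopping at the first
// gap in the sequence or after batchSize frames, and reports how many were sent.
func flushInOrder(pendingData map[uint64][]byte, expected *uint64, batchSize int,
	send func([]byte) error) int {

	processed := 0
	for processed < batchSize {
		data, ok := pendingData[*expected]
		if !ok {
			break // gap: the next frame has not arrived yet
		}
		if err := send(data); err != nil {
			break // leave the frame in place so it can be retried
		}
		delete(pendingData, *expected)
		*expected++
		processed++
	}
	return processed
}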
// CreateSocks5 - Create requests we close a Socks
Expand All @@ -285,11 +390,39 @@ func (s *Server) CreateSocks(ctx context.Context, req *sliverpb.Socks) (*sliverp

// CloseSocks - Client requests we close a Socks
func (s *Server) CloseSocks(ctx context.Context, req *sliverpb.Socks) (*commonpb.Empty, error) {
err := core.SocksTunnels.Close(req.TunnelID)
defer func() {
if r := recover(); r != nil {
rpcLog.Errorf("Recovered from panic in CloseSocks for tunnel %d: %v", req.TunnelID, r)
}
}()

tunnel := core.SocksTunnels.Get(req.TunnelID)
if tunnel != nil {
// We mark the tunnel closed first to prevent new operations
tunnel.Client = nil

// Close down the FromImplant channel if it exists
if tunnel.FromImplant != nil {
select {
case _, ok := <-tunnel.FromImplant:
if ok {
close(tunnel.FromImplant)
}
default:
close(tunnel.FromImplant)
}
tunnel.FromImplant = nil
}
}

// Clean up caches
toImplantCacheSocks.DeleteTun(req.TunnelID)
fromImplantCacheSocks.DeleteTun(req.TunnelID)
if err != nil {
return nil, err

// Remove from core tunnels last
if err := core.SocksTunnels.Close(req.TunnelID); err != nil {
rpcLog.Errorf("Error closing tunnel %d: %v", req.TunnelID, err)
}

return &commonpb.Empty{}, nil
}

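The deferred recover added to CloseSocks matters because teardown can now be triggered from several places (the monitor goroutine, the proxy-closure cleanup, and the client RPC), so a channel may already be closed by the time CloseSocks runs; in Go, closing an already-closed channel panics. A tiny standalone demonstration of that failure mode and the recover guard:

package main

import "fmt"

// closeSafely closes ch, turning a "close of closed channel" panic into a
// printed message instead of crashing the caller.
func closeSafely(ch chan struct{}) {
	defer func() {
		if r := recover(); r != nil {
			fmt.Println("recovered:", r)
		}
	}()
	close(ch)
}

func main() {
	ch := make(chan struct{})
	closeSafely(ch) // first close succeeds
	closeSafely(ch) // second close panics; the deferred recover swallows it
}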