Skip to content

Commit

Permalink
Force rotation X.509 SVIDs in Agent side (#5446)
Browse files Browse the repository at this point in the history
* Force rotation of X.509 workload SVIDs in lru cache
* Force rotation of X.509 workload SVIDs in store SVID cache
* Force rotation of Agent SVID

Signed-off-by: Marcos Yacob <marcosyacob@gmail.com>
  • Loading branch information
MarcosDY authored Sep 28, 2024
1 parent 182b594 commit 8f82eba
Show file tree
Hide file tree
Showing 22 changed files with 1,326 additions and 57 deletions.
5 changes: 3 additions & 2 deletions pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func (a *Agent) Run(ctx context.Context) error {
}
}

svidStoreCache := a.newSVIDStoreCache()
svidStoreCache := a.newSVIDStoreCache(metrics)

manager, err := a.newManager(ctx, sto, cat, metrics, as, svidStoreCache, nodeAttestor)
if err != nil {
Expand Down Expand Up @@ -324,10 +324,11 @@ func (a *Agent) newManager(ctx context.Context, sto storage.Storage, cat catalog
}
}

func (a *Agent) newSVIDStoreCache() *storecache.Cache {
func (a *Agent) newSVIDStoreCache(metrics telemetry.Metrics) *storecache.Cache {
config := &storecache.Config{
Log: a.c.Log.WithField(telemetry.SubsystemName, "svid_store_cache"),
TrustDomain: a.c.TrustDomain,
Metrics: metrics,
}

return storecache.New(config)
Expand Down
129 changes: 129 additions & 0 deletions pkg/agent/manager/cache/lru_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cache

import (
"context"
"crypto/x509"
"fmt"
"sort"
"sync"
Expand All @@ -14,6 +15,7 @@ import (
"github.com/spiffe/spire/pkg/common/backoff"
"github.com/spiffe/spire/pkg/common/telemetry"
agentmetrics "github.com/spiffe/spire/pkg/common/telemetry/agent"
"github.com/spiffe/spire/pkg/common/x509util"
"github.com/spiffe/spire/proto/spire/common"
)

Expand All @@ -22,13 +24,26 @@ const (
SVIDCacheMaxSize = 1000
// SVIDSyncInterval is the interval at which SVIDs are synced with subscribers
SVIDSyncInterval = 500 * time.Millisecond
// Default batch size for processing tainted SVIDs
defaultProcessingBatchSize = 100
)

var (
// Time interval between SVID batch processing
processingTaintedX509SVIDInterval = 5 * time.Second
)

// UpdateEntries holds information for an entries update to the cache.
type UpdateEntries struct {
// Bundles is a set of ALL trust bundles available to the agent, keyed by trust domain
Bundles map[spiffeid.TrustDomain]*spiffebundle.Bundle

// TaintedX509Authorities is a set of all tainted X.509 authorities notified by the server.
TaintedX509Authorities []string

// TaintedJWTAuthorities is a set of all tainted JWT authorities notified by the server.
TaintedJWTAuthorities []string

// RegistrationEntries is a set of all registration entries available to the
// agent, keyed by registration entry id.
RegistrationEntries map[string]*common.RegistrationEntry
Expand Down Expand Up @@ -125,6 +140,10 @@ type LRUCache struct {
svids map[string]*X509SVID

subscribeBackoffFn func() backoff.BackOff

processingBatchSize int
// used to debug scheduled batchs for tainted authorities
taintedBatchProcessedCh chan struct{}
}

func NewLRUCache(log logrus.FieldLogger, trustDomain spiffeid.TrustDomain, bundle *Bundle, metrics telemetry.Metrics, clk clock.Clock) *LRUCache {
Expand All @@ -146,6 +165,7 @@ func NewLRUCache(log logrus.FieldLogger, trustDomain spiffeid.TrustDomain, bundl
subscribeBackoffFn: func() backoff.BackOff {
return backoff.NewBackoff(clk, SVIDSyncInterval)
},
processingBatchSize: defaultProcessingBatchSize,
}
}

Expand Down Expand Up @@ -493,6 +513,35 @@ func (c *LRUCache) UpdateSVIDs(update *UpdateSVIDs) {
}
}

// TaintX509SVIDs initiates the processing of all cached SVIDs, checking if they are tainted
// by any of the provided authorities.
// It schedules the processing to run asynchronously in batches.
func (c *LRUCache) TaintX509SVIDs(ctx context.Context, taintedX509Authorities []*x509.Certificate) {
c.mu.RLock()
defer c.mu.RUnlock()

var entriesToProcess []string
for key, svid := range c.svids {
if svid != nil && len(svid.Chain) > 0 {
entriesToProcess = append(entriesToProcess, key)
}
}

// Check if there are any entries to process before scheduling
if len(entriesToProcess) == 0 {
c.log.Debug("No SVID entries to process for tainted X.509 authorities")
return
}

// Schedule the rotation process in a separate goroutine
go func() {
c.scheduleRotation(ctx, entriesToProcess, taintedX509Authorities)
}()

c.log.WithField(telemetry.Count, len(entriesToProcess)).
Debug("Scheduled rotation for SVID entries due to tainted X.509 authorities")
}

// GetStaleEntries obtains a list of stale entries
func (c *LRUCache) GetStaleEntries() []*StaleEntry {
c.mu.Lock()
Expand Down Expand Up @@ -531,6 +580,86 @@ func (c *LRUCache) SyncSVIDsWithSubscribers() {
c.syncSVIDsWithSubscribers()
}

// scheduleRotation processes SVID entries in batches, removing those tainted by X.509 authorities.
// The process continues at regular intervals until all entries have been processed or the context is cancelled.
func (c *LRUCache) scheduleRotation(ctx context.Context, entryIDs []string, taintedX509Authorities []*x509.Certificate) {
ticker := c.clk.Ticker(processingTaintedX509SVIDInterval)
defer ticker.Stop()

// Ensure consistent order for test cases if channel is used
if c.taintedBatchProcessedCh != nil {
sort.Strings(entryIDs)
}

for {
// Process entries in batches
batchSize := min(c.processingBatchSize, len(entryIDs))
processingEntries := entryIDs[:batchSize]

c.processTaintedSVIDs(processingEntries, taintedX509Authorities)

// Remove processed entries from the list
entryIDs = entryIDs[batchSize:]

entriesLeftCount := len(entryIDs)
if entriesLeftCount == 0 {
c.log.Info("Finished processing all tainted entries")
c.notifyTaintedBatchProcessed()
return
}
c.log.WithField(telemetry.Count, entriesLeftCount).Info("There are tainted X.509 SVIDs left to be processed")
c.notifyTaintedBatchProcessed()

select {
case <-ticker.C:
case <-ctx.Done():
c.log.WithError(ctx.Err()).Warn("Context cancelled, exiting rotation schedule")
return
}
}
}

func (c *LRUCache) notifyTaintedBatchProcessed() {
if c.taintedBatchProcessedCh != nil {
c.taintedBatchProcessedCh <- struct{}{}
}
}

// processTaintedSVIDs identifies and removes tainted SVIDs from the cache that have been signed by the given tainted authorities.
func (c *LRUCache) processTaintedSVIDs(entryIDs []string, taintedX509Authorities []*x509.Certificate) {
counter := telemetry.StartCall(c.metrics, telemetry.CacheManager, "", telemetry.ProcessTaintedSVIDs)
defer counter.Done(nil)

taintedSVIDs := 0

c.mu.Lock()
defer c.mu.Unlock()

for _, entryID := range entryIDs {
svid, exists := c.svids[entryID]
if !exists || svid == nil {
// Skip if the SVID is not in cache or is nil
continue
}

// Check if the SVID is signed by any tainted authority
isTainted, err := x509util.IsSignedByRoot(svid.Chain, taintedX509Authorities)
if err != nil {
c.log.WithError(err).
WithField(telemetry.RegistrationID, entryID).
Error("Failed to check if SVID is signed by tainted authority")
continue
}
if isTainted {
taintedSVIDs++
delete(c.svids, entryID)
}
}

agentmetrics.AddCacheManagerTaintedSVIDsSample(c.metrics, "", float32(taintedSVIDs))
c.log.WithField(telemetry.TaintedSVIDs, taintedSVIDs).Info("Tainted X.509 SVIDs")
}

// Notify subscriber of selector set only if all SVIDs for corresponding selector set are cached
// It returns whether all SVIDs are cached or not.
// This method should be retried with backoff to avoid lock contention.
Expand Down
Loading

0 comments on commit 8f82eba

Please sign in to comment.