diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index 64ed3b5211..b6358ec7b5 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -142,7 +142,7 @@ func (a *Agent) Run(ctx context.Context) error { } } - svidStoreCache := a.newSVIDStoreCache() + svidStoreCache := a.newSVIDStoreCache(metrics) manager, err := a.newManager(ctx, sto, cat, metrics, as, svidStoreCache, nodeAttestor) if err != nil { @@ -324,10 +324,11 @@ func (a *Agent) newManager(ctx context.Context, sto storage.Storage, cat catalog } } -func (a *Agent) newSVIDStoreCache() *storecache.Cache { +func (a *Agent) newSVIDStoreCache(metrics telemetry.Metrics) *storecache.Cache { config := &storecache.Config{ Log: a.c.Log.WithField(telemetry.SubsystemName, "svid_store_cache"), TrustDomain: a.c.TrustDomain, + Metrics: metrics, } return storecache.New(config) diff --git a/pkg/agent/manager/cache/lru_cache.go b/pkg/agent/manager/cache/lru_cache.go index 100bda2e87..49d6bc5e32 100644 --- a/pkg/agent/manager/cache/lru_cache.go +++ b/pkg/agent/manager/cache/lru_cache.go @@ -2,6 +2,7 @@ package cache import ( "context" + "crypto/x509" "fmt" "sort" "sync" @@ -14,6 +15,7 @@ import ( "github.com/spiffe/spire/pkg/common/backoff" "github.com/spiffe/spire/pkg/common/telemetry" agentmetrics "github.com/spiffe/spire/pkg/common/telemetry/agent" + "github.com/spiffe/spire/pkg/common/x509util" "github.com/spiffe/spire/proto/spire/common" ) @@ -22,6 +24,13 @@ const ( SVIDCacheMaxSize = 1000 // SVIDSyncInterval is the interval at which SVIDs are synced with subscribers SVIDSyncInterval = 500 * time.Millisecond + // Default batch size for processing tainted SVIDs + defaultProcessingBatchSize = 100 +) + +var ( + // Time interval between SVID batch processing + processingTaintedX509SVIDInterval = 5 * time.Second ) // UpdateEntries holds information for an entries update to the cache. @@ -29,6 +38,12 @@ type UpdateEntries struct { // Bundles is a set of ALL trust bundles available to the agent, keyed by trust domain Bundles map[spiffeid.TrustDomain]*spiffebundle.Bundle + // TaintedX509Authorities is a set of all tainted X.509 authorities notified by the server. + TaintedX509Authorities []string + + // TaintedJWTAuthorities is a set of all tainted JWT authorities notified by the server. + TaintedJWTAuthorities []string + // RegistrationEntries is a set of all registration entries available to the // agent, keyed by registration entry id. RegistrationEntries map[string]*common.RegistrationEntry @@ -125,6 +140,10 @@ type LRUCache struct { svids map[string]*X509SVID subscribeBackoffFn func() backoff.BackOff + + processingBatchSize int + // used to debug scheduled batchs for tainted authorities + taintedBatchProcessedCh chan struct{} } func NewLRUCache(log logrus.FieldLogger, trustDomain spiffeid.TrustDomain, bundle *Bundle, metrics telemetry.Metrics, clk clock.Clock) *LRUCache { @@ -146,6 +165,7 @@ func NewLRUCache(log logrus.FieldLogger, trustDomain spiffeid.TrustDomain, bundl subscribeBackoffFn: func() backoff.BackOff { return backoff.NewBackoff(clk, SVIDSyncInterval) }, + processingBatchSize: defaultProcessingBatchSize, } } @@ -493,6 +513,35 @@ func (c *LRUCache) UpdateSVIDs(update *UpdateSVIDs) { } } +// TaintX509SVIDs initiates the processing of all cached SVIDs, checking if they are tainted +// by any of the provided authorities. +// It schedules the processing to run asynchronously in batches. +func (c *LRUCache) TaintX509SVIDs(ctx context.Context, taintedX509Authorities []*x509.Certificate) { + c.mu.RLock() + defer c.mu.RUnlock() + + var entriesToProcess []string + for key, svid := range c.svids { + if svid != nil && len(svid.Chain) > 0 { + entriesToProcess = append(entriesToProcess, key) + } + } + + // Check if there are any entries to process before scheduling + if len(entriesToProcess) == 0 { + c.log.Debug("No SVID entries to process for tainted X.509 authorities") + return + } + + // Schedule the rotation process in a separate goroutine + go func() { + c.scheduleRotation(ctx, entriesToProcess, taintedX509Authorities) + }() + + c.log.WithField(telemetry.Count, len(entriesToProcess)). + Debug("Scheduled rotation for SVID entries due to tainted X.509 authorities") +} + // GetStaleEntries obtains a list of stale entries func (c *LRUCache) GetStaleEntries() []*StaleEntry { c.mu.Lock() @@ -531,6 +580,86 @@ func (c *LRUCache) SyncSVIDsWithSubscribers() { c.syncSVIDsWithSubscribers() } +// scheduleRotation processes SVID entries in batches, removing those tainted by X.509 authorities. +// The process continues at regular intervals until all entries have been processed or the context is cancelled. +func (c *LRUCache) scheduleRotation(ctx context.Context, entryIDs []string, taintedX509Authorities []*x509.Certificate) { + ticker := c.clk.Ticker(processingTaintedX509SVIDInterval) + defer ticker.Stop() + + // Ensure consistent order for test cases if channel is used + if c.taintedBatchProcessedCh != nil { + sort.Strings(entryIDs) + } + + for { + // Process entries in batches + batchSize := min(c.processingBatchSize, len(entryIDs)) + processingEntries := entryIDs[:batchSize] + + c.processTaintedSVIDs(processingEntries, taintedX509Authorities) + + // Remove processed entries from the list + entryIDs = entryIDs[batchSize:] + + entriesLeftCount := len(entryIDs) + if entriesLeftCount == 0 { + c.log.Info("Finished processing all tainted entries") + c.notifyTaintedBatchProcessed() + return + } + c.log.WithField(telemetry.Count, entriesLeftCount).Info("There are tainted X.509 SVIDs left to be processed") + c.notifyTaintedBatchProcessed() + + select { + case <-ticker.C: + case <-ctx.Done(): + c.log.WithError(ctx.Err()).Warn("Context cancelled, exiting rotation schedule") + return + } + } +} + +func (c *LRUCache) notifyTaintedBatchProcessed() { + if c.taintedBatchProcessedCh != nil { + c.taintedBatchProcessedCh <- struct{}{} + } +} + +// processTaintedSVIDs identifies and removes tainted SVIDs from the cache that have been signed by the given tainted authorities. +func (c *LRUCache) processTaintedSVIDs(entryIDs []string, taintedX509Authorities []*x509.Certificate) { + counter := telemetry.StartCall(c.metrics, telemetry.CacheManager, "", telemetry.ProcessTaintedSVIDs) + defer counter.Done(nil) + + taintedSVIDs := 0 + + c.mu.Lock() + defer c.mu.Unlock() + + for _, entryID := range entryIDs { + svid, exists := c.svids[entryID] + if !exists || svid == nil { + // Skip if the SVID is not in cache or is nil + continue + } + + // Check if the SVID is signed by any tainted authority + isTainted, err := x509util.IsSignedByRoot(svid.Chain, taintedX509Authorities) + if err != nil { + c.log.WithError(err). + WithField(telemetry.RegistrationID, entryID). + Error("Failed to check if SVID is signed by tainted authority") + continue + } + if isTainted { + taintedSVIDs++ + delete(c.svids, entryID) + } + } + + agentmetrics.AddCacheManagerTaintedSVIDsSample(c.metrics, "", float32(taintedSVIDs)) + c.log.WithField(telemetry.TaintedSVIDs, taintedSVIDs).Info("Tainted X.509 SVIDs") +} + // Notify subscriber of selector set only if all SVIDs for corresponding selector set are cached // It returns whether all SVIDs are cached or not. // This method should be retried with backoff to avoid lock contention. diff --git a/pkg/agent/manager/cache/lru_cache_test.go b/pkg/agent/manager/cache/lru_cache_test.go index 63da336a42..d1c1e62543 100644 --- a/pkg/agent/manager/cache/lru_cache_test.go +++ b/pkg/agent/manager/cache/lru_cache_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" "github.com/spiffe/go-spiffe/v2/bundle/spiffebundle" "github.com/spiffe/go-spiffe/v2/spiffeid" @@ -15,6 +16,8 @@ import ( "github.com/spiffe/spire/proto/spire/common" "github.com/spiffe/spire/test/clock" "github.com/spiffe/spire/test/fakes/fakemetrics" + "github.com/spiffe/spire/test/spiretest" + "github.com/spiffe/spire/test/testca" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -964,6 +967,196 @@ func TestSubscribeToLRUCacheChanges(t *testing.T) { } } +func TestTaintX509SVIDs(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + clk := clock.NewMock(t) + fakeMetrics := fakemetrics.New() + log, logHook := test.NewNullLogger() + log.Level = logrus.DebugLevel + + batchProcessedCh := make(chan struct{}, 1) + + // Initialize cache with configuration + cache := newTestLRUCacheWithConfig(clk) + cache.processingBatchSize = 4 + cache.log = log + cache.taintedBatchProcessedCh = batchProcessedCh + cache.metrics = fakeMetrics + + entries := createTestEntries(10) + updateEntries := &UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(entries...), + } + + // Add entries to cache + cache.UpdateEntries(updateEntries, nil) + + taintedCA := testca.New(t, trustDomain1) + newCA := testca.New(t, trustDomain1) + svids := makeX509SVIDs(entries...) + + // Prepare SVIDs (some are signed by tainted authority, others are not) + prepareSVIDs(t, entries[:3], svids, taintedCA) // SVIDs for e0-e2 tainted + prepareSVIDs(t, entries[3:5], svids, newCA) // SVIDs for e3-e4 not tainted + prepareSVIDs(t, entries[5:], svids, taintedCA) // SVIDs for e5-e9 tainted + + cache.svids = svids + require.Equal(t, 10, cache.CountX509SVIDs()) + + waitForBatchFinished := func() { + select { + case <-cache.taintedBatchProcessedCh: + case <-ctx.Done(): + require.Fail(t, "failed to process tainted authorities") + } + } + + assertBatchProcess := func(expectLogs []spiretest.LogEntry, expectMetrics []fakemetrics.MetricItem, svidIDs ...string) { + waitForBatchFinished() + spiretest.AssertLogs(t, logHook.AllEntries(), expectLogs) + assert.Equal(t, expectMetrics, fakeMetrics.AllMetrics()) + + assert.Len(t, cache.svids, len(svidIDs)) + for _, svidID := range svidIDs { + _, found := cache.svids[svidID] + assert.True(t, found, "svid not found: %q", svidID) + } + } + + expectElapsedTimeMetric := []fakemetrics.MetricItem{ + { + Type: fakemetrics.IncrCounterWithLabelsType, + Key: []string{"cache_manager", "", "process_tainted_svids"}, + Val: 1, + Labels: []telemetry.Label{ + { + Name: "status", + Value: "OK", + }, + }, + }, + { + Type: fakemetrics.MeasureSinceWithLabelsType, + Key: []string{"cache_manager", "", "process_tainted_svids", "elapsed_time"}, + Val: 0, + Labels: []telemetry.Label{ + { + Name: "status", + Value: "OK", + }, + }, + }, + } + + // Reset logs and metrics before testing + resetLogsAndMetrics(logHook, fakeMetrics) + + // Schedule taint and assert initial batch processing + cache.TaintX509SVIDs(ctx, taintedCA.X509Authorities()) + + expectLog := []spiretest.LogEntry{ + { + Level: logrus.DebugLevel, + Message: "Scheduled rotation for SVID entries due to tainted X.509 authorities", + Data: logrus.Fields{telemetry.Count: "10"}, + }, + { + Level: logrus.InfoLevel, + Message: "Tainted X.509 SVIDs", + Data: logrus.Fields{telemetry.TaintedSVIDs: "3"}, + }, + { + Level: logrus.InfoLevel, + Message: "There are tainted X.509 SVIDs left to be processed", + Data: logrus.Fields{telemetry.Count: "6"}, + }, + } + expectMetrics := append([]fakemetrics.MetricItem{ + {Type: fakemetrics.AddSampleType, Key: []string{telemetry.CacheManager, "", telemetry.TaintedSVIDs}, Val: 3}}, + expectElapsedTimeMetric...) + assertBatchProcess(expectLog, expectMetrics, "e3", "e4", "e5", "e6", "e7", "e8", "e9") + + // Advance clock, reset logs and metrics, and verify batch processing + resetLogsAndMetrics(logHook, fakeMetrics) + clk.Add(6 * time.Second) + + expectLog = []spiretest.LogEntry{ + { + Level: logrus.InfoLevel, + Message: "Tainted X.509 SVIDs", + Data: logrus.Fields{telemetry.TaintedSVIDs: "3"}, + }, + { + Level: logrus.InfoLevel, + Message: "There are tainted X.509 SVIDs left to be processed", + Data: logrus.Fields{telemetry.Count: "2"}, + }, + } + expectMetrics = append([]fakemetrics.MetricItem{ + {Type: fakemetrics.AddSampleType, Key: []string{telemetry.CacheManager, "", telemetry.TaintedSVIDs}, Val: 3}}, + expectElapsedTimeMetric...) + assertBatchProcess(expectLog, expectMetrics, "e3", "e4", "e8", "e9") + + // Advance clock again for the final batch + resetLogsAndMetrics(logHook, fakeMetrics) + clk.Add(6 * time.Second) + + expectLog = []spiretest.LogEntry{ + { + Level: logrus.InfoLevel, + Message: "Tainted X.509 SVIDs", + Data: logrus.Fields{telemetry.TaintedSVIDs: "2"}, + }, + { + Level: logrus.InfoLevel, + Message: "Finished processing all tainted entries", + }, + } + expectMetrics = append([]fakemetrics.MetricItem{ + {Type: fakemetrics.AddSampleType, Key: []string{telemetry.CacheManager, "", telemetry.TaintedSVIDs}, Val: 2}}, + expectElapsedTimeMetric...) + assertBatchProcess(expectLog, expectMetrics, "e3", "e4") +} + +func TestTaintX509SVIDsNoSVIDs(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + clk := clock.NewMock(t) + log, logHook := test.NewNullLogger() + log.Level = logrus.DebugLevel + + // Initialize cache with configuration + cache := newTestLRUCacheWithConfig(clk) + cache.log = log + + entries := createTestEntries(10) + updateEntries := &UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(entries...), + } + // All entries has no chain... + cache.svids = makeX509SVIDs(entries...) + + // Add entries to cache + cache.UpdateEntries(updateEntries, nil) + logHook.Reset() + + fakeBundle := []*x509.Certificate{{Raw: []byte("foo")}} + cache.TaintX509SVIDs(ctx, fakeBundle) + + expectLog := []spiretest.LogEntry{ + { + Level: logrus.DebugLevel, + Message: "No SVID entries to process for tainted X.509 authorities", + }, + } + spiretest.AssertLogs(t, logHook.AllEntries(), expectLog) +} + func TestMetrics(t *testing.T) { cache := newTestLRUCache(t) fakeMetrics := fakemetrics.New() @@ -1081,7 +1274,7 @@ func newTestLRUCache(t testing.TB) *LRUCache { func newTestLRUCacheWithConfig(clk clock.Clock) *LRUCache { log, _ := test.NewNullLogger() - return NewLRUCache(log, spiffeid.RequireTrustDomainFromString("domain.test"), bundleV1, telemetry.Blackhole{}, clk) + return NewLRUCache(log, trustDomain1, bundleV1, telemetry.Blackhole{}, clk) } // numEntries should not be more than 12 digits @@ -1225,3 +1418,31 @@ func makeFederatesWith(bundles ...*Bundle) []string { } return out } + +func createTestEntries(count int) []*common.RegistrationEntry { + var entries []*common.RegistrationEntry + for i := 0; i < count; i++ { + entry := makeRegistrationEntry(fmt.Sprintf("e%d", i), fmt.Sprintf("s%d", i)) + entries = append(entries, entry) + } + return entries +} + +func prepareSVIDs(t *testing.T, entries []*common.RegistrationEntry, svids map[string]*X509SVID, ca *testca.CA) { + for _, entry := range entries { + svid, ok := svids[entry.EntryId] + require.True(t, ok) + + chain, key := ca.CreateX509Certificate( + testca.WithID(spiffeid.RequireFromPath(trustDomain1, "/"+entry.EntryId)), + ) + + svid.Chain = chain + svid.PrivateKey = key + } +} + +func resetLogsAndMetrics(logHook *test.Hook, fakeMetrics *fakemetrics.FakeMetrics) { + logHook.Reset() + fakeMetrics.Reset() +} diff --git a/pkg/agent/manager/config.go b/pkg/agent/manager/config.go index cd90472525..abba259284 100644 --- a/pkg/agent/manager/config.go +++ b/pkg/agent/manager/config.go @@ -94,6 +94,9 @@ func newManager(c *Config) *manager { client: client, clk: c.Clk, svidStoreCache: c.SVIDStoreCache, + + processedTaintedX509Authorities: make(map[string]struct{}), + processedTaintedJWTAuthorities: make(map[string]struct{}), } return m diff --git a/pkg/agent/manager/manager.go b/pkg/agent/manager/manager.go index 3294b1c2fc..56013309df 100644 --- a/pkg/agent/manager/manager.go +++ b/pkg/agent/manager/manager.go @@ -22,6 +22,7 @@ import ( "github.com/spiffe/spire/pkg/common/rotationutil" "github.com/spiffe/spire/pkg/common/telemetry" "github.com/spiffe/spire/pkg/common/util" + "github.com/spiffe/spire/pkg/common/x509util" "github.com/spiffe/spire/pkg/server/api/limits" "github.com/spiffe/spire/proto/spire/common" ) @@ -171,6 +172,14 @@ type manager struct { // cache. syncedEntries map[string]*common.RegistrationEntry syncedBundles map[string]*common.Bundle + + // processedTaintedX509Authorities holds all the already processed tainted X.509 Authorities + // to prevent processing them again. + processedTaintedX509Authorities map[string]struct{} + + // processedTaintedJWTAuthorities holds all the already processed tainted JWT Authorities + // to prevent processing them again. + processedTaintedJWTAuthorities map[string]struct{} } func (m *manager) Initialize(ctx context.Context) error { @@ -316,7 +325,7 @@ func (m *manager) runSynchronizer(ctx context.Context) error { err := m.synchronize(ctx) switch { - case nodeutil.IsUnknownAuthorityError(err): + case x509util.IsUnknownAuthorityError(err): m.c.Log.WithError(err).Info("Synchronize failed, non-recoverable error") return fmt.Errorf("failed to sync with SPIRE Server: %w", err) case err != nil && nodeutil.ShouldAgentReattest(err): diff --git a/pkg/agent/manager/manager_test.go b/pkg/agent/manager/manager_test.go index 6fd93bf266..ccdb2eec94 100644 --- a/pkg/agent/manager/manager_test.go +++ b/pkg/agent/manager/manager_test.go @@ -9,11 +9,13 @@ import ( "errors" "fmt" "net" + "reflect" "sync" "sync/atomic" "testing" "time" + "github.com/sirupsen/logrus" testlog "github.com/sirupsen/logrus/hooks/test" "github.com/spiffe/go-spiffe/v2/bundle/spiffebundle" "github.com/spiffe/go-spiffe/v2/spiffeid" @@ -772,6 +774,178 @@ func TestSynchronizationUpdatesRegistrationEntries(t *testing.T) { m.cache.Entries()) } +func TestForceRotation(t *testing.T) { + dir := spiretest.TempDir(t) + km := fakeagentkeymanager.New(t, dir) + + clk := clock.NewMock(t) + // Big number to never get into regular rotation + ttl := 10000 + api := newMockAPI(t, &mockAPIConfig{ + km: km, + getAuthorizedEntries: func(*mockAPI, int32, *entryv1.GetAuthorizedEntriesRequest) (*entryv1.GetAuthorizedEntriesResponse, error) { + return makeGetAuthorizedEntriesResponse(t, "resp1", "resp2"), nil + }, + batchNewX509SVIDEntries: func(*mockAPI, int32) []*common.RegistrationEntry { + return makeBatchNewX509SVIDEntries("resp1", "resp2") + }, + svidTTL: ttl, + clk: clk, + }) + + baseSVID, baseSVIDKey := api.newSVID(joinTokenID, 1*time.Hour) + cat := fakeagentcatalog.New() + cat.SetKeyManager(km) + + log, logHook := testlog.NewNullLogger() + log.Level = logrus.DebugLevel + + c := &Config{ + ServerAddr: api.addr, + SVID: baseSVID, + SVIDKey: baseSVIDKey, + Log: log, + TrustDomain: trustDomain, + Storage: openStorage(t, dir), + Bundle: api.bundle, + Metrics: &telemetry.Blackhole{}, + RotationInterval: time.Hour, + SyncInterval: time.Hour, + Clk: clk, + Catalog: cat, + WorkloadKeyType: workloadkey.ECP256, + SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger, Metrics: &telemetry.Blackhole{}}), + RotationStrategy: rotationutil.NewRotationStrategy(0), + } + + m := newManager(c) + + sub, err := m.SubscribeToCacheChanges(context.Background(), cache.Selectors{ + {Type: "unix", Value: "uid:1111"}, + {Type: "spiffe_id", Value: joinTokenID.String()}, + }) + require.NoError(t, err) + defer sub.Finish() + + if err := m.Initialize(context.Background()); err != nil { + t.Fatal(err) + } + require.Equal(t, clk.Now(), m.GetLastSync()) + + // Before synchronization + identitiesBefore := identitiesByEntryID(m.cache.Identities()) + if len(identitiesBefore) != 3 { + t.Fatalf("3 cached identities were expected; got %d", len(identitiesBefore)) + } + + // This is the initial update based on the selector set + u := <-sub.Updates() + if len(u.Identities) != 3 { + t.Fatalf("expected 3 identities, got: %d", len(u.Identities)) + } + + if len(u.Bundle.X509Authorities()) != 1 { + t.Fatal("expected 1 bundle root CA") + } + + if !u.Bundle.Equal(api.bundle) { + t.Fatal("received bundle should be equals to the server bundle") + } + + for key, eu := range identitiesByEntryID(u.Identities) { + eb, ok := identitiesBefore[key] + if !ok { + t.Fatalf("an update was received for an inexistent entry on the cache with EntryId=%v", key) + } + require.Equal(t, eb, eu, "identity received does not match identity on cache") + } + + require.Equal(t, clk.Now(), m.GetLastSync()) + + // No ttl and bundle updates + clk.Add(time.Second) + require.NoError(t, m.synchronize(context.Background())) + select { + case <-sub.Updates(): + t.Fatal("update unexpected after 1 second") + default: + } + assert.False(t, m.svid.IsTainted()) + + // Taint authority + api.taintCurrentX509Authority() + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + // Initial synchronization + require.NoError(t, m.synchronize(ctx)) + + // Wait until tainted authorities are fully processed, then retry synchronization + assert.Eventually(t, func() bool { + for _, logEntry := range logHook.Entries { + if logEntry.Message == "Finished processing all tainted entries" { + return true + } + } + return false + }, time.Minute, 50*time.Millisecond, "No tainted authority processed") + + // Retry synchronization to handle potential edge case + require.NoError(t, m.synchronize(ctx)) + + select { + case u = <-sub.Updates(): + case <-ctx.Done(): + t.Fatal("Expected update after tainting authority, but none received") + } + + // SVID is signed by a tainted authority, it must be tainted + assert.True(t, m.svid.IsTainted()) + taintedSubjectKeyID := x509util.SubjectKeyIDToString(api.taintedX509Authority.SubjectKeyId) + expectProcessedTaintedX509Authorities := map[string]struct{}{ + taintedSubjectKeyID: {}, + } + assert.Equal(t, expectProcessedTaintedX509Authorities, m.processedTaintedX509Authorities) + + // Make sure the update contains the updated entries and that the cache + // has a consistent view. + identitiesAfter := identitiesByEntryID(m.cache.Identities()) + if len(identitiesAfter) != 3 { + t.Fatalf("expected 3 identities, got: %d", len(identitiesAfter)) + } + + for key, eb := range identitiesBefore { + ea, ok := identitiesAfter[key] + if !ok { + t.Fatalf("expected identity with EntryId=%v after synchronization", key) + } + require.NotEqual(t, eb, ea, "there is at least one identity that was not refreshed: %v", ea) + } + + if len(u.Identities) != 3 { + t.Fatalf("expected 3 identities, got: %d", len(u.Identities)) + } + + if len(u.Bundle.X509Authorities()) != 2 { + t.Fatal("expected 1 bundle root CA") + } + + if !u.Bundle.Equal(api.bundle) { + t.Fatal("received bundle should be equals to the server bundle") + } + + for key, eu := range identitiesByEntryID(u.Identities) { + ea, ok := identitiesAfter[key] + if !ok { + t.Fatalf("an update was received for an inexistent entry on the cache with EntryId=%v", key) + } + require.Equal(t, eu, ea, "entry received does not match entry on cache") + } + + require.Equal(t, clk.Now(), m.GetLastSync()) +} + func TestSubscribersGetUpToDateBundle(t *testing.T) { dir := spiretest.TempDir(t) km := fakeagentkeymanager.New(t, dir) @@ -1597,6 +1771,8 @@ type mockAPI struct { getAuthorizedEntriesCount int32 batchNewX509SVIDCount int32 + taintedX509Authority *x509.Certificate + clk clock.Clock // Add latest's SVIDs per entry, to verify returned SVIDs are valid @@ -1716,7 +1892,16 @@ func (h *mockAPI) NewJWTSVID(_ context.Context, req *svidv1.NewJWTSVIDRequest) ( } func (h *mockAPI) GetBundle(context.Context, *bundlev1.GetBundleRequest) (*types.Bundle, error) { - return api.BundleToProto(bundleutil.BundleProtoFromRootCAs(h.bundle.TrustDomain().IDString(), h.bundle.X509Authorities())) + bundle := bundleutil.BundleProtoFromRootCAs(h.bundle.TrustDomain().IDString(), h.bundle.X509Authorities()) + if h.taintedX509Authority != nil { + for _, eachRootCA := range bundle.RootCas { + if reflect.DeepEqual(eachRootCA.DerBytes, h.taintedX509Authority.Raw) { + eachRootCA.TaintedKey = true + } + } + } + + return api.BundleToProto(bundle) } func (h *mockAPI) GetFederatedBundle(_ context.Context, req *bundlev1.GetFederatedBundleRequest) (*types.Bundle, error) { @@ -1728,6 +1913,15 @@ func (h *mockAPI) GetFederatedBundle(_ context.Context, req *bundlev1.GetFederat }, nil } +// taintCurrentX509Authority create a new X.509 authority and taint old +func (h *mockAPI) taintCurrentX509Authority() { + h.taintedX509Authority = h.ca + ca, caKey := createCA(h.t, h.clk) + h.ca = ca + h.caKey = caKey + h.bundle.AddX509Authority(ca) +} + func (h *mockAPI) rotateCA() { ca, caKey := createCA(h.t, h.clk) h.ca = ca diff --git a/pkg/agent/manager/storecache/cache.go b/pkg/agent/manager/storecache/cache.go index f12225bd00..a2c7538f84 100644 --- a/pkg/agent/manager/storecache/cache.go +++ b/pkg/agent/manager/storecache/cache.go @@ -1,6 +1,8 @@ package storecache import ( + "context" + "crypto/x509" "sort" "sync" "time" @@ -10,6 +12,8 @@ import ( "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/agent/manager/cache" "github.com/spiffe/spire/pkg/common/telemetry" + telemetry_agent "github.com/spiffe/spire/pkg/common/telemetry/agent" + "github.com/spiffe/spire/pkg/common/x509util" "github.com/spiffe/spire/proto/spire/common" ) @@ -46,6 +50,7 @@ type cachedRecord struct { type Config struct { Log logrus.FieldLogger TrustDomain spiffeid.TrustDomain + Metrics telemetry.Metrics } type Cache struct { @@ -219,6 +224,38 @@ func (c *Cache) UpdateSVIDs(update *cache.UpdateSVIDs) { } } +func (c *Cache) TaintX509SVIDs(ctx context.Context, taintedX509Authorities []*x509.Certificate) { + c.mtx.Lock() + defer c.mtx.Unlock() + + counter := telemetry.StartCall(c.c.Metrics, telemetry.CacheManager, "svid_store", telemetry.ProcessTaintedSVIDs) + defer counter.Done(nil) + + taintedSVIDs := 0 + for _, record := range c.records { + // Skip nil or already tainted SVIDs + if record.svid == nil { + continue + } + + isTainted, err := x509util.IsSignedByRoot(record.svid.Chain, taintedX509Authorities) + if err != nil { + c.c.Log.WithError(err). + WithField(telemetry.RegistrationID, record.entry.EntryId). + Error("Failed to check if SVID is signed by tainted authority") + continue + } + + if isTainted { + taintedSVIDs++ + record.svid = nil // Mark SVID as tainted by setting it to nil + } + } + + telemetry_agent.AddCacheManagerExpiredSVIDsSample(c.c.Metrics, "svid_store", float32(taintedSVIDs)) + c.c.Log.WithField(telemetry.TaintedSVIDs, taintedSVIDs).Info("Tainted X.509 SVIDs") +} + // GetStaleEntries obtains a list of stale entries, that needs new SVIDs func (c *Cache) GetStaleEntries() []*cache.StaleEntry { c.mtx.Lock() diff --git a/pkg/agent/manager/storecache/cache_test.go b/pkg/agent/manager/storecache/cache_test.go index 1a05dc50ff..0e10503ee4 100644 --- a/pkg/agent/manager/storecache/cache_test.go +++ b/pkg/agent/manager/storecache/cache_test.go @@ -1,11 +1,14 @@ package storecache_test import ( + "context" "crypto/x509" + "fmt" "net/url" "testing" "time" + "github.com/hashicorp/go-metrics" "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" "github.com/spiffe/go-spiffe/v2/bundle/spiffebundle" @@ -14,7 +17,9 @@ import ( "github.com/spiffe/spire/pkg/agent/manager/storecache" "github.com/spiffe/spire/pkg/common/telemetry" "github.com/spiffe/spire/proto/spire/common" + "github.com/spiffe/spire/test/fakes/fakemetrics" "github.com/spiffe/spire/test/spiretest" + "github.com/spiffe/spire/test/testca" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -907,6 +912,141 @@ func TestUpdateEntriesCreatesNewEntriesOnCache(t *testing.T) { spiretest.AssertLogsAnyOrder(t, hook.AllEntries(), expectedLogs) } +func TestTaintX509SVIDs(t *testing.T) { + ctx := context.Background() + log, hook := test.NewNullLogger() + log.Level = logrus.DebugLevel + fakeMetrics := fakemetrics.New() + taintedAuthority := testca.New(t, td) + newAuthority := testca.New(t, td) + + c := storecache.New(&storecache.Config{ + Log: log, + TrustDomain: td, + Metrics: fakeMetrics, + }) + + // Create initial entries + entries := makeEntries(td, "e1", "e2", "e3", "e4", "e5") + updateEntries := &cache.UpdateEntries{ + Bundles: map[spiffeid.TrustDomain]*spiffebundle.Bundle{ + td: tdBundle, + }, + RegistrationEntries: entries, + } + + // Set entries to cache + c.UpdateEntries(updateEntries, nil) + + noTaintedSVID := createX509SVID(td, "e3", newAuthority) + updateSVIDs := &cache.UpdateSVIDs{ + X509SVIDs: map[string]*cache.X509SVID{ + "e1": createX509SVID(td, "e1", taintedAuthority), + "e2": createX509SVID(td, "e2", taintedAuthority), + "e3": noTaintedSVID, + "e5": createX509SVID(td, "e5", taintedAuthority), + }, + } + c.UpdateSVIDs(updateSVIDs) + + for _, tt := range []struct { + name string + taintedAuthorities []*x509.Certificate + expectSVID map[string]*cache.X509SVID + expectLogs []spiretest.LogEntry + expectMetrics []fakemetrics.MetricItem + }{ + { + name: "taint SVIDs", + taintedAuthorities: taintedAuthority.X509Authorities(), + expectSVID: map[string]*cache.X509SVID{ + "e1": nil, + "e2": nil, + "e3": noTaintedSVID, + "e4": nil, + "e5": nil, + }, + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.InfoLevel, + Message: "Tainted X.509 SVIDs", + Data: logrus.Fields{ + telemetry.TaintedSVIDs: "3", + }, + }, + }, + expectMetrics: []fakemetrics.MetricItem{ + { + Type: fakemetrics.AddSampleType, + Key: []string{"cache_manager", "svid_store", "expiring_svids", "svid_store"}, + Val: 3, + }, + { + Type: fakemetrics.IncrCounterWithLabelsType, + Key: []string{"cache_manager", "svid_store", "process_tainted_svids"}, + Val: 1, + Labels: []metrics.Label{{Name: "status", Value: "OK"}}, + }, + { + Type: fakemetrics.MeasureSinceWithLabelsType, + Key: []string{"cache_manager", "svid_store", "process_tainted_svids", "elapsed_time"}, + Val: 0, + Labels: []metrics.Label{{Name: "status", Value: "OK"}}, + }, + }, + }, + { + name: "taint again", + taintedAuthorities: taintedAuthority.X509Authorities(), + expectSVID: map[string]*cache.X509SVID{ + "e1": nil, + "e2": nil, + "e3": noTaintedSVID, + "e4": nil, + "e5": nil, + }, + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.InfoLevel, + Message: "Tainted X.509 SVIDs", + Data: logrus.Fields{ + telemetry.TaintedSVIDs: "0", + }, + }, + }, + expectMetrics: []fakemetrics.MetricItem{ + { + Type: fakemetrics.AddSampleType, + Key: []string{"cache_manager", "svid_store", "expiring_svids", "svid_store"}, + Val: 0, + }, + { + Type: fakemetrics.IncrCounterWithLabelsType, + Key: []string{"cache_manager", "svid_store", "process_tainted_svids"}, + Val: 1, + Labels: []metrics.Label{{Name: "status", Value: "OK"}}, + }, + { + Type: fakemetrics.MeasureSinceWithLabelsType, + Key: []string{"cache_manager", "svid_store", "process_tainted_svids", "elapsed_time"}, + Val: 0, + Labels: []metrics.Label{{Name: "status", Value: "OK"}}, + }, + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + hook.Reset() + fakeMetrics.Reset() + + c.TaintX509SVIDs(ctx, tt.taintedAuthorities) + assert.Equal(t, tt.expectSVID, svidMapFromRecords(c.Records())) + spiretest.AssertLogs(t, hook.AllEntries(), tt.expectLogs) + assert.Equal(t, tt.expectMetrics, fakeMetrics.AllMetrics()) + }) + } +} + func TestUpdateSVIDs(t *testing.T) { log, hook := test.NewNullLogger() log.Level = logrus.DebugLevel @@ -1241,3 +1381,45 @@ func createTestEntry() *common.RegistrationEntry { RevisionNumber: 1, } } + +func svidMapFromRecords(records []*storecache.Record) map[string]*cache.X509SVID { + recordsMap := make(map[string]*cache.X509SVID, len(records)) + for _, eachRecord := range records { + recordsMap[eachRecord.ID] = eachRecord.Svid + } + return recordsMap +} + +func createX509SVID(td spiffeid.TrustDomain, id string, ca *testca.CA) *cache.X509SVID { + chain, key := ca.CreateX509Certificate( + testca.WithID(spiffeid.RequireFromPath(td, "/"+id)), + ) + return &cache.X509SVID{ + Chain: chain, + PrivateKey: key, + } +} + +func makeEntries(td spiffeid.TrustDomain, ids ...string) map[string]*common.RegistrationEntry { + entries := make(map[string]*common.RegistrationEntry, len(ids)) + for _, id := range ids { + entries[id] = &common.RegistrationEntry{ + EntryId: id, + SpiffeId: spiffeid.RequireFromPath(td, "/"+id).String(), + Selectors: makeSelectors(id), + StoreSvid: true, + } + } + return entries +} + +func makeSelectors(values ...string) []*common.Selector { + var selectors []*common.Selector + for _, value := range values { + selectors = append(selectors, &common.Selector{ + Type: "t", + Value: fmt.Sprintf("v:%s", value), + }) + } + return selectors +} diff --git a/pkg/agent/manager/sync.go b/pkg/agent/manager/sync.go index 6e826716cc..b856bdede7 100644 --- a/pkg/agent/manager/sync.go +++ b/pkg/agent/manager/sync.go @@ -4,6 +4,8 @@ import ( "context" "crypto" "crypto/x509" + "fmt" + "strings" "time" "github.com/sirupsen/logrus" @@ -15,6 +17,7 @@ import ( "github.com/spiffe/spire/pkg/common/telemetry" telemetry_agent "github.com/spiffe/spire/pkg/common/telemetry/agent" "github.com/spiffe/spire/pkg/common/util" + "github.com/spiffe/spire/pkg/common/x509util" "github.com/spiffe/spire/proto/spire/common" ) @@ -33,6 +36,10 @@ type SVIDCache interface { // GetStaleEntries gets a list of records that need update SVIDs GetStaleEntries() []*cache.StaleEntry + + // TaintX509SVIDs marks all SVIDs signed by a tainted X.509 authority as tainted + // to force their rotation. + TaintX509SVIDs(ctx context.Context, taintedX509Authorities []*x509.Certificate) } func (m *manager) syncSVIDs(ctx context.Context) (err error) { @@ -40,6 +47,44 @@ func (m *manager) syncSVIDs(ctx context.Context) (err error) { return m.updateSVIDs(ctx, m.c.Log.WithField(telemetry.CacheType, "workload"), m.cache) } +// processTaintedAuthorities verifies if a new authority is tainted and forces rotation in all caches if required. +func (m *manager) processTaintedAuthorities(ctx context.Context, x509Authorities []string, jwtAuthorities []string) error { + newTaintedX509Authorities := getNewItems(m.processedTaintedX509Authorities, x509Authorities) + if len(newTaintedX509Authorities) > 0 { + m.c.Log.WithField(telemetry.SubjectKeyIDs, strings.Join(newTaintedX509Authorities, ",")). + Debug("New tainted X.509 authorities found") + + taintedX509Authorities, err := bundleutil.FindX509Authorities(m.c.Bundle, newTaintedX509Authorities) + if err != nil { + return fmt.Errorf("failed to search X.509 authorities: %w", err) + } + + // Taint all regular X.509 SVIDs + m.cache.TaintX509SVIDs(ctx, taintedX509Authorities) + + // Taint all SVIDStore SVIDs + m.svidStoreCache.TaintX509SVIDs(ctx, taintedX509Authorities) + + // Notify rotator about new tainted authorities + if err := m.svid.NotifyTaintedAuthorities(taintedX509Authorities); err != nil { + return err + } + + for _, subjectKeyID := range newTaintedX509Authorities { + m.processedTaintedX509Authorities[subjectKeyID] = struct{}{} + } + } + + newTaintedJWTAuthorities := getNewItems(m.processedTaintedJWTAuthorities, jwtAuthorities) + if len(newTaintedJWTAuthorities) > 0 { + m.c.Log.WithField(telemetry.SubjectKeyIDs, strings.Join(newTaintedJWTAuthorities, ",")). + Debug("New tainted JWT authorities found") + // TODO: IMPLEMENT!!! + } + + return nil +} + // synchronize fetches the authorized entries from the server, updates the // cache, and fetches missing/expiring SVIDs. func (m *manager) synchronize(ctx context.Context) (err error) { @@ -48,6 +93,11 @@ func (m *manager) synchronize(ctx context.Context) (err error) { return err } + // Process all tainted authorities. The bundle is shared between both caches using regular cache data. + if err := m.processTaintedAuthorities(ctx, cacheUpdate.TaintedX509Authorities, cacheUpdate.TaintedJWTAuthorities); err != nil { + return err + } + if err := m.updateCache(ctx, cacheUpdate, m.c.Log.WithField(telemetry.CacheType, "workload"), "", m.cache); err != nil { return err } @@ -254,6 +304,27 @@ func (m *manager) fetchEntries(ctx context.Context) (_ *cache.UpdateEntries, _ * return nil, nil, err } + // Get all Subject Key IDs and KeyIDs of tainted authorities + var taintedX509Authorities []string + var taintedJWTAuthorities []string + if b, ok := update.Bundles[m.c.TrustDomain.IDString()]; ok { + for _, rootCA := range b.RootCas { + if rootCA.TaintedKey { + cert, err := x509.ParseCertificate(rootCA.DerBytes) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse tainted x509 authority: %w", err) + } + subjectKeyID := x509util.SubjectKeyIDToString(cert.SubjectKeyId) + taintedX509Authorities = append(taintedX509Authorities, subjectKeyID) + } + } + for _, jwtKey := range b.JwtSigningKeys { + if jwtKey.TaintedKey { + taintedJWTAuthorities = append(taintedJWTAuthorities, jwtKey.Kid) + } + } + } + cacheEntries := make(map[string]*common.RegistrationEntry) storeEntries := make(map[string]*common.RegistrationEntry) @@ -267,11 +338,15 @@ func (m *manager) fetchEntries(ctx context.Context) (_ *cache.UpdateEntries, _ * } return &cache.UpdateEntries{ - Bundles: bundles, - RegistrationEntries: cacheEntries, + Bundles: bundles, + RegistrationEntries: cacheEntries, + TaintedJWTAuthorities: taintedJWTAuthorities, + TaintedX509Authorities: taintedX509Authorities, }, &cache.UpdateEntries{ - Bundles: bundles, - RegistrationEntries: storeEntries, + Bundles: bundles, + RegistrationEntries: storeEntries, + TaintedJWTAuthorities: taintedJWTAuthorities, + TaintedX509Authorities: taintedX509Authorities, }, nil } @@ -303,3 +378,14 @@ func parseBundles(bundles map[string]*common.Bundle) (map[spiffeid.TrustDomain]* } return out, nil } + +func getNewItems(current map[string]struct{}, items []string) []string { + var newItems []string + for _, subjectKeyID := range items { + if _, ok := current[subjectKeyID]; !ok { + newItems = append(newItems, subjectKeyID) + } + } + + return newItems +} diff --git a/pkg/agent/svid/rotator.go b/pkg/agent/svid/rotator.go index 434230d854..0b66fe3df9 100644 --- a/pkg/agent/svid/rotator.go +++ b/pkg/agent/svid/rotator.go @@ -21,12 +21,17 @@ import ( "github.com/spiffe/spire/pkg/common/telemetry" telemetry_agent "github.com/spiffe/spire/pkg/common/telemetry/agent" "github.com/spiffe/spire/pkg/common/util" + "github.com/spiffe/spire/pkg/common/x509util" "google.golang.org/grpc" ) type Rotator interface { Run(ctx context.Context) error Reattest(ctx context.Context) error + // NotifyTaintedAuthorities processes new tainted authorities. If the current SVID is compromised, + // it is marked to force rotation. + NotifyTaintedAuthorities([]*x509.Certificate) error + IsTainted() bool State() State Subscribe() observer.Stream @@ -58,6 +63,8 @@ type rotator struct { // Hook that will be called when the SVID rotation finishes rotationFinishedHook func() + + tainted bool } type State struct { @@ -130,6 +137,43 @@ func (r *rotator) Subscribe() observer.Stream { return r.state.Observe() } +func (r *rotator) IsTainted() bool { + r.rotMtx.RLock() + defer r.rotMtx.RUnlock() + + return r.tainted +} + +func (r *rotator) setTainted(tainted bool) { + r.rotMtx.Lock() + defer r.rotMtx.Unlock() + + r.tainted = tainted +} + +func (r *rotator) NotifyTaintedAuthorities(taintedAuthorities []*x509.Certificate) error { + state, ok := r.state.Value().(State) + if !ok { + return fmt.Errorf("unexpected state value type: %T", r.state.Value()) + } + + if r.IsTainted() { + r.c.Log.Debug("Agent SVID already tainted") + return nil + } + + tainted, err := x509util.IsSignedByRoot(state.SVID, taintedAuthorities) + if err != nil { + return fmt.Errorf("failed to check if SVID is tainted: %w", err) + } + + if tainted { + r.c.Log.Info("Agent SVID is tainted by a root authority, forcing rotation") + r.setTainted(tainted) + } + return nil +} + func (r *rotator) GetRotationMtx() *sync.RWMutex { return r.rotMtx } @@ -162,7 +206,7 @@ func (r *rotator) rotateSVIDIfNeeded(ctx context.Context) (err error) { return fmt.Errorf("unexpected value type: %T", r.state.Value()) } - if r.c.RotationStrategy.ShouldRotateX509(r.clk.Now(), state.SVID[0]) { + if r.c.RotationStrategy.ShouldRotateX509(r.clk.Now(), state.SVID[0]) || r.IsTainted() { if state.Reattestable { err = r.reattest(ctx) } else { @@ -222,6 +266,7 @@ func (r *rotator) reattest(ctx context.Context) (err error) { } r.state.Update(s) + r.tainted = false // We must release the client because its underlaying connection is tied to an // expired SVID, so next time the client is used, it will get a new connection with @@ -269,6 +314,7 @@ func (r *rotator) rotateSVID(ctx context.Context) (err error) { } r.state.Update(s) + r.tainted = false // We must release the client because its underlaying connection is tied to an // expired SVID, so next time the client is used, it will get a new connection with diff --git a/pkg/agent/svid/rotator_test.go b/pkg/agent/svid/rotator_test.go index bc7c2fc279..cacc69c3b8 100644 --- a/pkg/agent/svid/rotator_test.go +++ b/pkg/agent/svid/rotator_test.go @@ -120,7 +120,9 @@ func TestRotator(t *testing.T) { // Create the starting SVID svidKey, err := svidKM.GenerateKey(context.Background(), nil) require.NoError(t, err) + svid, err := createTestSVID(svidKey.Public(), caCert, caKey, clk.Now(), clk.Now().Add(tt.notAfter)) + require.NoError(t, err) // Advance the clock by one second so SVID will always be expired @@ -381,6 +383,163 @@ func TestRotationFails(t *testing.T) { } } +func TestNotifyTaintedAuthority(t *testing.T) { + caCert, caKey := testca.CreateCACertificate(t, nil, nil) + anotherCert, _ := testca.CreateCACertificate(t, nil, nil) + + svidKM := keymanager.ForSVID(fakeagentkeymanager.New(t, "")) + clk := clock.NewMock(t) + log, logHook := test.NewNullLogger() + log.Level = logrus.DebugLevel + + mockClient := &fakeClient{ + clk: clk, + caCert: caCert, + caKey: caKey, + } + + // Create the bundle + bundle := make(map[spiffeid.TrustDomain]*spiffebundle.Bundle) + bundle[trustDomain] = spiffebundle.FromX509Authorities(trustDomain, []*x509.Certificate{caCert}) + + // Create the starting SVID + svidKey, err := svidKM.GenerateKey(context.Background(), nil) + require.NoError(t, err) + + svid, err := createTestSVID(svidKey.Public(), caCert, caKey, clk.Now(), clk.Now().Add(time.Minute)) + require.NoError(t, err) + + // Initialize the rotator + rotator, _ := newRotator(&RotatorConfig{ + SVIDKeyManager: svidKM, + Log: log, + Metrics: telemetry.Blackhole{}, + TrustDomain: trustDomain, + BundleStream: cache.NewBundleStream(observer.NewProperty(bundle).Observe()), + Clk: clk, + SVID: svid, + SVIDKey: svidKey, + NodeAttestor: fakeagentnodeattestor.New(t, fakeagentnodeattestor.Config{}), + RotationStrategy: rotationutil.NewRotationStrategy(0), + }) + rotator.client = mockClient + + // Ensure cert is not tainted initially + require.False(t, rotator.IsTainted()) + + for _, tt := range []struct { + name string + authorities []*x509.Certificate + + expectTainted bool + expectLogs []spiretest.LogEntry + }{ + { + name: "no tainted", + authorities: []*x509.Certificate{anotherCert}, + expectTainted: false, + }, + { + name: "taint successfully", + authorities: []*x509.Certificate{caCert}, + expectTainted: true, + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.InfoLevel, + Message: "Agent SVID is tainted by a root authority, forcing rotation", + }, + }, + }, + { + name: "already tainted", + authorities: []*x509.Certificate{caCert}, + expectTainted: true, + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.DebugLevel, + Message: "Agent SVID already tainted", + }, + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + logHook.Reset() + + err := rotator.NotifyTaintedAuthorities(tt.authorities) + require.NoError(t, err) + + assert.Equal(t, tt.expectTainted, rotator.IsTainted()) + spiretest.AssertLogs(t, logHook.AllEntries(), tt.expectLogs) + }) + } +} + +func TestTaintedSVIDIsRotated(t *testing.T) { + caCert, caKey := testca.CreateCACertificate(t, nil, nil) + + svidKM := keymanager.ForSVID(fakeagentkeymanager.New(t, "")) + clk := clock.NewMock(t) + log, _ := test.NewNullLogger() + + mockClient := &fakeClient{ + clk: clk, + caCert: caCert, + caKey: caKey, + } + + // Create the bundle + bundle := make(map[spiffeid.TrustDomain]*spiffebundle.Bundle) + bundle[trustDomain] = spiffebundle.FromX509Authorities(trustDomain, []*x509.Certificate{caCert}) + + // Create the starting SVID + svidKey, err := svidKM.GenerateKey(context.Background(), nil) + require.NoError(t, err) + + svid, err := createTestSVID(svidKey.Public(), caCert, caKey, clk.Now(), clk.Now().Add(time.Minute)) + require.NoError(t, err) + + // Initialize the rotator + rotator, _ := newRotator(&RotatorConfig{ + SVIDKeyManager: svidKM, + Log: log, + Metrics: telemetry.Blackhole{}, + TrustDomain: trustDomain, + BundleStream: cache.NewBundleStream(observer.NewProperty(bundle).Observe()), + Clk: clk, + SVID: svid, + SVIDKey: svidKey, + NodeAttestor: fakeagentnodeattestor.New(t, fakeagentnodeattestor.Config{}), + RotationStrategy: rotationutil.NewRotationStrategy(0), + }) + rotator.client = mockClient + rotationFinishedCh := make(chan struct{}, 1) + rotator.rotationFinishedHook = func() { + close(rotationFinishedCh) + } + + // Mark SVID as tainted + rotator.tainted = true + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + errCh := make(chan error) + go func() { + errCh <- rotator.Run(ctx) + }() + + select { + case err = <-errCh: + t.Fatalf("unexpected error during first rotation loop: %v", err) + case <-rotationFinishedCh: + // Rotation expected + case <-ctx.Done(): + t.Fatal("expected rotation to finish before timeout") + } + + require.False(t, rotator.IsTainted(), "SVID must not be tainted after rotation") +} + type fakeClient struct { clk clock.Clock caCert *x509.Certificate diff --git a/pkg/common/bundleutil/bundle.go b/pkg/common/bundleutil/bundle.go index ac3dd8bff9..592c348648 100644 --- a/pkg/common/bundleutil/bundle.go +++ b/pkg/common/bundleutil/bundle.go @@ -12,6 +12,7 @@ import ( "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire-api-sdk/proto/spire/api/types" "github.com/spiffe/spire/pkg/common/telemetry" + "github.com/spiffe/spire/pkg/common/x509util" "github.com/spiffe/spire/proto/spire/common" "google.golang.org/protobuf/proto" ) @@ -29,7 +30,8 @@ func CommonBundleFromProto(b *types.Bundle) (*common.Bundle, error) { var rootCAs []*common.Certificate for _, rootCA := range b.X509Authorities { rootCAs = append(rootCAs, &common.Certificate{ - DerBytes: rootCA.Asn1, + DerBytes: rootCA.Asn1, + TaintedKey: rootCA.Tainted, }) } @@ -40,9 +42,10 @@ func CommonBundleFromProto(b *types.Bundle) (*common.Bundle, error) { } jwtKeys = append(jwtKeys, &common.PublicKey{ - PkixBytes: key.PublicKey, - Kid: key.KeyId, - NotAfter: key.ExpiresAt, + PkixBytes: key.PublicKey, + Kid: key.KeyId, + NotAfter: key.ExpiresAt, + TaintedKey: key.Tainted, }) } @@ -237,6 +240,32 @@ pruneRootCA: return newBundle, changed, nil } +// FindX509Authorities search for all X.509 authorities with provided subjectKeyIDs +func FindX509Authorities(bundle *spiffebundle.Bundle, subjectKeyIDs []string) ([]*x509.Certificate, error) { + var x509Authorities []*x509.Certificate + for _, subjectKeyID := range subjectKeyIDs { + x509Authority, err := getX509Authority(bundle, subjectKeyID) + if err != nil { + return nil, err + } + + x509Authorities = append(x509Authorities, x509Authority) + } + + return x509Authorities, nil +} + +func getX509Authority(bundle *spiffebundle.Bundle, subjectKeyID string) (*x509.Certificate, error) { + for _, x509Authority := range bundle.X509Authorities() { + authoritySKID := x509util.SubjectKeyIDToString(x509Authority.SubjectKeyId) + if authoritySKID == subjectKeyID { + return x509Authority, nil + } + } + + return nil, fmt.Errorf("no X.509 authority found with SubjectKeyID %q", subjectKeyID) +} + func cloneBundle(b *common.Bundle) *common.Bundle { return proto.Clone(b).(*common.Bundle) } diff --git a/pkg/common/bundleutil/bundle_test.go b/pkg/common/bundleutil/bundle_test.go index e8d2229547..51a19e4c17 100644 --- a/pkg/common/bundleutil/bundle_test.go +++ b/pkg/common/bundleutil/bundle_test.go @@ -16,6 +16,7 @@ import ( testlog "github.com/sirupsen/logrus/hooks/test" "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire-api-sdk/proto/spire/api/types" + "github.com/spiffe/spire/pkg/common/x509util" "github.com/spiffe/spire/proto/spire/common" "github.com/spiffe/spire/test/spiretest" "github.com/spiffe/spire/test/testca" @@ -153,6 +154,42 @@ func TestCommonBundleFromProto(t *testing.T) { }, }, }, + { + name: "tainted authority", + bundle: &types.Bundle{ + TrustDomain: td.Name(), + RefreshHint: 10, + X509Authorities: []*types.X509Certificate{ + { + Asn1: rootCA.Raw, + Tainted: true, + }, + }, + JwtAuthorities: []*types.JWTKey{ + { + PublicKey: pkixBytes, + KeyId: "key-id-1", + ExpiresAt: 1590514224, + Tainted: true, + }, + }, + SequenceNumber: 42, + }, + expectBundle: &common.Bundle{ + TrustDomainId: td.IDString(), + RefreshHint: 10, + SequenceNumber: 42, + RootCas: []*common.Certificate{{DerBytes: rootCA.Raw, TaintedKey: true}}, + JwtSigningKeys: []*common.PublicKey{ + { + PkixBytes: pkixBytes, + Kid: "key-id-1", + NotAfter: 1590514224, + TaintedKey: true, + }, + }, + }, + }, { name: "Empty key ID", bundle: &types.Bundle{ @@ -362,6 +399,38 @@ func TestSPIFFEBundleFromProto(t *testing.T) { } } +func TestFindX509Authorities(t *testing.T) { + td := spiffeid.RequireTrustDomainFromString("example.org") + + skID1 := x509util.SubjectKeyIDToString([]byte("ca1")) + ca1 := &x509.Certificate{ + SubjectKeyId: []byte("ca1"), + } + ca2 := &x509.Certificate{ + SubjectKeyId: []byte("ca2"), + } + skID3 := x509util.SubjectKeyIDToString([]byte("ca3")) + ca3 := &x509.Certificate{ + SubjectKeyId: []byte("ca3"), + } + testBundle := spiffebundle.FromX509Authorities(td, []*x509.Certificate{ca1, ca2, ca3}) + + runTest := func(skIDs []string, expectErr string, expectResp ...*x509.Certificate) { + found, err := FindX509Authorities(testBundle, skIDs) + if expectErr != "" { + require.EqualError(t, err, expectErr) + require.Nil(t, found) + return + } + require.NoError(t, err) + require.Equal(t, expectResp, found) + } + + runTest([]string{skID1}, "", ca1) + runTest([]string{skID1, skID3}, "", ca1, ca3) + runTest([]string{skID1, "foo"}, `no X.509 authority found with SubjectKeyID "foo"`) +} + func createBundle(certs []*x509.Certificate, jwtKeys []*common.PublicKey) *common.Bundle { bundle := BundleProtoFromRootCAs("spiffe://foo", certs) bundle.JwtSigningKeys = jwtKeys diff --git a/pkg/common/nodeutil/node.go b/pkg/common/nodeutil/node.go index f331441bf7..b74a8f7d20 100644 --- a/pkg/common/nodeutil/node.go +++ b/pkg/common/nodeutil/node.go @@ -2,7 +2,6 @@ package nodeutil import ( "errors" - "strings" "github.com/spiffe/spire-api-sdk/proto/spire/api/types" "github.com/spiffe/spire/proto/spire/common" @@ -20,7 +19,6 @@ var ( shouldShutDown = map[types.PermissionDeniedDetails_Reason]struct{}{ types.PermissionDeniedDetails_AGENT_BANNED: {}, } - unknowAuthorityErr = "x509: certificate signed by unknown authority" ) // IsAgentBanned determines if a given attested node is banned or not. @@ -34,17 +32,6 @@ func ShouldAgentReattest(err error) bool { return isExpectedPermissionDenied(err, shouldReattest) } -// IsUnknownAuthorityError returns tru if the Server returned an unknow authority error when verifying -// presented SVID -func IsUnknownAuthorityError(err error) bool { - if err == nil { - return false - } - - // Since it is an rpc error we are unable to use errors.As since it is not possible to unwrap - return strings.Contains(err.Error(), unknowAuthorityErr) -} - // ShouldAgentShutdown returns true if the Server returned an error worth shutting down the Agent func ShouldAgentShutdown(err error) bool { return isExpectedPermissionDenied(err, shouldShutDown) diff --git a/pkg/common/nodeutil/node_test.go b/pkg/common/nodeutil/node_test.go index 2d4b33ddc5..730c2d5ae1 100644 --- a/pkg/common/nodeutil/node_test.go +++ b/pkg/common/nodeutil/node_test.go @@ -1,16 +1,12 @@ package nodeutil_test import ( - "errors" "fmt" "testing" - "github.com/spiffe/go-spiffe/v2/spiffeid" - "github.com/spiffe/go-spiffe/v2/svid/x509svid" "github.com/spiffe/spire-api-sdk/proto/spire/api/types" "github.com/spiffe/spire/pkg/common/nodeutil" "github.com/spiffe/spire/proto/spire/common" - "github.com/spiffe/spire/test/testca" "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -51,29 +47,6 @@ func TestShouldAgentReattest(t *testing.T) { require.False(t, nodeutil.ShouldAgentReattest(getError(t, codes.PermissionDenied, nil))) } -func TestIsUnknownAuthority(t *testing.T) { - t.Run("no error provided", func(t *testing.T) { - require.False(t, nodeutil.IsUnknownAuthorityError(nil)) - }) - - t.Run("unexpected error", func(t *testing.T) { - require.False(t, nodeutil.IsUnknownAuthorityError(errors.New("oh no"))) - }) - - t.Run("unknown authority err", func(t *testing.T) { - // Create two bundles with same TD and an SVID that is signed by one of them - ca := testca.New(t, spiffeid.RequireTrustDomainFromString("test.td")) - ca2 := testca.New(t, spiffeid.RequireTrustDomainFromString("test.td")) - svid := ca2.CreateX509SVID(spiffeid.RequireFromString("spiffe://test.td/w1")) - - // Verify must fail - _, _, err := x509svid.Verify(svid.Certificates, ca.X509Bundle()) - require.Error(t, err) - - require.True(t, nodeutil.IsUnknownAuthorityError(err)) - }) -} - func TestShouldAgentShutdown(t *testing.T) { agentExpired := &types.PermissionDeniedDetails{ Reason: types.PermissionDeniedDetails_AGENT_EXPIRED, diff --git a/pkg/common/telemetry/agent/manager.go b/pkg/common/telemetry/agent/manager.go index bd17d420df..b30105a397 100644 --- a/pkg/common/telemetry/agent/manager.go +++ b/pkg/common/telemetry/agent/manager.go @@ -46,6 +46,16 @@ func AddCacheManagerOutdatedSVIDsSample(m telemetry.Metrics, cacheType string, c m.AddSample(key, count) } +// AddCacheManagerTaintedSVIDsSample count of tainted SVIDs according to +// agent cache manager +func AddCacheManagerTaintedSVIDsSample(m telemetry.Metrics, cacheType string, count float32) { + key := []string{telemetry.CacheManager, cacheType, telemetry.TaintedSVIDs} + if cacheType != "" { + key = append(key, cacheType) + } + m.AddSample(key, count) +} + // End Add Samples func SetSyncStats(m telemetry.Metrics, stats client.SyncStats) { diff --git a/pkg/common/telemetry/names.go b/pkg/common/telemetry/names.go index f88341b364..165e41f2f8 100644 --- a/pkg/common/telemetry/names.go +++ b/pkg/common/telemetry/names.go @@ -547,6 +547,9 @@ const ( // SubjectKeyID tags a certificate subject key ID SubjectKeyID = "subject_key_id" + // SubjectKeyIDs tags a list of subject key ID + SubjectKeyIDs = "subject_key_ids" + // SVIDMapSize is the gauge key for the size of the LRU cache SVID map SVIDMapSize = "lru_cache_svid_map_size" @@ -777,6 +780,9 @@ const ( // RegistrationManager functionality related to a registration manager RegistrationManager = "registration_manager" + // TaintedSVIDs tags tainted SVID count/list + TaintedSVIDs = "tainted_svids" + // Telemetry tags a telemetry module Telemetry = "telemetry" @@ -915,6 +921,9 @@ const ( // PushJWTKeyUpstream functionality related to pushing a public JWT Key to an upstream server. PushJWTKeyUpstream = "push_jwtkey_upstream" + // ProcessTaintedSVIDs functionality related to processing tainted SVIDs. + ProcessTaintedSVIDs = "process_tainted_svids" + // SDSAPI functionality related to SDS; should be used with other tags // to add clarity SDSAPI = "sds_api" diff --git a/pkg/common/x509util/cert.go b/pkg/common/x509util/cert.go index 28ce5f960e..2cfb4ba216 100644 --- a/pkg/common/x509util/cert.go +++ b/pkg/common/x509util/cert.go @@ -4,10 +4,16 @@ import ( "crypto" "crypto/rand" "crypto/x509" + "fmt" + "strings" "github.com/spiffe/spire/pkg/common/cryptoutil" ) +const ( + unknowAuthorityErr = "x509: certificate signed by unknown authority" +) + func CreateCertificate(template, parent *x509.Certificate, pub, priv any) (*x509.Certificate, error) { certDER, err := x509.CreateCertificate(rand.Reader, template, parent, pub, priv) if err != nil { @@ -72,3 +78,45 @@ func RawCertsFromCertificates(certs []*x509.Certificate) [][]byte { } return rawCerts } + +// IsUnknownAuthorityError returns tru if the Server returned an unknow authority error when verifying +// presented SVID +func IsUnknownAuthorityError(err error) bool { + if err == nil { + return false + } + + // Since it is an rpc error we are unable to use errors.As since it is not possible to unwrap + return strings.Contains(err.Error(), unknowAuthorityErr) +} + +// IsSignedByRoot checks if the provided certificate chain is signed by one of the specified root CAs. +func IsSignedByRoot(chain []*x509.Certificate, rootCAs []*x509.Certificate) (bool, error) { + if len(chain) == 0 { + return false, nil + } + rootPool := x509.NewCertPool() + for _, x509Authority := range rootCAs { + rootPool.AddCert(x509Authority) + } + + intermediatePool := x509.NewCertPool() + for _, intermediateCA := range chain[1:] { + intermediatePool.AddCert(intermediateCA) + } + + // Verify certificate chain, using tainted authorities as root + _, err := chain[0].Verify(x509.VerifyOptions{ + Intermediates: intermediatePool, + Roots: rootPool, + }) + if err == nil { + return true, nil + } + + if IsUnknownAuthorityError(err) { + return false, nil + } + + return false, fmt.Errorf("failed to verify certificate chain: %w", err) +} diff --git a/pkg/common/x509util/cert_test.go b/pkg/common/x509util/cert_test.go new file mode 100644 index 0000000000..72cae6b90a --- /dev/null +++ b/pkg/common/x509util/cert_test.go @@ -0,0 +1,71 @@ +package x509util_test + +import ( + "crypto/x509" + "errors" + "testing" + + "github.com/spiffe/go-spiffe/v2/spiffeid" + "github.com/spiffe/go-spiffe/v2/svid/x509svid" + "github.com/spiffe/spire/pkg/common/x509util" + "github.com/spiffe/spire/test/testca" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIsUnknownAuthority(t *testing.T) { + t.Run("no error provided", func(t *testing.T) { + require.False(t, x509util.IsUnknownAuthorityError(nil)) + }) + + t.Run("unexpected error", func(t *testing.T) { + require.False(t, x509util.IsUnknownAuthorityError(errors.New("oh no"))) + }) + + t.Run("unknown authority err", func(t *testing.T) { + // Create two bundles with same TD and an SVID that is signed by one of them + ca := testca.New(t, spiffeid.RequireTrustDomainFromString("test.td")) + ca2 := testca.New(t, spiffeid.RequireTrustDomainFromString("test.td")) + svid := ca2.CreateX509SVID(spiffeid.RequireFromString("spiffe://test.td/w1")) + + // Verify must fail + _, _, err := x509svid.Verify(svid.Certificates, ca.X509Bundle()) + require.Error(t, err) + + require.True(t, x509util.IsUnknownAuthorityError(err)) + }) +} + +func TestIsSignedByRoot(t *testing.T) { + td := spiffeid.RequireTrustDomainFromString("example.org") + ca1 := testca.New(t, td) + intermediate := ca1.ChildCA(testca.WithID(td.ID())) + svid1 := intermediate.CreateX509SVID(spiffeid.RequireFromPath(td, "/w1")) + + ca2 := testca.New(t, td) + svid2 := ca2.CreateX509SVID(spiffeid.RequireFromPath(td, "/w2")) + + invalidCertificate := []*x509.Certificate{{Raw: []byte("invalid")}} + + testSignedByRoot := func(t *testing.T, chain []*x509.Certificate, rootCAs []*x509.Certificate, expect bool, expectError string) { + isSigned, err := x509util.IsSignedByRoot(chain, rootCAs) + if expect { + assert.True(t, isSigned, "Expected chain to be signed by root") + } else { + assert.False(t, isSigned, "Expected chain NOT to be signed by root") + } + if expectError != "" { + assert.ErrorContains(t, err, expectError) + } else { + assert.NoError(t, err) + } + } + + testSignedByRoot(t, svid1.Certificates, ca1.X509Authorities(), true, "") + testSignedByRoot(t, svid2.Certificates, ca2.X509Authorities(), true, "") + testSignedByRoot(t, svid2.Certificates, ca1.X509Authorities(), false, "") + testSignedByRoot(t, svid1.Certificates, ca2.X509Authorities(), false, "") + testSignedByRoot(t, nil, ca2.X509Authorities(), false, "") + testSignedByRoot(t, svid1.Certificates, nil, false, "") + testSignedByRoot(t, invalidCertificate, ca1.X509Authorities(), false, "failed to verify certificate chain: x509: certificate has expired or is not yet valid") +} diff --git a/pkg/server/api/localauthority/v1/service.go b/pkg/server/api/localauthority/v1/service.go index 0c69411f2f..8faa7b3432 100644 --- a/pkg/server/api/localauthority/v1/service.go +++ b/pkg/server/api/localauthority/v1/service.go @@ -407,6 +407,10 @@ func (s *Service) TaintX509UpstreamAuthority(ctx context.Context, req *localauth return nil, api.MakeErr(log, codes.Internal, "failed to taint upstream authority", err) } + if err := s.ca.NotifyTaintedX509Authority(ctx, subjectKeyIDRequest); err != nil { + return nil, api.MakeErr(log, codes.Internal, "failed to notify tainted authority", err) + } + rpccontext.AuditRPC(ctx) log.Info("X.509 upstream authority tainted successfully") diff --git a/pkg/server/ca/manager/manager.go b/pkg/server/ca/manager/manager.go index e67530ef13..caa800e1c5 100644 --- a/pkg/server/ca/manager/manager.go +++ b/pkg/server/ca/manager/manager.go @@ -195,8 +195,8 @@ func (m *Manager) Close() { } } -func (m *Manager) NotifyTaintedX509Authority(ctx context.Context, authoirtyID string) error { - taintedAuthority, err := m.fetchRootCAByAuthorityID(ctx, authoirtyID) +func (m *Manager) NotifyTaintedX509Authority(ctx context.Context, authorityID string) error { + taintedAuthority, err := m.fetchRootCAByAuthorityID(ctx, authorityID) if err != nil { return err } diff --git a/pkg/server/svid/rotator.go b/pkg/server/svid/rotator.go index fc62dbaff8..169c16988d 100644 --- a/pkg/server/svid/rotator.go +++ b/pkg/server/svid/rotator.go @@ -49,7 +49,9 @@ func (r *Rotator) Interval() time.Duration { } func (r *Rotator) triggerTaintedReceived(tainted bool) { - r.taintedReceived <- tainted + if r.taintedReceived != nil { + r.taintedReceived <- tainted + } } // Run starts a ticker which monitors the server SVID