From 1d0ac019e94c720b6f6f5b60b73b8177866ff9b4 Mon Sep 17 00:00:00 2001 From: ianhundere <138915+ianhundere@users.noreply.github.com> Date: Sat, 14 Dec 2024 13:38:04 -0500 Subject: [PATCH] fix: improves kms key validation across providers. Signed-off-by: ianhundere <138915+ianhundere@users.noreply.github.com> --- .../certificate_maker_test.go | 74 +- go.mod | 9 +- go.sum | 12 +- pkg/certmaker/certmaker.go | 44 +- pkg/certmaker/certmaker_test.go | 995 ++++++++++++++---- pkg/certmaker/template_test.go | 141 ++- 6 files changed, 971 insertions(+), 304 deletions(-) diff --git a/cmd/certificate_maker/certificate_maker_test.go b/cmd/certificate_maker/certificate_maker_test.go index e3056b526..89f9a05f7 100644 --- a/cmd/certificate_maker/certificate_maker_test.go +++ b/cmd/certificate_maker/certificate_maker_test.go @@ -18,11 +18,10 @@ package main import ( "os" "path/filepath" + "strings" "testing" "github.com/spf13/cobra" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestGetConfigValue(t *testing.T) { @@ -84,22 +83,27 @@ func TestGetConfigValue(t *testing.T) { defer os.Unsetenv(tt.envVar) } got := getConfigValue(tt.flagValue, tt.envVar) - assert.Equal(t, tt.want, got) + if got != tt.want { + t.Errorf("got %v, want %v", got, tt.want) + } }) } } func TestInitLogger(t *testing.T) { logger := initLogger() - require.NotNil(t, logger) + if logger == nil { + t.Error("logger should not be nil") + } } func TestRunCreate(t *testing.T) { tmpDir, err := os.MkdirTemp("", "cert-test-*") - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } defer os.RemoveAll(tmpDir) - // Create test template files rootTemplate := `{ "subject": { "commonName": "Test Root CA" @@ -132,9 +136,13 @@ func TestRunCreate(t *testing.T) { rootTmplPath := filepath.Join(tmpDir, "root-template.json") leafTmplPath := filepath.Join(tmpDir, "leaf-template.json") err = os.WriteFile(rootTmplPath, []byte(rootTemplate), 0600) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } err = os.WriteFile(leafTmplPath, []byte(leafTemplate), 0600) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } tests := []struct { name string @@ -236,7 +244,6 @@ func TestRunCreate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // Set environment variables for k, v := range tt.envVars { os.Setenv(k, v) defer os.Unsetenv(k) @@ -247,7 +254,6 @@ func TestRunCreate(t *testing.T) { RunE: runCreate, } - // Add all flags that runCreate expects cmd.Flags().StringVar(&kmsType, "kms-type", "", "KMS provider type (awskms, gcpkms, azurekms)") cmd.Flags().StringVar(&kmsRegion, "aws-region", "", "AWS KMS region") cmd.Flags().StringVar(&kmsKeyID, "kms-key-id", "", "KMS key identifier") @@ -267,17 +273,19 @@ func TestRunCreate(t *testing.T) { err := cmd.Execute() if tt.wantError { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errMsg) + if !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("error %q should contain %q", err.Error(), tt.errMsg) + } } else { - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } } }) } } func TestCreateCommand(t *testing.T) { - // Create a test command cmd := &cobra.Command{ Use: "test", RunE: func(_ *cobra.Command, _ []string) error { @@ -285,40 +293,50 @@ func TestCreateCommand(t *testing.T) { }, } - // Add flags cmd.Flags().StringVar(&kmsType, "kms-type", "", "KMS type") cmd.Flags().StringVar(&kmsRegion, "aws-region", "", "AWS KMS region") cmd.Flags().StringVar(&rootKeyID, "root-key-id", "", "Root key ID") cmd.Flags().StringVar(&leafKeyID, "leaf-key-id", "", "Leaf key ID") - // Test missing required flags err := cmd.Execute() - require.NoError(t, err) // No required flags set yet + if err != nil { + t.Fatalf("unexpected error: %v", err) + } - // Test flag parsing err = cmd.ParseFlags([]string{ "--kms-type", "awskms", "--aws-region", "us-west-2", "--root-key-id", "arn:aws:kms:us-west-2:123456789012:key/1234abcd-12ab-34cd-56ef-1234567890ab", "--leaf-key-id", "arn:aws:kms:us-west-2:123456789012:key/9876fedc-ba98-7654-3210-fedcba987654", }) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } - // Verify flag values - assert.Equal(t, "awskms", kmsType) - assert.Equal(t, "us-west-2", kmsRegion) - assert.Equal(t, "arn:aws:kms:us-west-2:123456789012:key/1234abcd-12ab-34cd-56ef-1234567890ab", rootKeyID) - assert.Equal(t, "arn:aws:kms:us-west-2:123456789012:key/9876fedc-ba98-7654-3210-fedcba987654", leafKeyID) + if kmsType != "awskms" { + t.Errorf("got kmsType %v, want awskms", kmsType) + } + if kmsRegion != "us-west-2" { + t.Errorf("got kmsRegion %v, want us-west-2", kmsRegion) + } + if rootKeyID != "arn:aws:kms:us-west-2:123456789012:key/1234abcd-12ab-34cd-56ef-1234567890ab" { + t.Errorf("got rootKeyID %v, want arn:aws:kms:us-west-2:123456789012:key/1234abcd-12ab-34cd-56ef-1234567890ab", rootKeyID) + } + if leafKeyID != "arn:aws:kms:us-west-2:123456789012:key/9876fedc-ba98-7654-3210-fedcba987654" { + t.Errorf("got leafKeyID %v, want arn:aws:kms:us-west-2:123456789012:key/9876fedc-ba98-7654-3210-fedcba987654", leafKeyID) + } } func TestRootCommand(t *testing.T) { - // Test help output rootCmd.SetArgs([]string{"--help"}) err := rootCmd.Execute() - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } - // Test unknown command rootCmd.SetArgs([]string{"unknown"}) err = rootCmd.Execute() - require.Error(t, err) + if err == nil { + t.Error("expected error for unknown command, got nil") + } } diff --git a/go.mod b/go.mod index 4f6b3cdb1..88ab3969a 100644 --- a/go.mod +++ b/go.mod @@ -33,7 +33,6 @@ require ( github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.19.0 github.com/spiffe/go-spiffe/v2 v2.4.0 - github.com/stretchr/testify v1.10.0 github.com/tink-crypto/tink-go-awskms/v2 v2.1.0 github.com/tink-crypto/tink-go-gcpkms/v2 v2.2.0 github.com/tink-crypto/tink-go/v2 v2.2.0 @@ -62,9 +61,9 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys v0.10.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/keyvault/internal v0.7.1 // indirect - github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.1.0 // indirect - github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.0.0 // indirect - github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect + github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.0 // indirect + github.com/AzureAD/microsoft-authentication-library-for-go v1.3.1 // indirect github.com/Masterminds/goutils v1.1.1 // indirect github.com/Masterminds/semver/v3 v3.3.0 // indirect github.com/Masterminds/sprig/v3 v3.3.0 // indirect @@ -89,7 +88,6 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/chainguard-dev/clog v1.5.1 // indirect github.com/common-nighthawk/go-figure v0.0.0-20210622060536-734e95fb86be // indirect - github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-jose/go-jose/v3 v3.0.3 // indirect github.com/go-logr/logr v1.4.2 // indirect @@ -128,7 +126,6 @@ require ( github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pkg/errors v0.9.1 // indirect - github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/prometheus/procfs v0.15.1 // indirect github.com/ryanuber/go-glob v1.0.0 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect diff --git a/go.sum b/go.sum index 92961219e..0bf43c1ab 100644 --- a/go.sum +++ b/go.sum @@ -35,10 +35,10 @@ github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys v0.10.0 h1:m/sWOGCREuSBqg2 github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys v0.10.0/go.mod h1:Pu5Zksi2KrU7LPbZbNINx6fuVrUp/ffvpxdDj+i8LeE= github.com/Azure/azure-sdk-for-go/sdk/keyvault/internal v0.7.1 h1:FbH3BbSb4bvGluTesZZ+ttN/MDsnMmQP36OSnDuSXqw= github.com/Azure/azure-sdk-for-go/sdk/keyvault/internal v0.7.1/go.mod h1:9V2j0jn9jDEkCkv8w/bKTNppX/d0FVA1ud77xCIP4KA= -github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.1.0 h1:DRiANoJTiW6obBQe3SqZizkuV1PEgfiiGivmVocDy64= -github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.1.0/go.mod h1:qLIye2hwb/ZouqhpSD9Zn3SJipvpEnz1Ywl3VUk9Y0s= -github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.0.0 h1:D3occbWoio4EBLkbkevetNMAVX197GkzbUMtqjGWn80= -github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.0.0/go.mod h1:bTSOgj05NGRuHHhQwAdPnYr9TOdNmKlZTgGLL6nyAdI= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.0 h1:7rKG7UmnrxX4N53TFhkYqjc+kVUZuw0fL8I3Fh+Ld9E= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.0/go.mod h1:Wjo+24QJVhhl/L7jy6w9yzFF2yDOf3cKECAa8ecf9vE= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.0 h1:eXnN9kaS8TiDwXjoie3hMRLuwdUBUMW9KRgOqB3mCaw= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.0/go.mod h1:XIpam8wumeZ5rVMuhdDQLMfIPDf1WO3IzrCRO3e3e3o= github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM= github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= github.com/AzureAD/microsoft-authentication-library-for-go v1.3.1 h1:gUDtaZk8heteyfdmv+pcfHvhR9llnh7c7GMwZ8RVG04= @@ -154,8 +154,6 @@ github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfU github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= -github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= @@ -408,6 +406,8 @@ go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= +go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= diff --git a/pkg/certmaker/certmaker.go b/pkg/certmaker/certmaker.go index cb4c6ec7b..cf5e03363 100644 --- a/pkg/certmaker/certmaker.go +++ b/pkg/certmaker/certmaker.go @@ -229,7 +229,19 @@ func ValidateKMSConfig(config KMSConfig) error { if keyID == "" { return nil } - if !strings.HasPrefix(keyID, "arn:aws:kms:") && !strings.HasPrefix(keyID, "alias/") { + if strings.HasPrefix(keyID, "arn:aws:kms:") { + parts := strings.Split(keyID, ":") + if len(parts) < 6 { + return fmt.Errorf("invalid AWS KMS ARN format for %s", keyType) + } + if parts[3] != config.Region { + return fmt.Errorf("region in ARN (%s) does not match configured region (%s)", parts[3], config.Region) + } + } else if strings.HasPrefix(keyID, "alias/") { + if strings.TrimPrefix(keyID, "alias/") == "" { + return fmt.Errorf("alias name cannot be empty for %s", keyType) + } + } else { return fmt.Errorf("awskms %s must start with 'arn:aws:kms:' or 'alias/'", keyType) } return nil @@ -250,11 +262,20 @@ func ValidateKMSConfig(config KMSConfig) error { if keyID == "" { return nil } - if !strings.HasPrefix(keyID, "projects/") { - return fmt.Errorf("gcpkms %s must start with 'projects/'", keyType) + requiredComponents := []struct { + component string + message string + }{ + {"projects/", "must start with 'projects/'"}, + {"/locations/", "must contain '/locations/'"}, + {"/keyRings/", "must contain '/keyRings/'"}, + {"/cryptoKeys/", "must contain '/cryptoKeys/'"}, + {"/cryptoKeyVersions/", "must contain '/cryptoKeyVersions/'"}, } - if !strings.Contains(keyID, "/locations/") || !strings.Contains(keyID, "/keyRings/") { - return fmt.Errorf("invalid gcpkms key format for %s: %s", keyType, keyID) + for _, req := range requiredComponents { + if !strings.Contains(keyID, req.component) { + return fmt.Errorf("gcpkms %s %s", keyType, req.message) + } } return nil } @@ -280,12 +301,19 @@ func ValidateKMSConfig(config KMSConfig) error { if keyID == "" { return nil } - // Validate format: azurekms:name=;vault= if !strings.HasPrefix(keyID, "azurekms:name=") { return fmt.Errorf("azurekms %s must start with 'azurekms:name='", keyType) } - if !strings.Contains(keyID, ";vault=") { - return fmt.Errorf("vault name is required for Azure Key Vault") + nameStart := strings.Index(keyID, "name=") + 5 + vaultIndex := strings.Index(keyID, ";vault=") + if vaultIndex == -1 { + return fmt.Errorf("azurekms %s must contain ';vault=' parameter", keyType) + } + if strings.TrimSpace(keyID[nameStart:vaultIndex]) == "" { + return fmt.Errorf("key name cannot be empty for %s", keyType) + } + if strings.TrimSpace(keyID[vaultIndex+7:]) == "" { + return fmt.Errorf("vault name cannot be empty for %s", keyType) } return nil } diff --git a/pkg/certmaker/certmaker_test.go b/pkg/certmaker/certmaker_test.go index 87a778ea9..2150e18a3 100644 --- a/pkg/certmaker/certmaker_test.go +++ b/pkg/certmaker/certmaker_test.go @@ -21,6 +21,7 @@ import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rand" + "crypto/rsa" "crypto/x509" "crypto/x509/pkix" "encoding/pem" @@ -28,58 +29,91 @@ import ( "math/big" "os" "path/filepath" + "strings" "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "go.step.sm/crypto/kms/apiv1" ) -// mockKMS provides an in-memory KMS for testing -type mockKMS struct { - keys map[string]crypto.Signer +type mockKMSProvider struct { + name string + keys map[string]*ecdsa.PrivateKey + signers map[string]crypto.Signer } -func newMockKMS() *mockKMS { - keys := make(map[string]crypto.Signer) - // Create test keys - for _, id := range []string{"root-key", "intermediate-key", "leaf-key"} { - priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - panic(err) +func newMockKMSProvider() *mockKMSProvider { + m := &mockKMSProvider{ + name: "test", + keys: make(map[string]*ecdsa.PrivateKey), + signers: make(map[string]crypto.Signer), + } + + rootKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + intermediateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + leafKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + m.keys["root-key"] = rootKey + m.keys["intermediate-key"] = intermediateKey + m.keys["leaf-key"] = leafKey + + return m +} + +func (m *mockKMSProvider) CreateKey(*apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) { + return nil, fmt.Errorf("not implemented") +} + +func (m *mockKMSProvider) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) { + keyName := req.SigningKey + if strings.HasPrefix(keyName, "arn:aws:kms:") { + parts := strings.Split(keyName, "/") + if len(parts) > 0 { + keyName = parts[len(parts)-1] } - keys[id] = priv } - return &mockKMS{keys: keys} + + key, ok := m.keys[keyName] + if !ok { + return nil, fmt.Errorf("key not found: %s", req.SigningKey) + } + m.signers[keyName] = key + return key, nil } -func (m *mockKMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) { - if signer, ok := m.keys[req.SigningKey]; ok { - return signer, nil +func (m *mockKMSProvider) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) { + key, ok := m.keys[req.Name] + if !ok { + return nil, fmt.Errorf("key not found: %s", req.Name) } - return nil, fmt.Errorf("key not found: %s", req.SigningKey) + return key.Public(), nil } -func (m *mockKMS) CreateKey(_ *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) { - return nil, fmt.Errorf("CreateKey is not supported in mockKMS") +func (m *mockKMSProvider) Close() error { + return nil } -func (m *mockKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) { - if signer, ok := m.keys[req.Name]; ok { - return signer.Public(), nil - } - return nil, fmt.Errorf("key not found: %s", req.Name) +type mockInvalidKMS struct { + apiv1.KeyManager } -func (m *mockKMS) Close() error { +func (m *mockInvalidKMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) { + return nil, fmt.Errorf("invalid KMS configuration: unsupported KMS type") +} + +func (m *mockInvalidKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) { + return nil, fmt.Errorf("invalid KMS configuration: unsupported KMS type") +} + +func (m *mockInvalidKMS) Close() error { return nil } -// TestParseTemplate tests JSON template parsing func TestParseTemplate(t *testing.T) { tmpFile, err := os.CreateTemp("", "cert-template-*.json") - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } defer os.Remove(tmpFile.Name()) templateContent := `{ @@ -105,162 +139,463 @@ func TestParseTemplate(t *testing.T) { }` err = os.WriteFile(tmpFile.Name(), []byte(templateContent), 0600) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } tmpl, err := ParseTemplate(tmpFile.Name(), nil) - require.NoError(t, err) - assert.Equal(t, "Test CA", tmpl.Subject.CommonName) - assert.True(t, tmpl.IsCA) - assert.Equal(t, 0, tmpl.MaxPathLen) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tmpl.Subject.CommonName != "Test CA" { + t.Errorf("got %v, want Test CA", tmpl.Subject.CommonName) + } + if !tmpl.IsCA { + t.Errorf("got %v, want true", tmpl.IsCA) + } + if tmpl.MaxPathLen != 0 { + t.Errorf("got %v, want 0", tmpl.MaxPathLen) + } } -// TestCreateCertificates tests certificate chain creation func TestCreateCertificates(t *testing.T) { - rootContent := `{ - "subject": { - "country": ["US"], - "organization": ["Sigstore"], - "organizationalUnit": ["Fulcio Root CA"], - "commonName": "fulcio.sigstore.dev" + tests := []struct { + name string + setup func(t *testing.T) (string, KMSConfig, apiv1.KeyManager) + wantError string + }{ + { + name: "successful certificate creation", + setup: func(t *testing.T) (string, KMSConfig, apiv1.KeyManager) { + tmpDir, err := os.MkdirTemp("", "cert-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + rootTemplate := filepath.Join(tmpDir, "root.json") + err = os.WriteFile(rootTemplate, []byte(`{ + "subject": {"commonName": "Test Root CA"}, + "issuer": {"commonName": "Test Root CA"}, + "keyUsage": ["certSign", "crlSign"], + "basicConstraints": {"isCA": true, "maxPathLen": 1}, + "notBefore": "2024-01-01T00:00:00Z", + "notAfter": "2025-01-01T00:00:00Z" + }`), 0600) + if err != nil { + t.Fatalf("Failed to write root template: %v", err) + } + + leafTemplate := filepath.Join(tmpDir, "leaf.json") + err = os.WriteFile(leafTemplate, []byte(`{ + "subject": {"commonName": "Test Leaf"}, + "keyUsage": ["digitalSignature"], + "basicConstraints": {"isCA": false}, + "extKeyUsage": ["CodeSigning"], + "notBefore": "2024-01-01T00:00:00Z", + "notAfter": "2025-01-01T00:00:00Z" + }`), 0600) + if err != nil { + t.Fatalf("Failed to write leaf template: %v", err) + } + + config := KMSConfig{ + Type: "awskms", + Region: "us-west-2", + RootKeyID: "arn:aws:kms:us-west-2:123456789012:key/root-key", + LeafKeyID: "arn:aws:kms:us-west-2:123456789012:key/leaf-key", + } + + return tmpDir, config, newMockKMSProvider() + }, }, - "issuer": { - "commonName": "fulcio.sigstore.dev" + { + name: "invalid template path", + setup: func(t *testing.T) (string, KMSConfig, apiv1.KeyManager) { + tmpDir, err := os.MkdirTemp("", "cert-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + config := KMSConfig{ + Type: "awskms", + Region: "us-west-2", + RootKeyID: "arn:aws:kms:us-west-2:123456789012:key/root-key", + LeafKeyID: "arn:aws:kms:us-west-2:123456789012:key/leaf-key", + } + + return tmpDir, config, newMockKMSProvider() + }, + wantError: "error parsing root template", }, - "notBefore": "2024-01-01T00:00:00Z", - "notAfter": "2034-01-01T00:00:00Z", - "basicConstraints": { - "isCA": true, - "maxPathLen": 1 + { + name: "invalid KMS configuration", + setup: func(t *testing.T) (string, KMSConfig, apiv1.KeyManager) { + tmpDir, err := os.MkdirTemp("", "cert-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + rootTemplate := filepath.Join(tmpDir, "root.json") + err = os.WriteFile(rootTemplate, []byte(`{ + "subject": {"commonName": "Test Root CA"}, + "issuer": {"commonName": "Test Root CA"}, + "keyUsage": ["certSign", "crlSign"], + "basicConstraints": {"isCA": true, "maxPathLen": 1}, + "notBefore": "2024-01-01T00:00:00Z", + "notAfter": "2025-01-01T00:00:00Z" + }`), 0600) + if err != nil { + t.Fatalf("Failed to write root template: %v", err) + } + + leafTemplate := filepath.Join(tmpDir, "leaf.json") + err = os.WriteFile(leafTemplate, []byte(`{ + "subject": {"commonName": "Test Leaf"}, + "keyUsage": ["digitalSignature"], + "basicConstraints": {"isCA": false}, + "extKeyUsage": ["CodeSigning"], + "notBefore": "2024-01-01T00:00:00Z", + "notAfter": "2025-01-01T00:00:00Z" + }`), 0600) + if err != nil { + t.Fatalf("Failed to write leaf template: %v", err) + } + + config := KMSConfig{ + Type: "invalid", + RootKeyID: "test-key", + LeafKeyID: "leaf-key", + } + + return tmpDir, config, &mockInvalidKMS{} + }, + wantError: "invalid KMS configuration: unsupported KMS type", }, - "keyUsage": [ - "certSign", - "crlSign" - ] - }` - - leafContent := `{ - "subject": { - "country": ["US"], - "organization": ["Sigstore"], - "organizationalUnit": ["Fulcio"], - "commonName": "fulcio.sigstore.dev" + { + name: "with intermediate certificate", + setup: func(t *testing.T) (string, KMSConfig, apiv1.KeyManager) { + tmpDir, err := os.MkdirTemp("", "cert-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + rootTemplate := filepath.Join(tmpDir, "root.json") + err = os.WriteFile(rootTemplate, []byte(`{ + "subject": {"commonName": "Test Root CA"}, + "issuer": {"commonName": "Test Root CA"}, + "keyUsage": ["certSign", "crlSign"], + "basicConstraints": {"isCA": true, "maxPathLen": 1}, + "notBefore": "2024-01-01T00:00:00Z", + "notAfter": "2025-01-01T00:00:00Z" + }`), 0600) + if err != nil { + t.Fatalf("Failed to write root template: %v", err) + } + + intermediateTemplate := filepath.Join(tmpDir, "intermediate.json") + err = os.WriteFile(intermediateTemplate, []byte(`{ + "subject": {"commonName": "Test Intermediate CA"}, + "keyUsage": ["certSign", "crlSign"], + "basicConstraints": {"isCA": true, "maxPathLen": 0}, + "notBefore": "2024-01-01T00:00:00Z", + "notAfter": "2025-01-01T00:00:00Z" + }`), 0600) + if err != nil { + t.Fatalf("Failed to write intermediate template: %v", err) + } + + leafTemplate := filepath.Join(tmpDir, "leaf.json") + err = os.WriteFile(leafTemplate, []byte(`{ + "subject": {"commonName": "Test Leaf"}, + "keyUsage": ["digitalSignature"], + "basicConstraints": {"isCA": false}, + "extKeyUsage": ["CodeSigning"], + "notBefore": "2024-01-01T00:00:00Z", + "notAfter": "2025-01-01T00:00:00Z" + }`), 0600) + if err != nil { + t.Fatalf("Failed to write leaf template: %v", err) + } + + config := KMSConfig{ + Type: "awskms", + Region: "us-west-2", + RootKeyID: "arn:aws:kms:us-west-2:123456789012:key/root-key", + IntermediateKeyID: "arn:aws:kms:us-west-2:123456789012:key/intermediate-key", + LeafKeyID: "arn:aws:kms:us-west-2:123456789012:key/leaf-key", + } + + return tmpDir, config, newMockKMSProvider() + }, }, - "issuer": { - "commonName": "fulcio.sigstore.dev" + { + name: "invalid intermediate template", + setup: func(t *testing.T) (string, KMSConfig, apiv1.KeyManager) { + tmpDir, err := os.MkdirTemp("", "cert-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + rootTemplate := filepath.Join(tmpDir, "root.json") + err = os.WriteFile(rootTemplate, []byte(`{ + "subject": {"commonName": "Test Root CA"}, + "issuer": {"commonName": "Test Root CA"}, + "keyUsage": ["certSign", "crlSign"], + "basicConstraints": {"isCA": true, "maxPathLen": 1}, + "notBefore": "2024-01-01T00:00:00Z", + "notAfter": "2025-01-01T00:00:00Z" + }`), 0600) + if err != nil { + t.Fatalf("Failed to write root template: %v", err) + } + + leafTemplate := filepath.Join(tmpDir, "leaf.json") + err = os.WriteFile(leafTemplate, []byte(`{ + "subject": {"commonName": "Test Leaf"}, + "keyUsage": ["digitalSignature"], + "basicConstraints": {"isCA": false}, + "extKeyUsage": ["CodeSigning"], + "notBefore": "2024-01-01T00:00:00Z", + "notAfter": "2025-01-01T00:00:00Z" + }`), 0600) + if err != nil { + t.Fatalf("Failed to write leaf template: %v", err) + } + + config := KMSConfig{ + Type: "awskms", + Region: "us-west-2", + RootKeyID: "arn:aws:kms:us-west-2:123456789012:key/root-key", + IntermediateKeyID: "arn:aws:kms:us-west-2:123456789012:key/intermediate-key", + LeafKeyID: "arn:aws:kms:us-west-2:123456789012:key/leaf-key", + } + + return tmpDir, config, newMockKMSProvider() + }, + wantError: "error parsing intermediate template", }, - "notBefore": "2024-01-01T00:00:00Z", - "notAfter": "2034-01-01T00:00:00Z", - "basicConstraints": { - "isCA": false + { + name: "invalid intermediate key", + setup: func(t *testing.T) (string, KMSConfig, apiv1.KeyManager) { + tmpDir, err := os.MkdirTemp("", "cert-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + rootTemplate := filepath.Join(tmpDir, "root.json") + err = os.WriteFile(rootTemplate, []byte(`{ + "subject": {"commonName": "Test Root CA"}, + "issuer": {"commonName": "Test Root CA"}, + "keyUsage": ["certSign", "crlSign"], + "basicConstraints": {"isCA": true, "maxPathLen": 1}, + "notBefore": "2024-01-01T00:00:00Z", + "notAfter": "2025-01-01T00:00:00Z" + }`), 0600) + if err != nil { + t.Fatalf("Failed to write root template: %v", err) + } + + intermediateTemplate := filepath.Join(tmpDir, "intermediate.json") + err = os.WriteFile(intermediateTemplate, []byte(`{ + "subject": {"commonName": "Test Intermediate CA"}, + "keyUsage": ["certSign", "crlSign"], + "basicConstraints": {"isCA": true, "maxPathLen": 0}, + "notBefore": "2024-01-01T00:00:00Z", + "notAfter": "2025-01-01T00:00:00Z" + }`), 0600) + if err != nil { + t.Fatalf("Failed to write intermediate template: %v", err) + } + + leafTemplate := filepath.Join(tmpDir, "leaf.json") + err = os.WriteFile(leafTemplate, []byte(`{ + "subject": {"commonName": "Test Leaf"}, + "keyUsage": ["digitalSignature"], + "basicConstraints": {"isCA": false}, + "extKeyUsage": ["CodeSigning"], + "notBefore": "2024-01-01T00:00:00Z", + "notAfter": "2025-01-01T00:00:00Z" + }`), 0600) + if err != nil { + t.Fatalf("Failed to write leaf template: %v", err) + } + + config := KMSConfig{ + Type: "awskms", + Region: "us-west-2", + RootKeyID: "arn:aws:kms:us-west-2:123456789012:key/root-key", + IntermediateKeyID: "arn:aws:kms:us-west-2:123456789012:key/nonexistent-key", + LeafKeyID: "arn:aws:kms:us-west-2:123456789012:key/leaf-key", + } + + return tmpDir, config, newMockKMSProvider() + }, + wantError: "error creating intermediate signer", }, - "keyUsage": [ - "digitalSignature" - ], - "extKeyUsage": [ - "CodeSigning" - ] - }` - - t.Run("Fulcio without intermediate", func(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "cert-test-fulcio-*") - require.NoError(t, err) - t.Cleanup(func() { os.RemoveAll(tmpDir) }) - - km := newMockKMS() - config := KMSConfig{ - Type: "mockkms", - RootKeyID: "root-key", - LeafKeyID: "leaf-key", - Options: make(map[string]string), - } - - rootTmplPath := filepath.Join(tmpDir, "root-template.json") - leafTmplPath := filepath.Join(tmpDir, "leaf-template.json") - rootCertPath := filepath.Join(tmpDir, "root.pem") - leafCertPath := filepath.Join(tmpDir, "leaf.pem") - - err = os.WriteFile(rootTmplPath, []byte(rootContent), 0600) - require.NoError(t, err) - - err = os.WriteFile(leafTmplPath, []byte(leafContent), 0600) - require.NoError(t, err) - - err = CreateCertificates(km, config, - rootTmplPath, leafTmplPath, - rootCertPath, leafCertPath, - "", "", "") - require.NoError(t, err) - - verifyDirectChain(t, rootCertPath, leafCertPath) - }) - - t.Run("Fulcio with intermediate", func(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "cert-test-fulcio-*") - require.NoError(t, err) - t.Cleanup(func() { os.RemoveAll(tmpDir) }) - - intermediateContent := `{ - "subject": { - "country": ["US"], - "organization": ["Sigstore"], - "organizationalUnit": ["Fulcio Intermediate CA"], - "commonName": "fulcio.sigstore.dev" + { + name: "error creating root certificate", + setup: func(t *testing.T) (string, KMSConfig, apiv1.KeyManager) { + tmpDir, err := os.MkdirTemp("", "cert-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + rootTemplate := filepath.Join(tmpDir, "root.json") + err = os.WriteFile(rootTemplate, []byte(`{ + "subject": {}, + "issuer": {} + }`), 0600) + if err != nil { + t.Fatalf("Failed to write root template: %v", err) + } + + config := KMSConfig{ + Type: "awskms", + Region: "us-west-2", + RootKeyID: "arn:aws:kms:us-west-2:123456789012:key/root-key", + LeafKeyID: "arn:aws:kms:us-west-2:123456789012:key/leaf-key", + } + + return tmpDir, config, newMockKMSProvider() }, - "issuer": { - "commonName": "fulcio.sigstore.dev" + wantError: "error parsing root template: notBefore and notAfter times must be specified", + }, + { + name: "error creating leaf certificate", + setup: func(t *testing.T) (string, KMSConfig, apiv1.KeyManager) { + tmpDir, err := os.MkdirTemp("", "cert-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + rootTemplate := filepath.Join(tmpDir, "root.json") + err = os.WriteFile(rootTemplate, []byte(`{ + "subject": {"commonName": "Test Root CA"}, + "issuer": {"commonName": "Test Root CA"}, + "keyUsage": ["certSign", "crlSign"], + "basicConstraints": {"isCA": true, "maxPathLen": 1}, + "notBefore": "2024-01-01T00:00:00Z", + "notAfter": "2025-01-01T00:00:00Z" + }`), 0600) + if err != nil { + t.Fatalf("Failed to write root template: %v", err) + } + + leafTemplate := filepath.Join(tmpDir, "leaf.json") + err = os.WriteFile(leafTemplate, []byte(`{ + "subject": {}, + "issuer": {} + }`), 0600) + if err != nil { + t.Fatalf("Failed to write leaf template: %v", err) + } + + config := KMSConfig{ + Type: "awskms", + Region: "us-west-2", + RootKeyID: "arn:aws:kms:us-west-2:123456789012:key/root-key", + LeafKeyID: "arn:aws:kms:us-west-2:123456789012:key/leaf-key", + } + + return tmpDir, config, newMockKMSProvider() }, - "notBefore": "2024-01-01T00:00:00Z", - "notAfter": "2034-01-01T00:00:00Z", - "basicConstraints": { - "isCA": true, - "maxPathLen": 0 + wantError: "error parsing leaf template: notBefore and notAfter times must be specified", + }, + { + name: "error writing certificates", + setup: func(t *testing.T) (string, KMSConfig, apiv1.KeyManager) { + tmpDir, err := os.MkdirTemp("", "cert-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + rootTemplate := filepath.Join(tmpDir, "root.json") + err = os.WriteFile(rootTemplate, []byte(`{ + "subject": {"commonName": "Test Root CA"}, + "issuer": {"commonName": "Test Root CA"}, + "keyUsage": ["certSign", "crlSign"], + "basicConstraints": {"isCA": true, "maxPathLen": 1}, + "notBefore": "2024-01-01T00:00:00Z", + "notAfter": "2025-01-01T00:00:00Z" + }`), 0600) + if err != nil { + t.Fatalf("Failed to write root template: %v", err) + } + + leafTemplate := filepath.Join(tmpDir, "leaf.json") + err = os.WriteFile(leafTemplate, []byte(`{ + "subject": {"commonName": "Test Leaf"}, + "keyUsage": ["digitalSignature"], + "basicConstraints": {"isCA": false}, + "extKeyUsage": ["CodeSigning"], + "notBefore": "2024-01-01T00:00:00Z", + "notAfter": "2025-01-01T00:00:00Z" + }`), 0600) + if err != nil { + t.Fatalf("Failed to write leaf template: %v", err) + } + + config := KMSConfig{ + Type: "awskms", + Region: "us-west-2", + RootKeyID: "arn:aws:kms:us-west-2:123456789012:key/root-key", + LeafKeyID: "arn:aws:kms:us-west-2:123456789012:key/leaf-key", + } + + outDir := filepath.Join(tmpDir, "out") + err = os.MkdirAll(outDir, 0444) + if err != nil { + t.Fatalf("Failed to create output directory: %v", err) + } + + return tmpDir, config, newMockKMSProvider() }, - "keyUsage": [ - "certSign", - "crlSign" - ] - }` - - km := newMockKMS() - config := KMSConfig{ - Type: "mockkms", - RootKeyID: "root-key", - IntermediateKeyID: "intermediate-key", - LeafKeyID: "leaf-key", - Options: make(map[string]string), - } + wantError: "error writing root certificate", + }, + } - rootTmplPath := filepath.Join(tmpDir, "root-template.json") - leafTmplPath := filepath.Join(tmpDir, "leaf-template.json") - intermediateTmplPath := filepath.Join(tmpDir, "intermediate-template.json") - rootCertPath := filepath.Join(tmpDir, "root.pem") - intermediateCertPath := filepath.Join(tmpDir, "intermediate.pem") - leafCertPath := filepath.Join(tmpDir, "leaf.pem") - - err = os.WriteFile(rootTmplPath, []byte(rootContent), 0600) - require.NoError(t, err) - err = os.WriteFile(intermediateTmplPath, []byte(intermediateContent), 0600) - require.NoError(t, err) - err = os.WriteFile(leafTmplPath, []byte(leafContent), 0600) - require.NoError(t, err) - - err = CreateCertificates(km, config, - rootTmplPath, leafTmplPath, - rootCertPath, leafCertPath, - "intermediate-key", intermediateTmplPath, intermediateCertPath) - require.NoError(t, err) - - verifyIntermediateChain(rootCertPath, intermediateCertPath, leafCertPath) - }) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir, config, kms := tt.setup(t) + defer os.RemoveAll(tmpDir) + + outDir := filepath.Join(tmpDir, "out") + err := os.MkdirAll(outDir, 0755) + if err != nil { + t.Fatalf("Failed to create output directory: %v", err) + } + + err = CreateCertificates(kms, config, + filepath.Join(tmpDir, "root.json"), + filepath.Join(tmpDir, "leaf.json"), + filepath.Join(outDir, "root.crt"), + filepath.Join(outDir, "leaf.crt"), + config.IntermediateKeyID, + filepath.Join(tmpDir, "intermediate.json"), + filepath.Join(outDir, "intermediate.crt")) + + if tt.wantError != "" { + if err == nil { + t.Error("Expected error but got none") + } else if !strings.Contains(err.Error(), tt.wantError) { + t.Errorf("Expected error containing %q, got %q", tt.wantError, err.Error()) + } + } else { + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + } + }) + } } -// TestWriteCertificateToFile tests certificate file writing func TestWriteCertificateToFile(t *testing.T) { tmpDir, err := os.MkdirTemp("", "cert-write-test-*") - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } defer os.RemoveAll(tmpDir) - // Create a test certificate cert := &x509.Certificate{ Subject: pkix.Name{ CommonName: "Test CA", @@ -303,23 +638,32 @@ func TestWriteCertificateToFile(t *testing.T) { t.Run(tt.name, func(t *testing.T) { err := WriteCertificateToFile(tt.cert, tt.path) if tt.wantError { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errMsg) + if err == nil { + t.Errorf("expected error, got nil") + } else if !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("error %q should contain %q", err.Error(), tt.errMsg) + } } else { - require.NoError(t, err) - // Verify the file exists and contains a PEM block + if err != nil { + t.Errorf("unexpected error: %v", err) + } content, err := os.ReadFile(tt.path) - require.NoError(t, err) + if err != nil { + t.Errorf("unexpected error: %v", err) + } block, _ := pem.Decode(content) - require.NotNil(t, block) - assert.Equal(t, "CERTIFICATE", block.Type) + if block == nil { + t.Errorf("failed to decode PEM block") + } + if block.Type != "CERTIFICATE" { + t.Errorf("got %v, want CERTIFICATE", block.Type) + } } }) } } func verifyIntermediateChain(rootPath, intermediatePath, leafPath string) error { - // Read certificates rootPEM, err := os.ReadFile(rootPath) if err != nil { return fmt.Errorf("error reading root certificate: %w", err) @@ -333,7 +677,6 @@ func verifyIntermediateChain(rootPath, intermediatePath, leafPath string) error return fmt.Errorf("error reading leaf certificate: %w", err) } - // Parse certificates rootBlock, _ := pem.Decode(rootPEM) if rootBlock == nil { return fmt.Errorf("failed to decode root certificate PEM") @@ -361,14 +704,12 @@ func verifyIntermediateChain(rootPath, intermediatePath, leafPath string) error return fmt.Errorf("error parsing leaf certificate: %w", err) } - // Create certificate pools roots := x509.NewCertPool() roots.AddCert(rootCert) intermediates := x509.NewCertPool() intermediates.AddCert(intermediateCert) - // Verify the chain opts := x509.VerifyOptions{ Roots: roots, Intermediates: intermediates, @@ -390,18 +731,26 @@ func verifyDirectChain(t *testing.T, rootPath, leafPath string) { Roots: rootPool, KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageCodeSigning}, }) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } } func loadCertificate(t *testing.T, path string) *x509.Certificate { data, err := os.ReadFile(path) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } block, _ := pem.Decode(data) - require.NotNil(t, block) + if block == nil { + t.Fatalf("failed to decode PEM block") + } cert, err := x509.ParseCertificate(block.Bytes) - require.NoError(t, err) + if err != nil { + t.Fatalf("error parsing certificate: %v", err) + } return cert } @@ -453,7 +802,7 @@ func TestValidateKMSConfig(t *testing.T) { RootKeyID: "invalid-key-id", }, wantErr: true, - wantErrMsg: "gcpkms RootKeyID must start with 'projects/'", + wantErrMsg: "must start with 'projects/'", }, { name: "azure_kms_missing_tenant_id", @@ -477,7 +826,7 @@ func TestValidateKMSConfig(t *testing.T) { }, }, wantErr: true, - wantErrMsg: "vault name is required for Azure Key Vault", + wantErrMsg: "azurekms RootKeyID must contain ';vault=' parameter", }, { name: "azure_kms_missing_options", @@ -497,22 +846,80 @@ func TestValidateKMSConfig(t *testing.T) { wantErr: true, wantErrMsg: "unsupported KMS type: unsupported", }, + { + name: "aws_kms_invalid_arn_format", + config: KMSConfig{ + Type: "awskms", + Region: "us-west-2", + RootKeyID: "arn:aws:kms:us-west-2:invalid", + }, + wantErr: true, + wantErrMsg: "invalid AWS KMS ARN format for RootKeyID", + }, + { + name: "aws_kms_region_mismatch", + config: KMSConfig{ + Type: "awskms", + Region: "us-west-2", + RootKeyID: "arn:aws:kms:us-east-1:123456789012:key/test-key", + }, + wantErr: true, + wantErrMsg: "region in ARN (us-east-1) does not match configured region (us-west-2)", + }, + { + name: "aws_kms_empty_alias", + config: KMSConfig{ + Type: "awskms", + Region: "us-west-2", + RootKeyID: "alias/", + }, + wantErr: true, + wantErrMsg: "alias name cannot be empty for RootKeyID", + }, + { + name: "azure_kms_empty_key_name", + config: KMSConfig{ + Type: "azurekms", + RootKeyID: "azurekms:name=;vault=test-vault", + Options: map[string]string{ + "tenant-id": "test-tenant", + }, + }, + wantErr: true, + wantErrMsg: "key name cannot be empty for RootKeyID", + }, + { + name: "azure_kms_empty_vault_name", + config: KMSConfig{ + Type: "azurekms", + RootKeyID: "azurekms:name=test-key;vault=", + Options: map[string]string{ + "tenant-id": "test-tenant", + }, + }, + wantErr: true, + wantErrMsg: "vault name cannot be empty for RootKeyID", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := ValidateKMSConfig(tt.config) if tt.wantErr { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.wantErrMsg) + if err == nil { + t.Errorf("expected error, got nil") + } else if !strings.Contains(err.Error(), tt.wantErrMsg) { + t.Errorf("error %q should contain %q", err.Error(), tt.wantErrMsg) + } } else { - require.NoError(t, err) + if err != nil { + t.Errorf("unexpected error: %v", err) + } } }) } } -// TestValidateTemplate tests template validation func TestValidateTemplate(t *testing.T) { tests := []struct { name string @@ -891,10 +1298,15 @@ func TestValidateTemplate(t *testing.T) { t.Run(tt.name, func(t *testing.T) { err := ValidateTemplate(tt.tmpl, tt.parent, tt.certType) if tt.wantError != "" { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.wantError) + if err == nil { + t.Errorf("expected error, got nil") + } else if !strings.Contains(err.Error(), tt.wantError) { + t.Errorf("error %q should contain %q", err.Error(), tt.wantError) + } } else { - require.NoError(t, err) + if err != nil { + t.Errorf("unexpected error: %v", err) + } } }) } @@ -945,10 +1357,15 @@ func TestValidateTemplateKeyUsageCombinations(t *testing.T) { t.Run(tt.name, func(t *testing.T) { err := ValidateTemplate(tt.tmpl, tt.parent, tt.certType) if tt.wantError != "" { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.wantError) + if err == nil { + t.Errorf("expected error, got nil") + } else if !strings.Contains(err.Error(), tt.wantError) { + t.Errorf("error %q should contain %q", err.Error(), tt.wantError) + } } else { - require.NoError(t, err) + if err != nil { + t.Errorf("unexpected error: %v", err) + } } }) } @@ -999,10 +1416,15 @@ func TestValidateLeafCertificateKeyUsage(t *testing.T) { t.Run(tt.name, func(t *testing.T) { err := ValidateTemplate(tt.tmpl, tt.parent, "leaf") if tt.wantError { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errMsg) + if err == nil { + t.Errorf("expected error, got nil") + } else if !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("error %q should contain %q", err.Error(), tt.errMsg) + } } else { - require.NoError(t, err) + if err != nil { + t.Errorf("unexpected error: %v", err) + } } }) } @@ -1010,8 +1432,9 @@ func TestValidateLeafCertificateKeyUsage(t *testing.T) { func TestValidateTemplatePath(t *testing.T) { tests := []struct { - name string - path string + name string + path string + setup func() string wantError string }{ @@ -1025,7 +1448,9 @@ func TestValidateTemplatePath(t *testing.T) { path: "template.txt", setup: func() string { f, err := os.CreateTemp("", "template.txt") - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } return f.Name() }, wantError: "must have .json extension", @@ -1035,9 +1460,13 @@ func TestValidateTemplatePath(t *testing.T) { path: "invalid.json", setup: func() string { f, err := os.CreateTemp("", "template*.json") - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } err = os.WriteFile(f.Name(), []byte("invalid json"), 0600) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } return f.Name() }, wantError: "invalid JSON", @@ -1047,9 +1476,13 @@ func TestValidateTemplatePath(t *testing.T) { path: "valid.json", setup: func() string { f, err := os.CreateTemp("", "template*.json") - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } err = os.WriteFile(f.Name(), []byte(`{"key": "value"}`), 0600) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } return f.Name() }, }, @@ -1065,10 +1498,15 @@ func TestValidateTemplatePath(t *testing.T) { err := ValidateTemplatePath(path) if tt.wantError != "" { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.wantError) + if err == nil { + t.Errorf("expected error, got nil") + } else if !strings.Contains(err.Error(), tt.wantError) { + t.Errorf("error %q should contain %q", err.Error(), tt.wantError) + } } else { - require.NoError(t, err) + if err != nil { + t.Errorf("unexpected error: %v", err) + } } }) } @@ -1089,18 +1527,18 @@ func TestGCPKMSValidation(t *testing.T) { wantError: "must start with 'projects/'", }, { - name: "missing_locations_in_key_path", + name: "missing_required_components", config: KMSConfig{ Type: "gcpkms", RootKeyID: "projects/test-project", }, - wantError: "invalid gcpkms key format", + wantError: "gcpkms RootKeyID must contain '/locations/'", }, { name: "valid_GCP_key_format", config: KMSConfig{ Type: "gcpkms", - RootKeyID: "projects/test-project/locations/global/keyRings/test-ring/cryptoKeys/test-key", + RootKeyID: "projects/test-project/locations/global/keyRings/test-ring/cryptoKeys/test-key/cryptoKeyVersions/1", }, wantError: "", }, @@ -1110,10 +1548,15 @@ func TestGCPKMSValidation(t *testing.T) { t.Run(tt.name, func(t *testing.T) { err := ValidateKMSConfig(tt.config) if tt.wantError != "" { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.wantError) + if err == nil { + t.Errorf("expected error, got nil") + } else if !strings.Contains(err.Error(), tt.wantError) { + t.Errorf("error %q should contain %q", err.Error(), tt.wantError) + } } else { - require.NoError(t, err) + if err != nil { + t.Errorf("unexpected error: %v", err) + } } }) } @@ -1162,7 +1605,7 @@ func TestAzureKMSValidation(t *testing.T) { "tenant-id": "test-tenant", }, }, - wantError: "vault name is required", + wantError: "azurekms RootKeyID must contain ';vault=' parameter", }, { name: "valid config", @@ -1180,10 +1623,15 @@ func TestAzureKMSValidation(t *testing.T) { t.Run(tt.name, func(t *testing.T) { err := ValidateKMSConfig(tt.config) if tt.wantError != "" { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.wantError) + if err == nil { + t.Errorf("expected error, got nil") + } else if !strings.Contains(err.Error(), tt.wantError) { + t.Errorf("error %q should contain %q", err.Error(), tt.wantError) + } } else { - require.NoError(t, err) + if err != nil { + t.Errorf("unexpected error: %v", err) + } } }) } @@ -1222,7 +1670,7 @@ func TestInitKMSErrors(t *testing.T) { name: "GCP KMS with nonexistent credentials file", config: KMSConfig{ Type: "gcpkms", - RootKeyID: "projects/test-project/locations/global/keyRings/test-ring/cryptoKeys/test-key", + RootKeyID: "projects/test-project/locations/global/keyRings/test-ring/cryptoKeys/test-key/cryptoKeyVersions/1", Options: map[string]string{ "credentials-file": "/nonexistent/credentials.json", }, @@ -1234,8 +1682,123 @@ func TestInitKMSErrors(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { _, err := InitKMS(ctx, tt.config) - require.Error(t, err) - assert.Contains(t, err.Error(), tt.wantError) + if err == nil { + t.Errorf("expected error, got nil") + } else if !strings.Contains(err.Error(), tt.wantError) { + t.Errorf("error %q should contain %q", err.Error(), tt.wantError) + } + }) + } +} + +func TestInitKMS(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "kms-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("Failed to generate private key: %v", err) + } + + privKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privKey), + }) + + credsFile := filepath.Join(tmpDir, "test-credentials.json") + err = os.WriteFile(credsFile, []byte(fmt.Sprintf(`{ + "type": "service_account", + "project_id": "test-project", + "private_key_id": "test-key-id", + "private_key": %q, + "client_email": "test@test-project.iam.gserviceaccount.com", + "client_id": "123456789", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/test@test-project.iam.gserviceaccount.com" + }`, string(privKeyPEM))), 0600) + if err != nil { + t.Fatalf("Failed to write credentials file: %v", err) + } + + ctx := context.Background() + tests := []struct { + name string + config KMSConfig + wantError bool + errMsg string + }{ + { + name: "valid AWS KMS config", + config: KMSConfig{ + Type: "awskms", + Region: "us-west-2", + RootKeyID: "arn:aws:kms:us-west-2:123456789012:key/test-key", + LeafKeyID: "arn:aws:kms:us-west-2:123456789012:key/leaf-key", + Options: map[string]string{}, + }, + wantError: false, + }, + { + name: "valid GCP KMS config", + config: KMSConfig{ + Type: "gcpkms", + RootKeyID: "projects/test-project/locations/global/keyRings/test-ring/cryptoKeys/test-key/cryptoKeyVersions/1", + LeafKeyID: "projects/test-project/locations/global/keyRings/test-ring/cryptoKeys/leaf-key/cryptoKeyVersions/1", + Options: map[string]string{ + "credentials-file": credsFile, + }, + }, + wantError: false, + }, + { + name: "valid Azure KMS config", + config: KMSConfig{ + Type: "azurekms", + RootKeyID: "azurekms:name=test-key;vault=test-vault", + LeafKeyID: "azurekms:name=leaf-key;vault=test-vault", + Options: map[string]string{ + "tenant-id": "test-tenant", + }, + }, + wantError: false, + }, + { + name: "invalid KMS type", + config: KMSConfig{ + Type: "invalid", + RootKeyID: "test-key", + LeafKeyID: "leaf-key", + }, + wantError: true, + errMsg: "invalid KMS configuration: unsupported KMS type", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + km, err := InitKMS(ctx, tt.config) + if tt.wantError { + if err == nil { + t.Error("expected error but got nil") + } else if !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("error %q should contain %q", err.Error(), tt.errMsg) + } + if km != nil { + t.Error("expected nil KMS but got non-nil") + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if km == nil { + t.Error("expected non-nil KMS but got nil") + } + } }) } } diff --git a/pkg/certmaker/template_test.go b/pkg/certmaker/template_test.go index 68a4df2c9..055fbbdbf 100644 --- a/pkg/certmaker/template_test.go +++ b/pkg/certmaker/template_test.go @@ -19,10 +19,8 @@ import ( "crypto/x509" "crypto/x509/pkix" "os" + "strings" "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestValidateTemplateFields(t *testing.T) { @@ -223,10 +221,9 @@ func TestValidateTemplateFields(t *testing.T) { t.Run(tt.name, func(t *testing.T) { err := ValidateTemplate(tt.tmpl, tt.parent, tt.certType) if tt.wantError != "" { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.wantError) - } else { - require.NoError(t, err) + if !strings.Contains(err.Error(), tt.wantError) { + t.Errorf("error %q should contain %q", err.Error(), tt.wantError) + } } }) } @@ -268,22 +265,31 @@ func TestParseTemplateErrors(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tmpFile, err := os.CreateTemp("", "cert-template-*.json") - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } defer os.Remove(tmpFile.Name()) err = os.WriteFile(tmpFile.Name(), []byte(tt.content), 0600) - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } _, err = ParseTemplate(tmpFile.Name(), nil) - require.Error(t, err) - assert.Contains(t, err.Error(), tt.wantError) + if err == nil { + t.Errorf("expected error, got nil") + } else if !strings.Contains(err.Error(), tt.wantError) { + t.Errorf("error %q should contain %q", err.Error(), tt.wantError) + } }) } - // Test non-existent file _, err := ParseTemplate("nonexistent.json", nil) - require.Error(t, err) - assert.Contains(t, err.Error(), "error reading template file") + if err == nil { + t.Errorf("expected error, got nil") + } else if !strings.Contains(err.Error(), "error reading template file") { + t.Errorf("error %q should contain %q", err.Error(), "error reading template file") + } } func TestInvalidCertificateType(t *testing.T) { @@ -299,18 +305,28 @@ func TestInvalidCertificateType(t *testing.T) { } err := ValidateTemplate(tmpl, nil, "invalid") - require.Error(t, err) - assert.Contains(t, err.Error(), "invalid certificate type") + if err == nil { + t.Errorf("expected error, got nil") + } else if !strings.Contains(err.Error(), "invalid certificate type") { + t.Errorf("error %q should contain %q", err.Error(), "invalid certificate type") + } } func TestContainsExtKeyUsage(t *testing.T) { - assert.False(t, containsExtKeyUsage(nil, "CodeSigning"), "empty list should return false") - assert.False(t, containsExtKeyUsage([]string{}, "CodeSigning"), "empty list should return false") - assert.True(t, containsExtKeyUsage([]string{"CodeSigning"}, "CodeSigning"), "should find matching usage") - assert.False(t, containsExtKeyUsage([]string{"OtherUsage"}, "CodeSigning"), "should not find non-matching usage") + if containsExtKeyUsage(nil, "CodeSigning") { + t.Error("empty list (nil) should return false") + } + if containsExtKeyUsage([]string{}, "CodeSigning") { + t.Error("empty list should return false") + } + if !containsExtKeyUsage([]string{"CodeSigning"}, "CodeSigning") { + t.Error("should find matching usage") + } + if containsExtKeyUsage([]string{"OtherUsage"}, "CodeSigning") { + t.Error("should not find non-matching usage") + } } -// Helper function to check if an extended key usage is present func containsExtKeyUsage(usages []string, target string) bool { for _, usage := range usages { if usage == target { @@ -402,20 +418,49 @@ func TestCreateCertificateFromTemplate(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cert, err := CreateCertificateFromTemplate(tt.tmpl, tt.parent) if tt.wantError { - require.Error(t, err) + if err == nil { + t.Errorf("expected error, got nil") + } return } - require.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } - // Verify the certificate fields - assert.Equal(t, tt.tmpl.Subject.CommonName, cert.Subject.CommonName) - assert.Equal(t, tt.tmpl.Subject.Country, cert.Subject.Country) - assert.Equal(t, tt.tmpl.Subject.Organization, cert.Subject.Organization) - assert.Equal(t, tt.tmpl.Subject.OrganizationalUnit, cert.Subject.OrganizationalUnit) - assert.Equal(t, tt.tmpl.BasicConstraints.IsCA, cert.IsCA) + if cert.Subject.CommonName != tt.tmpl.Subject.CommonName { + t.Errorf("CommonName got %v, want %v", cert.Subject.CommonName, tt.tmpl.Subject.CommonName) + } - if tt.tmpl.BasicConstraints.IsCA { - assert.Equal(t, tt.tmpl.BasicConstraints.MaxPathLen, cert.MaxPathLen) + for _, usage := range tt.tmpl.KeyUsage { + switch usage { + case "certSign": + if cert.KeyUsage&x509.KeyUsageCertSign == 0 { + t.Error("expected KeyUsageCertSign to be set") + } + case "crlSign": + if cert.KeyUsage&x509.KeyUsageCRLSign == 0 { + t.Error("expected KeyUsageCRLSign to be set") + } + case "digitalSignature": + if cert.KeyUsage&x509.KeyUsageDigitalSignature == 0 { + t.Error("expected KeyUsageDigitalSignature to be set") + } + } + } + + for _, usage := range tt.tmpl.ExtKeyUsage { + if usage == "CodeSigning" { + found := false + for _, certUsage := range cert.ExtKeyUsage { + if certUsage == x509.ExtKeyUsageCodeSigning { + found = true + break + } + } + if !found { + t.Error("expected ExtKeyUsageCodeSigning to be set") + } + } } }) } @@ -424,20 +469,36 @@ func TestCreateCertificateFromTemplate(t *testing.T) { func TestSetKeyUsagesAndExtKeyUsages(t *testing.T) { cert := &x509.Certificate{} - // Test key usages SetKeyUsages(cert, []string{"certSign", "crlSign", "digitalSignature"}) - assert.True(t, cert.KeyUsage&x509.KeyUsageCertSign != 0) - assert.True(t, cert.KeyUsage&x509.KeyUsageCRLSign != 0) - assert.True(t, cert.KeyUsage&x509.KeyUsageDigitalSignature != 0) + if cert.KeyUsage&x509.KeyUsageCertSign == 0 { + t.Error("expected KeyUsageCertSign to be set") + } + if cert.KeyUsage&x509.KeyUsageCRLSign == 0 { + t.Error("expected KeyUsageCRLSign to be set") + } + if cert.KeyUsage&x509.KeyUsageDigitalSignature == 0 { + t.Error("expected KeyUsageDigitalSignature to be set") + } - // Test extended key usages SetExtKeyUsages(cert, []string{"CodeSigning"}) - assert.Contains(t, cert.ExtKeyUsage, x509.ExtKeyUsageCodeSigning) + found := false + for _, usage := range cert.ExtKeyUsage { + if usage == x509.ExtKeyUsageCodeSigning { + found = true + break + } + } + if !found { + t.Error("expected ExtKeyUsageCodeSigning to be set") + } - // Test with empty usages newCert := &x509.Certificate{} SetKeyUsages(newCert, nil) SetExtKeyUsages(newCert, nil) - assert.Equal(t, x509.KeyUsage(0), newCert.KeyUsage) - assert.Empty(t, newCert.ExtKeyUsage) + if newCert.KeyUsage != x509.KeyUsage(0) { + t.Error("expected KeyUsage to be cleared") + } + if len(newCert.ExtKeyUsage) != 0 { + t.Error("expected ExtKeyUsage to be cleared") + } }