Skip to content

Commit

Permalink
fix: improves kms key validation across providers.
Browse files Browse the repository at this point in the history
Signed-off-by: ianhundere <138915+ianhundere@users.noreply.github.com>
  • Loading branch information
ianhundere committed Jan 13, 2025
1 parent 27f0d29 commit 1d0ac01
Show file tree
Hide file tree
Showing 6 changed files with 971 additions and 304 deletions.
74 changes: 46 additions & 28 deletions cmd/certificate_maker/certificate_maker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@ package main
import (
"os"
"path/filepath"
"strings"
"testing"

"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestGetConfigValue(t *testing.T) {
Expand Down Expand Up @@ -84,22 +83,27 @@ func TestGetConfigValue(t *testing.T) {
defer os.Unsetenv(tt.envVar)
}
got := getConfigValue(tt.flagValue, tt.envVar)
assert.Equal(t, tt.want, got)
if got != tt.want {
t.Errorf("got %v, want %v", got, tt.want)
}
})
}
}

func TestInitLogger(t *testing.T) {
logger := initLogger()
require.NotNil(t, logger)
if logger == nil {
t.Error("logger should not be nil")
}
}

func TestRunCreate(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "cert-test-*")
require.NoError(t, err)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer os.RemoveAll(tmpDir)

// Create test template files
rootTemplate := `{
"subject": {
"commonName": "Test Root CA"
Expand Down Expand Up @@ -132,9 +136,13 @@ func TestRunCreate(t *testing.T) {
rootTmplPath := filepath.Join(tmpDir, "root-template.json")
leafTmplPath := filepath.Join(tmpDir, "leaf-template.json")
err = os.WriteFile(rootTmplPath, []byte(rootTemplate), 0600)
require.NoError(t, err)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
err = os.WriteFile(leafTmplPath, []byte(leafTemplate), 0600)
require.NoError(t, err)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

tests := []struct {
name string
Expand Down Expand Up @@ -236,7 +244,6 @@ func TestRunCreate(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Set environment variables
for k, v := range tt.envVars {
os.Setenv(k, v)
defer os.Unsetenv(k)
Expand All @@ -247,7 +254,6 @@ func TestRunCreate(t *testing.T) {
RunE: runCreate,
}

// Add all flags that runCreate expects
cmd.Flags().StringVar(&kmsType, "kms-type", "", "KMS provider type (awskms, gcpkms, azurekms)")
cmd.Flags().StringVar(&kmsRegion, "aws-region", "", "AWS KMS region")
cmd.Flags().StringVar(&kmsKeyID, "kms-key-id", "", "KMS key identifier")
Expand All @@ -267,58 +273,70 @@ func TestRunCreate(t *testing.T) {
err := cmd.Execute()

if tt.wantError {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errMsg)
if !strings.Contains(err.Error(), tt.errMsg) {
t.Errorf("error %q should contain %q", err.Error(), tt.errMsg)
}
} else {
require.NoError(t, err)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
}
})
}
}

func TestCreateCommand(t *testing.T) {
// Create a test command
cmd := &cobra.Command{
Use: "test",
RunE: func(_ *cobra.Command, _ []string) error {
return nil
},
}

// Add flags
cmd.Flags().StringVar(&kmsType, "kms-type", "", "KMS type")
cmd.Flags().StringVar(&kmsRegion, "aws-region", "", "AWS KMS region")
cmd.Flags().StringVar(&rootKeyID, "root-key-id", "", "Root key ID")
cmd.Flags().StringVar(&leafKeyID, "leaf-key-id", "", "Leaf key ID")

// Test missing required flags
err := cmd.Execute()
require.NoError(t, err) // No required flags set yet
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

// Test flag parsing
err = cmd.ParseFlags([]string{
"--kms-type", "awskms",
"--aws-region", "us-west-2",
"--root-key-id", "arn:aws:kms:us-west-2:123456789012:key/1234abcd-12ab-34cd-56ef-1234567890ab",
"--leaf-key-id", "arn:aws:kms:us-west-2:123456789012:key/9876fedc-ba98-7654-3210-fedcba987654",
})
require.NoError(t, err)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

// Verify flag values
assert.Equal(t, "awskms", kmsType)
assert.Equal(t, "us-west-2", kmsRegion)
assert.Equal(t, "arn:aws:kms:us-west-2:123456789012:key/1234abcd-12ab-34cd-56ef-1234567890ab", rootKeyID)
assert.Equal(t, "arn:aws:kms:us-west-2:123456789012:key/9876fedc-ba98-7654-3210-fedcba987654", leafKeyID)
if kmsType != "awskms" {
t.Errorf("got kmsType %v, want awskms", kmsType)
}
if kmsRegion != "us-west-2" {
t.Errorf("got kmsRegion %v, want us-west-2", kmsRegion)
}
if rootKeyID != "arn:aws:kms:us-west-2:123456789012:key/1234abcd-12ab-34cd-56ef-1234567890ab" {
t.Errorf("got rootKeyID %v, want arn:aws:kms:us-west-2:123456789012:key/1234abcd-12ab-34cd-56ef-1234567890ab", rootKeyID)
}
if leafKeyID != "arn:aws:kms:us-west-2:123456789012:key/9876fedc-ba98-7654-3210-fedcba987654" {
t.Errorf("got leafKeyID %v, want arn:aws:kms:us-west-2:123456789012:key/9876fedc-ba98-7654-3210-fedcba987654", leafKeyID)
}
}

func TestRootCommand(t *testing.T) {
// Test help output
rootCmd.SetArgs([]string{"--help"})
err := rootCmd.Execute()
require.NoError(t, err)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

// Test unknown command
rootCmd.SetArgs([]string{"unknown"})
err = rootCmd.Execute()
require.Error(t, err)
if err == nil {
t.Error("expected error for unknown command, got nil")
}
}
9 changes: 3 additions & 6 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ require (
github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.19.0
github.com/spiffe/go-spiffe/v2 v2.4.0
github.com/stretchr/testify v1.10.0
github.com/tink-crypto/tink-go-awskms/v2 v2.1.0
github.com/tink-crypto/tink-go-gcpkms/v2 v2.2.0
github.com/tink-crypto/tink-go/v2 v2.2.0
Expand Down Expand Up @@ -62,9 +61,9 @@ require (
github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys v0.10.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/keyvault/internal v0.7.1 // indirect
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.1.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.0.0 // indirect
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.0 // indirect
github.com/AzureAD/microsoft-authentication-library-for-go v1.3.1 // indirect
github.com/Masterminds/goutils v1.1.1 // indirect
github.com/Masterminds/semver/v3 v3.3.0 // indirect
github.com/Masterminds/sprig/v3 v3.3.0 // indirect
Expand All @@ -89,7 +88,6 @@ require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/chainguard-dev/clog v1.5.1 // indirect
github.com/common-nighthawk/go-figure v0.0.0-20210622060536-734e95fb86be // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-jose/go-jose/v3 v3.0.3 // indirect
github.com/go-logr/logr v1.4.2 // indirect
Expand Down Expand Up @@ -128,7 +126,6 @@ require (
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/ryanuber/go-glob v1.0.0 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
Expand Down
12 changes: 6 additions & 6 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys v0.10.0 h1:m/sWOGCREuSBqg2
github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys v0.10.0/go.mod h1:Pu5Zksi2KrU7LPbZbNINx6fuVrUp/ffvpxdDj+i8LeE=
github.com/Azure/azure-sdk-for-go/sdk/keyvault/internal v0.7.1 h1:FbH3BbSb4bvGluTesZZ+ttN/MDsnMmQP36OSnDuSXqw=
github.com/Azure/azure-sdk-for-go/sdk/keyvault/internal v0.7.1/go.mod h1:9V2j0jn9jDEkCkv8w/bKTNppX/d0FVA1ud77xCIP4KA=
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.1.0 h1:DRiANoJTiW6obBQe3SqZizkuV1PEgfiiGivmVocDy64=
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.1.0/go.mod h1:qLIye2hwb/ZouqhpSD9Zn3SJipvpEnz1Ywl3VUk9Y0s=
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.0.0 h1:D3occbWoio4EBLkbkevetNMAVX197GkzbUMtqjGWn80=
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.0.0/go.mod h1:bTSOgj05NGRuHHhQwAdPnYr9TOdNmKlZTgGLL6nyAdI=
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.0 h1:7rKG7UmnrxX4N53TFhkYqjc+kVUZuw0fL8I3Fh+Ld9E=
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.0/go.mod h1:Wjo+24QJVhhl/L7jy6w9yzFF2yDOf3cKECAa8ecf9vE=
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.0 h1:eXnN9kaS8TiDwXjoie3hMRLuwdUBUMW9KRgOqB3mCaw=
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.0/go.mod h1:XIpam8wumeZ5rVMuhdDQLMfIPDf1WO3IzrCRO3e3e3o=
github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM=
github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE=
github.com/AzureAD/microsoft-authentication-library-for-go v1.3.1 h1:gUDtaZk8heteyfdmv+pcfHvhR9llnh7c7GMwZ8RVG04=
Expand Down Expand Up @@ -154,8 +154,6 @@ github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfU
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
Expand Down Expand Up @@ -408,6 +406,8 @@ go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
Expand Down
44 changes: 36 additions & 8 deletions pkg/certmaker/certmaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,19 @@ func ValidateKMSConfig(config KMSConfig) error {
if keyID == "" {
return nil
}
if !strings.HasPrefix(keyID, "arn:aws:kms:") && !strings.HasPrefix(keyID, "alias/") {
if strings.HasPrefix(keyID, "arn:aws:kms:") {
parts := strings.Split(keyID, ":")
if len(parts) < 6 {
return fmt.Errorf("invalid AWS KMS ARN format for %s", keyType)
}
if parts[3] != config.Region {
return fmt.Errorf("region in ARN (%s) does not match configured region (%s)", parts[3], config.Region)
}
} else if strings.HasPrefix(keyID, "alias/") {
if strings.TrimPrefix(keyID, "alias/") == "" {
return fmt.Errorf("alias name cannot be empty for %s", keyType)
}
} else {
return fmt.Errorf("awskms %s must start with 'arn:aws:kms:' or 'alias/'", keyType)
}
return nil
Expand All @@ -250,11 +262,20 @@ func ValidateKMSConfig(config KMSConfig) error {
if keyID == "" {
return nil
}
if !strings.HasPrefix(keyID, "projects/") {
return fmt.Errorf("gcpkms %s must start with 'projects/'", keyType)
requiredComponents := []struct {
component string
message string
}{
{"projects/", "must start with 'projects/'"},
{"/locations/", "must contain '/locations/'"},
{"/keyRings/", "must contain '/keyRings/'"},
{"/cryptoKeys/", "must contain '/cryptoKeys/'"},
{"/cryptoKeyVersions/", "must contain '/cryptoKeyVersions/'"},
}
if !strings.Contains(keyID, "/locations/") || !strings.Contains(keyID, "/keyRings/") {
return fmt.Errorf("invalid gcpkms key format for %s: %s", keyType, keyID)
for _, req := range requiredComponents {
if !strings.Contains(keyID, req.component) {
return fmt.Errorf("gcpkms %s %s", keyType, req.message)
}
}
return nil
}
Expand All @@ -280,12 +301,19 @@ func ValidateKMSConfig(config KMSConfig) error {
if keyID == "" {
return nil
}
// Validate format: azurekms:name=<key-name>;vault=<vault-name>
if !strings.HasPrefix(keyID, "azurekms:name=") {
return fmt.Errorf("azurekms %s must start with 'azurekms:name='", keyType)
}
if !strings.Contains(keyID, ";vault=") {
return fmt.Errorf("vault name is required for Azure Key Vault")
nameStart := strings.Index(keyID, "name=") + 5
vaultIndex := strings.Index(keyID, ";vault=")
if vaultIndex == -1 {
return fmt.Errorf("azurekms %s must contain ';vault=' parameter", keyType)
}
if strings.TrimSpace(keyID[nameStart:vaultIndex]) == "" {
return fmt.Errorf("key name cannot be empty for %s", keyType)
}
if strings.TrimSpace(keyID[vaultIndex+7:]) == "" {
return fmt.Errorf("vault name cannot be empty for %s", keyType)
}
return nil
}
Expand Down
Loading

0 comments on commit 1d0ac01

Please sign in to comment.