From e3d1b5d8792060ae40662d3a392565d30ad72d87 Mon Sep 17 00:00:00 2001 From: mrrishi Date: Tue, 31 Dec 2024 01:48:34 +0530 Subject: [PATCH] resolve validate credentials --- cmd/aws/create.go | 14 +++------ cmd/aws/create_test.go | 64 ++++++++++++++++++++++++++++++++++++++++++ go.mod | 4 ++- 3 files changed, 71 insertions(+), 11 deletions(-) create mode 100644 cmd/aws/create_test.go diff --git a/cmd/aws/create.go b/cmd/aws/create.go index 59e539ba..e6d29a94 100644 --- a/cmd/aws/create.go +++ b/cmd/aws/create.go @@ -16,7 +16,6 @@ import ( "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ssm" - "github.com/aws/aws-sdk-go-v2/service/sts" internalssh "github.com/konstructio/kubefirst-api/pkg/ssh" pkg "github.com/konstructio/kubefirst-api/pkg/utils" "github.com/konstructio/kubefirst/internal/catalog" @@ -66,7 +65,8 @@ func createAws(cmd *cobra.Command, _ []string) error { return nil } - creds, err := ValidateAWSRegionAndRetrieveCredentials(cfg) + ctx := context.Background() + creds, err := getSessionCredentials(ctx, cfg.Credentials) if err != nil { progress.Error(err.Error()) return fmt.Errorf("failed to retrieve AWS credentials: %w", err) @@ -166,16 +166,10 @@ func ValidateProvidedFlags(cfg aws.Config, gitProvider, amiType, nodeType string return nil } -func ValidateAWSRegionAndRetrieveCredentials(cfg aws.Config) (*aws.Credentials, error) { - // Validate region by creating a client - stsClient := sts.NewFromConfig(cfg) - _, err := stsClient.GetCallerIdentity(context.TODO(), &sts.GetCallerIdentityInput{}) - if err != nil { - return nil, fmt.Errorf("failed to validate AWS region: %w", err) - } +func getSessionCredentials(ctx context.Context, cfg aws.CredentialsProvider) (*aws.Credentials, error) { // Retrieve credentials - creds, err := cfg.Credentials.Retrieve(context.TODO()) + creds, err := cfg.Retrieve(ctx) if err != nil { return nil, fmt.Errorf("failed to retrieve AWS credentials: %w", err) } diff --git a/cmd/aws/create_test.go b/cmd/aws/create_test.go new file mode 100644 index 00000000..1e6ae53c --- /dev/null +++ b/cmd/aws/create_test.go @@ -0,0 +1,64 @@ +package aws + +import ( + "context" + "errors" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/stretchr/testify/assert" +) + +type mockCredentialsProvider struct { + creds aws.Credentials + err error +} + +func (m *mockCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) { + return m.creds, m.err +} + +func TestValidateCredentials(t *testing.T) { + tests := []struct { + name string + creds aws.Credentials + err error + expectedErr error + }{ + { + name: "valid credentials", + creds: aws.Credentials{ + AccessKeyID: "test-access-key-id", + SecretAccessKey: "test-secret-access-key", + SessionToken: "test-session-token", + }, + err: nil, + expectedErr: nil, + }, + { + name: "failed to retrieve credentials", + creds: aws.Credentials{}, + err: errors.New("failed to retrieve credentials"), + expectedErr: errors.New("failed to retrieve AWS credentials: failed to retrieve credentials"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockProvider := &mockCredentialsProvider{ + creds: tt.creds, + err: tt.err, + } + + creds, err := getSessionCredentials(context.Background(), mockProvider) + if tt.expectedErr != nil { + assert.Nil(t, creds) + assert.EqualError(t, err, tt.expectedErr.Error()) + } else { + assert.NotNil(t, creds) + assert.NoError(t, err) + assert.Equal(t, tt.creds, *creds) + } + }) + } +} diff --git a/go.mod b/go.mod index c4cb9999..ba159603 100644 --- a/go.mod +++ b/go.mod @@ -28,6 +28,7 @@ require ( github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.8.1 github.com/spf13/viper v1.19.0 + github.com/stretchr/testify v1.9.0 go.mongodb.org/mongo-driver v1.17.1 golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f gopkg.in/yaml.v3 v3.0.1 @@ -62,6 +63,7 @@ require ( github.com/gorilla/websocket v1.5.3 // indirect github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f // indirect github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/x448/float16 v0.8.4 // indirect go.opentelemetry.io/contrib/detectors/gcp v1.29.0 // indirect @@ -120,7 +122,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/servicequotas v1.25.7 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.24.7 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.33.2 + github.com/aws/aws-sdk-go-v2/service/sts v1.33.2 // indirect github.com/aws/smithy-go v1.22.1 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/aymerick/douceur v0.2.0 // indirect