Skip to content

Commit

Permalink
resolve validate credentials
Browse files Browse the repository at this point in the history
  • Loading branch information
jokestax committed Dec 30, 2024
1 parent 481f39c commit e3d1b5d
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 11 deletions.
14 changes: 4 additions & 10 deletions cmd/aws/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/ssm"
"github.com/aws/aws-sdk-go-v2/service/sts"
internalssh "github.com/konstructio/kubefirst-api/pkg/ssh"
pkg "github.com/konstructio/kubefirst-api/pkg/utils"
"github.com/konstructio/kubefirst/internal/catalog"
Expand Down Expand Up @@ -66,7 +65,8 @@ func createAws(cmd *cobra.Command, _ []string) error {
return nil
}

creds, err := ValidateAWSRegionAndRetrieveCredentials(cfg)
ctx := context.Background()
creds, err := getSessionCredentials(ctx, cfg.Credentials)
if err != nil {
progress.Error(err.Error())
return fmt.Errorf("failed to retrieve AWS credentials: %w", err)
Expand Down Expand Up @@ -166,16 +166,10 @@ func ValidateProvidedFlags(cfg aws.Config, gitProvider, amiType, nodeType string
return nil
}

func ValidateAWSRegionAndRetrieveCredentials(cfg aws.Config) (*aws.Credentials, error) {
// Validate region by creating a client
stsClient := sts.NewFromConfig(cfg)
_, err := stsClient.GetCallerIdentity(context.TODO(), &sts.GetCallerIdentityInput{})
if err != nil {
return nil, fmt.Errorf("failed to validate AWS region: %w", err)
}
func getSessionCredentials(ctx context.Context, cfg aws.CredentialsProvider) (*aws.Credentials, error) {

Check failure on line 169 in cmd/aws/create.go

View workflow job for this annotation

GitHub Actions / build

unnecessary leading newline (whitespace)

// Retrieve credentials
creds, err := cfg.Credentials.Retrieve(context.TODO())
creds, err := cfg.Retrieve(ctx)
if err != nil {
return nil, fmt.Errorf("failed to retrieve AWS credentials: %w", err)
}
Expand Down
64 changes: 64 additions & 0 deletions cmd/aws/create_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package aws

import (
"context"
"errors"
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/stretchr/testify/assert"
)

type mockCredentialsProvider struct {
creds aws.Credentials
err error
}

func (m *mockCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
return m.creds, m.err
}

func TestValidateCredentials(t *testing.T) {
tests := []struct {
name string
creds aws.Credentials
err error
expectedErr error
}{
{
name: "valid credentials",
creds: aws.Credentials{
AccessKeyID: "test-access-key-id",
SecretAccessKey: "test-secret-access-key",
SessionToken: "test-session-token",
},
err: nil,
expectedErr: nil,
},
{
name: "failed to retrieve credentials",
creds: aws.Credentials{},
err: errors.New("failed to retrieve credentials"),
expectedErr: errors.New("failed to retrieve AWS credentials: failed to retrieve credentials"),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockProvider := &mockCredentialsProvider{
creds: tt.creds,
err: tt.err,
}

creds, err := getSessionCredentials(context.Background(), mockProvider)
if tt.expectedErr != nil {
assert.Nil(t, creds)
assert.EqualError(t, err, tt.expectedErr.Error())
} else {
assert.NotNil(t, creds)
assert.NoError(t, err)
assert.Equal(t, tt.creds, *creds)
}
})
}
}
4 changes: 3 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ require (
github.com/sirupsen/logrus v1.9.3
github.com/spf13/cobra v1.8.1
github.com/spf13/viper v1.19.0
github.com/stretchr/testify v1.9.0
go.mongodb.org/mongo-driver v1.17.1
golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f
gopkg.in/yaml.v3 v3.0.1
Expand Down Expand Up @@ -62,6 +63,7 @@ require (
github.com/gorilla/websocket v1.5.3 // indirect
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f // indirect
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/x448/float16 v0.8.4 // indirect
go.opentelemetry.io/contrib/detectors/gcp v1.29.0 // indirect
Expand Down Expand Up @@ -120,7 +122,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/servicequotas v1.25.7 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.24.7 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.33.2
github.com/aws/aws-sdk-go-v2/service/sts v1.33.2 // indirect
github.com/aws/smithy-go v1.22.1 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/aymerick/douceur v0.2.0 // indirect
Expand Down

0 comments on commit e3d1b5d

Please sign in to comment.