Skip to content

Commit

Permalink
Refactor registry tester to support aws config (#6761)
Browse files Browse the repository at this point in the history
  • Loading branch information
d8660091 authored Oct 4, 2023
1 parent 17d6a80 commit 79c297c
Show file tree
Hide file tree
Showing 4 changed files with 297 additions and 47 deletions.
12 changes: 5 additions & 7 deletions pkg/curatedpackages/packagecontrollerclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ const (

type PackageControllerClientOpt func(client *PackageControllerClient)

type registryAccessTester func(ctx context.Context, accessKey, secret, registry, region string) error

type PackageControllerClient struct {
kubeConfig string
chart *releasev1.Image
Expand Down Expand Up @@ -76,7 +74,7 @@ type PackageControllerClient struct {
mu sync.Mutex

// registryAccessTester test if the aws credential has access to registry
registryAccessTester registryAccessTester
registryAccessTester RegistryAccessTester
}

// ClientBuilder returns a k8s client for the specified cluster.
Expand Down Expand Up @@ -112,7 +110,7 @@ func NewPackageControllerClientFullLifecycle(logger logr.Logger, chartManager Ch
skipWaitForPackageBundle: true,
eksaRegion: eksaDefaultRegion,
clientBuilder: clientBuilder,
registryAccessTester: TestRegistryAccess,
registryAccessTester: &DefaultRegistryAccessTester{},
}
}

Expand Down Expand Up @@ -171,7 +169,7 @@ func NewPackageControllerClient(chartManager ChartManager, kubectl KubectlRunner
kubectl: kubectl,
registryMirror: registryMirror,
eksaRegion: eksaDefaultRegion,
registryAccessTester: TestRegistryAccess,
registryAccessTester: &DefaultRegistryAccessTester{},
}

for _, o := range options {
Expand Down Expand Up @@ -269,7 +267,7 @@ func (pc *PackageControllerClient) GetCuratedPackagesRegistries(ctx context.Cont
}

regionalRegistry := GetRegionalRegistry(defaultRegistry, pc.eksaRegion)
if err := pc.registryAccessTester(ctx, pc.eksaAccessKeyID, pc.eksaSecretAccessKey, regionalRegistry, pc.eksaRegion); err == nil {
if err := pc.registryAccessTester.Test(ctx, pc.eksaAccessKeyID, pc.eksaSecretAccessKey, pc.eksaRegion, pc.eksaAwsConfig, regionalRegistry); err == nil {
// use regional registry when the above credential is good
logger.V(6).Info("Using regional registry")
defaultRegistry = regionalRegistry
Expand Down Expand Up @@ -619,7 +617,7 @@ func WithClusterSpec(clusterSpec *cluster.Spec) func(client *PackageControllerCl
}

// WithRegistryAccessTester sets the registryTester.
func WithRegistryAccessTester(registryTester registryAccessTester) func(client *PackageControllerClient) {
func WithRegistryAccessTester(registryTester RegistryAccessTester) func(client *PackageControllerClient) {
return func(config *PackageControllerClient) {
config.registryAccessTester = registryTester
}
Expand Down
10 changes: 7 additions & 3 deletions pkg/curatedpackages/packagecontrollerclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1178,6 +1178,12 @@ func TestEnableFullLifecyclePath(t *testing.T) {
}
}

type stubRegistryAccessTester struct{}

func (s *stubRegistryAccessTester) Test(ctx context.Context, accessKey, secret, registry, region, awsConfig string) error {
return nil
}

func TestGetCuratedPackagesRegistries(s *testing.T) {
s.Run("substitutes a region if set", func(t *testing.T) {
ctrl := gomock.NewController(t)
Expand Down Expand Up @@ -1248,9 +1254,7 @@ func TestGetCuratedPackagesRegistries(s *testing.T) {
cm, k, clusterName, kubeConfig, chart, nil,
curatedpackages.WithManagementClusterName(clusterName),
curatedpackages.WithValuesFileWriter(writer),
curatedpackages.WithRegistryAccessTester(func(ctx context.Context, accessKey, secret, registry, region string) error {
return nil
}),
curatedpackages.WithRegistryAccessTester(&stubRegistryAccessTester{}),
)

expected := "TODO.dkr.ecr.us-west-2.amazonaws.com"
Expand Down
131 changes: 116 additions & 15 deletions pkg/curatedpackages/regional_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import (
"fmt"
"io"
"net/http"
"os"
"strings"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/ecr"
Expand All @@ -22,28 +24,33 @@ var prodRegionalECRMap = map[string]string{
"us-east-2": "TODO.dkr.ecr.us-east-2.amazonaws.com",
}

// TestRegistryAccess test if the packageControllerClient has valid credential to access registry.
func TestRegistryAccess(ctx context.Context, accessKey, secret, registry, region string) error {
cfg, err := config.LoadDefaultConfig(ctx,
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(accessKey, secret, "")),
config.WithRegion(region),
)
if err != nil {
return err
}
// RegistryAccessTester test if AWS credentials has valid permission to access an ECR registry.
type RegistryAccessTester interface {
Test(ctx context.Context, accessKey, secret, region, awsConfig, registry string) error
}

ecrClient := ecr.NewFromConfig(cfg)
out, err := ecrClient.GetAuthorizationToken(context.Background(), &ecr.GetAuthorizationTokenInput{})
// DefaultRegistryAccessTester the default implementation of RegistryAccessTester.
type DefaultRegistryAccessTester struct{}

// Test if the AWS static credential or sharedConfig has valid permission to access an ECR registry.
func (r *DefaultRegistryAccessTester) Test(ctx context.Context, accessKey, secret, region, awsConfig, registry string) (err error) {
authTokenProvider := &DefaultRegistryAuthTokenProvider{}

var authToken string
if len(awsConfig) > 0 {
authToken, err = authTokenProvider.GetTokenByAWSConfig(ctx, awsConfig)
} else {
authToken, err = authTokenProvider.GetTokenByAWSKeySecret(ctx, accessKey, secret, region)
}
if err != nil {
return err
}
authToken := out.AuthorizationData[0].AuthorizationToken

return TestRegistryWithAuthToken(*authToken, registry, http.DefaultClient.Do)
return TestRegistryWithAuthToken(authToken, registry, http.DefaultClient.Do)
}

// TestRegistryWithAuthToken test if the registry can be acccessed with auth token.
func TestRegistryWithAuthToken(authToken, registry string, getResponse func(req *http.Request) (*http.Response, error)) error {
func TestRegistryWithAuthToken(authToken, registry string, do Do) error {
manifestPath := "/v2/eks-anywhere-packages/manifests/latest"

req, err := http.NewRequest("GET", "https://"+registry+manifestPath, nil)
Expand All @@ -52,7 +59,7 @@ func TestRegistryWithAuthToken(authToken, registry string, getResponse func(req
}
req.Header.Add("Authorization", "Basic "+authToken)

resp2, err := getResponse(req)
resp2, err := do(req)
if err != nil {
return err
}
Expand All @@ -76,3 +83,97 @@ func GetRegionalRegistry(defaultRegistry, region string) string {
}
return prodRegionalECRMap[region]
}

// RegistryAuthTokenProvider provides auth token for registry access.
type RegistryAuthTokenProvider interface {
GetTokenByAWSConfig(ctx context.Context, awsConfig string) (string, error)
GetTokenByAWSKeySecret(ctx context.Context, key, secret, region string) (string, error)
}

// DefaultRegistryAuthTokenProvider provides auth token for AWS ECR registry access.
type DefaultRegistryAuthTokenProvider struct{}

// GetTokenByAWSConfig get auth token by AWS config.
func (d *DefaultRegistryAuthTokenProvider) GetTokenByAWSConfig(ctx context.Context, awsConfig string) (string, error) {
cfg, err := ParseAWSConfig(ctx, awsConfig)
if err != nil {
return "", err
}
return getAuthorizationToken(*cfg)
}

// ParseAWSConfig parse AWS config from string.
func ParseAWSConfig(ctx context.Context, awsConfig string) (*aws.Config, error) {
file, err := os.CreateTemp("", "eksa-temp-aws-config-*")
if err != nil {
return nil, err
}
if _, err := file.Write([]byte(awsConfig)); err != nil {
return nil, err
}
defer os.Remove(file.Name())
if err != nil {
return nil, err
}

cfg, err := config.LoadDefaultConfig(ctx,
config.WithSharedConfigFiles([]string{file.Name()}),
)
if err != nil {
return nil, err
}
return &cfg, nil
}

// GetAWSConfigFromKeySecret get AWS config from key, secret and region.
func GetAWSConfigFromKeySecret(ctx context.Context, key, secret, region string) (*aws.Config, error) {
cfg, err := config.LoadDefaultConfig(ctx,
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(key, secret, "")),
config.WithRegion(region),
)
if err != nil {
return nil, err
}
return &cfg, nil
}

// GetTokenByAWSKeySecret get auth token by AWS key and secret.
func (d *DefaultRegistryAuthTokenProvider) GetTokenByAWSKeySecret(ctx context.Context, key, secret, region string) (string, error) {
cfg, err := GetAWSConfigFromKeySecret(ctx, key, secret, region)
if err != nil {
return "", err
}

return getAuthorizationToken(*cfg)
}

func getAuthorizationToken(cfg aws.Config) (string, error) {
ecrClient := ecr.NewFromConfig(cfg)
out, err := ecrClient.GetAuthorizationToken(context.Background(), &ecr.GetAuthorizationTokenInput{})
if err != nil {
return "", fmt.Errorf("ecrClient cannot get authorization token: %w", err)
}
authToken := out.AuthorizationData[0].AuthorizationToken
return *authToken, nil
}

// Do is a function type that takes a http request and returns a http response.
type Do func(req *http.Request) (*http.Response, error)

// TestRegistryAccessWithAWSConfig test if the AWS config has valid permission to access container registry.
func TestRegistryAccessWithAWSConfig(ctx context.Context, awsConfig, registry string, tokenProvider RegistryAuthTokenProvider, do Do) error {
token, err := tokenProvider.GetTokenByAWSConfig(ctx, awsConfig)
if err != nil {
return err
}
return TestRegistryWithAuthToken(token, registry, do)
}

// TestRegistryAccessWithAWSKeySecret test if the AWS key and secret has valid permission to access container registry.
func TestRegistryAccessWithAWSKeySecret(ctx context.Context, key, secret, region, registry string, tokenProvider RegistryAuthTokenProvider, do Do) error {
token, err := tokenProvider.GetTokenByAWSKeySecret(ctx, key, secret, region)
if err != nil {
return err
}
return TestRegistryWithAuthToken(token, registry, do)
}
Loading

0 comments on commit 79c297c

Please sign in to comment.