diff --git a/go.mod b/go.mod index a408887c292f..9bd53d631c36 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.20 require ( github.com/Masterminds/sprig v2.22.0+incompatible github.com/aws/aws-sdk-go v1.42.23 - github.com/aws/aws-sdk-go-v2 v1.16.14 + github.com/aws/aws-sdk-go-v2 v1.21.0 github.com/aws/aws-sdk-go-v2/config v1.15.3 github.com/aws/aws-sdk-go-v2/credentials v1.11.2 github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.3 @@ -16,7 +16,7 @@ require ( github.com/aws/eks-distro-build-tooling/release v0.0.0-20211103003257-a7e2379eae5e github.com/aws/etcdadm-bootstrap-provider v1.0.7-rc3 github.com/aws/etcdadm-controller v1.0.6-rc3 - github.com/aws/smithy-go v1.13.2 + github.com/aws/smithy-go v1.14.2 github.com/docker/cli v23.0.5+incompatible github.com/go-git/go-git/v5 v5.4.2 github.com/go-logr/logr v1.2.3 @@ -81,9 +81,10 @@ require ( github.com/VictorLowther/simplexml v0.0.0-20180716164440-0bff93621230 // indirect github.com/VictorLowther/soap v0.0.0-20150314151524-8e36fca84b22 // indirect github.com/acomagu/bufpipe v1.0.3 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.9 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.41 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.35 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.3.10 // indirect + github.com/aws/aws-sdk-go-v2/service/ecr v1.20.0 github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.3 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.11.3 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.16.3 // indirect diff --git a/go.sum b/go.sum index be9f9768115d..0bad3487afb9 100644 --- a/go.sum +++ b/go.sum @@ -524,6 +524,8 @@ github.com/aws/aws-sdk-go v1.42.23/go.mod h1:gyRszuZ/icHmHAVE4gc/r+cfCmhA1AD+vqf github.com/aws/aws-sdk-go-v2 v1.16.2/go.mod h1:ytwTPBG6fXTZLxxeeCCWj2/EMYp/xDUgX+OET6TLNNU= github.com/aws/aws-sdk-go-v2 v1.16.14 h1:db6GvO4Z2UqHt5gvT0lr6J5x5P+oQ7bdRzczVaRekMU= github.com/aws/aws-sdk-go-v2 v1.16.14/go.mod h1:s/G+UV29dECbF5rf+RNj1xhlmvoNurGSr+McVSRj59w= +github.com/aws/aws-sdk-go-v2 v1.21.0 h1:gMT0IW+03wtYJhRqTVYn0wLzwdnK9sRMcxmtfGzRdJc= +github.com/aws/aws-sdk-go-v2 v1.21.0/go.mod h1:/RfNgGmRxI+iFOB1OeJUyxiU+9s88k3pfHvDagGEp0M= github.com/aws/aws-sdk-go-v2/config v1.15.3 h1:5AlQD0jhVXlGzwo+VORKiUuogkG7pQcLJNzIzK7eodw= github.com/aws/aws-sdk-go-v2/config v1.15.3/go.mod h1:9YL3v07Xc/ohTsxFXzan9ZpFpdTOFl4X65BAKYaz8jg= github.com/aws/aws-sdk-go-v2/credentials v1.11.2 h1:RQQ5fzclAKJyY5TvF+fkjJEwzK4hnxQCLOu5JXzDmQo= @@ -532,12 +534,18 @@ github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.3 h1:LWPg5zjHV9oz/myQr4wMs0g github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.3/go.mod h1:uk1vhHHERfSVCUnqSqz8O48LBYDSC+k6brng09jcMOk= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.9 h1:onz/VaaxZ7Z4V+WIN9Txly9XLTmoOh1oJ8XcAC3pako= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.9/go.mod h1:AnVH5pvai0pAF4lXRq0bmhbes1u9R8wTE+g+183bZNM= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.41 h1:22dGT7PneFMx4+b3pz7lMTRyN8ZKH7M2cW4GP9yUS2g= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.41/go.mod h1:CrObHAuPneJBlfEJ5T3szXOUkLEThaGfvnhTf33buas= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.3 h1:9stUQR/u2KXU6HkFJYlqnZEjBnbgrVbG6I5HN09xZh0= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.3/go.mod h1:ssOhaLpRlh88H3UmEcsBoVKq309quMvm3Ds8e9d4eJM= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.35 h1:SijA0mgjV8E+8G45ltVHs0fvKpTj8xmZJ3VwhGKtUSI= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.35/go.mod h1:SJC1nEVVva1g3pHAIdCp7QsRIkMmLAgoDquQ9Rr8kYw= github.com/aws/aws-sdk-go-v2/internal/ini v1.3.10 h1:by9P+oy3P/CwggN4ClnW2D4oL91QV7pBzBICi1chZvQ= github.com/aws/aws-sdk-go-v2/internal/ini v1.3.10/go.mod h1:8DcYQcz0+ZJaSxANlHIsbbi6S+zMwjwdDqwW3r9AzaE= github.com/aws/aws-sdk-go-v2/service/ec2 v1.34.0 h1:dfWleW7/a3+TR6qJynYZsaovCEStQOep5x+BxkiBDhc= github.com/aws/aws-sdk-go-v2/service/ec2 v1.34.0/go.mod h1:37MWOQMGyj8lcranOwo716OHvJgeFJUOaWu6vk1pWNE= +github.com/aws/aws-sdk-go-v2/service/ecr v1.20.0 h1:Qw8H7V55d2P1d/a9+cLgAcdez4GtP6l30KQAeYqx9vY= +github.com/aws/aws-sdk-go-v2/service/ecr v1.20.0/go.mod h1:pGwmNL8hN0jpBfKfTbmu+Rl0bJkDhaGl+9PQLrZ4KLo= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.3 h1:Gh1Gpyh01Yvn7ilO/b/hr01WgNpaszfbKMUgqM186xQ= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.3/go.mod h1:wlY6SVjuwvh3TVRpTqdy4I1JpBFLX4UGeKZdWntaocw= github.com/aws/aws-sdk-go-v2/service/sso v1.11.3 h1:frW4ikGcxfAEDfmQqWgMLp+F1n4nRo9sF39OcIb5BkQ= @@ -555,6 +563,8 @@ github.com/aws/etcdadm-controller v1.0.6-rc3/go.mod h1:60QVQeYClyeV22MpI+SMBDx/d github.com/aws/smithy-go v1.11.2/go.mod h1:3xHYmszWVx2c0kIwQeEVf9uSm4fYZt67FBJnwub1bgM= github.com/aws/smithy-go v1.13.2 h1:TBLKyeJfXTrTXRHmsv4qWt9IQGYyWThLYaJWSahTOGE= github.com/aws/smithy-go v1.13.2/go.mod h1:Tg+OJXh4MB2R/uN61Ko2f6hTZwB/ZYGOtib8J3gBHzA= +github.com/aws/smithy-go v1.14.2 h1:MJU9hqBGbvWZdApzpvoF2WAIJDbtjK2NDJSiJP7HblQ= +github.com/aws/smithy-go v1.14.2/go.mod h1:Tg+OJXh4MB2R/uN61Ko2f6hTZwB/ZYGOtib8J3gBHzA= github.com/benbjohnson/clock v1.0.3/go.mod h1:bGMdMPoPVvcYyt1gHDf4J2KE153Yf9BuiUKYMaxlTDM= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A= diff --git a/pkg/curatedpackages/packagecontrollerclient.go b/pkg/curatedpackages/packagecontrollerclient.go index c53f0e7e1f12..8369aadfa3df 100644 --- a/pkg/curatedpackages/packagecontrollerclient.go +++ b/pkg/curatedpackages/packagecontrollerclient.go @@ -40,6 +40,8 @@ const ( type PackageControllerClientOpt func(client *PackageControllerClient) +type registryAccessTester func(ctx context.Context, accessKey, secret, registry, region string) error + type PackageControllerClient struct { kubeConfig string chart *releasev1.Image @@ -72,6 +74,9 @@ type PackageControllerClient struct { // mu provides some thread-safety. mu sync.Mutex + + // registryAccessTester test if the aws credential has access to registry + registryAccessTester registryAccessTester } // ClientBuilder returns a k8s client for the specified cluster. @@ -107,6 +112,7 @@ func NewPackageControllerClientFullLifecycle(logger logr.Logger, chartManager Ch skipWaitForPackageBundle: true, eksaRegion: eksaDefaultRegion, clientBuilder: clientBuilder, + registryAccessTester: TestRegistryAccess, } } @@ -158,13 +164,14 @@ func (pc *PackageControllerClient) EnableFullLifecycle(ctx context.Context, log // NewPackageControllerClient instantiates a new instance of PackageControllerClient. func NewPackageControllerClient(chartManager ChartManager, kubectl KubectlRunner, clusterName, kubeConfig string, chart *releasev1.Image, registryMirror *registrymirror.RegistryMirror, options ...PackageControllerClientOpt) *PackageControllerClient { pcc := &PackageControllerClient{ - kubeConfig: kubeConfig, - clusterName: clusterName, - chart: chart, - chartManager: chartManager, - kubectl: kubectl, - registryMirror: registryMirror, - eksaRegion: eksaDefaultRegion, + kubeConfig: kubeConfig, + clusterName: clusterName, + chart: chart, + chartManager: chartManager, + kubectl: kubectl, + registryMirror: registryMirror, + eksaRegion: eksaDefaultRegion, + registryAccessTester: TestRegistryAccess, } for _, o := range options { @@ -186,7 +193,7 @@ func NewPackageControllerClient(chartManager ChartManager, kubectl KubectlRunner func (pc *PackageControllerClient) Enable(ctx context.Context) error { ociURI := fmt.Sprintf("%s%s", "oci://", pc.registryMirror.ReplaceRegistry(pc.chart.Image())) clusterName := fmt.Sprintf("clusterName=%s", pc.clusterName) - sourceRegistry, defaultRegistry, defaultImageRegistry := pc.GetCuratedPackagesRegistries() + sourceRegistry, defaultRegistry, defaultImageRegistry := pc.GetCuratedPackagesRegistries(ctx) sourceRegistry = fmt.Sprintf("sourceRegistry=%s", sourceRegistry) defaultRegistry = fmt.Sprintf("defaultRegistry=%s", defaultRegistry) defaultImageRegistry = fmt.Sprintf("defaultImageRegistry=%s", defaultImageRegistry) @@ -232,7 +239,7 @@ func (pc *PackageControllerClient) Enable(ctx context.Context) error { } // GetCuratedPackagesRegistries gets value for configurable registries from PBC. -func (pc *PackageControllerClient) GetCuratedPackagesRegistries() (sourceRegistry, defaultRegistry, defaultImageRegistry string) { +func (pc *PackageControllerClient) GetCuratedPackagesRegistries(ctx context.Context) (sourceRegistry, defaultRegistry, defaultImageRegistry string) { sourceRegistry = publicProdECR defaultImageRegistry = packageProdDomain accountName := prodAccount @@ -260,6 +267,16 @@ func (pc *PackageControllerClient) GetCuratedPackagesRegistries() (sourceRegistr if pc.eksaRegion != eksaDefaultRegion { defaultImageRegistry = strings.ReplaceAll(defaultImageRegistry, eksaDefaultRegion, pc.eksaRegion) } + + regionalRegistry := GetRegionalRegistry(defaultRegistry, pc.eksaRegion) + if err := pc.registryAccessTester(ctx, pc.eksaAccessKeyID, pc.eksaSecretAccessKey, regionalRegistry, pc.eksaRegion); err == nil { + // use regional registry when the above credential is good + logger.V(6).Info("Using regional registry") + defaultRegistry = regionalRegistry + defaultImageRegistry = regionalRegistry + } else { + logger.V(6).Info("Using fallback registry", "Registry", defaultRegistry, "RegionalRegistryAccessIssue", err) + } } return sourceRegistry, defaultRegistry, defaultImageRegistry } @@ -600,3 +617,10 @@ func WithClusterSpec(clusterSpec *cluster.Spec) func(client *PackageControllerCl config.clusterSpec = &clusterSpec.Cluster.Spec } } + +// WithRegistryAccessTester sets the registryTester. +func WithRegistryAccessTester(registryTester registryAccessTester) func(client *PackageControllerClient) { + return func(config *PackageControllerClient) { + config.registryAccessTester = registryTester + } +} diff --git a/pkg/curatedpackages/packagecontrollerclient_test.go b/pkg/curatedpackages/packagecontrollerclient_test.go index 289927eac737..3f8e04e31d82 100644 --- a/pkg/curatedpackages/packagecontrollerclient_test.go +++ b/pkg/curatedpackages/packagecontrollerclient_test.go @@ -274,7 +274,7 @@ func TestEnableSuccess(t *testing.T) { clusterName := fmt.Sprintf("clusterName=%s", "billy") valueFilePath := filepath.Join("billy", filewriter.DefaultTmpFolder, valueFileName) ociURI := fmt.Sprintf("%s%s", "oci://", tt.registryMirror.ReplaceRegistry(tt.chart.Image())) - sourceRegistry, defaultRegistry, defaultImageRegistry := tt.command.GetCuratedPackagesRegistries() + sourceRegistry, defaultRegistry, defaultImageRegistry := tt.command.GetCuratedPackagesRegistries(context.Background()) sourceRegistry = fmt.Sprintf("sourceRegistry=%s", sourceRegistry) defaultRegistry = fmt.Sprintf("defaultRegistry=%s", defaultRegistry) defaultImageRegistry = fmt.Sprintf("defaultImageRegistry=%s", defaultImageRegistry) @@ -317,7 +317,7 @@ func TestEnableSucceedInWorkloadCluster(t *testing.T) { clusterName := fmt.Sprintf("clusterName=%s", "billy") valueFilePath := filepath.Join("billy", filewriter.DefaultTmpFolder, valueFileName) ociURI := fmt.Sprintf("%s%s", "oci://", tt.registryMirror.ReplaceRegistry(tt.chart.Image())) - sourceRegistry, defaultRegistry, defaultImageRegistry := tt.command.GetCuratedPackagesRegistries() + sourceRegistry, defaultRegistry, defaultImageRegistry := tt.command.GetCuratedPackagesRegistries(context.Background()) sourceRegistry = fmt.Sprintf("sourceRegistry=%s", sourceRegistry) defaultRegistry = fmt.Sprintf("defaultRegistry=%s", defaultRegistry) defaultImageRegistry = fmt.Sprintf("defaultImageRegistry=%s", defaultImageRegistry) @@ -360,7 +360,7 @@ func TestEnableSucceedInWorkloadClusterWhenPackageBundleControllerNotExist(t *te clusterName := fmt.Sprintf("clusterName=%s", "billy") valueFilePath := filepath.Join("billy", filewriter.DefaultTmpFolder, valueFileName) ociURI := fmt.Sprintf("%s%s", "oci://", tt.registryMirror.ReplaceRegistry(tt.chart.Image())) - sourceRegistry, defaultRegistry, defaultImageRegistry := tt.command.GetCuratedPackagesRegistries() + sourceRegistry, defaultRegistry, defaultImageRegistry := tt.command.GetCuratedPackagesRegistries(context.Background()) sourceRegistry = fmt.Sprintf("sourceRegistry=%s", sourceRegistry) defaultRegistry = fmt.Sprintf("defaultRegistry=%s", defaultRegistry) defaultImageRegistry = fmt.Sprintf("defaultImageRegistry=%s", defaultImageRegistry) @@ -439,7 +439,7 @@ func TestEnableWithProxy(t *testing.T) { httpsProxy := fmt.Sprintf("proxy.HTTPS_PROXY=%s", tt.httpsProxy) noProxy := fmt.Sprintf("proxy.NO_PROXY=%s", strings.Join(tt.noProxy, "\\,")) ociURI := fmt.Sprintf("%s%s", "oci://", tt.registryMirror.ReplaceRegistry(tt.chart.Image())) - sourceRegistry, defaultRegistry, defaultImageRegistry := tt.command.GetCuratedPackagesRegistries() + sourceRegistry, defaultRegistry, defaultImageRegistry := tt.command.GetCuratedPackagesRegistries(context.Background()) sourceRegistry = fmt.Sprintf("sourceRegistry=%s", sourceRegistry) defaultRegistry = fmt.Sprintf("defaultRegistry=%s", defaultRegistry) defaultImageRegistry = fmt.Sprintf("defaultImageRegistry=%s", defaultImageRegistry) @@ -490,7 +490,7 @@ func TestEnableWithEmptyProxy(t *testing.T) { clusterName := fmt.Sprintf("clusterName=%s", "billy") valueFilePath := filepath.Join("billy", filewriter.DefaultTmpFolder, valueFileName) ociURI := fmt.Sprintf("%s%s", "oci://", tt.registryMirror.ReplaceRegistry(tt.chart.Image())) - sourceRegistry, defaultRegistry, defaultImageRegistry := tt.command.GetCuratedPackagesRegistries() + sourceRegistry, defaultRegistry, defaultImageRegistry := tt.command.GetCuratedPackagesRegistries(context.Background()) sourceRegistry = fmt.Sprintf("sourceRegistry=%s", sourceRegistry) defaultRegistry = fmt.Sprintf("defaultRegistry=%s", defaultRegistry) defaultImageRegistry = fmt.Sprintf("defaultImageRegistry=%s", defaultImageRegistry) @@ -529,7 +529,7 @@ func TestEnableFail(t *testing.T) { clusterName := fmt.Sprintf("clusterName=%s", "billy") valueFilePath := filepath.Join("billy", filewriter.DefaultTmpFolder, valueFileName) ociURI := fmt.Sprintf("%s%s", "oci://", tt.registryMirror.ReplaceRegistry(tt.chart.Image())) - sourceRegistry, defaultRegistry, defaultImageRegistry := tt.command.GetCuratedPackagesRegistries() + sourceRegistry, defaultRegistry, defaultImageRegistry := tt.command.GetCuratedPackagesRegistries(context.Background()) sourceRegistry = fmt.Sprintf("sourceRegistry=%s", sourceRegistry) defaultRegistry = fmt.Sprintf("defaultRegistry=%s", defaultRegistry) defaultImageRegistry = fmt.Sprintf("defaultImageRegistry=%s", defaultImageRegistry) @@ -559,7 +559,7 @@ func TestEnableFailNoActiveBundle(t *testing.T) { clusterName := fmt.Sprintf("clusterName=%s", "billy") valueFilePath := filepath.Join("billy", filewriter.DefaultTmpFolder, valueFileName) ociURI := fmt.Sprintf("%s%s", "oci://", tt.registryMirror.ReplaceRegistry(tt.chart.Image())) - sourceRegistry, defaultRegistry, defaultImageRegistry := tt.command.GetCuratedPackagesRegistries() + sourceRegistry, defaultRegistry, defaultImageRegistry := tt.command.GetCuratedPackagesRegistries(context.Background()) sourceRegistry = fmt.Sprintf("sourceRegistry=%s", sourceRegistry) defaultRegistry = fmt.Sprintf("defaultRegistry=%s", defaultRegistry) defaultImageRegistry = fmt.Sprintf("defaultImageRegistry=%s", defaultImageRegistry) @@ -589,7 +589,7 @@ func TestEnableSuccessWhenCronJobFails(t *testing.T) { clusterName := fmt.Sprintf("clusterName=%s", "billy") valueFilePath := filepath.Join("billy", filewriter.DefaultTmpFolder, valueFileName) ociURI := fmt.Sprintf("%s%s", "oci://", tt.registryMirror.ReplaceRegistry(tt.chart.Image())) - sourceRegistry, defaultRegistry, defaultImageRegistry := tt.command.GetCuratedPackagesRegistries() + sourceRegistry, defaultRegistry, defaultImageRegistry := tt.command.GetCuratedPackagesRegistries(context.Background()) sourceRegistry = fmt.Sprintf("sourceRegistry=%s", sourceRegistry) defaultRegistry = fmt.Sprintf("defaultRegistry=%s", defaultRegistry) defaultImageRegistry = fmt.Sprintf("defaultImageRegistry=%s", defaultImageRegistry) @@ -657,7 +657,7 @@ func TestEnableActiveBundleCustomTimeout(t *testing.T) { clusterName := fmt.Sprintf("clusterName=%s", "billy") valueFilePath := filepath.Join("billy", filewriter.DefaultTmpFolder, valueFileName) ociURI := fmt.Sprintf("%s%s", "oci://", tt.registryMirror.ReplaceRegistry(tt.chart.Image())) - sourceRegistry, defaultRegistry, defaultImageRegistry := tt.command.GetCuratedPackagesRegistries() + sourceRegistry, defaultRegistry, defaultImageRegistry := tt.command.GetCuratedPackagesRegistries(context.Background()) sourceRegistry = fmt.Sprintf("sourceRegistry=%s", sourceRegistry) defaultRegistry = fmt.Sprintf("defaultRegistry=%s", defaultRegistry) defaultImageRegistry = fmt.Sprintf("defaultImageRegistry=%s", defaultImageRegistry) @@ -696,7 +696,7 @@ func TestEnableActiveBundleWaitLoops(t *testing.T) { clusterName := fmt.Sprintf("clusterName=%s", "billy") valueFilePath := filepath.Join("billy", filewriter.DefaultTmpFolder, valueFileName) ociURI := fmt.Sprintf("%s%s", "oci://", tt.registryMirror.ReplaceRegistry(tt.chart.Image())) - sourceRegistry, defaultRegistry, defaultImageRegistry := tt.command.GetCuratedPackagesRegistries() + sourceRegistry, defaultRegistry, defaultImageRegistry := tt.command.GetCuratedPackagesRegistries(context.Background()) sourceRegistry = fmt.Sprintf("sourceRegistry=%s", sourceRegistry) defaultRegistry = fmt.Sprintf("defaultRegistry=%s", defaultRegistry) defaultImageRegistry = fmt.Sprintf("defaultImageRegistry=%s", defaultImageRegistry) @@ -756,7 +756,7 @@ func TestEnableActiveBundleTimesOut(t *testing.T) { clusterName := fmt.Sprintf("clusterName=%s", "billy") valueFilePath := filepath.Join("billy", filewriter.DefaultTmpFolder, valueFileName) ociURI := fmt.Sprintf("%s%s", "oci://", tt.registryMirror.ReplaceRegistry(tt.chart.Image())) - sourceRegistry, defaultRegistry, defaultImageRegistry := tt.command.GetCuratedPackagesRegistries() + sourceRegistry, defaultRegistry, defaultImageRegistry := tt.command.GetCuratedPackagesRegistries(context.Background()) sourceRegistry = fmt.Sprintf("sourceRegistry=%s", sourceRegistry) defaultRegistry = fmt.Sprintf("defaultRegistry=%s", defaultRegistry) defaultImageRegistry = fmt.Sprintf("defaultImageRegistry=%s", defaultImageRegistry) @@ -802,7 +802,7 @@ func TestEnableActiveBundleNamespaceTimesOut(t *testing.T) { clusterName := fmt.Sprintf("clusterName=%s", "billy") valueFilePath := filepath.Join("billy", filewriter.DefaultTmpFolder, valueFileName) ociURI := fmt.Sprintf("%s%s", "oci://", tt.registryMirror.ReplaceRegistry(tt.chart.Image())) - sourceRegistry, defaultRegistry, defaultImageRegistry := tt.command.GetCuratedPackagesRegistries() + sourceRegistry, defaultRegistry, defaultImageRegistry := tt.command.GetCuratedPackagesRegistries(context.Background()) sourceRegistry = fmt.Sprintf("sourceRegistry=%s", sourceRegistry) defaultRegistry = fmt.Sprintf("defaultRegistry=%s", defaultRegistry) defaultImageRegistry = fmt.Sprintf("defaultImageRegistry=%s", defaultImageRegistry) @@ -991,7 +991,7 @@ func TestGetCuratedPackagesRegistriesDefaultRegion(t *testing.T) { g := NewWithT(t) cluster := cluster.Spec{Config: &cluster.Config{Cluster: &v1alpha1.Cluster{Spec: clusterSpec}}} sut := curatedpackages.NewPackageControllerClient(nil, nil, "billy", "", chart, nil, curatedpackages.WithClusterSpec(&cluster)) - _, _, img := sut.GetCuratedPackagesRegistries() + _, _, img := sut.GetCuratedPackagesRegistries(context.Background()) g.Expect(img).To(Equal("783794618700.dkr.ecr.us-west-2.amazonaws.com")) } @@ -1008,7 +1008,7 @@ func TestGetCuratedPackagesRegistriesCustomRegion(t *testing.T) { g := NewWithT(t) cluster := cluster.Spec{Config: &cluster.Config{Cluster: &v1alpha1.Cluster{Spec: clusterSpec}}} sut := curatedpackages.NewPackageControllerClient(nil, nil, "billy", "", chart, nil, curatedpackages.WithClusterSpec(&cluster), curatedpackages.WithEksaRegion("test")) - _, _, img := sut.GetCuratedPackagesRegistries() + _, _, img := sut.GetCuratedPackagesRegistries(context.Background()) g.Expect(img).To(Equal("783794618700.dkr.ecr.test.amazonaws.com")) } @@ -1199,7 +1199,7 @@ func TestGetCuratedPackagesRegistries(s *testing.T) { ) expected := "783794618700.dkr.ecr.testing.amazonaws.com" - _, _, got := client.GetCuratedPackagesRegistries() + _, _, got := client.GetCuratedPackagesRegistries(context.Background()) if got != expected { t.Errorf("expected %q, got %q", expected, got) @@ -1225,12 +1225,44 @@ func TestGetCuratedPackagesRegistries(s *testing.T) { ) expected := "783794618700.dkr.ecr.us-west-2.amazonaws.com" - _, _, got := client.GetCuratedPackagesRegistries() + _, _, got := client.GetCuratedPackagesRegistries(context.Background()) if got != expected { t.Errorf("expected %q, got %q", expected, got) } }) + + s.Run("get regional registries", func(t *testing.T) { + ctrl := gomock.NewController(t) + k := mocks.NewMockKubectlRunner(ctrl) + cm := mocks.NewMockChartManager(ctrl) + kubeConfig := "kubeconfig.kubeconfig" + chart := &artifactsv1.Image{ + Name: "test_controller", + URI: "test_registry/eks-anywhere/eks-anywhere-packages:v1", + } + // eksaRegion := "test-region" + clusterName := "billy" + writer, _ := filewriter.NewWriter(clusterName) + client := curatedpackages.NewPackageControllerClient( + cm, k, clusterName, kubeConfig, chart, nil, + curatedpackages.WithManagementClusterName(clusterName), + curatedpackages.WithValuesFileWriter(writer), + curatedpackages.WithRegistryAccessTester(func(ctx context.Context, accessKey, secret, registry, region string) error { + return nil + }), + ) + + expected := "TODO.dkr.ecr.us-west-2.amazonaws.com" + _, actualDefaultRegistry, actualImageRegistry := client.GetCuratedPackagesRegistries(context.Background()) + + if actualDefaultRegistry != expected { + t.Errorf("expected %q, got %q", expected, actualDefaultRegistry) + } + if actualImageRegistry != expected { + t.Errorf("expected %q, got %q", expected, actualImageRegistry) + } + }) } func TestReconcile(s *testing.T) { diff --git a/pkg/curatedpackages/regional_registry.go b/pkg/curatedpackages/regional_registry.go new file mode 100644 index 000000000000..569ee105da7e --- /dev/null +++ b/pkg/curatedpackages/regional_registry.go @@ -0,0 +1,78 @@ +package curatedpackages + +import ( + "context" + "fmt" + "io" + "net/http" + "strings" + + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/ecr" +) + +const ( + devRegionalECR string = "067575901363.dkr.ecr.us-west-2.amazonaws.com" + stagingRegionalECR string = "TODO.dkr.ecr.us-west-2.amazonaws.com" +) + +var prodRegionalECRMap = map[string]string{ + "us-west-2": "TODO.dkr.ecr.us-west-2.amazonaws.com", + "us-east-2": "TODO.dkr.ecr.us-east-2.amazonaws.com", +} + +// TestRegistryAccess test if the packageControllerClient has valid credential to access registry. +func TestRegistryAccess(ctx context.Context, accessKey, secret, registry, region string) error { + cfg, err := config.LoadDefaultConfig(ctx, + config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(accessKey, secret, "")), + config.WithRegion(region), + ) + if err != nil { + return err + } + + ecrClient := ecr.NewFromConfig(cfg) + out, err := ecrClient.GetAuthorizationToken(context.Background(), &ecr.GetAuthorizationTokenInput{}) + if err != nil { + return err + } + authToken := out.AuthorizationData[0].AuthorizationToken + + return TestRegistryWithAuthToken(*authToken, registry, http.DefaultClient.Do) +} + +// TestRegistryWithAuthToken test if the registry can be acccessed with auth token. +func TestRegistryWithAuthToken(authToken, registry string, getResponse func(req *http.Request) (*http.Response, error)) error { + manifestPath := "/v2/eks-anywhere-packages/manifests/latest" + + req, err := http.NewRequest("GET", "https://"+registry+manifestPath, nil) + if err != nil { + return err + } + req.Header.Add("Authorization", "Basic "+authToken) + + resp2, err := getResponse(req) + if err != nil { + return err + } + + bodyBytes, err := io.ReadAll(resp2.Body) + // 404 means the IAM policy is good, so 404 is good here + if resp2.StatusCode != 200 && resp2.StatusCode != 404 { + return fmt.Errorf("%s\n, %v", string(bodyBytes), err) + } + + return nil +} + +// GetRegionalRegistry get the regional registry corresponding to defaultRegistry in a specific region. +func GetRegionalRegistry(defaultRegistry, region string) string { + if strings.Contains(defaultRegistry, devAccount) { + return devRegionalECR + } + if strings.Contains(defaultRegistry, stagingAccount) { + return stagingRegionalECR + } + return prodRegionalECRMap[region] +} diff --git a/pkg/curatedpackages/regional_registry_test.go b/pkg/curatedpackages/regional_registry_test.go new file mode 100644 index 000000000000..d573b6486e1f --- /dev/null +++ b/pkg/curatedpackages/regional_registry_test.go @@ -0,0 +1,42 @@ +package curatedpackages_test + +import ( + "bytes" + "io" + "net/http" + "testing" + + "github.com/aws/eks-anywhere/pkg/curatedpackages" +) + +func TestTestRegistry(t *testing.T) { + err := curatedpackages.TestRegistryWithAuthToken("authToken", "registry_url", func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewReader(nil)), + }, nil + }) + if err != nil { + t.Errorf("Registry is good, but error has been returned %v\n", err) + } + + err = curatedpackages.TestRegistryWithAuthToken("authToken", "registry_url", func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 404, + Body: io.NopCloser(bytes.NewReader(nil)), + }, nil + }) + if err != nil { + t.Errorf("Registry is good, but error has been returned %v\n", err) + } + + err = curatedpackages.TestRegistryWithAuthToken("authToken", "registry_url", func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 400, + Body: io.NopCloser(bytes.NewReader(nil)), + }, nil + }) + if err == nil { + t.Errorf("Error should have been returned") + } +}