From f24d2edfcc6e8e8b022013671703b249d9a4164d Mon Sep 17 00:00:00 2001 From: vj Date: Fri, 30 Aug 2024 13:17:17 -0600 Subject: [PATCH] =?UTF-8?q?=F0=9F=A7=B9=20improve=20aws=20s3=20buckets=20r?= =?UTF-8?q?esource?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- providers/aws/resources/aws.lr | 2 +- providers/aws/resources/aws_s3.go | 188 ++++++++++++++++++++++-------- 2 files changed, 140 insertions(+), 50 deletions(-) diff --git a/providers/aws/resources/aws.lr b/providers/aws/resources/aws.lr index bce64ee084..bd0f032478 100644 --- a/providers/aws/resources/aws.lr +++ b/providers/aws/resources/aws.lr @@ -1717,7 +1717,7 @@ aws.s3 @defaults("buckets") { } // Amazon S3 bucket -private aws.s3.bucket @defaults("name location public") { +private aws.s3.bucket @defaults("name public") { // ARN of the bucket arn string // Name of the bucket diff --git a/providers/aws/resources/aws_s3.go b/providers/aws/resources/aws_s3.go index 44cb6c8a62..5a654bf5af 100644 --- a/providers/aws/resources/aws_s3.go +++ b/providers/aws/resources/aws_s3.go @@ -58,41 +58,29 @@ func (a *mqlAwsS3) buckets() ([]interface{}, error) { conn := a.MqlRuntime.Connection.(*connection.AwsConnection) svc := conn.S3("") - ctx := context.Background() - buckets, err := svc.ListBuckets(ctx, &s3.ListBucketsInput{}) - if err != nil { - return nil, err + totalBuckets := make([]s3types.Bucket, 0) + params := &s3.ListBucketsInput{} + paginator := s3.NewListBucketsPaginator(svc, params, func(o *s3.ListBucketsPaginatorOptions) { + o.Limit = 100 + }) + for paginator.HasMorePages() { + output, err := paginator.NextPage(context.TODO()) + if err != nil { + return nil, err + } + totalBuckets = append(totalBuckets, output.Buckets...) } res := []interface{}{} - for i := range buckets.Buckets { - bucket := buckets.Buckets[i] + for i := range totalBuckets { + bucket := totalBuckets[i] - location, err := svc.GetBucketLocation(ctx, &s3.GetBucketLocationInput{ - Bucket: bucket.Name, - }) - if err != nil { - log.Error().Err(err).Str("bucket", *bucket.Name).Msg("Could not get bucket location") - continue - } - if location == nil { - log.Error().Err(err).Str("bucket", *bucket.Name).Msg("Could not get bucket location (returned null)") - continue - } - - region := string(location.LocationConstraint) - // us-east-1 returns "" therefore we set it explicitly - // https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetBucketLocation.html#API_GetBucketLocation_ResponseSyntax - if region == "" { - region = "us-east-1" - } mqlS3Bucket, err := CreateResource(a.MqlRuntime, "aws.s3.bucket", map[string]*llx.RawData{ "name": llx.StringDataPtr(bucket.Name), "arn": llx.StringData(fmt.Sprintf(s3ArnPattern, convert.ToString(bucket.Name))), "exists": llx.BoolData(true), - "location": llx.StringData(region), "createdTime": llx.TimeDataPtr(bucket.CreationDate), }) if err != nil { @@ -104,6 +92,34 @@ func (a *mqlAwsS3) buckets() ([]interface{}, error) { return res, nil } +func (s *mqlAwsS3Bucket) location() (string, error) { + if s.Location.Data == "" { + return s.fetchBucketLocation() + } + return "", errors.Newf("no location found for bucket %s", s.Name.Data) +} + +func (s *mqlAwsS3Bucket) fetchBucketLocation() (string, error) { + conn := s.MqlRuntime.Connection.(*connection.AwsConnection) + svc := conn.S3("") + ctx := context.Background() + location, err := svc.GetBucketLocation(ctx, &s3.GetBucketLocationInput{ + Bucket: aws.String(s.Name.Data), + }) + if err != nil { + return "", err + } + if location == nil { + return "", errors.Newf("no location found for bucket %s", s.Name.Data) + } + loc := string(location.LocationConstraint) + if location.LocationConstraint == "" { + // us-east-1 comes back as empty string + loc = "us-east-1" + } + return loc, nil +} + func initAwsS3Bucket(runtime *plugin.Runtime, args map[string]*llx.RawData) (map[string]*llx.RawData, plugin.Resource, error) { // NOTE: bucket only initializes with arn and name if len(args) >= 2 { @@ -185,6 +201,13 @@ func emptyAwsS3BucketPolicy(runtime *plugin.Runtime) (*mqlAwsS3BucketPolicy, err } func (a *mqlAwsS3Bucket) policy() (*mqlAwsS3BucketPolicy, error) { + if a.Location.Data == "" { + l, err := a.fetchBucketLocation() + if err != nil { + return nil, err + } + a.Location.Data = l + } conn := a.MqlRuntime.Connection.(*connection.AwsConnection) bucketname := a.Name.Data @@ -220,6 +243,13 @@ func (a *mqlAwsS3Bucket) policy() (*mqlAwsS3BucketPolicy, error) { } func (a *mqlAwsS3Bucket) tags() (map[string]interface{}, error) { + if a.Location.Data == "" { + l, err := a.fetchBucketLocation() + if err != nil { + return nil, err + } + a.Location.Data = l + } bucketname := a.Name.Data location := a.Location.Data @@ -251,33 +281,16 @@ func (a *mqlAwsS3Bucket) tags() (map[string]interface{}, error) { return res, nil } -func (a *mqlAwsS3Bucket) location() (string, error) { - bucketname := a.Name.Data - - conn := a.MqlRuntime.Connection.(*connection.AwsConnection) - - svc := conn.S3("") - ctx := context.Background() - - location, err := svc.GetBucketLocation(ctx, &s3.GetBucketLocationInput{ - Bucket: &bucketname, - }) - if err != nil { - return "", err - } - - region := string(location.LocationConstraint) - // us-east-1 returns "" therefore we set it explicitly - // https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetBucketLocation.html#API_GetBucketLocation_ResponseSyntax - if region == "" { - region = "us-east-1" - } - return region, nil -} - func (a *mqlAwsS3Bucket) gatherAcl() (*s3.GetBucketAclOutput, error) { bucketname := a.Name.Data location := a.Location.Data + if a.Location.Data == "" { + l, err := a.fetchBucketLocation() + if err != nil { + return nil, err + } + a.Location.Data = l + } conn := a.MqlRuntime.Connection.(*connection.AwsConnection) @@ -296,6 +309,13 @@ func (a *mqlAwsS3Bucket) gatherAcl() (*s3.GetBucketAclOutput, error) { } func (a *mqlAwsS3Bucket) acl() ([]interface{}, error) { + if a.Location.Data == "" { + l, err := a.fetchBucketLocation() + if err != nil { + return nil, err + } + a.Location.Data = l + } bucketname := a.Name.Data acl, err := a.gatherAcl() @@ -344,6 +364,13 @@ func (a *mqlAwsS3Bucket) acl() ([]interface{}, error) { } func (a *mqlAwsS3Bucket) publicAccessBlock() (interface{}, error) { + if a.Location.Data == "" { + l, err := a.fetchBucketLocation() + if err != nil { + return nil, err + } + a.Location.Data = l + } bucketname := a.Name.Data location := a.Location.Data conn := a.MqlRuntime.Connection.(*connection.AwsConnection) @@ -365,6 +392,13 @@ func (a *mqlAwsS3Bucket) publicAccessBlock() (interface{}, error) { } func (a *mqlAwsS3Bucket) owner() (map[string]interface{}, error) { + if a.Location.Data == "" { + l, err := a.fetchBucketLocation() + if err != nil { + return nil, err + } + a.Location.Data = l + } acl, err := a.gatherAcl() if err != nil { return nil, err @@ -388,6 +422,13 @@ const ( ) func (a *mqlAwsS3Bucket) public() (bool, error) { + if a.Location.Data == "" { + l, err := a.fetchBucketLocation() + if err != nil { + return false, err + } + a.Location.Data = l + } acl, err := a.gatherAcl() if err != nil { return false, err @@ -403,6 +444,13 @@ func (a *mqlAwsS3Bucket) public() (bool, error) { } func (a *mqlAwsS3Bucket) cors() ([]interface{}, error) { + if a.Location.Data == "" { + l, err := a.fetchBucketLocation() + if err != nil { + return nil, err + } + a.Location.Data = l + } bucketname := a.Name.Data location := a.Location.Data @@ -443,6 +491,13 @@ func (a *mqlAwsS3Bucket) cors() ([]interface{}, error) { } func (a *mqlAwsS3Bucket) logging() (map[string]interface{}, error) { + if a.Location.Data == "" { + l, err := a.fetchBucketLocation() + if err != nil { + return nil, err + } + a.Location.Data = l + } bucketname := a.Name.Data bucketlocation := a.Location.Data @@ -479,6 +534,13 @@ func (a *mqlAwsS3Bucket) logging() (map[string]interface{}, error) { } func (a *mqlAwsS3Bucket) versioning() (map[string]interface{}, error) { + if a.Location.Data == "" { + l, err := a.fetchBucketLocation() + if err != nil { + return nil, err + } + a.Location.Data = l + } bucketname := a.Name.Data location := a.Location.Data @@ -505,6 +567,13 @@ func (a *mqlAwsS3Bucket) versioning() (map[string]interface{}, error) { } func (a *mqlAwsS3Bucket) replication() (interface{}, error) { + if a.Location.Data == "" { + l, err := a.fetchBucketLocation() + if err != nil { + return nil, err + } + a.Location.Data = l + } bucketname := a.Name.Data region := a.Location.Data @@ -526,6 +595,13 @@ func (a *mqlAwsS3Bucket) replication() (interface{}, error) { } func (a *mqlAwsS3Bucket) encryption() (interface{}, error) { + if a.Location.Data == "" { + l, err := a.fetchBucketLocation() + if err != nil { + return nil, err + } + a.Location.Data = l + } bucketname := a.Name.Data region := a.Location.Data @@ -552,6 +628,13 @@ func (a *mqlAwsS3Bucket) encryption() (interface{}, error) { } func (a *mqlAwsS3Bucket) defaultLock() (string, error) { + if a.Location.Data == "" { + l, err := a.fetchBucketLocation() + if err != nil { + return "", err + } + a.Location.Data = l + } bucketname := a.Name.Data region := a.Location.Data @@ -574,6 +657,13 @@ func (a *mqlAwsS3Bucket) defaultLock() (string, error) { } func (a *mqlAwsS3Bucket) staticWebsiteHosting() (map[string]interface{}, error) { + if a.Location.Data == "" { + l, err := a.fetchBucketLocation() + if err != nil { + return nil, err + } + a.Location.Data = l + } bucketname := a.Name.Data region := a.Location.Data