Skip to content

Commit

Permalink
Remove discovery filters parameter from Discover and solely rely on f…
Browse files Browse the repository at this point in the history
…ilters from AwsConnection.

Signed-off-by: Vasil Sirakov <sirakov97@gmail.com>
  • Loading branch information
VasilSirakov committed Sep 25, 2024
1 parent 77779d0 commit 9247bf5
Show file tree
Hide file tree
Showing 7 changed files with 227 additions and 106 deletions.
22 changes: 16 additions & 6 deletions providers/aws/connection/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,12 @@ type GeneralResourceDiscoveryFilters struct {
}

type Ec2DiscoveryFilters struct {
Regions []string
Tags map[string]string
InstanceIds []string
Regions []string
Tags map[string]string
InstanceIds []string
ExcludeRegions []string
ExcludeTags map[string]string
ExcludeInstanceIds []string
}
type EcrDiscoveryFilters struct {
Tags []string
Expand Down Expand Up @@ -126,9 +129,10 @@ func NewAwsConnection(id uint32, asset *inventory.Asset, conf *inventory.Config)
return c, nil
}

// TODO: @vasil - unit test.
func parseOptsToFilters(opts map[string]string) DiscoveryFilters {
d := DiscoveryFilters{
Ec2DiscoveryFilters: Ec2DiscoveryFilters{Tags: map[string]string{}},
Ec2DiscoveryFilters: Ec2DiscoveryFilters{Tags: map[string]string{}, ExcludeTags: map[string]string{}},
EcsDiscoveryFilters: EcsDiscoveryFilters{},
EcrDiscoveryFilters: EcrDiscoveryFilters{Tags: []string{}},
GeneralDiscoveryFilters: GeneralResourceDiscoveryFilters{Tags: map[string]string{}},
Expand All @@ -137,12 +141,18 @@ func parseOptsToFilters(opts map[string]string) DiscoveryFilters {
switch {
case strings.HasPrefix(k, "ec2:tag:"):
d.Ec2DiscoveryFilters.Tags[strings.TrimPrefix(k, "ec2:tag:")] = v
case k == "ec2:region":
case strings.HasPrefix(k, "exclude:ec2:tag:"):
d.Ec2DiscoveryFilters.ExcludeTags[strings.TrimPrefix(k, "exclude:ec2:tag:")] = v
case strings.HasPrefix(k, "ec2:region:"):
d.Ec2DiscoveryFilters.Regions = append(d.Ec2DiscoveryFilters.Regions, v)
case strings.HasPrefix(k, "exclude:ec2:region"):
d.Ec2DiscoveryFilters.ExcludeRegions = append(d.Ec2DiscoveryFilters.ExcludeRegions, v)
case k == "all:region", k == "region":
d.GeneralDiscoveryFilters.Regions = append(d.GeneralDiscoveryFilters.Regions, v)
case k == "instance-id":
case strings.HasPrefix(k, "instance-id:"):
d.Ec2DiscoveryFilters.InstanceIds = append(d.Ec2DiscoveryFilters.InstanceIds, v)
case strings.HasPrefix(k, "exclude:instance-id:"):
d.Ec2DiscoveryFilters.ExcludeInstanceIds = append(d.Ec2DiscoveryFilters.ExcludeInstanceIds, v)
case strings.HasPrefix(k, "all:tag:"):
d.GeneralDiscoveryFilters.Tags[strings.TrimPrefix(k, "all:tag:")] = v
case k == "ecr:tag":
Expand Down
95 changes: 95 additions & 0 deletions providers/aws/connection/connection_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package connection

import (
"testing"

"github.com/stretchr/testify/require"
)

// testParseOptsToFilters accepts a map which doesn't guarantee a deterministic iteration order. this means that slices
// in the parsed filters need to be compared individually ensuring their elements match regardless of their order.
func compareFilters(t *testing.T, expected, actual DiscoveryFilters) {
require.ElementsMatch(t, expected.Ec2DiscoveryFilters.Regions, actual.Ec2DiscoveryFilters.Regions)
require.ElementsMatch(t, expected.Ec2DiscoveryFilters.ExcludeRegions, actual.Ec2DiscoveryFilters.ExcludeRegions)

require.ElementsMatch(t, expected.Ec2DiscoveryFilters.InstanceIds, actual.Ec2DiscoveryFilters.InstanceIds)
require.ElementsMatch(t, expected.Ec2DiscoveryFilters.ExcludeInstanceIds, actual.Ec2DiscoveryFilters.ExcludeInstanceIds)

require.Equal(t, expected.Ec2DiscoveryFilters.Tags, actual.Ec2DiscoveryFilters.Tags)
require.Equal(t, expected.Ec2DiscoveryFilters.ExcludeTags, actual.Ec2DiscoveryFilters.ExcludeTags)

require.Equal(t, expected.EcsDiscoveryFilters, actual.EcsDiscoveryFilters)

require.Equal(t, expected.EcrDiscoveryFilters.Tags, actual.EcrDiscoveryFilters.Tags)

require.ElementsMatch(t, expected.GeneralDiscoveryFilters.Regions, actual.GeneralDiscoveryFilters.Regions)
require.Equal(t, expected.GeneralDiscoveryFilters.Tags, actual.GeneralDiscoveryFilters.Tags)
}

func TestParseOptsToFilters(t *testing.T) {
t.Run("all opts are mapped to discovery filters correctly", func(t *testing.T) {
opts := map[string]string{
// Ec2DiscoveryFilters.Tags
"ec2:tag:key1": "val1",
"ec2:tag:key2": "val2",
// Ec2DiscoveryFilters.ExcludeTags
"exclude:ec2:tag:key1": "val1",
"exclude:ec2:tag:key2": "val2",
// Ec2DiscoveryFilters.Regions
"ec2:region:us-east-1": "us-east-1",
"ec2:region:us-west-1": "us-west-1",
// Ec2DiscoveryFilters.ExcludeRegions
"exclude:ec2:region:us-east-1": "us-east-1",
"exclude:ec2:region:us-west-1": "us-west-1",
// Ec2DiscoveryFilters.InstanceIds
"instance-id:iid-1": "iid-1",
"instance-id:iid-2": "iid-2",
// Ec2DiscoveryFilters.ExcludeInstanceIds
"exclude:instance-id:iid-1": "iid-1",
"exclude:instance-id:iid-2": "iid-2",
// TODO: @vasil - include others?
}
expected := DiscoveryFilters{
Ec2DiscoveryFilters: Ec2DiscoveryFilters{
Regions: []string{
"us-east-1", "us-west-1",
},
ExcludeRegions: []string{
"us-east-1", "us-west-1",
},
InstanceIds: []string{
"iid-1", "iid-2",
},
ExcludeInstanceIds: []string{
"iid-1", "iid-2",
},
Tags: map[string]string{
"key1": "val1",
"key2": "val2",
},
ExcludeTags: map[string]string{
"key1": "val1",
"key2": "val2",
},
},
EcsDiscoveryFilters: EcsDiscoveryFilters{},
EcrDiscoveryFilters: EcrDiscoveryFilters{Tags: []string{}},
GeneralDiscoveryFilters: GeneralResourceDiscoveryFilters{Tags: map[string]string{}},
}

actual := parseOptsToFilters(opts)
compareFilters(t, expected, actual)
})

t.Run("empty opts are mapped to discovery filters correctly", func(t *testing.T) {
expected := DiscoveryFilters{
Ec2DiscoveryFilters: Ec2DiscoveryFilters{Tags: map[string]string{}, ExcludeTags: map[string]string{}},
EcsDiscoveryFilters: EcsDiscoveryFilters{},
EcrDiscoveryFilters: EcrDiscoveryFilters{Tags: []string{}},
GeneralDiscoveryFilters: GeneralResourceDiscoveryFilters{Tags: map[string]string{}},
}

actual := parseOptsToFilters(map[string]string{})
compareFilters(t, expected, actual)
})
}
2 changes: 1 addition & 1 deletion providers/aws/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,5 +269,5 @@ func (s *Service) discover(conn *connection.AwsConnection) (*inventory.Inventory
return nil, err
}

return resources.Discover(runtime, conn.Filters)
return resources.Discover(runtime)
}
34 changes: 34 additions & 0 deletions providers/aws/resources/aws_ec2.go
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,9 @@ func (a *mqlAwsEc2) getInstances(conn *connection.AwsConnection) []*jobpool.Job
if len(conn.Filters.Ec2DiscoveryFilters.Regions) > 0 {
regions = conn.Filters.Ec2DiscoveryFilters.Regions
}
for _, regionToExclude := range conn.Filters.Ec2DiscoveryFilters.ExcludeRegions {
regions = removeElement(regions, regionToExclude)
}
for _, region := range regions {
regionVal := region
f := func() (jobpool.JobResult, error) {
Expand Down Expand Up @@ -837,6 +840,9 @@ func (a *mqlAwsEc2) gatherInstanceInfo(instances []ec2types.Reservation, regionV
res := []interface{}{}
for _, reservation := range instances {
for _, instance := range reservation.Instances {
if shouldExcludeInstance(instance, conn.Filters.Ec2DiscoveryFilters) {
continue
}
mqlDevices := []interface{}{}
for i := range instance.BlockDeviceMappings {
device := instance.BlockDeviceMappings[i]
Expand Down Expand Up @@ -1769,3 +1775,31 @@ func (a *mqlAwsEc2Vpnconnection) id() (string, error) {
func (a *mqlAwsEc2Vgwtelemetry) id() (string, error) {
return a.OutsideIpAddress.Data, nil
}

// true if the instance should be excluded from results. filtering for excluded regions should happen before we retrieve the EC2 instance.
func shouldExcludeInstance(instance ec2types.Instance, filters connection.Ec2DiscoveryFilters) bool {
for _, id := range filters.ExcludeInstanceIds {
if instance.InstanceId != nil && *instance.InstanceId == id {
return true
}
}
for k, v := range filters.ExcludeTags {
for _, iTag := range instance.Tags {
if iTag.Key != nil && *iTag.Key == k &&
iTag.Value != nil && *iTag.Value == v {
return true
}
}
}
return false
}

func removeElement(slice []string, value string) []string {
result := []string{}
for _, v := range slice {
if v != value {
result = append(result, v)
}
}
return result
}
75 changes: 75 additions & 0 deletions providers/aws/resources/aws_ec2_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package resources

import (
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/stretchr/testify/require"
"go.mondoo.com/cnquery/v11/providers/aws/connection"
)

func TestShouldExcludeInstance(t *testing.T) {
instance := ec2types.Instance{
InstanceId: aws.String("iid"),
Tags: []ec2types.Tag{
{
Key: aws.String("key-1"),
Value: aws.String("val-1"),
},
{
Key: aws.String("key-2"),
Value: aws.String("val-2"),
},
},
}

t.Run("should exclude instance by id", func(t *testing.T) {
filters := connection.Ec2DiscoveryFilters{
ExcludeInstanceIds: []string{
"iid",
},
ExcludeTags: map[string]string{
"key-3": "val3",
},
}
require.True(t, shouldExcludeInstance(instance, filters))
})

t.Run("should exclude instance by matching tag", func(t *testing.T) {
filters := connection.Ec2DiscoveryFilters{
ExcludeInstanceIds: []string{
"iid-2",
},
ExcludeTags: map[string]string{
"key-2": "val2",
},
}
require.False(t, shouldExcludeInstance(instance, filters))
})

t.Run("should not exclude instance with only a matching tag key", func(t *testing.T) {
filters := connection.Ec2DiscoveryFilters{
ExcludeInstanceIds: []string{
"iid-2",
},
ExcludeTags: map[string]string{
"key-2": "val3",
"key-3": "val3",
},
}
require.False(t, shouldExcludeInstance(instance, filters))
})

t.Run("should not exclude instance when instance id and tags don't match", func(t *testing.T) {
filters := connection.Ec2DiscoveryFilters{
ExcludeInstanceIds: []string{
"iid-2",
},
ExcludeTags: map[string]string{
"key-3": "val3",
},
}
require.False(t, shouldExcludeInstance(instance, filters))
})
}
47 changes: 6 additions & 41 deletions providers/aws/resources/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,35 +114,6 @@ func containsInterfaceSlice(sl []interface{}, s string) bool {
return false
}

func instanceMatchesFilters(instance *mqlAwsEc2Instance, filters connection.DiscoveryFilters) bool {
regions := []string{}
if len(filters.GeneralDiscoveryFilters.Regions) > 0 {
regions = append(regions, filters.GeneralDiscoveryFilters.Regions...)
}
if len(filters.Ec2DiscoveryFilters.Regions) > 0 {
regions = append(regions, filters.Ec2DiscoveryFilters.Regions...)
}
if len(regions) > 0 && !contains(regions, instance.Region.Data) {
return false
}
if len(filters.Ec2DiscoveryFilters.InstanceIds) > 0 {
if !contains(filters.Ec2DiscoveryFilters.InstanceIds, instance.InstanceId.Data) {
return false
}
}
if len(filters.Ec2DiscoveryFilters.Tags) > 0 {
for k, v := range filters.Ec2DiscoveryFilters.Tags {
if instance.Tags.Data[k] == nil {
return false
}
if instance.Tags.Data[k].(string) != v {
return false
}
}
}
return true
}

func imageMatchesFilters(image *mqlAwsEcrImage, filters connection.DiscoveryFilters) bool {
f := filters.EcrDiscoveryFilters
if len(f.Tags) > 0 {
Expand Down Expand Up @@ -185,9 +156,8 @@ func discoveredAssetMatchesGeneralFilters(asset *inventory.Asset, filters connec
return true
}

func Discover(runtime *plugin.Runtime, filters connection.DiscoveryFilters) (*inventory.Inventory, error) {
func Discover(runtime *plugin.Runtime) (*inventory.Inventory, error) {
conn := runtime.Connection.(*connection.AwsConnection)

in := &inventory.Inventory{Spec: &inventory.InventorySpec{
Assets: []*inventory.Asset{},
}}
Expand All @@ -202,15 +172,15 @@ func Discover(runtime *plugin.Runtime, filters connection.DiscoveryFilters) (*in
targets := handleTargets(conn.Conf.Discover.Targets)
for i := range targets {
target := targets[i]
list, err := discover(runtime, awsAccount, target, filters)
list, err := discover(runtime, awsAccount, target, conn.Filters)
if err != nil {
log.Error().Err(err).Msg("error during discovery")
continue
}
if len(filters.GeneralDiscoveryFilters.Tags) > 0 {
if len(conn.Filters.GeneralDiscoveryFilters.Tags) > 0 {
newList := []*inventory.Asset{}
for i := range list {
if discoveredAssetMatchesGeneralFilters(list[i], filters.GeneralDiscoveryFilters) {
if discoveredAssetMatchesGeneralFilters(list[i], conn.Filters.GeneralDiscoveryFilters) {
newList = append(newList, list[i])
}
}
Expand Down Expand Up @@ -274,18 +244,15 @@ func discover(runtime *plugin.Runtime, awsAccount *mqlAwsAccount, target string,

ec2 := res.(*mqlAwsEc2)

// get instances already filters out instances not matched by the filters specified in the AwsConnection
ins := ec2.GetInstances()
if ins == nil {
return assetList, nil
}

for i := range ins.Data {
instance := ins.Data[i].(*mqlAwsEc2Instance)
if !instanceMatchesFilters(instance, filters) {
continue
}
assetList = append(assetList, addConnectionInfoToEc2Asset(instance, accountId, conn))

}
case DiscoverySSMInstances:
res, err := NewResource(runtime, "aws.ssm", map[string]*llx.RawData{})
Expand Down Expand Up @@ -431,16 +398,14 @@ func discover(runtime *plugin.Runtime, awsAccount *mqlAwsAccount, target string,

ec2 := res.(*mqlAwsEc2)

// get instances already filters out instances not matched by the filters specified in the AwsConnection
ins := ec2.GetInstances()
if ins == nil {
return assetList, nil
}

for i := range ins.Data {
instance := ins.Data[i].(*mqlAwsEc2Instance)
if !instanceMatchesFilters(instance, filters) {
continue
}
l := mapStringInterfaceToStringString(instance.Tags.Data)
assetList = append(assetList, MqlObjectToAsset(accountId,
mqlObject{
Expand Down
Loading

0 comments on commit 9247bf5

Please sign in to comment.