Skip to content

Commit

Permalink
refactor aws_kinesis_stream_scaler (#6400)
Browse files Browse the repository at this point in the history
Signed-off-by: Omer Aplatony <omerap12@gmail.com>
  • Loading branch information
omerap12 authored Dec 7, 2024
1 parent c2e19c1 commit 88ddc39
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 94 deletions.
69 changes: 17 additions & 52 deletions pkg/scalers/aws_kinesis_stream_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package scalers
import (
"context"
"fmt"
"strconv"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/kinesis"
Expand Down Expand Up @@ -41,11 +40,11 @@ func (w kinesisWrapperClient) DescribeStreamSummary(ctx context.Context, params
}

type awsKinesisStreamMetadata struct {
targetShardCount int64
activationTargetShardCount int64
streamName string
awsRegion string
awsEndpoint string
TargetShardCount int64 `keda:"name=shardCount, order=triggerMetadata, default=2"`
ActivationTargetShardCount int64 `keda:"name=activationShardCount, order=triggerMetadata, default=0"`
StreamName string `keda:"name=streamName, order=triggerMetadata"`
AwsRegion string `keda:"name=awsRegion, order=triggerMetadata"`
AwsEndpoint string `keda:"name=awsEndpoint, order=triggerMetadata, optional"`
awsAuthorization awsutils.AuthorizationMetadata
triggerIndex int
}
Expand All @@ -59,7 +58,7 @@ func NewAwsKinesisStreamScaler(ctx context.Context, config *scalersconfig.Scaler

logger := InitializeLogger(config, "aws_kinesis_stream_scaler")

meta, err := parseAwsKinesisStreamMetadata(config, logger)
meta, err := parseAwsKinesisStreamMetadata(config)
if err != nil {
return nil, fmt.Errorf("error parsing Kinesis stream metadata: %w", err)
}
Expand All @@ -78,44 +77,11 @@ func NewAwsKinesisStreamScaler(ctx context.Context, config *scalersconfig.Scaler
}, nil
}

func parseAwsKinesisStreamMetadata(config *scalersconfig.ScalerConfig, logger logr.Logger) (*awsKinesisStreamMetadata, error) {
meta := awsKinesisStreamMetadata{}
meta.targetShardCount = targetShardCountDefault

if val, ok := config.TriggerMetadata["shardCount"]; ok && val != "" {
shardCount, err := strconv.ParseInt(val, 10, 64)
if err != nil {
meta.targetShardCount = targetShardCountDefault
logger.Error(err, "Error parsing Kinesis stream metadata shardCount, using default %n", targetShardCountDefault)
} else {
meta.targetShardCount = shardCount
}
}

if val, ok := config.TriggerMetadata["activationShardCount"]; ok && val != "" {
activationShardCount, err := strconv.ParseInt(val, 10, 64)
if err != nil {
meta.activationTargetShardCount = activationTargetShardCountDefault
logger.Error(err, "Error parsing Kinesis stream metadata activationShardCount, using default %n", activationTargetShardCountDefault)
} else {
meta.activationTargetShardCount = activationShardCount
}
}
func parseAwsKinesisStreamMetadata(config *scalersconfig.ScalerConfig) (*awsKinesisStreamMetadata, error) {
meta := &awsKinesisStreamMetadata{}

if val, ok := config.TriggerMetadata["streamName"]; ok && val != "" {
meta.streamName = val
} else {
return nil, fmt.Errorf("no streamName given")
}

if val, ok := config.TriggerMetadata["awsRegion"]; ok && val != "" {
meta.awsRegion = val
} else {
return nil, fmt.Errorf("no awsRegion given")
}

if val, ok := config.TriggerMetadata["awsEndpoint"]; ok {
meta.awsEndpoint = val
if err := config.TypedConfig(meta); err != nil {
return nil, fmt.Errorf("error parsing Kinesis stream metadata: %w", err)
}

auth, err := awsutils.GetAwsAuthorization(config.TriggerUniqueKey, config.PodIdentity, config.TriggerMetadata, config.AuthParams, config.ResolvedEnv)
Expand All @@ -124,10 +90,9 @@ func parseAwsKinesisStreamMetadata(config *scalersconfig.ScalerConfig, logger lo
}

meta.awsAuthorization = auth

meta.triggerIndex = config.TriggerIndex

return &meta, nil
return meta, nil
}

func createKinesisClient(ctx context.Context, metadata *awsKinesisStreamMetadata) (*kinesis.Client, error) {
Expand All @@ -136,8 +101,8 @@ func createKinesisClient(ctx context.Context, metadata *awsKinesisStreamMetadata
return nil, err
}
return kinesis.NewFromConfig(*cfg, func(options *kinesis.Options) {
if metadata.awsEndpoint != "" {
options.BaseEndpoint = aws.String(metadata.awsEndpoint)
if metadata.AwsEndpoint != "" {
options.BaseEndpoint = aws.String(metadata.AwsEndpoint)
}
}), nil
}
Expand All @@ -150,9 +115,9 @@ func (s *awsKinesisStreamScaler) Close(context.Context) error {
func (s *awsKinesisStreamScaler) GetMetricSpecForScaling(context.Context) []v2.MetricSpec {
externalMetric := &v2.ExternalMetricSource{
Metric: v2.MetricIdentifier{
Name: GenerateMetricNameWithIndex(s.metadata.triggerIndex, kedautil.NormalizeString(fmt.Sprintf("aws-kinesis-%s", s.metadata.streamName))),
Name: GenerateMetricNameWithIndex(s.metadata.triggerIndex, kedautil.NormalizeString(fmt.Sprintf("aws-kinesis-%s", s.metadata.StreamName))),
},
Target: GetMetricTarget(s.metricType, s.metadata.targetShardCount),
Target: GetMetricTarget(s.metricType, s.metadata.TargetShardCount),
}
metricSpec := v2.MetricSpec{External: externalMetric, Type: externalMetricType}
return []v2.MetricSpec{metricSpec}
Expand All @@ -169,13 +134,13 @@ func (s *awsKinesisStreamScaler) GetMetricsAndActivity(ctx context.Context, metr

metric := GenerateMetricInMili(metricName, float64(shardCount))

return []external_metrics.ExternalMetricValue{metric}, shardCount > s.metadata.activationTargetShardCount, nil
return []external_metrics.ExternalMetricValue{metric}, shardCount > s.metadata.ActivationTargetShardCount, nil
}

// GetAwsKinesisOpenShardCount Get Kinesis open shard count
func (s *awsKinesisStreamScaler) GetAwsKinesisOpenShardCount(ctx context.Context) (int64, error) {
input := &kinesis.DescribeStreamSummaryInput{
StreamName: &s.metadata.streamName,
StreamName: &s.metadata.StreamName,
}

output, err := s.kinesisWrapperClient.DescribeStreamSummary(ctx, input)
Expand Down
75 changes: 33 additions & 42 deletions pkg/scalers/aws_kinesis_stream_scaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ var testAWSKinesisMetadata = []parseAWSKinesisMetadataTestData{
authParams: testAWSKinesisAuthentication,
expected: &awsKinesisStreamMetadata{},
isError: true,
comment: "metadata empty"},
comment: "metadata empty",
},
{
metadata: map[string]string{
"streamName": testAWSKinesisStreamName,
Expand All @@ -77,10 +78,10 @@ var testAWSKinesisMetadata = []parseAWSKinesisMetadataTestData{
"awsRegion": testAWSRegion},
authParams: testAWSKinesisAuthentication,
expected: &awsKinesisStreamMetadata{
targetShardCount: 2,
activationTargetShardCount: 1,
streamName: testAWSKinesisStreamName,
awsRegion: testAWSRegion,
TargetShardCount: 2,
ActivationTargetShardCount: 1,
StreamName: testAWSKinesisStreamName,
AwsRegion: testAWSRegion,
awsAuthorization: awsutils.AuthorizationMetadata{
AwsAccessKeyID: testAWSKinesisAccessKeyID,
AwsSecretAccessKey: testAWSKinesisSecretAccessKey,
Expand All @@ -101,11 +102,11 @@ var testAWSKinesisMetadata = []parseAWSKinesisMetadataTestData{
"awsEndpoint": testAWSEndpoint},
authParams: testAWSKinesisAuthentication,
expected: &awsKinesisStreamMetadata{
targetShardCount: 2,
activationTargetShardCount: 1,
streamName: testAWSKinesisStreamName,
awsRegion: testAWSRegion,
awsEndpoint: testAWSEndpoint,
TargetShardCount: 2,
ActivationTargetShardCount: 1,
StreamName: testAWSKinesisStreamName,
AwsRegion: testAWSRegion,
AwsEndpoint: testAWSEndpoint,
awsAuthorization: awsutils.AuthorizationMetadata{
AwsAccessKeyID: testAWSKinesisAccessKeyID,
AwsSecretAccessKey: testAWSKinesisSecretAccessKey,
Expand Down Expand Up @@ -147,10 +148,10 @@ var testAWSKinesisMetadata = []parseAWSKinesisMetadataTestData{
"awsRegion": testAWSRegion},
authParams: testAWSKinesisAuthentication,
expected: &awsKinesisStreamMetadata{
targetShardCount: targetShardCountDefault,
activationTargetShardCount: activationTargetShardCountDefault,
streamName: testAWSKinesisStreamName,
awsRegion: testAWSRegion,
TargetShardCount: targetShardCountDefault,
ActivationTargetShardCount: activationTargetShardCountDefault,
StreamName: testAWSKinesisStreamName,
AwsRegion: testAWSRegion,
awsAuthorization: awsutils.AuthorizationMetadata{
AwsAccessKeyID: testAWSKinesisAccessKeyID,
AwsSecretAccessKey: testAWSKinesisSecretAccessKey,
Expand All @@ -167,20 +168,10 @@ var testAWSKinesisMetadata = []parseAWSKinesisMetadataTestData{
"streamName": testAWSKinesisStreamName,
"shardCount": "a",
"awsRegion": testAWSRegion},
authParams: testAWSKinesisAuthentication,
expected: &awsKinesisStreamMetadata{
targetShardCount: 2,
streamName: testAWSKinesisStreamName,
awsRegion: testAWSRegion,
awsAuthorization: awsutils.AuthorizationMetadata{
AwsAccessKeyID: testAWSKinesisAccessKeyID,
AwsSecretAccessKey: testAWSKinesisSecretAccessKey,
PodIdentityOwner: true,
},
triggerIndex: 4,
},
isError: false,
comment: "properly formed stream name and region, wrong shard count",
authParams: testAWSKinesisAuthentication,
expected: &awsKinesisStreamMetadata{},
isError: true,
comment: "invalid shardCount value",
triggerIndex: 4,
},
{
Expand Down Expand Up @@ -221,9 +212,9 @@ var testAWSKinesisMetadata = []parseAWSKinesisMetadataTestData{
"awsSessionToken": testAWSKinesisSessionToken,
},
expected: &awsKinesisStreamMetadata{
targetShardCount: 2,
streamName: testAWSKinesisStreamName,
awsRegion: testAWSRegion,
TargetShardCount: 2,
StreamName: testAWSKinesisStreamName,
AwsRegion: testAWSRegion,
awsAuthorization: awsutils.AuthorizationMetadata{
AwsAccessKeyID: testAWSKinesisAccessKeyID,
AwsSecretAccessKey: testAWSKinesisSecretAccessKey,
Expand Down Expand Up @@ -273,9 +264,9 @@ var testAWSKinesisMetadata = []parseAWSKinesisMetadataTestData{
"awsRoleArn": testAWSKinesisRoleArn,
},
expected: &awsKinesisStreamMetadata{
targetShardCount: 2,
streamName: testAWSKinesisStreamName,
awsRegion: testAWSRegion,
TargetShardCount: 2,
StreamName: testAWSKinesisStreamName,
AwsRegion: testAWSRegion,
awsAuthorization: awsutils.AuthorizationMetadata{
AwsRoleArn: testAWSKinesisRoleArn,
PodIdentityOwner: true,
Expand All @@ -293,9 +284,9 @@ var testAWSKinesisMetadata = []parseAWSKinesisMetadataTestData{
"identityOwner": "operator"},
authParams: map[string]string{},
expected: &awsKinesisStreamMetadata{
targetShardCount: 2,
streamName: testAWSKinesisStreamName,
awsRegion: testAWSRegion,
TargetShardCount: 2,
StreamName: testAWSKinesisStreamName,
AwsRegion: testAWSRegion,
awsAuthorization: awsutils.AuthorizationMetadata{
PodIdentityOwner: false,
},
Expand All @@ -313,13 +304,13 @@ var awsKinesisMetricIdentifiers = []awsKinesisMetricIdentifier{
}

var awsKinesisGetMetricTestData = []*awsKinesisStreamMetadata{
{streamName: "Good"},
{streamName: testAWSKinesisErrorStream},
{StreamName: "Good"},
{StreamName: testAWSKinesisErrorStream},
}

func TestKinesisParseMetadata(t *testing.T) {
for _, testData := range testAWSKinesisMetadata {
result, err := parseAwsKinesisStreamMetadata(&scalersconfig.ScalerConfig{TriggerMetadata: testData.metadata, ResolvedEnv: testAWSKinesisAuthentication, AuthParams: testData.authParams, TriggerIndex: testData.triggerIndex}, logr.Discard())
result, err := parseAwsKinesisStreamMetadata(&scalersconfig.ScalerConfig{TriggerMetadata: testData.metadata, ResolvedEnv: testAWSKinesisAuthentication, AuthParams: testData.authParams, TriggerIndex: testData.triggerIndex})
if err != nil && !testData.isError {
t.Errorf("Expected success because %s got error, %s", testData.comment, err)
}
Expand All @@ -336,7 +327,7 @@ func TestKinesisParseMetadata(t *testing.T) {
func TestAWSKinesisGetMetricSpecForScaling(t *testing.T) {
for _, testData := range awsKinesisMetricIdentifiers {
ctx := context.Background()
meta, err := parseAwsKinesisStreamMetadata(&scalersconfig.ScalerConfig{TriggerMetadata: testData.metadataTestData.metadata, ResolvedEnv: testAWSKinesisAuthentication, AuthParams: testData.metadataTestData.authParams, TriggerIndex: testData.triggerIndex}, logr.Discard())
meta, err := parseAwsKinesisStreamMetadata(&scalersconfig.ScalerConfig{TriggerMetadata: testData.metadataTestData.metadata, ResolvedEnv: testAWSKinesisAuthentication, AuthParams: testData.metadataTestData.authParams, TriggerIndex: testData.triggerIndex})
if err != nil {
t.Fatal("Could not parse metadata:", err)
}
Expand All @@ -354,7 +345,7 @@ func TestAWSKinesisStreamScalerGetMetrics(t *testing.T) {
for _, meta := range awsKinesisGetMetricTestData {
scaler := awsKinesisStreamScaler{"", meta, &mockKinesis{}, logr.Discard()}
value, _, err := scaler.GetMetricsAndActivity(context.Background(), "MetricName")
switch meta.streamName {
switch meta.StreamName {
case testAWSKinesisErrorStream:
assert.Error(t, err, "expect error because of kinesis api error")
default:
Expand Down

0 comments on commit 88ddc39

Please sign in to comment.