Skip to content

Commit

Permalink
[serverless] Inject trace context into SQS/SNS/EventBridge (#2917)
Browse files Browse the repository at this point in the history
Co-authored-by: Dario Castañé <dario.castane@datadoghq.com>
  • Loading branch information
nhulston and darccio authored Oct 22, 2024
1 parent e104e1e commit a118199
Show file tree
Hide file tree
Showing 8 changed files with 1,021 additions and 3 deletions.
14 changes: 14 additions & 0 deletions contrib/aws/aws-sdk-go-v2/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ import (
"github.com/aws/aws-sdk-go-v2/service/sqs"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"

eventBridgeTracer "gopkg.in/DataDog/dd-trace-go.v1/contrib/aws/internal/eventbridge"
snsTracer "gopkg.in/DataDog/dd-trace-go.v1/contrib/aws/internal/sns"
sqsTracer "gopkg.in/DataDog/dd-trace-go.v1/contrib/aws/internal/sqs"
)

const componentName = "aws/aws-sdk-go-v2/aws"
Expand Down Expand Up @@ -105,6 +109,16 @@ func (mw *traceMiddleware) startTraceMiddleware(stack *middleware.Stack) error {
}
span, spanctx := tracer.StartSpanFromContext(ctx, spanName(serviceID, operation), opts...)

// Inject trace context
switch serviceID {
case "SQS":
sqsTracer.EnrichOperation(span, in, operation)
case "SNS":
snsTracer.EnrichOperation(span, in, operation)
case "EventBridge":
eventBridgeTracer.EnrichOperation(span, in, operation)
}

// Handle initialize and continue through the middleware chain.
out, metadata, err = next.HandleInitialize(spanctx, in)
if err != nil && (mw.cfg.errCheck == nil || mw.cfg.errCheck(err)) {
Expand Down
140 changes: 137 additions & 3 deletions contrib/aws/aws-sdk-go-v2/aws/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package aws
import (
"context"
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
Expand All @@ -24,12 +25,13 @@ import (
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/eventbridge"
eventBridgeTypes "github.com/aws/aws-sdk-go-v2/service/eventbridge/types"
"github.com/aws/aws-sdk-go-v2/service/kinesis"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/sfn"
"github.com/aws/aws-sdk-go-v2/service/sns"
"github.com/aws/aws-sdk-go-v2/service/sqs"
"github.com/aws/aws-sdk-go-v2/service/sqs/types"
sqsTypes "github.com/aws/aws-sdk-go-v2/service/sqs/types"
"github.com/aws/smithy-go/middleware"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -281,6 +283,66 @@ func TestAppendMiddlewareSqsReceiveMessage(t *testing.T) {
}
}

func TestAppendMiddlewareSqsSendMessage(t *testing.T) {
mt := mocktracer.Start()
defer mt.Stop()

expectedStatusCode := 200
server := mockAWS(expectedStatusCode)
defer server.Close()

resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) {
return aws.Endpoint{
PartitionID: "aws",
URL: server.URL,
SigningRegion: "eu-west-1",
}, nil
})

awsCfg := aws.Config{
Region: "eu-west-1",
Credentials: aws.AnonymousCredentials{},
EndpointResolver: resolver,
}

AppendMiddleware(&awsCfg)

sqsClient := sqs.NewFromConfig(awsCfg)
sendMessageInput := &sqs.SendMessageInput{
MessageBody: aws.String("test message"),
QueueUrl: aws.String("https://sqs.us-west-2.amazonaws.com/123456789012/MyQueueName"),
}
_, err := sqsClient.SendMessage(context.Background(), sendMessageInput)
require.NoError(t, err)

spans := mt.FinishedSpans()
require.Len(t, spans, 1)

s := spans[0]
assert.Equal(t, "SQS.request", s.OperationName())
assert.Equal(t, "SendMessage", s.Tag("aws.operation"))
assert.Equal(t, "SQS", s.Tag("aws.service"))
assert.Equal(t, "MyQueueName", s.Tag("queuename"))
assert.Equal(t, "SQS.SendMessage", s.Tag(ext.ResourceName))
assert.Equal(t, "aws.SQS", s.Tag(ext.ServiceName))

// Check for trace context injection
assert.NotNil(t, sendMessageInput.MessageAttributes)
assert.Contains(t, sendMessageInput.MessageAttributes, "_datadog")
ddAttr := sendMessageInput.MessageAttributes["_datadog"]
assert.Equal(t, "String", *ddAttr.DataType)
assert.NotEmpty(t, *ddAttr.StringValue)

// Decode and verify the injected trace context
var traceContext map[string]string
err = json.Unmarshal([]byte(*ddAttr.StringValue), &traceContext)
assert.NoError(t, err)
assert.Contains(t, traceContext, "x-datadog-trace-id")
assert.Contains(t, traceContext, "x-datadog-parent-id")
assert.NotEmpty(t, traceContext["x-datadog-trace-id"])
assert.NotEmpty(t, traceContext["x-datadog-parent-id"])
}

func TestAppendMiddlewareS3ListObjects(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -441,6 +503,22 @@ func TestAppendMiddlewareSnsPublish(t *testing.T) {
assert.Equal(t, server.URL+"/", s.Tag(ext.HTTPURL))
assert.Equal(t, "aws/aws-sdk-go-v2/aws", s.Tag(ext.Component))
assert.Equal(t, ext.SpanKindClient, s.Tag(ext.SpanKind))

// Check for trace context injection
assert.NotNil(t, tt.publishInput.MessageAttributes)
assert.Contains(t, tt.publishInput.MessageAttributes, "_datadog")
ddAttr := tt.publishInput.MessageAttributes["_datadog"]
assert.Equal(t, "Binary", *ddAttr.DataType)
assert.NotEmpty(t, ddAttr.BinaryValue)

// Decode and verify the injected trace context
var traceContext map[string]string
err := json.Unmarshal(ddAttr.BinaryValue, &traceContext)
assert.NoError(t, err)
assert.Contains(t, traceContext, "x-datadog-trace-id")
assert.Contains(t, traceContext, "x-datadog-parent-id")
assert.NotEmpty(t, traceContext["x-datadog-trace-id"])
assert.NotEmpty(t, traceContext["x-datadog-parent-id"])
})
}
}
Expand Down Expand Up @@ -657,6 +735,62 @@ func TestAppendMiddlewareEventBridgePutRule(t *testing.T) {
}
}

func TestAppendMiddlewareEventBridgePutEvents(t *testing.T) {
mt := mocktracer.Start()
defer mt.Stop()

expectedStatusCode := 200
server := mockAWS(expectedStatusCode)
defer server.Close()

resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) {
return aws.Endpoint{
PartitionID: "aws",
URL: server.URL,
SigningRegion: "eu-west-1",
}, nil
})

awsCfg := aws.Config{
Region: "eu-west-1",
Credentials: aws.AnonymousCredentials{},
EndpointResolver: resolver,
}

AppendMiddleware(&awsCfg)

eventbridgeClient := eventbridge.NewFromConfig(awsCfg)
putEventsInput := &eventbridge.PutEventsInput{
Entries: []eventBridgeTypes.PutEventsRequestEntry{
{
EventBusName: aws.String("my-event-bus"),
Detail: aws.String(`{"key": "value"}`),
},
},
}
eventbridgeClient.PutEvents(context.Background(), putEventsInput)

spans := mt.FinishedSpans()
require.Len(t, spans, 1)

s := spans[0]
assert.Equal(t, "PutEvents", s.Tag("aws.operation"))
assert.Equal(t, "EventBridge.PutEvents", s.Tag(ext.ResourceName))

// Check for trace context injection
assert.Len(t, putEventsInput.Entries, 1)
entry := putEventsInput.Entries[0]
var detail map[string]interface{}
err := json.Unmarshal([]byte(*entry.Detail), &detail)
assert.NoError(t, err)
assert.Contains(t, detail, "_datadog")
ddData, ok := detail["_datadog"].(map[string]interface{})
assert.True(t, ok)
assert.Contains(t, ddData, "x-datadog-start-time")
assert.Contains(t, ddData, "x-datadog-resource-name")
assert.Equal(t, "my-event-bus", ddData["x-datadog-resource-name"])
}

func TestAppendMiddlewareSfnDescribeStateMachine(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -971,8 +1105,8 @@ func TestMessagingNamingSchema(t *testing.T) {
_, err = sqsClient.SendMessage(ctx, msg)
require.NoError(t, err)

entry := types.SendMessageBatchRequestEntry{Id: aws.String("1"), MessageBody: aws.String("body")}
batchMsg := &sqs.SendMessageBatchInput{QueueUrl: sqsResp.QueueUrl, Entries: []types.SendMessageBatchRequestEntry{entry}}
entry := sqsTypes.SendMessageBatchRequestEntry{Id: aws.String("1"), MessageBody: aws.String("body")}
batchMsg := &sqs.SendMessageBatchInput{QueueUrl: sqsResp.QueueUrl, Entries: []sqsTypes.SendMessageBatchRequestEntry{entry}}
_, err = sqsClient.SendMessageBatch(ctx, batchMsg)
require.NoError(t, err)

Expand Down
112 changes: 112 additions & 0 deletions contrib/aws/internal/eventbridge/eventbridge.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// Unless explicitly stated otherwise all files in this repository are licensed
// under the Apache License Version 2.0.
// This product includes software developed at Datadog (https://www.datadoghq.com/).
// Copyright 2016 Datadog, Inc.

package eventbridge

import (
"encoding/json"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/eventbridge"
"github.com/aws/aws-sdk-go-v2/service/eventbridge/types"
"github.com/aws/smithy-go/middleware"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"gopkg.in/DataDog/dd-trace-go.v1/internal/log"
"strconv"
"time"
)

const (
datadogKey = "_datadog"
startTimeKey = "x-datadog-start-time"
resourceNameKey = "x-datadog-resource-name"
maxSizeBytes = 256 * 1024 // 256 KB
)

func EnrichOperation(span tracer.Span, in middleware.InitializeInput, operation string) {
switch operation {
case "PutEvents":
handlePutEvents(span, in)
}
}

func handlePutEvents(span tracer.Span, in middleware.InitializeInput) {
params, ok := in.Parameters.(*eventbridge.PutEventsInput)
if !ok {
log.Debug("Unable to read PutEvents params")
return
}

// Create trace context
carrier := tracer.TextMapCarrier{}
err := tracer.Inject(span.Context(), carrier)
if err != nil {
log.Debug("Unable to inject trace context: %s", err)
return
}

// Add start time
startTimeMillis := time.Now().UnixMilli()
carrier[startTimeKey] = strconv.FormatInt(startTimeMillis, 10)

carrierJSON, err := json.Marshal(carrier)
if err != nil {
log.Debug("Unable to marshal trace context: %s", err)
return
}

// Remove last '}'
reusedTraceContext := string(carrierJSON[:len(carrierJSON)-1])

for i := range params.Entries {
injectTraceContext(reusedTraceContext, &params.Entries[i])
}
}

func injectTraceContext(baseTraceContext string, entryPtr *types.PutEventsRequestEntry) {
if entryPtr == nil {
return
}

// Build the complete trace context
var traceContext string
if entryPtr.EventBusName != nil {
traceContext = fmt.Sprintf(`%s,"%s":"%s"}`, baseTraceContext, resourceNameKey, *entryPtr.EventBusName)
} else {
traceContext = baseTraceContext + "}"
}

// Get current detail string
var detail string
if entryPtr.Detail == nil || *entryPtr.Detail == "" {
detail = "{}"
} else {
detail = *entryPtr.Detail
}

// Basic JSON structure validation
if len(detail) < 2 || detail[len(detail)-1] != '}' {
log.Debug("Unable to parse detail JSON. Not injecting trace context into EventBridge payload.")
return
}

// Create new detail string
var newDetail string
if len(detail) > 2 {
// Case where detail is not empty
newDetail = fmt.Sprintf(`%s,"%s":%s}`, detail[:len(detail)-1], datadogKey, traceContext)
} else {
// Cae where detail is empty
newDetail = fmt.Sprintf(`{"%s":%s}`, datadogKey, traceContext)
}

// Check sizes
if len(newDetail) > maxSizeBytes {
log.Debug("Payload size too large to pass context")
return
}

entryPtr.Detail = aws.String(newDetail)
}
Loading

0 comments on commit a118199

Please sign in to comment.