Skip to content

Commit

Permalink
Use .get to avoid keyError issues (#226)
Browse files Browse the repository at this point in the history
  • Loading branch information
hghotra authored May 9, 2022
1 parent b14e6aa commit 8d80e67
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 33 deletions.
48 changes: 24 additions & 24 deletions datadog_lambda/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,8 +553,8 @@ def create_inferred_span(event, context):
def create_inferred_span_from_lambda_function_url_event(event, context):
request_context = event.get("requestContext")
domain = request_context.get("domainName")
method = request_context.get("http").get("method")
path = request_context.get("http").get("path")
method = request_context.get("http", {}).get("method")
path = request_context.get("http", {}).get("path")
resource = "{0} {1}".format(method, path)
tags = {
"operation_name": "aws.lambda.url",
Expand Down Expand Up @@ -583,11 +583,7 @@ def create_inferred_span_from_lambda_function_url_event(event, context):


def is_api_gateway_invocation_async(event):
return (
"headers" in event
and "X-Amz-Invocation-Type" in event.get("headers")
and event.get("headers").get("X-Amz-Invocation-Type") == "Event"
)
return event.get("headers", {}).get("X-Amz-Invocation-Type") == "Event"

This comment has been minimized.

Copy link
@kimi-p

kimi-p Jun 22, 2022

Contributor

missing one {} in .get("X-Amz-Invocation-Type")?



def create_inferred_span_from_api_gateway_websocket_event(event, context):
Expand Down Expand Up @@ -663,17 +659,17 @@ def create_inferred_span_from_api_gateway_event(event, context):
def create_inferred_span_from_http_api_event(event, context):
request_context = event.get("requestContext")
domain = request_context.get("domainName")
method = request_context.get("http").get("method")
method = request_context.get("http", {}).get("method")
path = event.get("rawPath")
resource = "{0} {1}".format(method, path)
tags = {
"operation_name": "aws.httpapi",
"endpoint": path,
"http.url": domain + path,
"http.method": request_context.get("http").get("method"),
"http.protocol": request_context.get("http").get("protocol"),
"http.source_ip": request_context.get("http").get("sourceIp"),
"http.user_agent": request_context.get("http").get("userAgent"),
"http.method": request_context.get("http", {}).get("method"),
"http.protocol": request_context.get("http", {}).get("protocol"),
"http.source_ip": request_context.get("http", {}).get("sourceIp"),
"http.user_agent": request_context.get("http", {}).get("userAgent"),
"resource_names": resource,
"request_id": context.aws_request_id,
"apiid": request_context.get("apiId"),
Expand Down Expand Up @@ -710,10 +706,10 @@ def create_inferred_span_from_sqs_event(event, context):
"queuename": queue_name,
"event_source_arn": event_source_arn,
"receipt_handle": event_record.get("receiptHandle"),
"sender_id": event_record.get("attributes").get("SenderId"),
"sender_id": event_record.get("attributes", {}).get("SenderId"),
}
InferredSpanInfo.set_tags(tags, tag_source="self", synchronicity="async")
request_time_epoch = event_record.get("attributes").get("SentTimestamp")
request_time_epoch = event_record.get("attributes", {}).get("SentTimestamp")
args = {
"service": "sqs",
"resource": queue_name,
Expand Down Expand Up @@ -756,7 +752,7 @@ def create_inferred_span_from_sqs_event(event, context):
def create_inferred_span_from_sns_event(event, context):
event_record = get_first_record(event)
sns_message = event_record.get("Sns")
topic_arn = event_record.get("Sns").get("TopicArn")
topic_arn = event_record.get("Sns", {}).get("TopicArn")
topic_name = topic_arn.split(":")[-1]
tags = {
"operation_name": "aws.sns",
Expand All @@ -773,7 +769,7 @@ def create_inferred_span_from_sns_event(event, context):

InferredSpanInfo.set_tags(tags, tag_source="self", synchronicity="async")
sns_dt_format = "%Y-%m-%dT%H:%M:%S.%fZ"
timestamp = event_record.get("Sns").get("Timestamp")
timestamp = event_record.get("Sns", {}).get("Timestamp")
dt = datetime.strptime(timestamp, sns_dt_format)

args = {
Expand Down Expand Up @@ -804,10 +800,12 @@ def create_inferred_span_from_kinesis_event(event, context):
"event_id": event_id,
"event_name": event_record.get("eventName"),
"event_version": event_record.get("eventVersion"),
"partition_key": event_record.get("kinesis").get("partitionKey"),
"partition_key": event_record.get("kinesis", {}).get("partitionKey"),
}
InferredSpanInfo.set_tags(tags, tag_source="self", synchronicity="async")
request_time_epoch = event_record.get("kinesis").get("approximateArrivalTimestamp")
request_time_epoch = event_record.get("kinesis", {}).get(
"approximateArrivalTimestamp"
)

args = {
"service": "kinesis",
Expand Down Expand Up @@ -839,7 +837,9 @@ def create_inferred_span_from_dynamodb_event(event, context):
"size_bytes": str(dynamodb_message.get("SizeBytes")),
}
InferredSpanInfo.set_tags(tags, synchronicity="async", tag_source="self")
request_time_epoch = event_record.get("dynamodb").get("ApproximateCreationDateTime")
request_time_epoch = event_record.get("dynamodb", {}).get(
"ApproximateCreationDateTime"
)
args = {
"service": "dynamodb",
"resource": table_name,
Expand All @@ -856,16 +856,16 @@ def create_inferred_span_from_dynamodb_event(event, context):

def create_inferred_span_from_s3_event(event, context):
event_record = get_first_record(event)
bucket_name = event_record.get("s3").get("bucket").get("name")
bucket_name = event_record.get("s3", {}).get("bucket", {}).get("name")
tags = {
"operation_name": "aws.s3",
"resource_names": bucket_name,
"event_name": event_record.get("eventName"),
"bucketname": bucket_name,
"bucket_arn": event_record.get("s3").get("bucket").get("arn"),
"object_key": event_record.get("s3").get("object").get("key"),
"object_size": str(event_record.get("s3").get("object").get("size")),
"object_etag": event_record.get("s3").get("object").get("eTag"),
"bucket_arn": event_record.get("s3", {}).get("bucket", {}).get("arn"),
"object_key": event_record.get("s3", {}).get("object", {}).get("key"),
"object_size": str(event_record.get("s3", {}).get("object", {}).get("size")),
"object_etag": event_record.get("s3", {}).get("object", {}).get("eTag"),
}
InferredSpanInfo.set_tags(tags, synchronicity="async", tag_source="self")
dt_format = "%Y-%m-%dT%H:%M:%S.%fZ"
Expand Down
20 changes: 11 additions & 9 deletions datadog_lambda/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def parse_event_source(event: dict) -> _EventSource:
event_source.subtype = EventSubtypes.API_GATEWAY
if "routeKey" in event:
event_source.subtype = EventSubtypes.HTTP_API
if "requestContext" in event and "messageDirection" in event["requestContext"]:
if event.get("requestContext", {}).get("messageDirection"):
event_source.subtype = EventSubtypes.WEBSOCKET

if request_context and request_context.get("elb"):
Expand Down Expand Up @@ -186,15 +186,17 @@ def parse_event_source_arn(source: _EventSource, event: dict, context: Any) -> s
event_record = get_first_record(event)
# e.g. arn:aws:s3:::lambda-xyz123-abc890
if source.to_string() == "s3":
return event_record.get("s3")["bucket"]["arn"]
return event_record.get("s3", {}).get("bucket", {}).get("arn")

# e.g. arn:aws:sns:us-east-1:123456789012:sns-lambda
if source.to_string() == "sns":
return event_record.get("Sns")["TopicArn"]
return event_record.get("Sns", {}).get("TopicArn")

# e.g. arn:aws:cloudfront::123456789012:distribution/ABC123XYZ
if source.event_type == EventTypes.CLOUDFRONT:
distribution_id = event_record.get("cf")["config"]["distributionId"]
distribution_id = (
event_record.get("cf", {}).get("config", {}).get("distributionId")
)
return "arn:{}:cloudfront::{}:distribution/{}".format(
aws_arn, account_id, distribution_id
)
Expand All @@ -215,18 +217,18 @@ def parse_event_source_arn(source: _EventSource, event: dict, context: Any) -> s
if source.event_type == EventTypes.API_GATEWAY:
request_context = event.get("requestContext")
return "arn:{}:apigateway:{}::/restapis/{}/stages/{}".format(
aws_arn, region, request_context["apiId"], request_context["stage"]
aws_arn, region, request_context.get("apiId"), request_context.get("stage")
)

# e.g. arn:aws:elasticloadbalancing:us-east-1:123456789012:targetgroup/lambda-xyz/123
if source.event_type == EventTypes.ALB:
request_context = event.get("requestContext")
return request_context.get("elb")["targetGroupArn"]
return request_context.get("elb", {}).get("targetGroupArn")

# e.g. arn:aws:logs:us-west-1:123456789012:log-group:/my-log-group-xyz
if source.event_type == EventTypes.CLOUDWATCH_LOGS:
with gzip.GzipFile(
fileobj=BytesIO(base64.b64decode(event["awslogs"]["data"]))
fileobj=BytesIO(base64.b64decode(event.get("awslogs", {}).get("data")))
) as decompress_stream:
data = b"".join(BufferedReader(decompress_stream))
logs = json.loads(data)
Expand Down Expand Up @@ -265,7 +267,7 @@ def extract_http_tags(event):
method = event.get("httpMethod")
if request_context and request_context.get("stage"):
if request_context.get("domainName"):
http_tags["http.url"] = request_context["domainName"]
http_tags["http.url"] = request_context.get("domainName")

path = request_context.get("path")
method = request_context.get("httpMethod")
Expand All @@ -282,7 +284,7 @@ def extract_http_tags(event):

headers = event.get("headers")
if headers and headers.get("Referer"):
http_tags["http.referer"] = headers["Referer"]
http_tags["http.referer"] = headers.get("Referer")

return http_tags

Expand Down

0 comments on commit 8d80e67

Please sign in to comment.