Skip to content

Commit

Permalink
feat(customSageMakerEndpoint): Enable Async configuration for endpoint (
Browse files Browse the repository at this point in the history
#591)

* feat(customSageMakerEndpoint): Enable Async configuration for endpoint (#591)
  • Loading branch information
krokoko authored Jul 31, 2024
1 parent 119b68f commit c633397
Show file tree
Hide file tree
Showing 12 changed files with 505 additions and 171 deletions.
1 change: 1 addition & 0 deletions apidocs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

## Interfaces

- [AsyncInferenceConfig](interfaces/AsyncInferenceConfig.md)
- [BaseClassProps](interfaces/BaseClassProps.md)
- [ContainerImageConfig](interfaces/ContainerImageConfig.md)
- [ContentGenerationAppSyncLambdaProps](interfaces/ContentGenerationAppSyncLambdaProps.md)
Expand Down
12 changes: 12 additions & 0 deletions apidocs/classes/CustomSageMakerEndpoint.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ enable disable xray tracing
***

### errorTopic?

> `readonly` `optional` **errorTopic**: `Topic`
***

### fieldLogLevel

> **fieldLogLevel**: `FieldLogLevel` = `appsync.FieldLogLevel.ALL`
Expand Down Expand Up @@ -212,6 +218,12 @@ Value will be appended to resources name.

***

### successTopic?

> `readonly` `optional` **successTopic**: `Topic`
***

### usageMetricMap

> `protected` `static` **usageMetricMap**: `Record`\<`string`, `number`\>
Expand Down
25 changes: 25 additions & 0 deletions apidocs/interfaces/AsyncInferenceConfig.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
[**@cdklabs/generative-ai-cdk-constructs**](../README.md)**Docs**

***

[@cdklabs/generative-ai-cdk-constructs](../README.md) / AsyncInferenceConfig

# Interface: AsyncInferenceConfig

## Properties

### failurePath

> `readonly` **failurePath**: `string`
***

### maxConcurrentInvocationsPerInstance?

> `readonly` `optional` **maxConcurrentInvocationsPerInstance**: `number`
***

### outputPath

> `readonly` **outputPath**: `string`
6 changes: 6 additions & 0 deletions apidocs/interfaces/CustomSageMakerEndpointProps.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@

## Properties

### asyncInference?

> `readonly` `optional` **asyncInference**: [`AsyncInferenceConfig`](AsyncInferenceConfig.md)
***

### container

> `readonly` **container**: [`ContainerImage`](../classes/ContainerImage.md)
Expand Down
44 changes: 42 additions & 2 deletions docs/generative_ai_cdk_constructs.drawio
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<mxfile host="Electron" modified="2024-07-02T00:52:48.483Z" agent="Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) draw.io/24.4.13 Chrome/124.0.6367.207 Electron/30.0.6 Safari/537.36" etag="7Bo-kMY3EzOB2eYsUAEn" version="24.4.13" type="device" pages="10">
<mxfile host="Electron" agent="Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) draw.io/24.7.5 Chrome/126.0.6478.183 Electron/31.3.0 Safari/537.36" version="24.7.5" pages="11">
<diagram id="yqzoU6PykweUqwamPqNK" name="aws-rag-appsync-stepfn-opensearch">
<mxGraphModel dx="2726" dy="658" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="850" pageHeight="1100" math="0" shadow="0">
<root>
Expand Down Expand Up @@ -626,7 +626,7 @@
</root>
</mxGraphModel>
</diagram>
<diagram name="CustomSageMakerEndpoint" id="Ld184xT8tr4mMkqV-7Tk">
<diagram name="CustomSageMakerEndpointRealTime" id="Ld184xT8tr4mMkqV-7Tk">
<mxGraphModel dx="1026" dy="658" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="850" pageHeight="1100" math="0" shadow="0">
<root>
<mxCell id="hdcXcgM0UPDYjAbK98_j-0" />
Expand Down Expand Up @@ -664,6 +664,46 @@
</root>
</mxGraphModel>
</diagram>
<diagram name="CustomSageMakerEndpointAsync" id="1XxpFNphD4dG7aegsWry">
<mxGraphModel dx="1026" dy="658" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="850" pageHeight="1100" math="0" shadow="0">
<root>
<mxCell id="0k5G1l83rrXwrz7Byu6q-0" />
<mxCell id="0k5G1l83rrXwrz7Byu6q-1" parent="0k5G1l83rrXwrz7Byu6q-0" />
<mxCell id="0k5G1l83rrXwrz7Byu6q-2" value="CustomSageMakerEndpoint (async)" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="0k5G1l83rrXwrz7Byu6q-1">
<mxGeometry x="90" y="530" width="610" height="60" as="geometry" />
</mxCell>
<mxCell id="0k5G1l83rrXwrz7Byu6q-3" value="AWS Cloud - User account" style="points=[[0,0],[0.25,0],[0.5,0],[0.75,0],[1,0],[1,0.25],[1,0.5],[1,0.75],[1,1],[0.75,1],[0.5,1],[0.25,1],[0,1],[0,0.75],[0,0.5],[0,0.25]];outlineConnect=0;gradientColor=none;html=1;whiteSpace=wrap;fontSize=12;fontStyle=0;container=1;pointerEvents=0;collapsible=0;recursiveResize=0;shape=mxgraph.aws4.group;grIcon=mxgraph.aws4.group_aws_cloud_alt;strokeColor=#232F3E;fillColor=none;verticalAlign=top;align=left;spacingLeft=30;fontColor=#232F3E;dashed=0;" vertex="1" parent="0k5G1l83rrXwrz7Byu6q-1">
<mxGeometry x="90" y="620" width="610" height="550" as="geometry" />
</mxCell>
<mxCell id="0k5G1l83rrXwrz7Byu6q-4" value="Amazon Simple Storage Service&lt;br&gt;&lt;b&gt;Model artifacts +&lt;/b&gt;&lt;div&gt;&lt;b&gt;input data +&lt;/b&gt;&lt;/div&gt;&lt;div&gt;&lt;b&gt;Inference results&lt;/b&gt;&lt;/div&gt;" style="sketch=0;points=[[0,0,0],[0.25,0,0],[0.5,0,0],[0.75,0,0],[1,0,0],[0,1,0],[0.25,1,0],[0.5,1,0],[0.75,1,0],[1,1,0],[0,0.25,0],[0,0.5,0],[0,0.75,0],[1,0.25,0],[1,0.5,0],[1,0.75,0]];outlineConnect=0;fontColor=#232F3E;fillColor=#7AA116;strokeColor=#ffffff;dashed=0;verticalLabelPosition=bottom;verticalAlign=top;align=center;html=1;fontSize=12;fontStyle=0;aspect=fixed;shape=mxgraph.aws4.resourceIcon;resIcon=mxgraph.aws4.s3;" vertex="1" parent="0k5G1l83rrXwrz7Byu6q-3">
<mxGeometry x="370" y="350" width="78" height="78" as="geometry" />
</mxCell>
<mxCell id="mQpaVntdGapTdL6flM7p-2" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;" edge="1" parent="0k5G1l83rrXwrz7Byu6q-3" source="0k5G1l83rrXwrz7Byu6q-5">
<mxGeometry relative="1" as="geometry">
<mxPoint x="148" y="220" as="targetPoint" />
</mxGeometry>
</mxCell>
<mxCell id="0k5G1l83rrXwrz7Byu6q-5" value="Amazon Elastic&amp;nbsp;&lt;div&gt;Container Registry&lt;br&gt;&lt;/div&gt;" style="sketch=0;points=[[0,0,0],[0.25,0,0],[0.5,0,0],[0.75,0,0],[1,0,0],[0,1,0],[0.25,1,0],[0.5,1,0],[0.75,1,0],[1,1,0],[0,0.25,0],[0,0.5,0],[0,0.75,0],[1,0.25,0],[1,0.5,0],[1,0.75,0]];outlineConnect=0;fontColor=#232F3E;fillColor=#ED7100;strokeColor=#ffffff;dashed=0;verticalLabelPosition=bottom;verticalAlign=top;align=center;html=1;fontSize=12;fontStyle=0;aspect=fixed;shape=mxgraph.aws4.resourceIcon;resIcon=mxgraph.aws4.ecr;" vertex="1" parent="0k5G1l83rrXwrz7Byu6q-3">
<mxGeometry x="110" y="350" width="78" height="78" as="geometry" />
</mxCell>
<mxCell id="mQpaVntdGapTdL6flM7p-0" value="Amazon Simple Notification Service&lt;div&gt;&lt;b&gt;Success and Error notifications&lt;/b&gt;&lt;/div&gt;" style="sketch=0;points=[[0,0,0],[0.25,0,0],[0.5,0,0],[0.75,0,0],[1,0,0],[0,1,0],[0.25,1,0],[0.5,1,0],[0.75,1,0],[1,1,0],[0,0.25,0],[0,0.5,0],[0,0.75,0],[1,0.25,0],[1,0.5,0],[1,0.75,0]];outlineConnect=0;fontColor=#232F3E;fillColor=#E7157B;strokeColor=#ffffff;dashed=0;verticalLabelPosition=bottom;verticalAlign=top;align=center;html=1;fontSize=12;fontStyle=0;aspect=fixed;shape=mxgraph.aws4.resourceIcon;resIcon=mxgraph.aws4.sns;" vertex="1" parent="0k5G1l83rrXwrz7Byu6q-3">
<mxGeometry x="370" y="100" width="78" height="78" as="geometry" />
</mxCell>
<mxCell id="0k5G1l83rrXwrz7Byu6q-8" value="Amazon SageMaker&amp;nbsp;&lt;div&gt;asynchronous endpoint&lt;/div&gt;" style="sketch=0;points=[[0,0,0],[0.25,0,0],[0.5,0,0],[0.75,0,0],[1,0,0],[0,1,0],[0.25,1,0],[0.5,1,0],[0.75,1,0],[1,1,0],[0,0.25,0],[0,0.5,0],[0,0.75,0],[1,0.25,0],[1,0.5,0],[1,0.75,0]];outlineConnect=0;fontColor=#232F3E;fillColor=#01A88D;strokeColor=#ffffff;dashed=0;verticalLabelPosition=bottom;verticalAlign=top;align=center;html=1;fontSize=12;fontStyle=0;aspect=fixed;shape=mxgraph.aws4.resourceIcon;resIcon=mxgraph.aws4.sagemaker;" vertex="1" parent="0k5G1l83rrXwrz7Byu6q-3">
<mxGeometry x="110" y="100" width="78" height="78" as="geometry" />
</mxCell>
<mxCell id="mQpaVntdGapTdL6flM7p-1" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;entryPerimeter=0;" edge="1" parent="0k5G1l83rrXwrz7Byu6q-3" source="0k5G1l83rrXwrz7Byu6q-8" target="mQpaVntdGapTdL6flM7p-0">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="0k5G1l83rrXwrz7Byu6q-7" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;" edge="1" parent="0k5G1l83rrXwrz7Byu6q-1">
<mxGeometry relative="1" as="geometry">
<mxPoint x="460" y="1009" as="sourcePoint" />
<mxPoint x="278" y="1009" as="targetPoint" />
</mxGeometry>
</mxCell>
</root>
</mxGraphModel>
</diagram>
<diagram id="zwjm-m0tJrHEoboqtmw_" name="aws-contentgen-appsync-lambda">
<mxGraphModel dx="2726" dy="669" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="850" pageHeight="1100" math="0" shadow="0">
<root>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ env: {
},
```

Here is a minimal deployable pattern definition:
Here is a minimal deployable pattern definition to deploy a real-time Amazon SageMaker endpoint:

TypeScript
```typescript
Expand Down Expand Up @@ -95,6 +95,12 @@ CustomSageMakerEndpoint(
)
```

The construct also allows you to deploy an asyncronous SageMaker endpoint. Amazon SageMaker Asynchronous Inference is a capability in SageMaker that queues incoming requests and processes them asynchronously. This option is ideal for requests with large payload sizes (up to 1GB), long processing times (up to one hour), and near real-time latency requirements.

Asynchronous Inference enables you to save on costs by autoscaling the instance count to zero when there are no requests to process, so you only pay when your endpoint is processing requests. For more information about asynchronous inference, please refer to the [documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/async-inference.html).

To configure the endpoint in asynchronous mode, you simply need to define the [AsyncInferenceConfig](#asyncinferenceconfig) in the construct properties. In this case, the construct will provision two Amazon Simple Notification Service topics which can be used to received notifications about inference (failure and success).

## Initializer

```
Expand Down Expand Up @@ -125,6 +131,17 @@ Parameters
| startupHealthCheckTimeoutInSeconds | Integer | ![Optional](https://img.shields.io/badge/optional-4169E1) | The timeout value, in seconds, for your inference container to pass health check by SageMaker Hosting |
| modelDataDownloadTimeoutInSeconds | Integer | ![Optional](https://img.shields.io/badge/optional-4169E1) | The timeout value, in seconds, to download and extract the model that you want to host from Amazon S3 to the individual inference instance associated with this production variant. |
| volumeSizeInGb | Integer | ![Optional](https://img.shields.io/badge/optional-4169E1) | The size, in GB, of the ML storage volume attached to individual inference instance associated with the production variant. Currently only Amazon EBS gp2 storage volumes are supported. |
| asyncInference | AsyncInferenceConfig | ![Optional](https://img.shields.io/badge/optional-4169E1) | Specifies configuration for how an endpoint performs asynchronous inference. Refer to [AsyncInferenceConfig](#asyncinferenceconfig) for details. If not defined, the endpoint will be configured as real-time.|

### AsyncInferenceConfig

If defined, the SageMaker endpoint will perform asynchronous inference.

| **Name** | **Type** | **Required** |**Description** |
|:-------------|:----------------|-----------------|-----------------|
| failurePath | string | ![Required](https://img.shields.io/badge/required-ff0000) | The Amazon S3 location to upload failure inference responses to. This location needs to be in the same bucket containing the model artifacts. |
| outputPath | string | ![Required](https://img.shields.io/badge/required-ff0000) | The Amazon S3 location to upload inference responses to. This location needs to be in the same bucket containing the model artifacts. |
| maxConcurrentInvocationsPerInstance | number | ![Optional](https://img.shields.io/badge/optional-4169E1) | The maximum number of concurrent requests sent by the SageMaker client to the model container. |

## Pattern Properties

Expand All @@ -141,6 +158,8 @@ Parameters
|instanceType| SageMakerInstanceType | The ML compute instance type |
|instanceCount| number | Number of instances to launch initially|
|role| [iam.Role](https://docs.aws.amazon.com/cdk/api/v2/docs/aws-cdk-lib.aws_iam.Role.html) |The IAM role that SageMaker can assume to access model artifacts and docker image for deployment on ML compute instances or for batch transform jobs |
|successTopic| [sns.Topic](https://docs.aws.amazon.com/cdk/api/v2/docs/aws-cdk-lib.aws_sns.Topic.html) | Amazon SNS topic to post a notification to when an inference completes successfully. If async configuration is not provided, this will not be defined.|
|errorTopic| [sns.Topic](https://docs.aws.amazon.com/cdk/api/v2/docs/aws-cdk-lib.aws_sns.Topic.html) | Amazon SNS topic to post a notification to when an inference fails. If async configuration is not provided, this will not be defined.|

## Default properties

Expand All @@ -149,12 +168,25 @@ Parameters
- modelDataDownloadTimeoutInSeconds: 600 if not provided
- instanceCount: 1 if not provided

If async configuration is enabled:
- Enable server-side encryption for SNS Topics using AWS managed KMS Key
- maxConcurrentInvocationsPerInstance: 10 if not provided

## Troubleshooting



## Architecture
![Architecture Diagram](architecture_CustomSageMakerEndpoint.png)

Real-time endpoint architecture:

![Architecture Real-time Diagram](architecture_rt_CustomSageMakerEndpoint.png)

Asynchronous endpoint architecture:

To invoke the endpoint, you need to place the request payload in Amazon Simple Storage Service (S3). You also need to provide a pointer to this payload as a part of the InvokeEndpointAsync request. Upon invocation, SageMaker queues the request for processing and returns an identifier and output location as a response. Upon processing, SageMaker places the result in the Amazon S3 location.

![Architecture Async Diagram](architecture_async_CustomSageMakerEndpoint.png)

## Cost

Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,22 @@
*/
import * as cdk from 'aws-cdk-lib';
import * as iam from 'aws-cdk-lib/aws-iam';
import * as kms from 'aws-cdk-lib/aws-kms';
import * as sagemaker from 'aws-cdk-lib/aws-sagemaker';
import * as sns from 'aws-cdk-lib/aws-sns';
import { Construct } from 'constructs';
import { ContainerImage } from './container-image';
import { SageMakerEndpointBase } from './sagemaker-endpoint-base';
import { SageMakerInstanceType } from './sagemaker-instance-type';
import { ConstructName } from '../../../common/base-class';
import { BaseClassProps } from '../../../common/base-class/base-class';

export interface AsyncInferenceConfig {
readonly failurePath: string;
readonly outputPath: string;
readonly maxConcurrentInvocationsPerInstance?: number;
}

export interface CustomSageMakerEndpointProps {
readonly modelId: string;
readonly endpointName: string;
Expand All @@ -33,6 +41,7 @@ export interface CustomSageMakerEndpointProps {
readonly volumeSizeInGb?: number | undefined;
readonly vpcConfig?: sagemaker.CfnModel.VpcConfigProperty | undefined;
readonly modelDataUrl: string;
readonly asyncInference?: AsyncInferenceConfig | undefined;

}

Expand All @@ -42,6 +51,8 @@ export class CustomSageMakerEndpoint extends SageMakerEndpointBase implements ia
public readonly cfnModel: sagemaker.CfnModel;
public readonly cfnEndpoint: sagemaker.CfnEndpoint;
public readonly cfnEndpointConfig: sagemaker.CfnEndpointConfig;
public readonly successTopic?: sns.Topic;
public readonly errorTopic?: sns.Topic;

public readonly instanceType?: SageMakerInstanceType;
public readonly instanceCount: number;
Expand All @@ -64,7 +75,6 @@ export class CustomSageMakerEndpoint extends SageMakerEndpointBase implements ia
const lambdaFunctions: cdk.aws_lambda.DockerImageFunction[]=[];
this.updateConstructUsageMetricCode( baseProps, scope, lambdaFunctions);


this.instanceType = props.instanceType;
this.modelId = props.modelId;
this.instanceCount = Math.max(1, props.instanceCount ?? 1);
Expand Down Expand Up @@ -128,6 +138,33 @@ export class CustomSageMakerEndpoint extends SageMakerEndpointBase implements ia
],
});

if (props.asyncInference) {

// build sns topics for success and failure
const successTopic = this.buildSnsTopic(`success-topic-${id}`, 'Success Topic');
const failureTopic = this.buildSnsTopic(`failure-topic-${id}`, 'Failure Topic');

this.errorTopic = failureTopic;
this.successTopic = successTopic;

// configure async inference
const asyncInferenceConfigProperty: sagemaker.CfnEndpointConfig.AsyncInferenceConfigProperty = {
outputConfig: {
s3FailurePath: props.asyncInference.failurePath,
s3OutputPath: props.asyncInference.outputPath,
notificationConfig: {
successTopic: successTopic.topicArn,
errorTopic: failureTopic.topicArn,
},
},
clientConfig: {
maxConcurrentInvocationsPerInstance: props.asyncInference.maxConcurrentInvocationsPerInstance ?? 10,
},
};

endpointConfig.asyncInferenceConfig = asyncInferenceConfigProperty;
}

endpointConfig.addDependency(model);

const endpoint = new sagemaker.CfnEndpoint(scope, `${modelIdStr}-endpoint-${id}`, {
Expand Down Expand Up @@ -164,4 +201,30 @@ export class CustomSageMakerEndpoint extends SageMakerEndpointBase implements ia
resourceArns: [this.endpointArn],
});
}

private buildSnsTopic(topicName: string, displayName: string): sns.Topic {
const masterKey = kms.Alias.fromAliasName(this, `aws-managed-key-${topicName}`, 'alias/aws/sns');

const topic = new sns.Topic(this, topicName, {
topicName,
displayName,
masterKey: masterKey,
});

topic.grantPublish(this.role);

topic.addToResourcePolicy(new iam.PolicyStatement({
actions: ['sns:Publish'],
effect: iam.Effect.DENY,
resources: [topic.topicArn],
conditions: {
Bool: {
'aws:SecureTransport': 'false',
},
},
principals: [new iam.AnyPrincipal()],
}));

return topic;
}
}
Loading

0 comments on commit c633397

Please sign in to comment.