Skip to content

Commit

Permalink
feature: Add support for Streaming Inference
Browse files Browse the repository at this point in the history
  • Loading branch information
mufaddal-rohawala committed Mar 13, 2024
1 parent 12bcf05 commit b317df8
Show file tree
Hide file tree
Showing 7 changed files with 473 additions and 0 deletions.
82 changes: 82 additions & 0 deletions src/sagemaker/base_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
JSONSerializer,
NumpySerializer,
)
from sagemaker.iterators import LineIterator
from sagemaker.session import production_variant, Session
from sagemaker.utils import name_from_base, stringify_object, format_tags

Expand Down Expand Up @@ -223,6 +224,7 @@ def _create_request_args(
target_variant=None,
inference_id=None,
custom_attributes=None,
target_container_hostname=None,
):
"""Placeholder docstring"""

Expand Down Expand Up @@ -284,9 +286,89 @@ def _create_request_args(
if self._get_component_name():
args["InferenceComponentName"] = self.component_name

if target_container_hostname:
args["TargetContainerHostname"] = target_container_hostname

Check warning on line 290 in src/sagemaker/base_predictor.py

View check run for this annotation

Codecov / codecov/patch

src/sagemaker/base_predictor.py#L289-L290

Added lines #L289 - L290 were not covered by tests

args["Body"] = data
return args

def predict_stream(
    self,
    data,
    initial_args=None,
    target_variant=None,
    inference_id=None,
    custom_attributes=None,
    component_name: Optional[str] = None,
    target_container_hostname=None,
    iterator=LineIterator,
):
    """Stream the inference response from the specified endpoint.

    Args:
        data (object): Input data for which you want the model to provide
            inference. If a serializer was specified when creating the
            Predictor, the result of the serializer is sent as input data.
            Otherwise the data must be sequence of bytes, and the predict
            method then sends the bytes in the request body as is.
        initial_args (dict[str,str]): Optional. Default arguments for boto3
            ``invoke_endpoint_with_response_stream`` call. Default is None
            (no default arguments).
        target_variant (str): The name of the production variant to run an
            inference request on (Default: None). Note that the
            ProductionVariant identifies the model you want to host and the
            resources you want to deploy for hosting it.
        inference_id (str): If you provide a value, it is added to the captured
            data when you enable data capture on the endpoint (Default: None).
        custom_attributes (str): Provides additional information about a request
            for an inference submitted to a model hosted at an Amazon SageMaker
            endpoint. The information is an opaque value that is forwarded
            verbatim. You could use this value, for example, to provide an ID
            that you can use to track a request or to provide other metadata
            that a service endpoint was programmed to process. The value must
            consist of no more than 1024 visible US-ASCII characters. The code
            in your model is responsible for setting or updating any custom
            attributes in the response. If your code does not set this value in
            the response, an empty value is returned. For example, if a custom
            attribute represents the trace ID, your model can prepend the
            custom attribute with Trace ID: in your post-processing function
            (Default: None).
        component_name (str): Optional. Name of the Amazon SageMaker inference
            component corresponding the predictor.
        target_container_hostname (str): If the endpoint hosts multiple
            containers and is configured to use direct invocation, this
            parameter specifies the host name of the container to invoke.
            (Default: None).
        iterator (:class:`~sagemaker.iterators.BaseIterator`): An iterator
            class which provides an iterable interface to deserialize a stream
            response from Inference Endpoint. An object of the iterator class
            provided will be returned by the predict_stream method
            (Default::class:`~sagemaker.iterators.LineIterator`). Iterators
            defined in :class:`~sagemaker.iterators` or custom iterators
            (needs to inherit :class:`~sagemaker.iterators.BaseIterator`) can
            be specified as an input.

    Returns:
        object (:class:`~sagemaker.iterators.BaseIterator`): An iterator object
            which would allow iteration on EventStream response will be
            returned. The object would be instantiated from `predict_stream`
            method's `iterator` parameter.
    """
    # [TODO]: clean up component_name in _create_request_args
    invoke_args = self._create_request_args(
        data=data,
        initial_args=initial_args,
        target_variant=target_variant,
        inference_id=inference_id,
        custom_attributes=custom_attributes,
        target_container_hostname=target_container_hostname,
    )

    # An explicitly supplied component name wins over one derived from the predictor.
    resolved_component_name = component_name or self._get_component_name()
    if resolved_component_name:
        invoke_args["InferenceComponentName"] = resolved_component_name

    runtime_client = self.sagemaker_session.sagemaker_runtime_client
    stream_response = runtime_client.invoke_endpoint_with_response_stream(**invoke_args)
    # Wrap the raw EventStream body in the caller-selected iterator for deserialization.
    return iterator(stream_response["Body"])

Check warning on line 370 in src/sagemaker/base_predictor.py

View check run for this annotation

Codecov / codecov/patch

src/sagemaker/base_predictor.py#L370

Added line #L370 was not covered by tests

def update_endpoint(
self,
initial_instance_count=None,
Expand Down
16 changes: 16 additions & 0 deletions src/sagemaker/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,19 @@ class AsyncInferenceModelError(AsyncInferenceError):

def __init__(self, message):
super().__init__(message=message)


class ModelStreamError(Exception):
    """Error reported by the model while an inference response is being streamed.

    Corresponds to the ``ModelStreamError`` event emitted inside the
    ``invoke_endpoint_with_response_stream`` response stream.

    Attributes:
        message (str): Human-readable description of the failure.
        code: Error code reported by the service, or None if none was given.
    """

    def __init__(self, message="An error occurred", code=None):
        self.message = message
        self.code = code
        # Append the code to the exception text only when one was provided.
        formatted = message if code is None else f"{message} (Code: {code})"
        super().__init__(formatted)

Check warning on line 98 in src/sagemaker/exceptions.py

View check run for this annotation

Codecov / codecov/patch

src/sagemaker/exceptions.py#L98

Added line #L98 was not covered by tests


class InternalStreamFailure(Exception):
    """Service-side failure that occurred while streaming an inference response.

    Corresponds to the ``InternalStreamFailure`` event emitted inside the
    ``invoke_endpoint_with_response_stream`` response stream.

    Attributes:
        message (str): Human-readable description of the failure.
    """

    def __init__(self, message="An error occurred"):
        # Keep the message on the instance for callers that inspect it directly.
        self.message = message
        super().__init__(message)

Check warning on line 104 in src/sagemaker/exceptions.py

View check run for this annotation

Codecov / codecov/patch

src/sagemaker/exceptions.py#L103-L104

Added lines #L103 - L104 were not covered by tests
154 changes: 154 additions & 0 deletions src/sagemaker/iterators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Implements iterators for deserializing data returned from an inference streaming endpoint."""
from __future__ import absolute_import

from abc import ABC, abstractmethod
import io

from sagemaker.exceptions import ModelStreamError, InternalStreamFailure


def handle_stream_errors(chunk):
    """Raise the matching SDK exception for an error event found in a stream chunk.

    The ``invoke_endpoint_with_response_stream`` API reports failures as dedicated
    events inside the ``botocore.eventstream.EventStream`` rather than as HTTP-level
    errors, so each chunk must be inspected for error payloads.

    Args:
        chunk (dict): A chunk of response received as part of
            ``botocore.eventstream.EventStream`` response object.

    Raises:
        ModelStreamError: If `ModelStreamError` error is detected in a chunk of
            `botocore.eventstream.EventStream` response object.
        InternalStreamFailure: If `InternalStreamFailure` error is detected in a
            chunk of `botocore.eventstream.EventStream` response object.
    """
    if "ModelStreamError" in chunk:
        details = chunk["ModelStreamError"]
        raise ModelStreamError(details["Message"], code=details["ErrorCode"])
    if "InternalStreamFailure" in chunk:
        raise InternalStreamFailure(chunk["InternalStreamFailure"]["Message"])

Check warning on line 40 in src/sagemaker/iterators.py

View check run for this annotation

Codecov / codecov/patch

src/sagemaker/iterators.py#L39-L40

Added lines #L39 - L40 were not covered by tests


class BaseIterator(ABC):
    """Skeleton for iterators that deserialize a streaming-inference response.

    Subclasses wrap the ``botocore.eventstream.EventStream`` returned by the
    Streaming Inference API
    (https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/
    sagemaker-runtime/client/invoke_endpoint_with_response_stream.html) and must
    override ``__iter__`` and ``__next__``.

    Tenets of an iterator class for a Streaming Inference API response:
    1. Needs to accept a botocore.eventstream.EventStream response.
    2. Needs to implement logic in __next__ to:
       2.1. Concatenate and provide the next chunk of response from the
            EventStream, parsing response_chunk["PayloadPart"]["Bytes"].
       2.2. Perform deserialization of the response chunk based on the
            expected response type.
       2.3. If PayloadPart is not in the EventStream response, handle errors.
    """

    def __init__(self, stream):
        """Store the event stream to be iterated.

        Args:
            stream: (botocore.eventstream.EventStream): Event Stream object to be iterated.
        """
        self.stream = stream

    @abstractmethod
    def __iter__(self):
        """Return the iterator object itself."""
        return self

    @abstractmethod
    def __next__(self):
        """Return the next element in the iteration."""
        pass

Check warning on line 77 in src/sagemaker/iterators.py

View check run for this annotation

Codecov / codecov/patch

src/sagemaker/iterators.py#L77

Added line #L77 was not covered by tests


class LineIterator(BaseIterator):
    r"""A helper class for parsing the byte stream input and provide iteration on lines
    with '\n' separators.
    """

    def __init__(self, stream):
        r"""Initialises a Iterator object to help parse the byte stream input and
        provide iteration on lines with '\n' separators

        Args:
            stream: (botocore.eventstream.EventStream): Event Stream object to be iterated.
        """
        super().__init__(stream)
        self.byte_iterator = iter(self.stream)
        # Accumulates every byte received so far; complete lines are consumed
        # from the front as the caller iterates.
        self.buffer = io.BytesIO()
        # Offset of the first byte that has not yet been returned to the caller.
        self.read_pos = 0

    def __iter__(self):
        """Returns an iterator object itself, which allows the object to be iterated.

        Returns:
            iter : object
                An iterator object representing the iterable.
        """
        return self

    def __next__(self):
        r"""Read and return one line (without its trailing '\n') from the event stream.

        The output of the event stream will be in the following format:
        ```
        b'{"outputs": [" a"]}\n'
        b'{"outputs": [" challenging"]}\n'
        b'{"outputs": [" problem"]}\n'
        ...
        ```

        While usually each PayloadPart event from the event stream will contain a byte
        array with a full json, this is not guaranteed and some of the json objects may
        be split across PayloadPart events. For example:
        ```
        {'PayloadPart': {'Bytes': b'{"outputs": '}}
        {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
        ```

        This class accounts for this by concatenating the bytes of each PayloadPart
        into an internal buffer and only returning complete lines (ending with a '\n'
        character). It maintains the position of the last read byte to ensure that
        previous bytes are not exposed again. If the stream ends with an unterminated
        line, those trailing bytes are returned as a final item before StopIteration.

        Returns:
            bytes: Read and return one line from the event stream.
        """
        # Even with "while True" loop the function still behaves like a generator
        # and sends the next new concatenated line
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord("\n"):
                # A full line is buffered; advance past it and strip the newline.
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if self.read_pos < self.buffer.getbuffer().nbytes:
                    # BUGFIX: the stream is exhausted, so `continue` would retry
                    # next() forever (an exhausted iterator keeps raising
                    # StopIteration) and hang when the final line has no trailing
                    # '\n'. Return the remaining partial line instead, then end.
                    self.buffer.seek(self.read_pos)
                    tail = self.buffer.read()
                    self.read_pos += len(tail)
                    return tail
                raise
            if "PayloadPart" not in chunk:
                # handle errors within API Response if any.
                handle_stream_errors(chunk)
                # BUGFIX: `"..." + chunk` raised TypeError (str + dict); format
                # the dict into the message instead so the event is merely logged.
                print(f"Unknown event type: {chunk}")
                continue
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])

Check warning on line 154 in src/sagemaker/iterators.py

View check run for this annotation

Codecov / codecov/patch

src/sagemaker/iterators.py#L150-L154

Added lines #L150 - L154 were not covered by tests
Binary file added tests/data/lmi-model-falcon-7b/mymodel-7B.tar.gz
Binary file not shown.
88 changes: 88 additions & 0 deletions tests/integ/test_predict_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import json
import os
import pytest

import tests.integ
import tests.integ.timeout

from sagemaker import image_uris
from sagemaker.iterators import LineIterator
from sagemaker.model import Model
from sagemaker.predictor import Predictor
from sagemaker.utils import unique_name_from_base

from tests.integ import DATA_DIR


ROLE = "SageMakerRole"
INSTANCE_COUNT = 1
INSTANCE_TYPE = "ml.g5.2xlarge"
LMI_FALCON_7B_DATA_PATH = os.path.join(DATA_DIR, "lmi-model-falcon-7b")


@pytest.fixture(scope="module")
def endpoint_name(sagemaker_session):
    """Deploy the LMI falcon-7b model once per module and yield its endpoint name.

    ``pytest.yield_fixture`` is deprecated and removed in pytest >= 6.2;
    ``pytest.fixture`` supports yield-fixtures directly. The endpoint is torn
    down by the surrounding timeout context when the module's tests finish.
    """
    lmi_endpoint_name = unique_name_from_base("lmi-model-falcon-7b")
    model_data = sagemaker_session.upload_data(
        path=os.path.join(LMI_FALCON_7B_DATA_PATH, "mymodel-7B.tar.gz"),
        key_prefix="large-model-lmi/code",
    )

    image_uri = image_uris.retrieve(
        framework="djl-deepspeed", region=sagemaker_session.boto_region_name, version="0.23.0"
    )

    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(
        endpoint_name=lmi_endpoint_name, sagemaker_session=sagemaker_session, hours=2
    ):
        lmi_model = Model(
            sagemaker_session=sagemaker_session,
            model_data=model_data,
            image_uri=image_uri,
            name=lmi_endpoint_name,  # model name
            role=ROLE,
        )
        lmi_model.deploy(
            INSTANCE_COUNT,
            INSTANCE_TYPE,
            endpoint_name=lmi_endpoint_name,
            container_startup_health_check_timeout=900,
        )
        yield lmi_endpoint_name


def test_predict_stream(sagemaker_session, endpoint_name):
    """Stream an inference from the deployed endpoint and check the assembled answer."""
    payload = {"inputs": "what does AWS stand for?", "parameters": {"max_new_tokens": 400}}
    predictor = Predictor(
        endpoint_name=endpoint_name,
        sagemaker_session=sagemaker_session,
    )

    # Request a streamed response; iterating it must not raise.
    stream_iterator = predictor.predict_stream(
        data=json.dumps(payload),
        initial_args={"ContentType": "application/json"},
        iterator=LineIterator,
    )

    # Each streamed line is a JSON document carrying the next output fragment.
    response = "".join(json.loads(line).get("outputs")[0] for line in stream_iterator)

    assert "AWS stands for Amazon Web Services." in response
Loading

0 comments on commit b317df8

Please sign in to comment.