-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feature: Add support for Streaming Inference
- Loading branch information
1 parent
12bcf05
commit b317df8
Showing
7 changed files
with
473 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""Implements iterators for deserializing data returned from an inference streaming endpoint.""" | ||
from __future__ import absolute_import | ||
|
||
from abc import ABC, abstractmethod | ||
import io | ||
|
||
from sagemaker.exceptions import ModelStreamError, InternalStreamFailure | ||
|
||
|
||
def handle_stream_errors(chunk): | ||
"""Handle API Response errors within `invoke_endpoint_with_response_stream` API if any. | ||
Args: | ||
chunk (dict): A chunk of response received as part of `botocore.eventstream.EventStream` | ||
response object. | ||
Raises: | ||
ModelStreamError: If `ModelStreamError` error is detected in a chunk of | ||
`botocore.eventstream.EventStream` response object. | ||
InternalStreamFailure: If `InternalStreamFailure` error is detected in a chunk of | ||
`botocore.eventstream.EventStream` response object. | ||
""" | ||
if "ModelStreamError" in chunk: | ||
raise ModelStreamError( | ||
chunk["ModelStreamError"]["Message"], code=chunk["ModelStreamError"]["ErrorCode"] | ||
) | ||
if "InternalStreamFailure" in chunk: | ||
raise InternalStreamFailure(chunk["InternalStreamFailure"]["Message"]) | ||
|
||
|
||
class BaseIterator(ABC): | ||
"""Abstract base class for creation of new iterators. | ||
Provides a skeleton for customization requiring the overriding of iterator methods | ||
__iter__ and __next__. | ||
Tenets of iterator class for Streaming Inference API Response | ||
(https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ | ||
sagemaker-runtime/client/invoke_endpoint_with_response_stream.html): | ||
1. Needs to accept an botocore.eventstream.EventStream response. | ||
2. Needs to implement logic in __next__ to: | ||
2.1. Concatenate and provide next chunk of response from botocore.eventstream.EventStream. | ||
While doing so parse the response_chunk["PayloadPart"]["Bytes"]. | ||
2.2. Perform deserialization of response chunk based on expected response type. | ||
2.3. If PayloadPart not in EventStream response, handle Errors. | ||
""" | ||
|
||
def __init__(self, stream): | ||
"""Initialises a Iterator object to help parse the byte event stream input. | ||
Args: | ||
stream: (botocore.eventstream.EventStream): Event Stream object to be iterated. | ||
""" | ||
self.stream = stream | ||
|
||
@abstractmethod | ||
def __iter__(self): | ||
"""Abstract __iter__ method, returns an iterator object itself""" | ||
return self | ||
|
||
@abstractmethod | ||
def __next__(self): | ||
"""Abstract __next__ method, is responsible for returning the next element in the | ||
iteration""" | ||
pass | ||
|
||
|
||
class LineIterator(BaseIterator): | ||
""" | ||
A helper class for parsing the byte stream input and provide iteration on lines with | ||
'\n' separators. | ||
""" | ||
|
||
def __init__(self, stream): | ||
"""Initialises a Iterator object to help parse the byte stream input and | ||
provide iteration on lines with '\n' separators | ||
Args: | ||
stream: (botocore.eventstream.EventStream): Event Stream object to be iterated. | ||
""" | ||
super().__init__(stream) | ||
self.byte_iterator = iter(self.stream) | ||
self.buffer = io.BytesIO() | ||
self.read_pos = 0 | ||
|
||
def __iter__(self): | ||
"""Returns an iterator object itself, which allows the object to be iterated. | ||
Returns: | ||
iter : object | ||
An iterator object representing the iterable. | ||
""" | ||
return self | ||
|
||
def __next__(self): | ||
""" | ||
The output of the event stream will be in the following format: | ||
``` | ||
b'{"outputs": [" a"]}\n' | ||
b'{"outputs": [" challenging"]}\n' | ||
b'{"outputs": [" problem"]}\n' | ||
... | ||
``` | ||
While usually each PayloadPart event from the event stream will contain a byte array | ||
with a full json, this is not guaranteed and some of the json objects may be split across | ||
PayloadPart events. For example: | ||
``` | ||
{'PayloadPart': {'Bytes': b'{"outputs": '}} | ||
{'PayloadPart': {'Bytes': b'[" problem"]}\n'}} | ||
``` | ||
This class accounts for this by concatenating bytes written via the 'write' function | ||
and then exposing a method which will return lines (ending with a '\n' character) within | ||
the buffer via the 'scan_lines' function. It maintains the position of the last read | ||
position to ensure that previous bytes are not exposed again. | ||
Returns: | ||
str: Read and return one line from the event stream. | ||
""" | ||
# Even with "while True" loop the function still behaves like a generator | ||
# and sends the next new concatenated line | ||
while True: | ||
self.buffer.seek(self.read_pos) | ||
line = self.buffer.readline() | ||
if line and line[-1] == ord("\n"): | ||
self.read_pos += len(line) | ||
return line[:-1] | ||
try: | ||
chunk = next(self.byte_iterator) | ||
except StopIteration: | ||
if self.read_pos < self.buffer.getbuffer().nbytes: | ||
continue | ||
raise | ||
if "PayloadPart" not in chunk: | ||
# handle errors within API Response if any. | ||
handle_stream_errors(chunk) | ||
print("Unknown event type:" + chunk) | ||
continue | ||
self.buffer.seek(0, io.SEEK_END) | ||
self.buffer.write(chunk["PayloadPart"]["Bytes"]) | ||
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
from __future__ import absolute_import | ||
|
||
import json | ||
import os | ||
import pytest | ||
|
||
import tests.integ | ||
import tests.integ.timeout | ||
|
||
from sagemaker import image_uris | ||
from sagemaker.iterators import LineIterator | ||
from sagemaker.model import Model | ||
from sagemaker.predictor import Predictor | ||
from sagemaker.utils import unique_name_from_base | ||
|
||
from tests.integ import DATA_DIR | ||
|
||
|
||
ROLE = "SageMakerRole" | ||
INSTANCE_COUNT = 1 | ||
INSTANCE_TYPE = "ml.g5.2xlarge" | ||
LMI_FALCON_7B_DATA_PATH = os.path.join(DATA_DIR, "lmi-model-falcon-7b") | ||
|
||
|
||
@pytest.yield_fixture(scope="module") | ||
def endpoint_name(sagemaker_session): | ||
lmi_endpoint_name = unique_name_from_base("lmi-model-falcon-7b") | ||
model_data = sagemaker_session.upload_data( | ||
path=os.path.join(LMI_FALCON_7B_DATA_PATH, "mymodel-7B.tar.gz"), | ||
key_prefix="large-model-lmi/code", | ||
) | ||
|
||
image_uri = image_uris.retrieve( | ||
framework="djl-deepspeed", region=sagemaker_session.boto_region_name, version="0.23.0" | ||
) | ||
|
||
with tests.integ.timeout.timeout_and_delete_endpoint_by_name( | ||
endpoint_name=lmi_endpoint_name, sagemaker_session=sagemaker_session, hours=2 | ||
): | ||
lmi_model = Model( | ||
sagemaker_session=sagemaker_session, | ||
model_data=model_data, | ||
image_uri=image_uri, | ||
name=lmi_endpoint_name, # model name | ||
role=ROLE, | ||
) | ||
lmi_model.deploy( | ||
INSTANCE_COUNT, | ||
INSTANCE_TYPE, | ||
endpoint_name=lmi_endpoint_name, | ||
container_startup_health_check_timeout=900, | ||
) | ||
yield lmi_endpoint_name | ||
|
||
|
||
def test_predict_stream(sagemaker_session, endpoint_name): | ||
data = {"inputs": "what does AWS stand for?", "parameters": {"max_new_tokens": 400}} | ||
initial_args = {"ContentType": "application/json"} | ||
predictor = Predictor( | ||
endpoint_name=endpoint_name, | ||
sagemaker_session=sagemaker_session, | ||
) | ||
|
||
# Validate that no exception is raised when the target_variant is specified. | ||
stream_iterator = predictor.predict_stream( | ||
data=json.dumps(data), | ||
initial_args=initial_args, | ||
iterator=LineIterator, | ||
) | ||
|
||
response = "" | ||
for line in stream_iterator: | ||
resp = json.loads(line) | ||
response += resp.get("outputs")[0] | ||
|
||
assert "AWS stands for Amazon Web Services." in response |
Oops, something went wrong.