def predict_stream(
    self,
    data,
    initial_args=None,
    target_variant=None,
    inference_id=None,
    custom_attributes=None,
    component_name: Optional[str] = None,
    target_container_hostname=None,
    iterator=LineIterator,
):
    """Invoke the endpoint with response streaming and wrap the result in an iterator.

    The request arguments are assembled by ``_create_request_args`` and passed to the
    runtime client's ``invoke_endpoint_with_response_stream`` call. The ``Body`` of the
    response is handed to ``iterator`` and the resulting object is returned.

    Args:
        data (object): Input data for which you want the model to provide inference.
            If a serializer was specified when creating the Predictor, the result of
            the serializer is sent as input data. Otherwise the data must be a sequence
            of bytes, which is sent in the request body as is.
        initial_args (dict[str,str]): Optional. Default arguments for the boto3
            ``invoke_endpoint_with_response_stream`` call (Default: None).
        target_variant (str): The name of the production variant to run an inference
            request on (Default: None).
        inference_id (str): If provided, it is added to the captured data when data
            capture is enabled on the endpoint (Default: None).
        custom_attributes (str): Opaque additional information about the request that
            is forwarded verbatim to the model, for example an ID used to track the
            request. Must be no more than 1024 visible US-ASCII characters
            (Default: None).
        component_name (str): Optional. Name of the Amazon SageMaker inference
            component to invoke. When empty, the value from
            ``self._get_component_name()`` is used instead.
        target_container_hostname (str): If the endpoint hosts multiple containers and
            is configured for direct invocation, the host name of the container to
            invoke (Default: None).
        iterator (:class:`~sagemaker.iterators.BaseIterator`): Iterator class used to
            deserialize the stream response. Either one of the classes in
            :mod:`sagemaker.iterators` or a custom subclass of
            :class:`~sagemaker.iterators.BaseIterator`
            (Default: :class:`~sagemaker.iterators.LineIterator`).

    Returns:
        object (:class:`~sagemaker.iterators.BaseIterator`): An instance of ``iterator``
            built from the response ``Body`` event stream.
    """
    # [TODO]: clean up component_name in _create_request_args
    invoke_kwargs = self._create_request_args(
        data=data,
        initial_args=initial_args,
        target_variant=target_variant,
        inference_id=inference_id,
        custom_attributes=custom_attributes,
        target_container_hostname=target_container_hostname,
    )

    # An explicitly passed component name takes precedence over the predictor's own.
    resolved_component = component_name if component_name else self._get_component_name()
    if resolved_component:
        invoke_kwargs["InferenceComponentName"] = resolved_component

    runtime_client = self.sagemaker_session.sagemaker_runtime_client
    stream_response = runtime_client.invoke_endpoint_with_response_stream(**invoke_kwargs)
    return iterator(stream_response["Body"])
class ModelStreamError(Exception):
    """Error reported for a model failure while streaming an inference response.

    Attributes are set from the constructor arguments: ``message`` holds the error
    text and ``code`` holds the optional error code.
    """

    def __init__(self, message="An error occurred", code=None):
        """Store the message and code and build the exception text.

        Args:
            message (str): Description of the error.
            code: Optional error code. When it is not None it is appended to the
                exception text as ``" (Code: <code>)"``.
        """
        self.message = message
        self.code = code
        rendered = message if code is None else f"{message} (Code: {code})"
        super().__init__(rendered)


class InternalStreamFailure(Exception):
    """Error reported for an internal failure while streaming an inference response."""

    def __init__(self, message="An error occurred"):
        """Store the message and use it as the exception text.

        Args:
            message (str): Description of the failure.
        """
        self.message = message
        super().__init__(message)
def handle_stream_errors(chunk):
    """Raise the matching exception when a response-stream chunk carries an error event.

    Args:
        chunk (dict): A chunk of response received as part of the
            ``botocore.eventstream.EventStream`` response object.

    Raises:
        ModelStreamError: If the chunk has a ``ModelStreamError`` entry. The exception
            is built from that entry's ``Message`` and ``ErrorCode`` values.
        InternalStreamFailure: If the chunk has an ``InternalStreamFailure`` entry. The
            exception is built from that entry's ``Message`` value.
    """
    if "ModelStreamError" in chunk:
        raise ModelStreamError(
            chunk["ModelStreamError"]["Message"], code=chunk["ModelStreamError"]["ErrorCode"]
        )
    if "InternalStreamFailure" in chunk:
        raise InternalStreamFailure(chunk["InternalStreamFailure"]["Message"])


class BaseIterator(ABC):
    """Abstract base class for creation of new iterators.

    Provides a skeleton for customization requiring the overriding of iterator methods
    __iter__ and __next__.

    Tenets of iterator class for Streaming Inference API Response
    (https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/
    sagemaker-runtime/client/invoke_endpoint_with_response_stream.html):
    1. Needs to accept a botocore.eventstream.EventStream response.
    2. Needs to implement logic in __next__ to:
        2.1. Concatenate and provide next chunk of response from
            botocore.eventstream.EventStream. While doing so parse the
            response_chunk["PayloadPart"]["Bytes"].
        2.2. Perform deserialization of response chunk based on expected response type.
        2.3. If PayloadPart not in EventStream response, handle Errors.
    """

    def __init__(self, stream):
        """Initialise an iterator object to help parse the byte event stream input.

        Args:
            stream: (botocore.eventstream.EventStream): Event Stream object to be iterated.
        """
        self.stream = stream

    @abstractmethod
    def __iter__(self):
        """Abstract __iter__ method, returns the iterator object itself."""
        return self

    @abstractmethod
    def __next__(self):
        """Abstract __next__ method, responsible for returning the next element."""


class LineIterator(BaseIterator):
    r"""Iterate over the lines of a streamed response, split on '\n' separators.

    Payload bytes from the event stream are appended to an in-memory buffer and handed
    out one line at a time, without the trailing '\n'.
    """

    def __init__(self, stream):
        r"""Initialise the iterator with an empty buffer.

        Args:
            stream: (botocore.eventstream.EventStream): Event Stream object to be iterated.
        """
        super().__init__(stream)
        self.byte_iterator = iter(self.stream)
        self.buffer = io.BytesIO()
        # Offset in ``buffer`` of the first byte that has not been returned yet.
        self.read_pos = 0

    def __iter__(self):
        """Return the iterator object itself, which allows the object to be iterated."""
        return self

    def __next__(self):
        r"""Return the next line from the event stream.

        The output of the event stream is expected in the following format:

        ```
        b'{"outputs": [" a"]}\n'
        b'{"outputs": [" challenging"]}\n'
        b'{"outputs": [" problem"]}\n'
        ...
        ```

        A line may be split across several PayloadPart events, for example:
        ```
        {'PayloadPart': {'Bytes': b'{"outputs": '}}
        {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
        ```

        To account for this, payload bytes are appended to ``self.buffer`` and
        ``self.read_pos`` tracks how far the buffer has been consumed, so bytes that
        were already returned are not exposed again.

        If the stream ends while the buffer still holds bytes that are not terminated
        by '\n', those bytes are returned as the final line.

        Returns:
            bytes: One line from the event stream, without the trailing '\n'.

        Raises:
            StopIteration: When the stream is exhausted and no unread bytes remain.
            ModelStreamError: Propagated from ``handle_stream_errors``.
            InternalStreamFailure: Propagated from ``handle_stream_errors``.
        """
        # Loop until a complete line is buffered or the stream is exhausted.
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line.endswith(b"\n"):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if line:
                    # No more events will arrive, so flush the unterminated tail
                    # instead of re-reading it forever.
                    self.read_pos += len(line)
                    return line
                raise
            if "PayloadPart" not in chunk:
                # handle errors within API Response if any.
                handle_stream_errors(chunk)
                # ``chunk`` is a dict, so it must be formatted rather than concatenated.
                print(f"Unknown event type: {chunk}")
                continue
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])