Skip to content

Commit

Permalink
Added warning for load balancer (#240)
Browse files Browse the repository at this point in the history
  • Loading branch information
lu-ohai authored Oct 13, 2023
2 parents a4a3d9a + 8207481 commit 9018bde
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 45 deletions.
93 changes: 50 additions & 43 deletions ads/model/deployment/model_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from ads.model.deployment.model_deployment_runtime import (
ModelDeploymentCondaRuntime,
ModelDeploymentContainerRuntime,
ModelDeploymentMode,
ModelDeploymentRuntime,
ModelDeploymentRuntimeType,
OCIModelDeploymentRuntimeType,
Expand Down Expand Up @@ -80,11 +81,6 @@ class ModelDeploymentLogType:
ACCESS = "access"


class ModelDeploymentMode:
    """String constants for the deployment mode of a model deployment."""

    HTTPS = "HTTPS_ONLY"
    STREAM = "STREAM_ONLY"


class LogNotConfiguredError(Exception):  # pragma: no cover
    """Exception type signalling that a log is not configured."""

Expand Down Expand Up @@ -911,48 +907,59 @@ def predict(
"`data` and `json_input` are both provided. You can only use one of them."
)

if auto_serialize_data:
data = data or json_input
serialized_data = serializer.serialize(data=data)
return send_request(
data=serialized_data,
endpoint=endpoint,
is_json_payload=_is_json_serializable(serialized_data),
header=header,
)
try:
if auto_serialize_data:
data = data or json_input
serialized_data = serializer.serialize(data=data)
return send_request(
data=serialized_data,
endpoint=endpoint,
is_json_payload=_is_json_serializable(serialized_data),
header=header,
)

if json_input is not None:
if not _is_json_serializable(json_input):
raise ValueError(
"`json_input` must be json serializable. "
"Set `auto_serialize_data` to True, or serialize the provided input data first,"
"or using `data` to pass binary data."
if json_input is not None:
if not _is_json_serializable(json_input):
raise ValueError(
"`json_input` must be json serializable. "
"Set `auto_serialize_data` to True, or serialize the provided input data first,"
"or using `data` to pass binary data."
)
utils.get_logger().warning(
"The `json_input` argument of `predict()` will be deprecated soon. "
"Please use `data` argument. "
)
utils.get_logger().warning(
"The `json_input` argument of `predict()` will be deprecated soon. "
"Please use `data` argument. "
)
data = json_input
data = json_input

is_json_payload = _is_json_serializable(data)
if not isinstance(data, bytes) and not is_json_payload:
raise TypeError(
"`data` is not bytes or json serializable. Set `auto_serialize_data` to `True` to serialize the input data."
)
if model_name and model_version:
header["model-name"] = model_name
header["model-version"] = model_version
elif bool(model_version) ^ bool(model_name):
raise ValueError(
"`model_name` and `model_version` have to be provided together."
is_json_payload = _is_json_serializable(data)
if not isinstance(data, bytes) and not is_json_payload:
raise TypeError(
"`data` is not bytes or json serializable. Set `auto_serialize_data` to `True` to serialize the input data."
)
if model_name and model_version:
header["model-name"] = model_name
header["model-version"] = model_version
elif bool(model_version) ^ bool(model_name):
raise ValueError(
"`model_name` and `model_version` have to be provided together."
)
prediction = send_request(
data=data,
endpoint=endpoint,
is_json_payload=is_json_payload,
header=header,
)
prediction = send_request(
data=data,
endpoint=endpoint,
is_json_payload=is_json_payload,
header=header,
)
return prediction
return prediction
except oci.exceptions.ServiceError as ex:
# When the bandwidth exceeds the allocated value, the OCI backend raises a TooManyRequests error (429).
if ex.status == 429:
bandwidth_mbps = self.infrastructure.bandwidth_mbps or MODEL_DEPLOYMENT_BANDWIDTH_MBPS
utils.get_logger().warning(
f"Load balancer bandwidth exceeds the allocated {bandwidth_mbps} Mbps."
"To estimate the actual bandwidth, use formula: (payload size in KB) * (estimated requests per second) * 8 / 1024."
"To resolve the issue, try sizing down the payload, slowing down the request rate or increasing the allocated bandwidth."
)
raise

def activate(
self,
Expand Down
7 changes: 6 additions & 1 deletion ads/model/deployment/model_deployment_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ class OCIModelDeploymentRuntimeType:
CONTAINER = "OCIR_CONTAINER"


class ModelDeploymentMode:
    """String constants for the deployment mode of a model deployment."""

    HTTPS = "HTTPS_ONLY"
    STREAM = "STREAM_ONLY"


class ModelDeploymentRuntime(Builder):
"""A class used to represent a Model Deployment Runtime.
Expand Down Expand Up @@ -173,7 +178,7 @@ def deployment_mode(self) -> str:
str
The deployment mode of model deployment.
"""
return self.get_spec(self.CONST_DEPLOYMENT_MODE, None)
return self.get_spec(self.CONST_DEPLOYMENT_MODE, ModelDeploymentMode.HTTPS)

def with_deployment_mode(self, deployment_mode: str) -> "ModelDeploymentRuntime":
"""Sets the deployment mode of model deployment.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,17 @@
ModelDeployment,
ModelDeploymentProperties,
)
from ads.model.deployment.model_deployment_infrastructure import ModelDeploymentInfrastructure
from ads.model.deployment.model_deployment_runtime import ModelDeploymentCondaRuntime


class ModelDeploymentTestCase(unittest.TestCase):
MODEL_ID = "<MODEL_OCID>"
with patch.object(oci_client, "OCIClientFactory"):
test_model_deployment = ModelDeployment(
model_deployment_id="test_model_deployment_id", properties={}
model_deployment_id="test_model_deployment_id", properties={},
infrastructure=ModelDeploymentInfrastructure(),
runtime=ModelDeploymentCondaRuntime()
)

@patch("requests.post")
Expand Down

0 comments on commit 9018bde

Please sign in to comment.