Skip to content

Commit

Permalink
[ODSC-63984] BYOC TEI deployment for embedding models (#975)
Browse files Browse the repository at this point in the history
  • Loading branch information
VipulMascarenhas authored Oct 29, 2024
2 parents c5bb94e + cf81e28 commit 4936025
Show file tree
Hide file tree
Showing 13 changed files with 668 additions and 87 deletions.
9 changes: 9 additions & 0 deletions ads/aqua/common/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving"
AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving"
AQUA_LLAMA_CPP_CONTAINER_FAMILY = "odsc-llama-cpp-serving"
AQUA_TEI_CONTAINER_FAMILY = "odsc-tei-serving"


class InferenceContainerParamType(str, metaclass=ExtendedEnumMeta):
Expand Down Expand Up @@ -80,3 +81,11 @@ class RqsAdditionalDetails(str, metaclass=ExtendedEnumMeta):
MODEL_VERSION_SET_NAME = "modelVersionSetName"
PROJECT_ID = "projectId"
VERSION_LABEL = "versionLabel"


class TextEmbeddingInferenceContainerParams(str, metaclass=ExtendedEnumMeta):
    """Command-line parameter names for the Text Embeddings Inference (TEI) container.

    Contains only the subset of params that are required for enabling model
    deployment in OCI Data Science. More options are available at
    https://huggingface.co/docs/text-embeddings-inference/en/cli_arguments

    The values are bare option names; they do not include a leading "--".
    """

    # Option name for the model identifier/location argument.
    MODEL_ID = "model-id"
    # Option name for the port argument.
    PORT = "port"
88 changes: 83 additions & 5 deletions ads/aqua/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
InferenceContainerParamType,
InferenceContainerType,
RqsAdditionalDetails,
TextEmbeddingInferenceContainerParams,
)
from ads.aqua.common.errors import (
AquaFileNotFoundError,
Expand All @@ -51,6 +52,7 @@
MODEL_BY_REFERENCE_OSS_PATH_KEY,
SERVICE_MANAGED_CONTAINER_URI_SCHEME,
SUPPORTED_FILE_FORMATS,
TEI_CONTAINER_DEFAULT_HOST,
TGI_INFERENCE_RESTRICTED_PARAMS,
UNKNOWN,
UNKNOWN_JSON_STR,
Expand All @@ -63,7 +65,12 @@
from ads.common.object_storage_details import ObjectStorageDetails
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
from ads.common.utils import copy_file, get_console_link, upload_to_os
from ads.config import AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, TENANCY_OCID
from ads.config import (
AQUA_MODEL_DEPLOYMENT_FOLDER,
AQUA_SERVICE_MODELS_BUCKET,
CONDA_BUCKET_NS,
TENANCY_OCID,
)
from ads.model import DataScienceModel, ModelVersionSet

logger = logging.getLogger("ads.aqua")
Expand Down Expand Up @@ -569,15 +576,13 @@ def get_container_image(
A dict of allowed configs.
"""

container_image = UNKNOWN
config = config_file_name or get_container_config()
config_file_name = service_config_path()

if container_type not in config:
raise AquaValueError(
f"{config_file_name} does not have config details for model: {container_type}"
)
return UNKNOWN

container_image = None
mapping = config[container_type]
versions = [obj["version"] for obj in mapping]
# assumes numbered versions, update if `latest` is used
Expand Down Expand Up @@ -1078,3 +1083,76 @@ def list_hf_models(query: str) -> List[str]:
return [model.id for model in models if model.disabled is None]
except HfHubHTTPError as err:
raise format_hf_custom_error_message(err) from err


def generate_tei_cmd_var(os_path: str) -> List[str]:
    """Build the CMD arguments for a Text Embedding Inference container.

    Only the parameters essential for an OCI model deployment are emitted
    (model id and port); every other container option keeps its default.

    Parameters
    ----------
    os_path: str
        OCI bucket path where the model artifacts are uploaded -
        oci://bucket@namespace/prefix

    Returns
    -------
    List[str]
        Command line arguments, laid out as alternating flag / value entries.
    """
    flag_prefix = "--"

    # Trailing slashes are dropped before parsing so that exactly one "/" is
    # appended to the resulting model directory below.
    artifact_prefix = ObjectStorageDetails.from_path(os_path.rstrip("/")).filepath
    model_location = f"{AQUA_MODEL_DEPLOYMENT_FOLDER}{artifact_prefix}/"

    model_id_flag = f"{flag_prefix}{TextEmbeddingInferenceContainerParams.MODEL_ID}"
    port_flag = f"{flag_prefix}{TextEmbeddingInferenceContainerParams.PORT}"

    # NOTE: despite its name, TEI_CONTAINER_DEFAULT_HOST is passed here as the
    # value of the port flag.
    return [
        model_id_flag,
        model_location,
        port_flag,
        TEI_CONTAINER_DEFAULT_HOST,
    ]


def parse_cmd_var(cmd_list: List[str]) -> dict:
    """Parse a flat list of command-line tokens into a key-value dictionary.

    Every token that starts with ``--`` becomes a key (prefix included). Its
    value is the token that immediately follows it, unless that token is
    itself a key or the list ends, in which case the value is ``None``.

    Tokens that do not start with ``--`` and do not directly follow a key are
    ignored. If a key appears more than once, the last occurrence wins.

    Parameters
    ----------
    cmd_list: List[str]
        Command-line tokens, e.g. ``["--port", "8080", "--verbose"]``.

    Returns
    -------
    dict
        Mapping of key token to its value token, or ``None`` for keys that
        have no value.
    """
    parsed_cmd = {}
    total = len(cmd_list)

    for i, cmd in enumerate(cmd_list):
        if not cmd.startswith("--"):
            # Value tokens are consumed through the look-ahead of the key that
            # precedes them, so there is nothing to do for them here.
            continue
        has_value = i + 1 < total and not cmd_list[i + 1].startswith("--")
        parsed_cmd[cmd] = cmd_list[i + 1] if has_value else None
    return parsed_cmd


def validate_cmd_var(cmd_var: List[str], overrides: List[str]) -> List[str]:
    """Combine default CMD parameters with additional user-supplied ones.

    Both lists are converted to lists of strings. The additional parameters
    may only add new keys: if any ``--`` key in ``overrides`` is also present
    in ``cmd_var``, the call is rejected.

    Parameters
    ----------
    cmd_var: List[str]
        Default list of parameters.
    overrides: List[str]
        Additional parameters to append. May be empty or None.

    Returns
    -------
    List[str]
        ``cmd_var`` followed by ``overrides``, all entries as strings. When
        ``overrides`` is empty or None, only the stringified ``cmd_var``.

    Raises
    ------
    AquaValueError
        If ``overrides`` contains a key that is already set in ``cmd_var``.
    """
    cmd_var = [str(x) for x in cmd_var]
    if not overrides:
        return cmd_var
    overrides = [str(x) for x in overrides]

    # Dict key views support set intersection directly.
    common_keys = parse_cmd_var(cmd_var).keys() & parse_cmd_var(overrides).keys()
    if common_keys:
        # Sort so the error message is stable across runs; set iteration
        # order for strings is not deterministic.
        raise AquaValueError(
            f"The following CMD input cannot be overridden for model deployment: {', '.join(sorted(common_keys))}"
        )

    return cmd_var + overrides
1 change: 1 addition & 0 deletions ads/aqua/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,4 @@
"--port",
"--host",
}
TEI_CONTAINER_DEFAULT_HOST = "8080"
1 change: 1 addition & 0 deletions ads/aqua/model/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class ModelCustomMetadataFields(str, metaclass=ExtendedEnumMeta):
DEPLOYMENT_CONTAINER = "deployment-container"
EVALUATION_CONTAINER = "evaluation-container"
FINETUNE_CONTAINER = "finetune-container"
DEPLOYMENT_CONTAINER_URI = "deployment-container-uri"


class ModelTask(str, metaclass=ExtendedEnumMeta):
Expand Down
2 changes: 2 additions & 0 deletions ads/aqua/model/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class AquaModel(AquaModelSummary, DataClassSerializable):

model_card: str = None
inference_container: str = None
inference_container_uri: str = None
finetuning_container: str = None
evaluation_container: str = None
artifact_location: str = None
Expand Down Expand Up @@ -287,6 +288,7 @@ class ImportModelDetails(CLIBuilderMixin):
compartment_id: Optional[str] = None
project_id: Optional[str] = None
model_file: Optional[str] = None
inference_container_uri: Optional[str] = None

def __post_init__(self):
self._command = "model register"
Loading

0 comments on commit 4936025

Please sign in to comment.