Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ODSC-63984] BYOC TEI deployment for embedding models #975

Merged
merged 19 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions ads/aqua/common/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving"
AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving"
AQUA_LLAMA_CPP_CONTAINER_FAMILY = "odsc-llama-cpp-serving"
AQUA_TEI_CONTAINER_FAMILY = "odsc-tei-serving"


class InferenceContainerParamType(str, metaclass=ExtendedEnumMeta):
Expand Down Expand Up @@ -80,3 +81,11 @@ class RqsAdditionalDetails(str, metaclass=ExtendedEnumMeta):
MODEL_VERSION_SET_NAME = "modelVersionSetName"
PROJECT_ID = "projectId"
VERSION_LABEL = "versionLabel"


class TextEmbeddingInferenceContainerParams(str, metaclass=ExtendedEnumMeta):
"""Contains a subset of params that are required for enabling model deployment in OCI Data Science. More options
are available at https://huggingface.co/docs/text-embeddings-inference/en/cli_arguments"""

MODEL_ID = "model-id"
PORT = "port"
88 changes: 83 additions & 5 deletions ads/aqua/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
InferenceContainerParamType,
InferenceContainerType,
RqsAdditionalDetails,
TextEmbeddingInferenceContainerParams,
)
from ads.aqua.common.errors import (
AquaFileNotFoundError,
Expand All @@ -51,6 +52,7 @@
MODEL_BY_REFERENCE_OSS_PATH_KEY,
SERVICE_MANAGED_CONTAINER_URI_SCHEME,
SUPPORTED_FILE_FORMATS,
TEI_CONTAINER_DEFAULT_HOST,
TGI_INFERENCE_RESTRICTED_PARAMS,
UNKNOWN,
UNKNOWN_JSON_STR,
Expand All @@ -63,7 +65,12 @@
from ads.common.object_storage_details import ObjectStorageDetails
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
from ads.common.utils import copy_file, get_console_link, upload_to_os
from ads.config import AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, TENANCY_OCID
from ads.config import (
AQUA_MODEL_DEPLOYMENT_FOLDER,
AQUA_SERVICE_MODELS_BUCKET,
CONDA_BUCKET_NS,
TENANCY_OCID,
)
from ads.model import DataScienceModel, ModelVersionSet

logger = logging.getLogger("ads.aqua")
Expand Down Expand Up @@ -569,15 +576,13 @@ def get_container_image(
A dict of allowed configs.
"""

container_image = UNKNOWN
config = config_file_name or get_container_config()
config_file_name = service_config_path()

if container_type not in config:
raise AquaValueError(
f"{config_file_name} does not have config details for model: {container_type}"
)
return UNKNOWN

container_image = None
mapping = config[container_type]
versions = [obj["version"] for obj in mapping]
# assumes numbered versions, update if `latest` is used
Expand Down Expand Up @@ -1078,3 +1083,76 @@ def list_hf_models(query: str) -> List[str]:
return [model.id for model in models if model.disabled is None]
except HfHubHTTPError as err:
raise format_hf_custom_error_message(err) from err


def generate_tei_cmd_var(os_path: str) -> List[str]:
"""This utility functions generates CMD params for Text Embedding Inference container. Only the
essential parameters for OCI model deployment are added, defaults are used for the rest.
Parameters
----------
os_path: str
OCI bucket path where the model artifacts are uploaded - oci://bucket@namespace/prefix
Returns
-------
cmd_var:
List of command line arguments
"""

cmd_prefix = "--"
cmd_var = [
f"{cmd_prefix}{TextEmbeddingInferenceContainerParams.MODEL_ID}",
f"{AQUA_MODEL_DEPLOYMENT_FOLDER}{ObjectStorageDetails.from_path(os_path.rstrip('/')).filepath}/",
f"{cmd_prefix}{TextEmbeddingInferenceContainerParams.PORT}",
TEI_CONTAINER_DEFAULT_HOST,
]

return cmd_var


def parse_cmd_var(cmd_list: List[str]) -> dict:
"""Helper functions that parses a list into a key-value dictionary. The list contains keys separated by the prefix
'--' and the value of the key is the subsequent element.
"""
parsed_cmd = {}

for i, cmd in enumerate(cmd_list):
if cmd.startswith("--"):
if i + 1 < len(cmd_list) and not cmd_list[i + 1].startswith("--"):
parsed_cmd[cmd] = cmd_list[i + 1]
i += 1
else:
parsed_cmd[cmd] = None
return parsed_cmd


def validate_cmd_var(cmd_var: List[str], overrides: List[str]) -> List[str]:
"""This function accepts two lists of parameters and combines them. If the second list shares the common parameter
names/keys, then it raises an error.
Parameters
----------
cmd_var: List[str]
Default list of parameters
overrides: List[str]
List of parameters to override
Returns
-------
List[str] of combined parameters
"""
cmd_var = [str(x) for x in cmd_var]
if not overrides:
return cmd_var
overrides = [str(x) for x in overrides]

cmd_dict = parse_cmd_var(cmd_var)
overrides_dict = parse_cmd_var(overrides)

# check for conflicts
common_keys = set(cmd_dict.keys()) & set(overrides_dict.keys())
if common_keys:
raise AquaValueError(
f"The following CMD input cannot be overridden for model deployment: {', '.join(common_keys)}"
)

combined_cmd_var = cmd_var + overrides
return combined_cmd_var
1 change: 1 addition & 0 deletions ads/aqua/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,4 @@
"--port",
"--host",
}
TEI_CONTAINER_DEFAULT_HOST = "8080"
1 change: 1 addition & 0 deletions ads/aqua/model/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class ModelCustomMetadataFields(str, metaclass=ExtendedEnumMeta):
DEPLOYMENT_CONTAINER = "deployment-container"
EVALUATION_CONTAINER = "evaluation-container"
FINETUNE_CONTAINER = "finetune-container"
DEPLOYMENT_CONTAINER_URI = "deployment-container-uri"


class ModelTask(str, metaclass=ExtendedEnumMeta):
Expand Down
2 changes: 2 additions & 0 deletions ads/aqua/model/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class AquaModel(AquaModelSummary, DataClassSerializable):

model_card: str = None
inference_container: str = None
inference_container_uri: str = None
finetuning_container: str = None
evaluation_container: str = None
artifact_location: str = None
Expand Down Expand Up @@ -287,6 +288,7 @@ class ImportModelDetails(CLIBuilderMixin):
compartment_id: Optional[str] = None
project_id: Optional[str] = None
model_file: Optional[str] = None
inference_container_uri: Optional[str] = None

def __post_init__(self):
self._command = "model register"
Loading
Loading