Skip to content

Commit

Permalink
Merge branch 'main' into v2.11.15
Browse files Browse the repository at this point in the history
  • Loading branch information
VipulMascarenhas committed Jul 3, 2024
2 parents a7a9271 + c32e074 commit 6e0c223
Show file tree
Hide file tree
Showing 10 changed files with 294 additions and 263 deletions.
97 changes: 77 additions & 20 deletions ads/aqua/common/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
"""AQUA utils and constants."""

import asyncio
import base64
import json
Expand All @@ -19,13 +19,30 @@
import oci
from oci.data_science.models import JobRun, Model

from ads.aqua.common.enums import RqsAdditionalDetails
from ads.aqua.common.enums import (
InferenceContainerParamType,
InferenceContainerType,
RqsAdditionalDetails,
)
from ads.aqua.common.errors import (
AquaFileNotFoundError,
AquaRuntimeError,
AquaValueError,
)
from ads.aqua.constants import *
from ads.aqua.constants import (
AQUA_GA_LIST,
COMPARTMENT_MAPPING_KEY,
CONSOLE_LINK_RESOURCE_TYPE_MAPPING,
CONTAINER_INDEX,
MAXIMUM_ALLOWED_DATASET_IN_BYTE,
MODEL_BY_REFERENCE_OSS_PATH_KEY,
SERVICE_MANAGED_CONTAINER_URI_SCHEME,
SUPPORTED_FILE_FORMATS,
TGI_INFERENCE_RESTRICTED_PARAMS,
UNKNOWN,
UNKNOWN_JSON_STR,
VLLM_INFERENCE_RESTRICTED_PARAMS,
)
from ads.aqua.data import AquaResourceIdentifier
from ads.common.auth import default_signer
from ads.common.decorator.threaded import threaded
Expand Down Expand Up @@ -74,15 +91,15 @@ def get_status(evaluation_status: str, job_run_status: str = None):

status = LifecycleStatus.UNKNOWN
if evaluation_status == Model.LIFECYCLE_STATE_ACTIVE:
if (
job_run_status == JobRun.LIFECYCLE_STATE_IN_PROGRESS
or job_run_status == JobRun.LIFECYCLE_STATE_ACCEPTED
):
if job_run_status in {
JobRun.LIFECYCLE_STATE_IN_PROGRESS,
JobRun.LIFECYCLE_STATE_ACCEPTED,
}:
status = JobRun.LIFECYCLE_STATE_IN_PROGRESS
elif (
job_run_status == JobRun.LIFECYCLE_STATE_FAILED
or job_run_status == JobRun.LIFECYCLE_STATE_NEEDS_ATTENTION
):
elif job_run_status in {
JobRun.LIFECYCLE_STATE_FAILED,
JobRun.LIFECYCLE_STATE_NEEDS_ATTENTION,
}:
status = JobRun.LIFECYCLE_STATE_FAILED
else:
status = job_run_status
Expand Down Expand Up @@ -199,10 +216,7 @@ def read_file(file_path: str, **kwargs) -> str:
@threaded()
def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
artifact_path = f"{file_path.rstrip('/')}/{config_file_name}"
if artifact_path.startswith("oci://"):
signer = default_signer()
else:
signer = {}
signer = default_signer() if artifact_path.startswith("oci://") else {}
config = json.loads(
read_file(file_path=artifact_path, auth=signer, **kwargs) or UNKNOWN_JSON_STR
)
Expand Down Expand Up @@ -448,7 +462,7 @@ def _build_resource_identifier(


def _get_experiment_info(
model: Union[oci.resource_search.models.ResourceSummary, DataScienceModel]
model: Union[oci.resource_search.models.ResourceSummary, DataScienceModel],
) -> tuple:
"""Returns ocid and name of the experiment."""
return (
Expand Down Expand Up @@ -609,7 +623,7 @@ def extract_id_and_name_from_tag(tag: str):
base_model_name = UNKNOWN
try:
base_model_ocid, base_model_name = tag.split("#")
except:
except Exception:
pass

if not (is_valid_ocid(base_model_ocid) and base_model_name):
Expand Down Expand Up @@ -646,7 +660,7 @@ def get_resource_name(ocid: str) -> str:
try:
resource = query_resource(ocid, return_all=False)
name = resource.display_name if resource else UNKNOWN
except:
except Exception:
name = UNKNOWN
return name

Expand All @@ -670,8 +684,8 @@ def get_model_by_reference_paths(model_file_description: dict):

if not models:
raise AquaValueError(
f"Model path is not available in the model json artifact. "
f"Please check if the model created by reference has the correct artifact."
"Model path is not available in the model json artifact. "
"Please check if the model created by reference has the correct artifact."
)

if len(models) > 0:
Expand Down Expand Up @@ -848,3 +862,46 @@ def copy_model_config(artifact_path: str, os_path: str, auth: dict = None):
except Exception as ex:
logger.debug(ex)
logger.debug(f"Failed to copy config folder from {artifact_path} to {os_path}.")


def get_container_params_type(container_type_name: str) -> str:
    """Translate a deployment container type name into its params type.

    Parameters
    ----------
    container_type_name: str
        Name of the deployment container type, for example odsc-vllm-serving
        or odsc-tgi-serving. The comparison is done on the lowercased name.

    Returns
    -------
    str
        InferenceContainerParamType.PARAM_TYPE_VLLM when the name contains the
        vLLM container type, InferenceContainerParamType.PARAM_TYPE_TGI when it
        contains the TGI container type, and UNKNOWN otherwise.
    """
    # Substring containment is used on purpose rather than equality, so that
    # future variations of container_type_name still resolve.
    normalized_name = container_type_name.lower()

    if InferenceContainerType.CONTAINER_TYPE_VLLM in normalized_name:
        return InferenceContainerParamType.PARAM_TYPE_VLLM
    if InferenceContainerType.CONTAINER_TYPE_TGI in normalized_name:
        return InferenceContainerParamType.PARAM_TYPE_TGI
    return UNKNOWN


def get_restricted_params_by_container(container_type_name: str) -> set:
    """Look up the restricted params that apply to a deployment container type.

    Parameters
    ----------
    container_type_name: str
        Name of the deployment container type, for example odsc-vllm-serving
        or odsc-tgi-serving. The comparison is done on the lowercased name.

    Returns
    -------
    set
        VLLM_INFERENCE_RESTRICTED_PARAMS for a vLLM container,
        TGI_INFERENCE_RESTRICTED_PARAMS for a TGI container, and an empty set
        for any other container type.
    """
    # Substring containment is used on purpose rather than equality, so that
    # future variations of container_type_name still resolve.
    normalized_name = container_type_name.lower()

    if InferenceContainerType.CONTAINER_TYPE_VLLM in normalized_name:
        return VLLM_INFERENCE_RESTRICTED_PARAMS
    if InferenceContainerType.CONTAINER_TYPE_TGI in normalized_name:
        return TGI_INFERENCE_RESTRICTED_PARAMS
    return set()
47 changes: 30 additions & 17 deletions ads/aqua/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
"""This module defines constants used in ads.aqua module."""
Expand Down Expand Up @@ -45,19 +44,33 @@
SUPPORTED_FILE_FORMATS = ["jsonl"]
MODEL_BY_REFERENCE_OSS_PATH_KEY = "artifact_location"

CONSOLE_LINK_RESOURCE_TYPE_MAPPING = dict(
datasciencemodel="models",
datasciencemodeldeployment="model-deployments",
datasciencemodeldeploymentdev="model-deployments",
datasciencemodeldeploymentint="model-deployments",
datasciencemodeldeploymentpre="model-deployments",
datasciencejob="jobs",
datasciencejobrun="job-runs",
datasciencejobrundev="job-runs",
datasciencejobrunint="job-runs",
datasciencejobrunpre="job-runs",
datasciencemodelversionset="model-version-sets",
datasciencemodelversionsetpre="model-version-sets",
datasciencemodelversionsetint="model-version-sets",
datasciencemodelversionsetdev="model-version-sets",
)
# Lookup table keyed by resource type name; the dev/int/pre suffixed variants
# of a resource type map to the same value as the unsuffixed one.
# NOTE(review): the values look like console URL path segments — confirm
# against the code that builds console links.
CONSOLE_LINK_RESOURCE_TYPE_MAPPING = {
"datasciencemodel": "models",
"datasciencemodeldeployment": "model-deployments",
"datasciencemodeldeploymentdev": "model-deployments",
"datasciencemodeldeploymentint": "model-deployments",
"datasciencemodeldeploymentpre": "model-deployments",
"datasciencejob": "jobs",
"datasciencejobrun": "job-runs",
"datasciencejobrundev": "job-runs",
"datasciencejobrunint": "job-runs",
"datasciencejobrunpre": "job-runs",
"datasciencemodelversionset": "model-version-sets",
"datasciencemodelversionsetpre": "model-version-sets",
"datasciencemodelversionsetint": "model-version-sets",
"datasciencemodelversionsetdev": "model-version-sets",
}

# Restricted command-line flags for the vLLM inference container; this is the
# set ads.aqua.common.utils.get_restricted_params_by_container returns for
# vLLM container types.
# NOTE(review): presumably these are flags users may not override — confirm
# against the code that validates deployment params.
VLLM_INFERENCE_RESTRICTED_PARAMS = {
"--port",
"--host",
"--served-model-name",
"--seed",
}
# Restricted command-line flags for the TGI inference container; this is the
# set ads.aqua.common.utils.get_restricted_params_by_container returns for
# TGI container types.
TGI_INFERENCE_RESTRICTED_PARAMS = {
"--port",
"--hostname",
"--num-shard",
"--sharded",
"--trust-remote-code",
}
Loading

0 comments on commit 6e0c223

Please sign in to comment.