diff --git a/.github/workflows/run-unittests-py38-cov-report.yml b/.github/workflows/run-unittests-py38-cov-report.yml index 2b234e77d..06476f7d3 100644 --- a/.github/workflows/run-unittests-py38-cov-report.yml +++ b/.github/workflows/run-unittests-py38-cov-report.yml @@ -79,7 +79,7 @@ jobs: run: | pip install -e ".[feature-store-marketplace]" - name: "Run unitary tests folder with maximum ADS dependencies" - timeout-minutes: 40 + timeout-minutes: 60 shell: bash env: CONDA_PREFIX: /usr/share/miniconda diff --git a/.github/workflows/run-unittests-py39-py310.yml b/.github/workflows/run-unittests-py39-py310.yml index dcc5a0f1b..9ea6595e0 100644 --- a/.github/workflows/run-unittests-py39-py310.yml +++ b/.github/workflows/run-unittests-py39-py310.yml @@ -85,7 +85,7 @@ jobs: tests/unitary/with_extras/hpo - name: "Run unitary tests folder with maximum ADS dependencies" - timeout-minutes: 30 + timeout-minutes: 60 shell: bash env: CONDA_PREFIX: /usr/share/miniconda diff --git a/ads/aqua/__init__.py b/ads/aqua/__init__.py index d5d945be4..eed99b3e6 100644 --- a/ads/aqua/__init__.py +++ b/ads/aqua/__init__.py @@ -7,7 +7,7 @@ import os from ads import logger, set_auth -from ads.aqua.utils import fetch_service_compartment +from ads.aqua.common.utils import fetch_service_compartment from ads.config import OCI_RESOURCE_PRINCIPAL_VERSION ENV_VAR_LOG_LEVEL = "ADS_AQUA_LOG_LEVEL" diff --git a/ads/aqua/base.py b/ads/aqua/app.py similarity index 92% rename from ads/aqua/base.py rename to ads/aqua/app.py index ccada724c..966ed6de5 100644 --- a/ads/aqua/base.py +++ b/ads/aqua/app.py @@ -4,6 +4,7 @@ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import os +from dataclasses import fields from typing import Dict, Union import oci @@ -11,15 +12,15 @@ from ads import set_auth from ads.aqua import logger -from ads.aqua.data import Tags -from ads.aqua.exception import AquaRuntimeError, AquaValueError -from ads.aqua.utils import ( - UNKNOWN, +from ads.aqua.common.enums import Tags +from ads.aqua.common.errors import AquaRuntimeError, AquaValueError +from ads.aqua.common.utils import ( _is_valid_mvs, get_artifact_path, is_valid_ocid, load_config, ) +from ads.aqua.constants import UNKNOWN from ads.common import oci_client as oc from ads.common.auth import default_signer from ads.common.utils import extract_region @@ -160,7 +161,7 @@ def create_model_version_set( """ # TODO: tag should be selected based on which operation (eval/FT) invoke this method # currently only used by fine-tuning flow. - tag = Tags.AQUA_FINE_TUNING.value + tag = Tags.AQUA_FINE_TUNING if not model_version_set_id: try: @@ -277,8 +278,8 @@ def get_config(self, model_id: str, config_file_name: str) -> Dict: oci_model = self.ds_client.get_model(model_id).data oci_aqua = ( ( - Tags.AQUA_TAG.value in oci_model.freeform_tags - or Tags.AQUA_TAG.value.lower() in oci_model.freeform_tags + Tags.AQUA_TAG in oci_model.freeform_tags + or Tags.AQUA_TAG.lower() in oci_model.freeform_tags ) if oci_model.freeform_tags else False @@ -319,3 +320,22 @@ def telemetry(self): bucket=AQUA_TELEMETRY_BUCKET, namespace=AQUA_TELEMETRY_BUCKET_NS ) return self._telemetry + + +class CLIBuilderMixin: + """ + CLI builder from API interface. To be used with the DataClass only. 
+ """ + + def build_cli(self) -> str: + """ + Method to turn the dataclass attributes to CLI + """ + cmd = f"ads aqua {self._command}" + params = [ + f"--{field.name} {getattr(self,field.name)}" + for field in fields(self.__class__) + if getattr(self, field.name) + ] + cmd = f"{cmd} {' '.join(params)}" + return cmd diff --git a/ads/aqua/cli.py b/ads/aqua/cli.py index 14c8bd3f8..2bab77be9 100644 --- a/ads/aqua/cli.py +++ b/ads/aqua/cli.py @@ -4,20 +4,20 @@ # Copyright (c) 2024 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import os -import sys from ads.aqua import ( ENV_VAR_LOG_LEVEL, - set_log_level, ODSC_MODEL_COMPARTMENT_OCID, logger, + set_log_level, ) -from ads.aqua.deployment import AquaDeploymentApp +from ads.aqua.common.errors import AquaCLIError, AquaConfigError from ads.aqua.evaluation import AquaEvaluationApp -from ads.aqua.finetune import AquaFineTuningApp +from ads.aqua.finetuning import AquaFineTuningApp from ads.aqua.model import AquaModelApp -from ads.config import NB_SESSION_OCID +from ads.aqua.modeldeployment import AquaDeploymentApp from ads.common.utils import LOG_LEVELS +from ads.config import NB_SESSION_OCID class AquaCommand: @@ -35,6 +35,8 @@ class AquaCommand: def __init__( self, + debug: bool = None, + verbose: bool = None, log_level: str = os.environ.get(ENV_VAR_LOG_LEVEL, "ERROR").upper(), ): """ @@ -44,24 +46,64 @@ def __init__( ----- log_level (str): Sets the logging level for the application. - Default is retrieved from environment variable `LOG_LEVEL`, + Default is retrieved from environment variable `ADS_AQUA_LOG_LEVEL`, or 'ERROR' if not set. Example values include 'DEBUG', 'INFO', 'WARNING', 'ERROR', and 'CRITICAL'. + debug (bool): + Sets the logging level for the application to `DEBUG`. + verbose (bool): + Sets the logging level for the application to `INFO`. + + Raises + ------ + AquaCLIError: + When `--verbose` and `--debug` being used together. + When missing required `ODSC_MODEL_COMPARTMENT_OCID` env var. """ - if log_level.upper() not in LOG_LEVELS: - logger.error( - f"Log level should be one of {LOG_LEVELS}. Setting default to ERROR." + if verbose is not None and debug is not None: + raise AquaCLIError( + "Cannot use `--debug` and `--verbose` at the same time. " + "Please select either `--debug` for `DEBUG` level logging or " + "`--verbose` for `INFO` level logging." ) - log_level = "ERROR" - set_log_level(log_level) - # gracefully exit if env var is not set + elif verbose is not None: + self._validate_value("--verbose", verbose) + aqua_log_level = "INFO" + elif debug is not None: + self._validate_value("--debug", debug) + aqua_log_level = "DEBUG" + else: + if log_level.upper() not in LOG_LEVELS: + logger.warning( + f"Log level should be one of {LOG_LEVELS}. Setting default to ERROR." + ) + log_level = "ERROR" + aqua_log_level = log_level.upper() + + set_log_level(aqua_log_level) + if not ODSC_MODEL_COMPARTMENT_OCID: - logger.debug( - "ODSC_MODEL_COMPARTMENT_OCID environment variable is not set for Aqua." - ) if NB_SESSION_OCID: - logger.error( + raise AquaConfigError( f"Aqua is not available for the notebook session {NB_SESSION_OCID}. For more information, " f"please refer to the documentation." ) - sys.exit(1) + raise AquaConfigError( + "ODSC_MODEL_COMPARTMENT_OCID environment variable is not set for Aqua." + ) + + @staticmethod + def _validate_value(flag, value): + """Check if the given value for bool flag is valid. 
+ + Raises + ------ + AquaCLIError: + When the given value for bool flag is invalid. + """ + if value not in [True, False]: + raise AquaCLIError( + f"Invalid input `{value}` for flag: {flag}, a boolean value is required. " + "If you intend to chain a function call to the result, please separate the " + "flag and the subsequent function call with separator `-`." + ) diff --git a/ads/aqua/common/__init__.py b/ads/aqua/common/__init__.py new file mode 100644 index 000000000..9eadd9943 --- /dev/null +++ b/ads/aqua/common/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*-- + +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ diff --git a/ads/aqua/decorator.py b/ads/aqua/common/decorator.py similarity index 86% rename from ads/aqua/decorator.py rename to ads/aqua/common/decorator.py index 520e930d9..b697afff2 100644 --- a/ads/aqua/decorator.py +++ b/ads/aqua/common/decorator.py @@ -7,6 +7,7 @@ import sys from functools import wraps +from typing import TYPE_CHECKING, Union from oci.exceptions import ( ClientError, @@ -19,9 +20,12 @@ ) from tornado.web import HTTPError -from ads.aqua.exception import AquaError +from ads.aqua.common.errors import AquaError from ads.aqua.extension.base_handler import AquaAPIhandler +if TYPE_CHECKING: + from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler + def handle_exceptions(func): """Writes errors raised during call to JSON. @@ -53,11 +57,13 @@ def handle_exceptions(func): """ @wraps(func) - def inner_function(self: AquaAPIhandler, *args, **kwargs): + def inner_function( + self: Union[AquaAPIhandler, "AquaWSMsgHandler"], *args, **kwargs + ): try: return func(self, *args, **kwargs) except ServiceError as error: - self.write_error( + return self.write_error( status_code=error.status or 500, message=error.message, reason=error.message, @@ -69,25 +75,25 @@ def inner_function(self: AquaAPIhandler, *args, **kwargs): MissingEndpointForNonRegionalServiceClientError, RequestException, ) as error: - self.write_error( + return self.write_error( status_code=400, reason=f"{type(error).__name__}: {str(error)}", exc_info=sys.exc_info(), ) except ConnectTimeout as error: - self.write_error( + return self.write_error( status_code=408, reason=f"{type(error).__name__}: {str(error)}", exc_info=sys.exc_info(), ) except (MultipartUploadError, CompositeOperationError) as error: - self.write_error( + return self.write_error( status_code=500, reason=f"{type(error).__name__}: {str(error)}", exc_info=sys.exc_info(), ) except AquaError as error: - self.write_error( + return self.write_error( status_code=error.status, reason=error.reason, service_payload=error.service_payload, @@ -100,7 +106,7 @@ def inner_function(self: AquaAPIhandler, *args, **kwargs): exc_info=sys.exc_info(), ) except Exception as ex: - self.write_error( + return self.write_error( status_code=500, reason=f"{type(ex).__name__}: {str(ex)}", exc_info=sys.exc_info(), diff --git a/ads/aqua/common/enums.py b/ads/aqua/common/enums.py new file mode 100644 index 000000000..3db1fd72e --- /dev/null +++ b/ads/aqua/common/enums.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +""" +aqua.common.enums +~~~~~~~~~~~~~~ +This module contains the set of enums used in AQUA. 
+""" +from ads.common.extended_enum import ExtendedEnumMeta + + +class DataScienceResource(str, metaclass=ExtendedEnumMeta): + MODEL_DEPLOYMENT = "datasciencemodeldeployment" + MODEL = "datasciencemodel" + + +class Resource(str, metaclass=ExtendedEnumMeta): + JOB = "jobs" + JOBRUN = "jobruns" + MODEL = "models" + MODEL_DEPLOYMENT = "modeldeployments" + MODEL_VERSION_SET = "model-version-sets" + + +class Tags(str, metaclass=ExtendedEnumMeta): + TASK = "task" + LICENSE = "license" + ORGANIZATION = "organization" + AQUA_TAG = "OCI_AQUA" + AQUA_SERVICE_MODEL_TAG = "aqua_service_model" + AQUA_FINE_TUNED_MODEL_TAG = "aqua_fine_tuned_model" + AQUA_MODEL_NAME_TAG = "aqua_model_name" + AQUA_EVALUATION = "aqua_evaluation" + AQUA_FINE_TUNING = "aqua_finetuning" + READY_TO_FINE_TUNE = "ready_to_fine_tune" + READY_TO_IMPORT = "ready_to_import" + BASE_MODEL_CUSTOM = "aqua_custom_base_model" + AQUA_EVALUATION_MODEL_ID = "evaluation_model_id" + + +class InferenceContainerType(str, metaclass=ExtendedEnumMeta): + CONTAINER_TYPE_VLLM = "vllm" + CONTAINER_TYPE_TGI = "tgi" + + +class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta): + AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving" + AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving" + + +class InferenceContainerParamType(str, metaclass=ExtendedEnumMeta): + PARAM_TYPE_VLLM = "VLLM_PARAMS" + PARAM_TYPE_TGI = "TGI_PARAMS" + + +class HuggingFaceTags(str, metaclass=ExtendedEnumMeta): + TEXT_GENERATION_INFERENCE = "text-generation-inference" + + +class RqsAdditionalDetails(str, metaclass=ExtendedEnumMeta): + METADATA = "metadata" + CREATED_BY = "createdBy" + DESCRIPTION = "description" + MODEL_VERSION_SET_ID = "modelVersionSetId" + MODEL_VERSION_SET_NAME = "modelVersionSetName" + PROJECT_ID = "projectId" + VERSION_LABEL = "versionLabel" diff --git a/ads/aqua/exception.py b/ads/aqua/common/errors.py similarity index 77% rename from ads/aqua/exception.py rename to ads/aqua/common/errors.py index 908f6517e..adfe54397 100644 --- a/ads/aqua/exception.py +++ b/ads/aqua/common/errors.py @@ -10,6 +10,14 @@ This module contains the set of Aqua exceptions. """ +from ads.common.extended_enum import ExtendedEnumMeta + + +class ExitCode(str, metaclass=ExtendedEnumMeta): + SUCCESS = 0 + COMMON_ERROR = 1 + INVALID_CONFIG = 10 + class AquaError(Exception): """AquaError @@ -18,6 +26,8 @@ class AquaError(Exception): will inherit. 
""" + exit_code = 1 + def __init__( self, reason: str, @@ -80,3 +90,21 @@ class AquaResourceAccessError(AquaError): def __init__(self, reason, status=404, service_payload=None): super().__init__(reason, status, service_payload) + + +class AquaConfigError(AquaError): + """Exception raised when config for AQUA is invalid.""" + + exit_code = ExitCode.INVALID_CONFIG + + def __init__(self, reason, status=404, service_payload=None): + super().__init__(reason, status, service_payload) + + +class AquaCLIError(AquaError): + """Exception raised when AQUA CLI encounter error.""" + + exit_code = ExitCode.COMMON_ERROR + + def __init__(self, reason, status=None, service_payload=None): + super().__init__(reason, status, service_payload) diff --git a/ads/aqua/utils.py b/ads/aqua/common/utils.py similarity index 79% rename from ads/aqua/utils.py rename to ads/aqua/common/utils.py index d47284c0c..94d1d9a1e 100644 --- a/ads/aqua/utils.py +++ b/ads/aqua/common/utils.py @@ -10,7 +10,6 @@ import os import random import re -from enum import Enum from functools import wraps from pathlib import Path from string import Template @@ -20,64 +19,26 @@ import oci from oci.data_science.models import JobRun, Model -from ads.aqua.constants import RqsAdditionalDetails +from ads.aqua.common.enums import RqsAdditionalDetails +from ads.aqua.common.errors import ( + AquaFileNotFoundError, + AquaRuntimeError, + AquaValueError, +) +from ads.aqua.constants import * from ads.aqua.data import AquaResourceIdentifier -from ads.aqua.exception import AquaFileNotFoundError, AquaRuntimeError, AquaValueError from ads.common.auth import default_signer +from ads.common.extended_enum import ExtendedEnumMeta from ads.common.object_storage_details import ObjectStorageDetails from ads.common.oci_resource import SEARCH_TYPE, OCIResource -from ads.common.utils import get_console_link, upload_to_os +from ads.common.utils import get_console_link, upload_to_os, copy_file from ads.config import AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, TENANCY_OCID from ads.model import DataScienceModel, ModelVersionSet logger = logging.getLogger("ads.aqua") -UNKNOWN = "" -UNKNOWN_DICT = {} -README = "README.md" -LICENSE_TXT = "config/LICENSE.txt" -DEPLOYMENT_CONFIG = "deployment_config.json" -COMPARTMENT_MAPPING_KEY = "service-model-compartment" -CONTAINER_INDEX = "container_index.json" -EVALUATION_REPORT_JSON = "report.json" -EVALUATION_REPORT_MD = "report.md" -EVALUATION_REPORT = "report.html" -UNKNOWN_JSON_STR = "{}" -CONSOLE_LINK_RESOURCE_TYPE_MAPPING = dict( - datasciencemodel="models", - datasciencemodeldeployment="model-deployments", - datasciencemodeldeploymentdev="model-deployments", - datasciencemodeldeploymentint="model-deployments", - datasciencemodeldeploymentpre="model-deployments", - datasciencejob="jobs", - datasciencejobrun="job-runs", - datasciencejobrundev="job-runs", - datasciencejobrunint="job-runs", - datasciencejobrunpre="job-runs", - datasciencemodelversionset="model-version-sets", - datasciencemodelversionsetpre="model-version-sets", - datasciencemodelversionsetint="model-version-sets", - datasciencemodelversionsetdev="model-version-sets", -) -FINE_TUNING_RUNTIME_CONTAINER = "iad.ocir.io/ociodscdev/aqua_ft_cuda121:0.3.17.20" -DEFAULT_FT_BLOCK_STORAGE_SIZE = 750 -DEFAULT_FT_REPLICA = 1 -DEFAULT_FT_BATCH_SIZE = 1 -DEFAULT_FT_VALIDATION_SET_SIZE = 0.1 - -HF_MODELS = "/home/datascience/conda/pytorch21_p39_gpu_v1/" -MAXIMUM_ALLOWED_DATASET_IN_BYTE = 52428800 # 1024 x 1024 x 50 = 50MB -JOB_INFRASTRUCTURE_TYPE_DEFAULT_NETWORKING = 
"ME_STANDALONE" -NB_SESSION_IDENTIFIER = "NB_SESSION_OCID" -LIFECYCLE_DETAILS_MISSING_JOBRUN = "The asscociated JobRun resource has been deleted." -READY_TO_DEPLOY_STATUS = "ACTIVE" -READY_TO_FINE_TUNE_STATUS = "TRUE" -AQUA_GA_LIST = ["id19sfcrra6z"] -AQUA_MODEL_TYPE_SERVICE = "service" -AQUA_MODEL_TYPE_CUSTOM = "custom" - - -class LifecycleStatus(Enum): + +class LifecycleStatus(str, metaclass=ExtendedEnumMeta): UNKNOWN = "" @property @@ -136,8 +97,6 @@ def get_status(evaluation_status: str, job_run_status: str = None): JobRun.LIFECYCLE_STATE_FAILED: "The evaluation failed.", JobRun.LIFECYCLE_STATE_NEEDS_ATTENTION: "Missing jobrun information.", } -SUPPORTED_FILE_FORMATS = ["jsonl"] -MODEL_BY_REFERENCE_OSS_PATH_KEY = "artifact_location" def random_color_generator(word: str): @@ -201,22 +160,28 @@ def get_artifact_path(custom_metadata_list: List) -> str: Parameters ---------- custom_metadata_list: List - A list of custom metadata of model. + A list of custom metadata of OCI model. Returns ------- str: The artifact path from model. """ - for custom_metadata in custom_metadata_list: - if custom_metadata.key == MODEL_BY_REFERENCE_OSS_PATH_KEY: - if ObjectStorageDetails.is_oci_path(custom_metadata.value): - artifact_path = custom_metadata.value - else: - artifact_path = ObjectStorageDetails( - AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, custom_metadata.value - ).path - return artifact_path + try: + for custom_metadata in custom_metadata_list: + if custom_metadata.key == MODEL_BY_REFERENCE_OSS_PATH_KEY: + if ObjectStorageDetails.is_oci_path(custom_metadata.value): + artifact_path = custom_metadata.value + else: + artifact_path = ObjectStorageDetails( + AQUA_SERVICE_MODELS_BUCKET, + CONDA_BUCKET_NS, + custom_metadata.value, + ).path + return artifact_path + except Exception as ex: + logger.debug(ex) + logger.debug("Failed to get artifact path from custom metadata.") return UNKNOWN @@ -260,12 +225,10 @@ def is_valid_ocid(ocid: str) -> bool: bool: Whether the given ocid is valid. """ - # TODO: revisit pattern - pattern = ( - r"^ocid1\.([a-z0-9_]+)\.([a-z0-9]+)\.([a-z0-9-]*)(\.[^.]+)?\.([a-z0-9_]+)$" - ) - match = re.match(pattern, ocid) - return True + + if not ocid: + return False + return ocid.lower().startswith("ocid") def get_resource_type(ocid: str) -> str: @@ -520,6 +483,19 @@ def _build_job_identifier( return AquaResourceIdentifier() +def container_config_path(): + return f"oci://{AQUA_SERVICE_MODELS_BUCKET}@{CONDA_BUCKET_NS}/service_models/config" + + +def get_container_config(): + config = load_config( + file_path=container_config_path(), + config_file_name=CONTAINER_INDEX, + ) + + return config + + def get_container_image( config_file_name: str = None, container_type: str = None ) -> str: @@ -537,14 +513,8 @@ def get_container_image( A dict of allowed configs. """ - config_file_name = ( - f"oci://{AQUA_SERVICE_MODELS_BUCKET}@{CONDA_BUCKET_NS}/service_models/config" - ) - - config = load_config( - file_path=config_file_name, - config_file_name=CONTAINER_INDEX, - ) + config = config_file_name or get_container_config() + config_file_name = container_config_path() if container_type not in config: raise AquaValueError( @@ -583,7 +553,7 @@ def fetch_service_compartment() -> Union[str, None]: config_file_name=CONTAINER_INDEX, ) except Exception as e: - logger.error( + logger.debug( f"Config file {config_file_name}/{CONTAINER_INDEX} to fetch service compartment OCID could not be found. " f"\n{str(e)}." 
) @@ -750,3 +720,124 @@ def get_ocid_substring(ocid: str, key_len: int) -> str: """This helper function returns the last n characters of the ocid specified by key_len parameter. If ocid is None or length is less than key_len, it returns an empty string.""" return ocid[-key_len:] if ocid and len(ocid) > key_len else "" + + +def is_service_managed_container(container): + return container and container.startswith(SERVICE_MANAGED_CONTAINER_URI_SCHEME) + + +def get_params_list(params: str) -> List[str]: + """Parses the string parameter and returns a list of params. + + Parameters + ---------- + params + string parameters by separated by -- delimiter + + Returns + ------- + list of params + + """ + if not params: + return [] + return ["--" + param.strip() for param in params.split("--")[1:]] + + +def get_params_dict(params: Union[str, List[str]]) -> dict: + """Accepts a string or list of string of double-dash parameters and returns a dict with the parameter keys and values. + + Parameters + ---------- + params: + List of parameters or parameter string separated by space. + + Returns + ------- + dict containing parameter keys and values + + """ + params_list = get_params_list(params) if isinstance(params, str) else params + return { + split_result[0]: split_result[1] if len(split_result) > 1 else UNKNOWN + for split_result in (x.split() for x in params_list) + } + + +def get_combined_params(params1: str = None, params2: str = None) -> str: + """ + Combines string of double-dash parameters, and overrides the values from the second string in the first. + Parameters + ---------- + params1: + Parameter string with values + params2: + Parameter string with values that need to be overridden. + + Returns + ------- + A combined list with overridden values from params2. + """ + if not params1: + return params2 + if not params2: + return params1 + + # overwrite values from params2 into params1 + combined_params = [ + f"{key} {value}" if value else key + for key, value in { + **get_params_dict(params1), + **get_params_dict(params2), + }.items() + ] + + return " ".join(combined_params) + + +def copy_model_config(artifact_path: str, os_path: str, auth: dict = None): + """Copies the aqua model config folder from the artifact path to the user provided object storage path. + The config folder is overwritten if the files already exist at the destination path. + + Parameters + ---------- + artifact_path: + Path of the aqua model where config folder is available. + os_path: + User provided path where config folder will be copied. + auth: (Dict, optional). Defaults to None. + The default authentication is set using `ads.set_auth` API. If you need to override the + default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create appropriate + authentication signer and kwargs required to instantiate IdentityClient object. + + Returns + ------- + None + Nothing. 
+ """ + + try: + source_dir = ObjectStorageDetails( + AQUA_SERVICE_MODELS_BUCKET, + CONDA_BUCKET_NS, + f"{os.path.dirname(artifact_path).rstrip('/')}/config", + ).path + dest_dir = f"{os_path.rstrip('/')}/config" + + oss_details = ObjectStorageDetails.from_path(source_dir) + objects = oss_details.list_objects(fields="name").objects + + for obj in objects: + source_path = ObjectStorageDetails( + AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, obj.name + ).path + destination_path = os.path.join(dest_dir, os.path.basename(obj.name)) + copy_file( + uri_src=source_path, + uri_dst=destination_path, + force_overwrite=True, + auth=auth, + ) + except Exception as ex: + logger.debug(ex) + logger.debug(f"Failed to copy config folder from {artifact_path} to {os_path}.") diff --git a/ads/aqua/config/config.py b/ads/aqua/config/config.py new file mode 100644 index 000000000..2a358ce4f --- /dev/null +++ b/ads/aqua/config/config.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + + +# TODO: move this to global config.json in object storage +def get_finetuning_config_defaults(): + """Generate and return the fine-tuning default configuration dictionary.""" + return { + "shape": { + "VM.GPU.A10.1": {"batch_size": 1, "replica": "1-10"}, + "VM.GPU.A10.2": {"batch_size": 1, "replica": "1-10"}, + "BM.GPU.A10.4": {"batch_size": 1, "replica": 1}, + "BM.GPU4.8": {"batch_size": 4, "replica": 1}, + "BM.GPU.A100-v2.8": {"batch_size": 6, "replica": 1}, + } + } diff --git a/ads/aqua/constants.py b/ads/aqua/constants.py index 0d4ea4f78..7aaa63373 100644 --- a/ads/aqua/constants.py +++ b/ads/aqua/constants.py @@ -3,43 +3,61 @@ # Copyright (c) 2024 Oracle and/or its affiliates. 
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ """This module defines constants used in ads.aqua module.""" -from enum import Enum +UNKNOWN = "" UNKNOWN_VALUE = "" - - -class RqsAdditionalDetails: - METADATA = "metadata" - CREATED_BY = "createdBy" - DESCRIPTION = "description" - MODEL_VERSION_SET_ID = "modelVersionSetId" - MODEL_VERSION_SET_NAME = "modelVersionSetName" - PROJECT_ID = "projectId" - VERSION_LABEL = "versionLabel" - - -class FineTuningDefinedMetadata(Enum): - """Represents the defined metadata keys used in Fine Tuning.""" - - VAL_SET_SIZE = "val_set_size" - TRAINING_DATA = "training_data" - - -class FineTuningCustomMetadata(Enum): - """Represents the custom metadata keys used in Fine Tuning.""" - - FT_SOURCE = "fine_tune_source" - FT_SOURCE_NAME = "fine_tune_source_name" - FT_OUTPUT_PATH = "fine_tune_output_path" - FT_JOB_ID = "fine_tune_job_id" - FT_JOB_RUN_ID = "fine_tune_jobrun_id" - TRAINING_METRICS_FINAL = "train_metrics_final" - VALIDATION_METRICS_FINAL = "val_metrics_final" - TRAINING_METRICS_EPOCH = "train_metrics_epoch" - VALIDATION_METRICS_EPOCH = "val_metrics_epoch" - +READY_TO_IMPORT_STATUS = "TRUE" +UNKNOWN_DICT = {} +README = "README.md" +LICENSE_TXT = "config/LICENSE.txt" +DEPLOYMENT_CONFIG = "deployment_config.json" +COMPARTMENT_MAPPING_KEY = "service-model-compartment" +CONTAINER_INDEX = "container_index.json" +EVALUATION_REPORT_JSON = "report.json" +EVALUATION_REPORT_MD = "report.md" +EVALUATION_REPORT = "report.html" +UNKNOWN_JSON_STR = "{}" +FINE_TUNING_RUNTIME_CONTAINER = "iad.ocir.io/ociodscdev/aqua_ft_cuda121:0.3.17.20" +DEFAULT_FT_BLOCK_STORAGE_SIZE = 750 +DEFAULT_FT_REPLICA = 1 +DEFAULT_FT_BATCH_SIZE = 1 +DEFAULT_FT_VALIDATION_SET_SIZE = 0.1 + +MAXIMUM_ALLOWED_DATASET_IN_BYTE = 52428800 # 1024 x 1024 x 50 = 50MB +JOB_INFRASTRUCTURE_TYPE_DEFAULT_NETWORKING = "ME_STANDALONE" +NB_SESSION_IDENTIFIER = "NB_SESSION_OCID" +LIFECYCLE_DETAILS_MISSING_JOBRUN = "The asscociated JobRun resource has been deleted." +READY_TO_DEPLOY_STATUS = "ACTIVE" +READY_TO_FINE_TUNE_STATUS = "TRUE" +AQUA_GA_LIST = ["id19sfcrra6z"] +AQUA_MODEL_TYPE_SERVICE = "service" +AQUA_MODEL_TYPE_CUSTOM = "custom" +AQUA_MODEL_ARTIFACT_CONFIG = "config.json" +AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME = "_name_or_path" +AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE = "model_type" TRAINING_METRICS_FINAL = "training_metrics_final" VALIDATION_METRICS_FINAL = "validation_metrics_final" TRINING_METRICS = "training_metrics" VALIDATION_METRICS = "validation_metrics" + +SERVICE_MANAGED_CONTAINER_URI_SCHEME = "dsmc://" +SUPPORTED_FILE_FORMATS = ["jsonl"] +MODEL_BY_REFERENCE_OSS_PATH_KEY = "artifact_location" + +CONSOLE_LINK_RESOURCE_TYPE_MAPPING = dict( + datasciencemodel="models", + datasciencemodeldeployment="model-deployments", + datasciencemodeldeploymentdev="model-deployments", + datasciencemodeldeploymentint="model-deployments", + datasciencemodeldeploymentpre="model-deployments", + datasciencejob="jobs", + datasciencejobrun="job-runs", + datasciencejobrundev="job-runs", + datasciencejobrunint="job-runs", + datasciencejobrunpre="job-runs", + datasciencemodelversionset="model-version-sets", + datasciencemodelversionsetpre="model-version-sets", + datasciencemodelversionsetint="model-version-sets", + datasciencemodelversionsetdev="model-version-sets", +) diff --git a/ads/aqua/data.py b/ads/aqua/data.py index b10635169..511936b95 100644 --- a/ads/aqua/data.py +++ b/ads/aqua/data.py @@ -3,8 +3,7 @@ # Copyright (c) 2024 Oracle and/or its affiliates. 
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ -from dataclasses import dataclass -from enum import Enum +from dataclasses import dataclass, field from ads.common.serializer import DataClassSerializable @@ -16,27 +15,17 @@ class AquaResourceIdentifier(DataClassSerializable): url: str = "" -class Resource(Enum): - JOB = "jobs" - JOBRUN = "jobruns" - MODEL = "models" - MODEL_DEPLOYMENT = "modeldeployments" - MODEL_VERSION_SET = "model-version-sets" - - -class DataScienceResource(Enum): - MODEL_DEPLOYMENT = "datasciencemodeldeployment" - MODEL = "datasciencemodel" - - -class Tags(Enum): - TASK = "task" - LICENSE = "license" - ORGANIZATION = "organization" - AQUA_TAG = "OCI_AQUA" - AQUA_SERVICE_MODEL_TAG = "aqua_service_model" - AQUA_FINE_TUNED_MODEL_TAG = "aqua_fine_tuned_model" - AQUA_MODEL_NAME_TAG = "aqua_model_name" - AQUA_EVALUATION = "aqua_evaluation" - AQUA_FINE_TUNING = "aqua_finetuning" - READY_TO_FINE_TUNE = "ready_to_fine_tune" +@dataclass(repr=False) +class AquaJobSummary(DataClassSerializable): + """Represents an Aqua job summary.""" + + id: str + name: str + console_url: str + lifecycle_state: str + lifecycle_details: str + time_created: str + tags: dict + experiment: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) + source: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) + job: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) diff --git a/ads/aqua/evaluation/__init__.py b/ads/aqua/evaluation/__init__.py new file mode 100644 index 000000000..4a783f85f --- /dev/null +++ b/ads/aqua/evaluation/__init__.py @@ -0,0 +1,8 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*-- + +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +from ads.aqua.evaluation.evaluation import AquaEvaluationApp + +__all__ = ["AquaEvaluationApp"] diff --git a/ads/aqua/evaluation/constants.py b/ads/aqua/evaluation/constants.py new file mode 100644 index 000000000..0f0331f3f --- /dev/null +++ b/ads/aqua/evaluation/constants.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +""" +aqua.evaluation.const +~~~~~~~~~~~~~~ + +This module contains constants/enums used in Aqua Evaluation. 
+""" +from oci.data_science.models import JobRun + +from ads.common.extended_enum import ExtendedEnumMeta + +EVAL_TERMINATION_STATE = [ + JobRun.LIFECYCLE_STATE_SUCCEEDED, + JobRun.LIFECYCLE_STATE_FAILED, +] + + +class EvaluationCustomMetadata(str, metaclass=ExtendedEnumMeta): + EVALUATION_SOURCE = "evaluation_source" + EVALUATION_JOB_ID = "evaluation_job_id" + EVALUATION_JOB_RUN_ID = "evaluation_job_run_id" + EVALUATION_OUTPUT_PATH = "evaluation_output_path" + EVALUATION_SOURCE_NAME = "evaluation_source_name" + EVALUATION_ERROR = "aqua_evaluate_error" + + +class EvaluationConfig(str, metaclass=ExtendedEnumMeta): + PARAMS = "model_params" + + +class EvaluationReportJson(str, metaclass=ExtendedEnumMeta): + """Contains evaluation report.json fields name.""" + + METRIC_SUMMARY_RESULT = "metric_summary_result" + METRIC_RESULT = "metric_results" + MODEL_PARAMS = "model_params" + MODEL_DETAILS = "model_details" + DATA = "data" + DATASET = "dataset" + + +class EvaluationMetricResult(str, metaclass=ExtendedEnumMeta): + """Contains metric result's fields name in report.json.""" + + SHORT_NAME = "key" + NAME = "name" + DESCRIPTION = "description" + SUMMARY_DATA = "summary_data" + DATA = "data" diff --git a/ads/aqua/evaluation/entities.py b/ads/aqua/evaluation/entities.py new file mode 100644 index 000000000..df3f69f4a --- /dev/null +++ b/ads/aqua/evaluation/entities.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +""" +aqua.evaluation.entities +~~~~~~~~~~~~~~ + +This module contains dataclasses for aqua evaluation. +""" + +from dataclasses import dataclass, field +from typing import List, Optional, Union + +from ads.aqua.data import AquaResourceIdentifier +from ads.common.serializer import DataClassSerializable + + +@dataclass(repr=False) +class CreateAquaEvaluationDetails(DataClassSerializable): + """Dataclass to create aqua model evaluation. + + Fields + ------ + evaluation_source_id: str + The evaluation source id. Must be either model or model deployment ocid. + evaluation_name: str + The name for evaluation. + dataset_path: str + The dataset path for the evaluation. Could be either a local path from notebook session + or an object storage path. + report_path: str + The report path for the evaluation. Must be an object storage path. + model_parameters: dict + The parameters for the evaluation. + shape_name: str + The shape name for the evaluation job infrastructure. + memory_in_gbs: float + The memory in gbs for the shape selected. + ocpus: float + The ocpu count for the shape selected. + block_storage_size: int + The storage for the evaluation job infrastructure. + compartment_id: (str, optional). Defaults to `None`. + The compartment id for the evaluation. + project_id: (str, optional). Defaults to `None`. + The project id for the evaluation. + evaluation_description: (str, optional). Defaults to `None`. + The description for evaluation + experiment_id: (str, optional). Defaults to `None`. + The evaluation model version set id. If provided, + evaluation model will be associated with it. + experiment_name: (str, optional). Defaults to `None`. + The evaluation model version set name. If provided, + the model version set with the same name will be used if exists, + otherwise a new model version set will be created with the name. + experiment_description: (str, optional). Defaults to `None`. 
+ The description for the evaluation model version set. + log_group_id: (str, optional). Defaults to `None`. + The log group id for the evaluation job infrastructure. + log_id: (str, optional). Defaults to `None`. + The log id for the evaluation job infrastructure. + metrics: (list, optional). Defaults to `None`. + The metrics for the evaluation. + force_overwrite: (bool, optional). Defaults to `False`. + Whether to force overwrite the existing file in object storage. + """ + + evaluation_source_id: str + evaluation_name: str + dataset_path: str + report_path: str + model_parameters: dict + shape_name: str + block_storage_size: int + compartment_id: Optional[str] = None + project_id: Optional[str] = None + evaluation_description: Optional[str] = None + experiment_id: Optional[str] = None + experiment_name: Optional[str] = None + experiment_description: Optional[str] = None + memory_in_gbs: Optional[float] = None + ocpus: Optional[float] = None + log_group_id: Optional[str] = None + log_id: Optional[str] = None + metrics: Optional[List] = None + force_overwrite: Optional[bool] = False + + +@dataclass(repr=False) +class AquaEvalReport(DataClassSerializable): + evaluation_id: str = "" + content: str = "" + + +@dataclass(repr=False) +class ModelParams(DataClassSerializable): + max_tokens: str = "" + top_p: str = "" + top_k: str = "" + temperature: str = "" + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + stop: Optional[Union[str, List[str]]] = field(default_factory=list) + + +@dataclass(repr=False) +class AquaEvalParams(ModelParams, DataClassSerializable): + shape: str = "" + dataset_path: str = "" + report_path: str = "" + + +@dataclass(repr=False) +class AquaEvalMetric(DataClassSerializable): + key: str + name: str + description: str = "" + + +@dataclass(repr=False) +class AquaEvalMetricSummary(DataClassSerializable): + metric: str = "" + score: str = "" + grade: str = "" + + +@dataclass(repr=False) +class AquaEvalMetrics(DataClassSerializable): + id: str + report: str + metric_results: List[AquaEvalMetric] = field(default_factory=list) + metric_summary_result: List[AquaEvalMetricSummary] = field(default_factory=list) + + +@dataclass(repr=False) +class AquaEvaluationCommands(DataClassSerializable): + evaluation_id: str + evaluation_target_id: str + input_data: dict + metrics: list + output_dir: str + params: dict + + +@dataclass(repr=False) +class AquaEvaluationSummary(DataClassSerializable): + """Represents a summary of Aqua evalution.""" + + id: str + name: str + console_url: str + lifecycle_state: str + lifecycle_details: str + time_created: str + tags: dict + experiment: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) + source: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) + job: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) + parameters: AquaEvalParams = field(default_factory=AquaEvalParams) + + +@dataclass(repr=False) +class AquaEvaluationDetail(AquaEvaluationSummary, DataClassSerializable): + """Represents a details of Aqua evalution.""" + + log_group: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) + log: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) + introspection: dict = field(default_factory=dict) diff --git a/ads/aqua/evaluation/errors.py b/ads/aqua/evaluation/errors.py new file mode 100644 index 000000000..c25f9be44 --- /dev/null +++ b/ads/aqua/evaluation/errors.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python +# -*- coding: 
utf-8 -*- +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +""" +aqua.evaluation.errors +~~~~~~~~~~~~~~ + +This module contains errors in Aqua Evaluation. +""" + +from ads.common.extended_enum import ExtendedEnumMeta + + +class EvaluationJobExitCode(str, metaclass=ExtendedEnumMeta): + SUCCESS = 0 + COMMON_ERROR = 1 + + # Configuration-related issues 10-19 + INVALID_EVALUATION_CONFIG = 10 + EVALUATION_CONFIG_NOT_PROVIDED = 11 + INVALID_OUTPUT_DIR = 12 + INVALID_INPUT_DATASET_PATH = 13 + INVALID_EVALUATION_ID = 14 + INVALID_TARGET_EVALUATION_ID = 15 + INVALID_EVALUATION_CONFIG_VALIDATION = 16 + + # Evaluation process issues 20-39 + OUTPUT_DIR_NOT_FOUND = 20 + INVALID_INPUT_DATASET = 21 + INPUT_DATA_NOT_FOUND = 22 + EVALUATION_ID_NOT_FOUND = 23 + EVALUATION_ALREADY_PERFORMED = 24 + EVALUATION_TARGET_NOT_FOUND = 25 + NO_SUCCESS_INFERENCE_RESULT = 26 + COMPUTE_EVALUATION_ERROR = 27 + EVALUATION_REPORT_ERROR = 28 + MODEL_INFERENCE_WRONG_RESPONSE_FORMAT = 29 + UNSUPPORTED_METRICS = 30 + METRIC_CALCULATION_FAILURE = 31 + EVALUATION_MODEL_CATALOG_RECORD_CREATION_FAILED = 32 + + +EVALUATION_JOB_EXIT_CODE_MESSAGE = { + EvaluationJobExitCode.SUCCESS: "Success", + EvaluationJobExitCode.COMMON_ERROR: "An error occurred during the evaluation, please check the log for more information.", + EvaluationJobExitCode.INVALID_EVALUATION_CONFIG: "The provided evaluation configuration was not in the correct format, supported formats are YAML or JSON.", + EvaluationJobExitCode.EVALUATION_CONFIG_NOT_PROVIDED: "The evaluation config was not provided.", + EvaluationJobExitCode.INVALID_OUTPUT_DIR: "The specified output directory path is invalid.", + EvaluationJobExitCode.INVALID_INPUT_DATASET_PATH: "Dataset path is invalid.", + EvaluationJobExitCode.INVALID_EVALUATION_ID: "Evaluation ID was not found in the Model Catalog.", + EvaluationJobExitCode.INVALID_TARGET_EVALUATION_ID: "Target evaluation ID was not found in the Model Deployment.", + EvaluationJobExitCode.INVALID_EVALUATION_CONFIG_VALIDATION: "Validation errors in the evaluation config.", + EvaluationJobExitCode.OUTPUT_DIR_NOT_FOUND: "Destination folder does not exist or cannot be used for writing, verify the folder's existence and permissions.", + EvaluationJobExitCode.INVALID_INPUT_DATASET: "Input dataset is in an invalid format, ensure the dataset is in jsonl format and that includes the required columns: 'prompt', 'completion' (optional 'category').", + EvaluationJobExitCode.INPUT_DATA_NOT_FOUND: "Input data file does not exist or cannot be use for reading, verify the file's existence and permissions.", + EvaluationJobExitCode.EVALUATION_ID_NOT_FOUND: "Evaluation ID does not match any resource in the Model Catalog, or access may be blocked by policies.", + EvaluationJobExitCode.EVALUATION_ALREADY_PERFORMED: "Evaluation already has an attached artifact, indicating that the evaluation has already been performed.", + EvaluationJobExitCode.EVALUATION_TARGET_NOT_FOUND: "Target evaluation ID does not match any resources in Model Deployment.", + EvaluationJobExitCode.NO_SUCCESS_INFERENCE_RESULT: "Inference process completed without producing expected outcome, verify the model parameters and config.", + EvaluationJobExitCode.COMPUTE_EVALUATION_ERROR: "Evaluation process encountered an issue while calculating metrics.", + EvaluationJobExitCode.EVALUATION_REPORT_ERROR: "Failed to save the evaluation report due to an error. 
Ensure the evaluation model is currently active and the specified path for the output report is valid and accessible. Verify these conditions and reinitiate the evaluation process.", + EvaluationJobExitCode.MODEL_INFERENCE_WRONG_RESPONSE_FORMAT: "Evaluation encountered unsupported, or unexpected model output, verify the target evaluation model is compatible and produces the correct format.", + EvaluationJobExitCode.UNSUPPORTED_METRICS: "None of the provided metrics are supported by the framework.", + EvaluationJobExitCode.METRIC_CALCULATION_FAILURE: "All attempted metric calculations were unsuccessful. Please review the metric configurations and input data.", + EvaluationJobExitCode.EVALUATION_MODEL_CATALOG_RECORD_CREATION_FAILED: ( + "Failed to create a Model Catalog record for the evaluation. " + "This could be due to missing required permissions. " + "Please check the log for more information." + ), +} diff --git a/ads/aqua/evaluation.py b/ads/aqua/evaluation/evaluation.py similarity index 77% rename from ads/aqua/evaluation.py rename to ads/aqua/evaluation/evaluation.py index 39ad21cca..73d673c23 100644 --- a/ads/aqua/evaluation.py +++ b/ads/aqua/evaluation/evaluation.py @@ -8,12 +8,11 @@ import re import tempfile from concurrent.futures import ThreadPoolExecutor, as_completed -from dataclasses import asdict, dataclass, field +from dataclasses import asdict from datetime import datetime, timedelta -from enum import Enum from pathlib import Path from threading import Lock -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Union import oci from cachetools import TTLCache @@ -24,29 +23,40 @@ UpdateModelProvenanceDetails, ) -from ads.aqua import logger, utils -from ads.aqua.base import AquaApp -from ads.aqua.data import Tags -from ads.aqua.exception import ( +from ads.aqua import logger +from ads.aqua.app import AquaApp +from ads.aqua.common import utils +from ads.aqua.common.enums import ( + DataScienceResource, + Resource, + RqsAdditionalDetails, + Tags, +) +from ads.aqua.common.errors import ( AquaFileExistsError, AquaFileNotFoundError, AquaMissingKeyError, AquaRuntimeError, AquaValueError, ) -from ads.aqua.utils import ( - JOB_INFRASTRUCTURE_TYPE_DEFAULT_NETWORKING, - NB_SESSION_IDENTIFIER, - UNKNOWN, +from ads.aqua.common.utils import ( extract_id_and_name_from_tag, fire_and_forget, get_container_image, is_valid_ocid, upload_local_to_os, ) +from ads.aqua.constants import ( + JOB_INFRASTRUCTURE_TYPE_DEFAULT_NETWORKING, + NB_SESSION_IDENTIFIER, + UNKNOWN, + CONSOLE_LINK_RESOURCE_TYPE_MAPPING, +) +from ads.aqua.evaluation.constants import * +from ads.aqua.evaluation.entities import * +from ads.aqua.evaluation.errors import * from ads.common.auth import default_signer from ads.common.object_storage_details import ObjectStorageDetails -from ads.common.serializer import DataClassSerializable from ads.common.utils import get_console_link, get_files, get_log_links, upload_to_os from ads.config import ( AQUA_JOB_SUBNET_ID, @@ -69,279 +79,6 @@ from ads.model.model_version_set import ModelVersionSet from ads.telemetry import telemetry -EVAL_TERMINATION_STATE = [ - JobRun.LIFECYCLE_STATE_SUCCEEDED, - JobRun.LIFECYCLE_STATE_FAILED, -] - - -class EvaluationJobExitCode(Enum): - SUCCESS = 0 - COMMON_ERROR = 1 - - # Configuration-related issues 10-19 - INVALID_EVALUATION_CONFIG = 10 - EVALUATION_CONFIG_NOT_PROVIDED = 11 - INVALID_OUTPUT_DIR = 12 - INVALID_INPUT_DATASET_PATH = 13 - INVALID_EVALUATION_ID = 14 - INVALID_TARGET_EVALUATION_ID = 15 - 
INVALID_EVALUATION_CONFIG_VALIDATION = 16 - - # Evaluation process issues 20-39 - OUTPUT_DIR_NOT_FOUND = 20 - INVALID_INPUT_DATASET = 21 - INPUT_DATA_NOT_FOUND = 22 - EVALUATION_ID_NOT_FOUND = 23 - EVALUATION_ALREADY_PERFORMED = 24 - EVALUATION_TARGET_NOT_FOUND = 25 - NO_SUCCESS_INFERENCE_RESULT = 26 - COMPUTE_EVALUATION_ERROR = 27 - EVALUATION_REPORT_ERROR = 28 - MODEL_INFERENCE_WRONG_RESPONSE_FORMAT = 29 - UNSUPPORTED_METRICS = 30 - METRIC_CALCULATION_FAILURE = 31 - EVALUATION_MODEL_CATALOG_RECORD_CREATION_FAILED = 32 - - -EVALUATION_JOB_EXIT_CODE_MESSAGE = { - EvaluationJobExitCode.SUCCESS.value: "Success", - EvaluationJobExitCode.COMMON_ERROR.value: "An error occurred during the evaluation, please check the log for more information.", - EvaluationJobExitCode.INVALID_EVALUATION_CONFIG.value: "The provided evaluation configuration was not in the correct format, supported formats are YAML or JSON.", - EvaluationJobExitCode.EVALUATION_CONFIG_NOT_PROVIDED.value: "The evaluation config was not provided.", - EvaluationJobExitCode.INVALID_OUTPUT_DIR.value: "The specified output directory path is invalid.", - EvaluationJobExitCode.INVALID_INPUT_DATASET_PATH.value: "Dataset path is invalid.", - EvaluationJobExitCode.INVALID_EVALUATION_ID.value: "Evaluation ID was not found in the Model Catalog.", - EvaluationJobExitCode.INVALID_TARGET_EVALUATION_ID.value: "Target evaluation ID was not found in the Model Deployment.", - EvaluationJobExitCode.INVALID_EVALUATION_CONFIG_VALIDATION.value: "Validation errors in the evaluation config.", - EvaluationJobExitCode.OUTPUT_DIR_NOT_FOUND.value: "Destination folder does not exist or cannot be used for writing, verify the folder's existence and permissions.", - EvaluationJobExitCode.INVALID_INPUT_DATASET.value: "Input dataset is in an invalid format, ensure the dataset is in jsonl format and that includes the required columns: 'prompt', 'completion' (optional 'category').", - EvaluationJobExitCode.INPUT_DATA_NOT_FOUND.value: "Input data file does not exist or cannot be use for reading, verify the file's existence and permissions.", - EvaluationJobExitCode.EVALUATION_ID_NOT_FOUND.value: "Evaluation ID does not match any resource in the Model Catalog, or access may be blocked by policies.", - EvaluationJobExitCode.EVALUATION_ALREADY_PERFORMED.value: "Evaluation already has an attached artifact, indicating that the evaluation has already been performed.", - EvaluationJobExitCode.EVALUATION_TARGET_NOT_FOUND.value: "Target evaluation ID does not match any resources in Model Deployment.", - EvaluationJobExitCode.NO_SUCCESS_INFERENCE_RESULT.value: "Inference process completed without producing expected outcome, verify the model parameters and config.", - EvaluationJobExitCode.COMPUTE_EVALUATION_ERROR.value: "Evaluation process encountered an issue while calculating metrics.", - EvaluationJobExitCode.EVALUATION_REPORT_ERROR.value: "Failed to save the evaluation report due to an error. Ensure the evaluation model is currently active and the specified path for the output report is valid and accessible. 
Verify these conditions and reinitiate the evaluation process.", - EvaluationJobExitCode.MODEL_INFERENCE_WRONG_RESPONSE_FORMAT.value: "Evaluation encountered unsupported, or unexpected model output, verify the target evaluation model is compatible and produces the correct format.", - EvaluationJobExitCode.UNSUPPORTED_METRICS.value: "None of the provided metrics are supported by the framework.", - EvaluationJobExitCode.METRIC_CALCULATION_FAILURE.value: "All attempted metric calculations were unsuccessful. Please review the metric configurations and input data.", - EvaluationJobExitCode.EVALUATION_MODEL_CATALOG_RECORD_CREATION_FAILED.value: ( - "Failed to create a Model Catalog record for the evaluation. " - "This could be due to missing required permissions. " - "Please check the log for more information." - ), -} - - -class Resource(Enum): - JOB = "jobs" - MODEL = "models" - MODEL_DEPLOYMENT = "modeldeployments" - MODEL_VERSION_SET = "model-version-sets" - - -class DataScienceResource(Enum): - MODEL_DEPLOYMENT = "datasciencemodeldeployment" - MODEL = "datasciencemodel" - - -class EvaluationCustomMetadata(Enum): - EVALUATION_SOURCE = "evaluation_source" - EVALUATION_JOB_ID = "evaluation_job_id" - EVALUATION_JOB_RUN_ID = "evaluation_job_run_id" - EVALUATION_OUTPUT_PATH = "evaluation_output_path" - EVALUATION_SOURCE_NAME = "evaluation_source_name" - EVALUATION_ERROR = "aqua_evaluate_error" - - -class EvaluationModelTags(Enum): - AQUA_EVALUATION = "aqua_evaluation" - - -class EvaluationJobTags(Enum): - AQUA_EVALUATION = "aqua_evaluation" - EVALUATION_MODEL_ID = "evaluation_model_id" - - -class EvaluationUploadStatus(Enum): - IN_PROGRESS = "IN_PROGRESS" - COMPLETED = "COMPLETED" - - -@dataclass(repr=False) -class AquaResourceIdentifier(DataClassSerializable): - id: str = "" - name: str = "" - url: str = "" - - -@dataclass(repr=False) -class AquaEvalReport(DataClassSerializable): - evaluation_id: str = "" - content: str = "" - - -@dataclass(repr=False) -class ModelParams(DataClassSerializable): - max_tokens: str = "" - top_p: str = "" - top_k: str = "" - temperature: str = "" - presence_penalty: Optional[float] = 0.0 - frequency_penalty: Optional[float] = 0.0 - stop: Optional[Union[str, List[str]]] = field(default_factory=list) - - -@dataclass(repr=False) -class AquaEvalParams(ModelParams, DataClassSerializable): - shape: str = "" - dataset_path: str = "" - report_path: str = "" - - -@dataclass(repr=False) -class AquaEvalMetric(DataClassSerializable): - key: str - name: str - description: str = "" - - -@dataclass(repr=False) -class AquaEvalMetricSummary(DataClassSerializable): - metric: str = "" - score: str = "" - grade: str = "" - - -@dataclass(repr=False) -class AquaEvalMetrics(DataClassSerializable): - id: str - report: str - metric_results: List[AquaEvalMetric] = field(default_factory=list) - metric_summary_result: List[AquaEvalMetricSummary] = field(default_factory=list) - - -@dataclass(repr=False) -class AquaEvaluationCommands(DataClassSerializable): - evaluation_id: str - evaluation_target_id: str - input_data: dict - metrics: list - output_dir: str - params: dict - - -@dataclass(repr=False) -class AquaEvaluationSummary(DataClassSerializable): - """Represents a summary of Aqua evalution.""" - - id: str - name: str - console_url: str - lifecycle_state: str - lifecycle_details: str - time_created: str - tags: dict - experiment: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) - source: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) - job: 
AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) - parameters: AquaEvalParams = field(default_factory=AquaEvalParams) - - -@dataclass(repr=False) -class AquaEvaluationDetail(AquaEvaluationSummary, DataClassSerializable): - """Represents a details of Aqua evalution.""" - - log_group: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) - log: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) - introspection: dict = field(default_factory=dict) - - -class RqsAdditionalDetails: - METADATA = "metadata" - CREATED_BY = "createdBy" - DESCRIPTION = "description" - MODEL_VERSION_SET_ID = "modelVersionSetId" - MODEL_VERSION_SET_NAME = "modelVersionSetName" - PROJECT_ID = "projectId" - VERSION_LABEL = "versionLabel" - - -class EvaluationConfig: - PARAMS = "model_params" - - -@dataclass(repr=False) -class CreateAquaEvaluationDetails(DataClassSerializable): - """Dataclass to create aqua model evaluation. - - Fields - ------ - evaluation_source_id: str - The evaluation source id. Must be either model or model deployment ocid. - evaluation_name: str - The name for evaluation. - dataset_path: str - The dataset path for the evaluation. Could be either a local path from notebook session - or an object storage path. - report_path: str - The report path for the evaluation. Must be an object storage path. - model_parameters: dict - The parameters for the evaluation. - shape_name: str - The shape name for the evaluation job infrastructure. - memory_in_gbs: float - The memory in gbs for the shape selected. - ocpus: float - The ocpu count for the shape selected. - block_storage_size: int - The storage for the evaluation job infrastructure. - compartment_id: (str, optional). Defaults to `None`. - The compartment id for the evaluation. - project_id: (str, optional). Defaults to `None`. - The project id for the evaluation. - evaluation_description: (str, optional). Defaults to `None`. - The description for evaluation - experiment_id: (str, optional). Defaults to `None`. - The evaluation model version set id. If provided, - evaluation model will be associated with it. - experiment_name: (str, optional). Defaults to `None`. - The evaluation model version set name. If provided, - the model version set with the same name will be used if exists, - otherwise a new model version set will be created with the name. - experiment_description: (str, optional). Defaults to `None`. - The description for the evaluation model version set. - log_group_id: (str, optional). Defaults to `None`. - The log group id for the evaluation job infrastructure. - log_id: (str, optional). Defaults to `None`. - The log id for the evaluation job infrastructure. - metrics: (list, optional). Defaults to `None`. - The metrics for the evaluation. - force_overwrite: (bool, optional). Defaults to `False`. - Whether to force overwrite the existing file in object storage. 
- """ - - evaluation_source_id: str - evaluation_name: str - dataset_path: str - report_path: str - model_parameters: dict - shape_name: str - block_storage_size: int - compartment_id: Optional[str] = None - project_id: Optional[str] = None - evaluation_description: Optional[str] = None - experiment_id: Optional[str] = None - experiment_name: Optional[str] = None - experiment_description: Optional[str] = None - memory_in_gbs: Optional[float] = None - ocpus: Optional[float] = None - log_group_id: Optional[str] = None - log_id: Optional[str] = None - metrics: Optional[List] = None - force_overwrite: Optional[bool] = False - class AquaEvaluationApp(AquaApp): """Provides a suite of APIs to interact with Aqua evaluations within the @@ -367,6 +104,9 @@ class AquaEvaluationApp(AquaApp): _report_cache = TTLCache(maxsize=10, ttl=timedelta(hours=5), timer=datetime.now) _metrics_cache = TTLCache(maxsize=10, ttl=timedelta(hours=5), timer=datetime.now) _eval_cache = TTLCache(maxsize=200, ttl=timedelta(hours=10), timer=datetime.now) + _deletion_cache = TTLCache( + maxsize=10, ttl=timedelta(minutes=10), timer=datetime.now + ) _cache_lock = Lock() @telemetry(entry_point="plugin=evaluation&action=create", name="aqua") @@ -408,14 +148,14 @@ def create( evaluation_source = None if ( - DataScienceResource.MODEL_DEPLOYMENT.value + DataScienceResource.MODEL_DEPLOYMENT in create_aqua_evaluation_details.evaluation_source_id ): evaluation_source = ModelDeployment.from_id( create_aqua_evaluation_details.evaluation_source_id ) elif ( - DataScienceResource.MODEL.value + DataScienceResource.MODEL in create_aqua_evaluation_details.evaluation_source_id ): evaluation_source = DataScienceModel.from_id( @@ -500,11 +240,9 @@ def create( name=experiment_model_version_set_name, compartment_id=target_compartment, ) - if not utils._is_valid_mvs( - model_version_set, Tags.AQUA_EVALUATION.value - ): + if not utils._is_valid_mvs(model_version_set, Tags.AQUA_EVALUATION): raise AquaValueError( - f"Invalid experiment name. Please provide an experiment with `{Tags.AQUA_EVALUATION.value}` in tags." + f"Invalid experiment name. Please provide an experiment with `{Tags.AQUA_EVALUATION}` in tags." ) except: logger.debug( @@ -513,7 +251,7 @@ def create( ) evaluation_mvs_freeform_tags = { - Tags.AQUA_EVALUATION.value: Tags.AQUA_EVALUATION.value, + Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION, } model_version_set = ( @@ -534,23 +272,23 @@ def create( experiment_model_version_set_id = model_version_set.id else: model_version_set = ModelVersionSet.from_id(experiment_model_version_set_id) - if not utils._is_valid_mvs(model_version_set, Tags.AQUA_EVALUATION.value): + if not utils._is_valid_mvs(model_version_set, Tags.AQUA_EVALUATION): raise AquaValueError( - f"Invalid experiment id. Please provide an experiment with `{Tags.AQUA_EVALUATION.value}` in tags." + f"Invalid experiment id. Please provide an experiment with `{Tags.AQUA_EVALUATION}` in tags." 
) experiment_model_version_set_name = model_version_set.name evaluation_model_custom_metadata = ModelCustomMetadata() evaluation_model_custom_metadata.add( - key=EvaluationCustomMetadata.EVALUATION_SOURCE.value, + key=EvaluationCustomMetadata.EVALUATION_SOURCE, value=create_aqua_evaluation_details.evaluation_source_id, ) evaluation_model_custom_metadata.add( - key=EvaluationCustomMetadata.EVALUATION_OUTPUT_PATH.value, + key=EvaluationCustomMetadata.EVALUATION_OUTPUT_PATH, value=create_aqua_evaluation_details.report_path, ) evaluation_model_custom_metadata.add( - key=EvaluationCustomMetadata.EVALUATION_SOURCE_NAME.value, + key=EvaluationCustomMetadata.EVALUATION_SOURCE_NAME, value=evaluation_source.display_name, ) @@ -588,8 +326,8 @@ def create( # TODO: validate metrics if it's provided evaluation_job_freeform_tags = { - EvaluationJobTags.AQUA_EVALUATION.value: EvaluationJobTags.AQUA_EVALUATION.value, - EvaluationJobTags.EVALUATION_MODEL_ID.value: evaluation_model.id, + Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION, + Tags.AQUA_EVALUATION_MODEL_ID: evaluation_model.id, } evaluation_job = Job(name=evaluation_model.display_name).with_infrastructure( @@ -654,11 +392,11 @@ def create( ) evaluation_model_custom_metadata.add( - key=EvaluationCustomMetadata.EVALUATION_JOB_ID.value, + key=EvaluationCustomMetadata.EVALUATION_JOB_ID, value=evaluation_job.id, ) evaluation_model_custom_metadata.add( - key=EvaluationCustomMetadata.EVALUATION_JOB_RUN_ID.value, + key=EvaluationCustomMetadata.EVALUATION_JOB_RUN_ID, value=evaluation_job_run.id, ) updated_custom_metadata_list = [ @@ -671,7 +409,7 @@ def create( update_model_details=UpdateModelDetails( custom_metadata_list=updated_custom_metadata_list, freeform_tags={ - EvaluationModelTags.AQUA_EVALUATION.value: EvaluationModelTags.AQUA_EVALUATION.value, + Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION, }, ), ) @@ -702,7 +440,7 @@ def create( id=evaluation_model.id, name=evaluation_model.display_name, console_url=get_console_link( - resource=Resource.MODEL.value, + resource=Resource.MODEL, ocid=evaluation_model.id, region=self.region, ), @@ -713,7 +451,7 @@ def create( id=experiment_model_version_set_id, name=experiment_model_version_set_name, url=get_console_link( - resource=Resource.MODEL_VERSION_SET.value, + resource=Resource.MODEL_VERSION_SET, ocid=experiment_model_version_set_id, region=self.region, ), @@ -723,10 +461,10 @@ def create( name=evaluation_source.display_name, url=get_console_link( resource=( - Resource.MODEL_DEPLOYMENT.value - if DataScienceResource.MODEL_DEPLOYMENT.value + Resource.MODEL_DEPLOYMENT + if DataScienceResource.MODEL_DEPLOYMENT in create_aqua_evaluation_details.evaluation_source_id - else Resource.MODEL.value + else Resource.MODEL ), ocid=create_aqua_evaluation_details.evaluation_source_id, region=self.region, @@ -736,13 +474,13 @@ def create( id=evaluation_job.id, name=evaluation_job.name, url=get_console_link( - resource=Resource.JOB.value, + resource=Resource.JOB, ocid=evaluation_job.id, region=self.region, ), ), tags=dict( - aqua_evaluation=EvaluationModelTags.AQUA_EVALUATION.value, + aqua_evaluation=Tags.AQUA_EVALUATION, evaluation_job_id=evaluation_job.id, evaluation_source=create_aqua_evaluation_details.evaluation_source_id, evaluation_experiment_id=experiment_model_version_set_id, @@ -805,10 +543,10 @@ def _get_service_model_name( """ if isinstance(source, ModelDeployment): fine_tuned_model_tag = source.freeform_tags.get( - Tags.AQUA_FINE_TUNED_MODEL_TAG.value, UNKNOWN + Tags.AQUA_FINE_TUNED_MODEL_TAG, UNKNOWN ) if not 
fine_tuned_model_tag: - return source.freeform_tags.get(Tags.AQUA_MODEL_NAME_TAG.value) + return source.freeform_tags.get(Tags.AQUA_MODEL_NAME_TAG) else: return extract_id_and_name_from_tag(fine_tuned_model_tag)[1] @@ -971,12 +709,10 @@ def list( models = utils.query_resources( compartment_id=compartment_id, resource_type="datasciencemodel", - tag_list=[EvaluationModelTags.AQUA_EVALUATION.value], + tag_list=[Tags.AQUA_EVALUATION], ) logger.info(f"Fetched {len(models)} evaluations.") - # TODO: add filter based on project_id if needed. - mapping = self._prefetch_resources(compartment_id) evaluations = [] @@ -988,7 +724,7 @@ def list( else: jobrun_id = self._get_attribute_from_model_metadata( - model, EvaluationCustomMetadata.EVALUATION_JOB_RUN_ID.value + model, EvaluationCustomMetadata.EVALUATION_JOB_RUN_ID ) job_run = mapping.get(jobrun_id) @@ -1197,11 +933,13 @@ def load_metrics(self, eval_id: str) -> AquaEvalMetrics: ) files_in_artifact = get_files(temp_dir) - report_content = self._read_from_artifact( + md_report_content = self._read_from_artifact( temp_dir, files_in_artifact, utils.EVALUATION_REPORT_MD ) + + # json report not availiable for failed evaluation try: - report = json.loads( + json_report = json.loads( self._read_from_artifact( temp_dir, files_in_artifact, utils.EVALUATION_REPORT_JSON ) @@ -1210,27 +948,32 @@ def load_metrics(self, eval_id: str) -> AquaEvalMetrics: logger.debug( "Failed to load `report.json` from evaluation artifact" f"{str(e)}" ) - report = {} + json_report = {} - # TODO: after finalizing the format of report.json, move the constant to class eval_metrics = AquaEvalMetrics( id=eval_id, - report=base64.b64encode(report_content).decode(), + report=base64.b64encode(md_report_content).decode(), metric_results=[ AquaEvalMetric( - key=metric_key, - name=metadata.get("name", utils.UNKNOWN), - description=metadata.get("description", utils.UNKNOWN), + key=metadata.get(EvaluationMetricResult.SHORT_NAME, utils.UNKNOWN), + name=metadata.get(EvaluationMetricResult.NAME, utils.UNKNOWN), + description=metadata.get( + EvaluationMetricResult.DESCRIPTION, utils.UNKNOWN + ), ) - for metric_key, metadata in report.get("metric_results", {}).items() + for _, metadata in json_report.get( + EvaluationReportJson.METRIC_RESULT, {} + ).items() ], metric_summary_result=[ AquaEvalMetricSummary(**m) - for m in report.get("metric_summary_result", [{}]) + for m in json_report.get( + EvaluationReportJson.METRIC_SUMMARY_RESULT, [{}] + ) ], ) - if report_content: + if md_report_content: self._metrics_cache.__setitem__(key=eval_id, value=eval_metrics) return eval_metrics @@ -1371,6 +1114,7 @@ def _cancel_job_run(run, model): @telemetry(entry_point="plugin=evaluation&action=delete", name="aqua") def delete(self, eval_id): """Deletes the job and the associated model for the given evaluation id. + Parameters ---------- eval_id: str @@ -1383,9 +1127,9 @@ def delete(self, eval_id): Raises ------ AquaRuntimeError: - if a model doesn't exist for the given eval_id + if a model doesn't exist for the given eval_id. AquaMissingKeyError: - if training_id is missing the job run id + if job/jobrun id is missing. 
""" model = DataScienceModel.from_id(eval_id) @@ -1396,20 +1140,32 @@ def delete(self, eval_id): try: job_id = model.custom_metadata_list.get( - EvaluationCustomMetadata.EVALUATION_JOB_ID.value + EvaluationCustomMetadata.EVALUATION_JOB_ID ).value except Exception: raise AquaMissingKeyError( - f"Custom metadata is missing {EvaluationCustomMetadata.EVALUATION_JOB_ID.value} key" + f"Custom metadata is missing {EvaluationCustomMetadata.EVALUATION_JOB_ID} key" ) job = DataScienceJob.from_id(job_id) self._delete_job_and_model(job, model) + try: + jobrun_id = model.custom_metadata_list.get( + EvaluationCustomMetadata.EVALUATION_JOB_RUN_ID + ).value + jobrun = utils.query_resource(jobrun_id, return_all=False) + except Exception: + logger.debug("Associated Job Run OCID is missing.") + jobrun = None + + self._eval_cache.pop(key=eval_id, default=None) + self._deletion_cache.__setitem__(key=eval_id, value="") + status = dict( id=eval_id, - lifecycle_state="DELETING", + lifecycle_state=jobrun.lifecycle_state if jobrun else "DELETING", time_accepted=datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f%z"), ) return status @@ -1503,7 +1259,7 @@ def _get_source( """Returns ocid and name of the model has been evaluated.""" source_id = self._get_attribute_from_model_metadata( evaluation, - EvaluationCustomMetadata.EVALUATION_SOURCE.value, + EvaluationCustomMetadata.EVALUATION_SOURCE, ) try: @@ -1512,20 +1268,20 @@ def _get_source( source.display_name if source else self._get_attribute_from_model_metadata( - evaluation, EvaluationCustomMetadata.EVALUATION_SOURCE_NAME.value + evaluation, EvaluationCustomMetadata.EVALUATION_SOURCE_NAME ) ) - if not source_name: + # try to resolve source_name from source id + if source_id and not source_name: resource_type = utils.get_resource_type(source_id) - # TODO: adjust resource principal mapping - if resource_type == "datasciencemodel": - source_name = self.ds_client.get_model(source_id).data.display_name - elif resource_type == "datasciencemodeldeployment": + if resource_type.startswith("datasciencemodeldeployment"): source_name = self.ds_client.get_model_deployment( source_id ).data.display_name + elif resource_type.startswith("datasciencemodel"): + source_name = self.ds_client.get_model(source_id).data.display_name else: raise AquaRuntimeError( f"Not supported source type: {resource_type}" @@ -1589,7 +1345,7 @@ def _build_resource_identifier( ) -> AquaResourceIdentifier: """Constructs AquaResourceIdentifier based on the given ocid and display name.""" try: - resource_type = utils.CONSOLE_LINK_RESOURCE_TYPE_MAPPING.get( + resource_type = CONSOLE_LINK_RESOURCE_TYPE_MAPPING.get( utils.get_resource_type(id) ) @@ -1620,7 +1376,7 @@ def _fetch_jobrun( """Extracts job run id from metadata, and gets related job run information.""" jobrun_id = jobrun_id or self._get_attribute_from_model_metadata( - resource, EvaluationCustomMetadata.EVALUATION_JOB_RUN_ID.value + resource, EvaluationCustomMetadata.EVALUATION_JOB_RUN_ID ) logger.info(f"Fetching associated job run: {jobrun_id}") @@ -1654,8 +1410,6 @@ def _fetch_runtime_params( "model parameters have not been saved in correct format in model taxonomy. ", service_payload={"params": params}, ) - # TODO: validate the format of parameters. 
-            # self._validate_params(params)
 
             return AquaEvalParams(**params[EvaluationConfig.PARAMS])
         except Exception as e:
@@ -1688,7 +1442,6 @@ def _build_job_identifier(
             )
 
         return AquaResourceIdentifier()
 
-    # TODO: fix the logic for determine termination state
     def _get_status(
         self,
         model: oci.resource_search.models.ResourceSummary,
@@ -1697,30 +1450,33 @@
         ] = None,
     ) -> dict:
         """Builds evaluation status based on the model status and job run status.
-        When detect `aqua_evaluation_error` in custom metadata, the jobrun is failed.
-        However, if jobrun failed before saving this meta, we need to check the existance
-        of the evaluation artifact.
+        When job run information is missing, the status is decided as follows:
 
-        """
-        # TODO: revisit for CANCELED evaluation
-        job_run_status = (
-            JobRun.LIFECYCLE_STATE_FAILED
-            if self._get_attribute_from_model_metadata(
-                model, EvaluationCustomMetadata.EVALUATION_ERROR.value
-            )
-            else None
-        )
+            * If the evaluation has just been deleted, the job run status is reported as deleted.
+            * If `aqua_evaluation_error` is detected in the custom metadata, the job run is considered failed.
+            * If the job run failed before saving this metadata, we need to check the existence
+              of the evaluation artifact.
+        """
         model_status = model.lifecycle_state
-        job_run_status = job_run_status or (
-            jobrun.lifecycle_state
-            if jobrun and not jobrun.lifecycle_state == JobRun.LIFECYCLE_STATE_DELETED
-            else (
-                JobRun.LIFECYCLE_STATE_SUCCEEDED
-                if self._if_eval_artifact_exist(model)
-                else JobRun.LIFECYCLE_STATE_FAILED
-            )
-        )
+        job_run_status = None
+
+        if jobrun:
+            job_run_status = jobrun.lifecycle_state
+
+        if jobrun is None:
+            if model.identifier in self._deletion_cache.keys():
+                job_run_status = JobRun.LIFECYCLE_STATE_DELETED
+
+            elif self._get_attribute_from_model_metadata(
+                model, EvaluationCustomMetadata.EVALUATION_ERROR
+            ):
+                job_run_status = JobRun.LIFECYCLE_STATE_FAILED
+
+            elif self._if_eval_artifact_exist(model):
+                job_run_status = JobRun.LIFECYCLE_STATE_SUCCEEDED
+            else:
+                job_run_status = JobRun.LIFECYCLE_STATE_FAILED
 
         lifecycle_state = utils.LifecycleStatus.get_status(
             evaluation_status=model_status, job_run_status=job_run_status
@@ -1738,21 +1494,17 @@
 
         return dict(
             lifecycle_state=(
-                lifecycle_state
-                if isinstance(lifecycle_state, str)
-                else lifecycle_state.value
+                lifecycle_state if isinstance(lifecycle_state, str) else lifecycle_state
            ),
            lifecycle_details=lifecycle_details,
        )
 
     def _prefetch_resources(self, compartment_id) -> dict:
         """Fetches all AQUA resources."""
-        # TODO: handle cross compartment/tenency resources
-        # TODO: add cache
         resources = utils.query_resources(
             compartment_id=compartment_id,
             resource_type="all",
-            tag_list=[EvaluationModelTags.AQUA_EVALUATION.value, "OCI_AQUA"],
+            tag_list=[Tags.AQUA_EVALUATION, "OCI_AQUA"],
             connect_by_ampersands=False,
             return_all=False,
         )
diff --git a/ads/aqua/extension/__init__.py b/ads/aqua/extension/__init__.py
index 5fe85bcfb..4c8d9f3f3 100644
--- a/ads/aqua/extension/__init__.py
+++ b/ads/aqua/extension/__init__.py
@@ -14,6 +14,7 @@
 from ads.aqua.extension.finetune_handler import __handlers__ as __finetune_handlers__
 from ads.aqua.extension.model_handler import __handlers__ as __model_handlers__
 from ads.aqua.extension.ui_handler import __handlers__ as __ui_handlers__
+from ads.aqua.extension.ui_websocket_handler import __handlers__ as __ws_handlers__
 
 __handlers__ = (
     __finetune_handlers__
@@ -22,6 +23,7 @@
     + __deployment_handlers__
     + __ui_handlers__
     + __eval_handlers__
+    + __ws_handlers__
 )
diff --git 
a/ads/aqua/extension/aqua_ws_msg_handler.py b/ads/aqua/extension/aqua_ws_msg_handler.py new file mode 100644 index 000000000..1494ce028 --- /dev/null +++ b/ads/aqua/extension/aqua_ws_msg_handler.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*-- + +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +import traceback +from abc import abstractmethod +from http.client import responses +from typing import List + +from tornado.web import HTTPError + +from ads.aqua import logger +from ads.aqua.common.decorator import handle_exceptions +from ads.aqua.extension.base_handler import AquaAPIhandler +from ads.aqua.extension.models.ws_models import ( + AquaWsError, + BaseRequest, + BaseResponse, + ErrorResponse, + RequestResponseType, +) +from ads.config import AQUA_TELEMETRY_BUCKET, AQUA_TELEMETRY_BUCKET_NS +from ads.telemetry.client import TelemetryClient + + +class AquaWSMsgHandler: + message: str + + def __init__(self, message: str): + self.message = message + try: + self.telemetry = TelemetryClient( + bucket=AQUA_TELEMETRY_BUCKET, namespace=AQUA_TELEMETRY_BUCKET_NS + ) + except: + pass + + @staticmethod + @abstractmethod + def get_message_types() -> List[RequestResponseType]: + """This method should be implemented by the child class. + This method should return a list of RequestResponseType that the child class can handle + """ + pass + + @abstractmethod + @handle_exceptions + def process(self) -> BaseResponse: + """This method should be implemented by the child class. + This method will contain the core logic to be executed for handling the message + """ + pass + + def write_error(self, status_code, **kwargs): + """AquaWSMSGhandler errors are JSON, not human pages.""" + reason = kwargs.get("reason") + service_payload = kwargs.get("service_payload", {}) + default_msg = responses.get(status_code, "Unknown HTTP Error") + message = AquaAPIhandler.get_default_error_messages( + service_payload, str(status_code), kwargs.get("message", default_msg) + ) + reply = { + "status": status_code, + "message": message, + "service_payload": service_payload, + "reason": reason, + } + exc_info = kwargs.get("exc_info") + if exc_info: + logger.error("".join(traceback.format_exception(*exc_info))) + e = exc_info[1] + if isinstance(e, HTTPError): + reply["message"] = e.log_message or message + reply["reason"] = e.reason + else: + logger.warning(reply["message"]) + # telemetry may not be present if there is an error while initializing + if hasattr(self, "telemetry"): + self.telemetry.record_event_async( + category="aqua/error", + action=str(status_code), + value=reason, + ) + response = AquaWsError( + status=status_code, + message=message, + service_payload=service_payload, + reason=reason, + ) + base_message = BaseRequest.from_json(self.message, ignore_unknown=True) + return ErrorResponse( + message_id=base_message.message_id, + kind=RequestResponseType.Error, + data=response, + ) diff --git a/ads/aqua/extension/base_handler.py b/ads/aqua/extension/base_handler.py index b92e2deab..d84602bf7 100644 --- a/ads/aqua/extension/base_handler.py +++ b/ads/aqua/extension/base_handler.py @@ -137,10 +137,3 @@ def get_default_error_messages( return messages[status_code] else: return default_msg - - -# todo: remove after error handler is implemented -class Errors(str): - INVALID_INPUT_DATA_FORMAT = "Invalid format of input data." - NO_INPUT_DATA = "No input data provided." 
- MISSING_REQUIRED_PARAMETER = "Missing required parameter: '{}'" diff --git a/ads/aqua/extension/common_handler.py b/ads/aqua/extension/common_handler.py index dd56284fe..2f78bc973 100644 --- a/ads/aqua/extension/common_handler.py +++ b/ads/aqua/extension/common_handler.py @@ -6,11 +6,15 @@ from importlib import metadata +import requests +from tornado.web import HTTPError + from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID -from ads.aqua.decorator import handle_exceptions -from ads.aqua.exception import AquaResourceAccessError +from ads.aqua.common.decorator import handle_exceptions +from ads.aqua.common.errors import AquaResourceAccessError, AquaRuntimeError +from ads.aqua.common.utils import fetch_service_compartment, known_realm from ads.aqua.extension.base_handler import AquaAPIhandler -from ads.aqua.utils import known_realm, fetch_service_compartment +from ads.aqua.extension.errors import Errors class ADSVersionHandler(AquaAPIhandler): @@ -33,9 +37,11 @@ def get(self): Returns ------- - status dict: - ok or compatible - Raises: + status dict: + ok or compatible + + Raises + ------ AquaResourceAccessError: raised when aqua is not accessible in the given session/region. """ diff --git a/ads/aqua/extension/deployment_handler.py b/ads/aqua/extension/deployment_handler.py index 1e74fdcf4..717ebfd2d 100644 --- a/ads/aqua/extension/deployment_handler.py +++ b/ads/aqua/extension/deployment_handler.py @@ -7,9 +7,11 @@ from tornado.web import HTTPError -from ads.aqua.decorator import handle_exceptions -from ads.aqua.deployment import AquaDeploymentApp, MDInferenceResponse, ModelParams -from ads.aqua.extension.base_handler import AquaAPIhandler, Errors +from ads.aqua.common.decorator import handle_exceptions +from ads.aqua.extension.errors import Errors +from ads.aqua.extension.base_handler import AquaAPIhandler +from ads.aqua.modeldeployment import AquaDeploymentApp, MDInferenceResponse +from ads.aqua.modeldeployment.entities import ModelParams from ads.config import COMPARTMENT_OCID, PROJECT_OCID @@ -93,6 +95,11 @@ def post(self, *args, **kwargs): description = input_data.get("description") instance_count = input_data.get("instance_count") bandwidth_mbps = input_data.get("bandwidth_mbps") + web_concurrency = input_data.get("web_concurrency") + server_port = input_data.get("server_port") + health_check_port = input_data.get("health_check_port") + env_var = input_data.get("env_var") + container_family = input_data.get("container_family") self.finish( AquaDeploymentApp().create( @@ -107,6 +114,11 @@ def post(self, *args, **kwargs): access_log_id=access_log_id, predict_log_id=predict_log_id, bandwidth_mbps=bandwidth_mbps, + web_concurrency=web_concurrency, + server_port=server_port, + health_check_port=health_check_port, + env_var=env_var, + container_family=container_family, ) ) @@ -192,8 +204,62 @@ def post(self, *args, **kwargs): ) +class AquaDeploymentParamsHandler(AquaAPIhandler): + """Handler for Aqua deployment params REST APIs. + + Methods + ------- + get(self, model_id) + Retrieves a list of model deployment parameters. + post(self, *args, **kwargs) + Validates parameters for the given model id. + """ + + @handle_exceptions + def get(self, model_id): + """Handle GET request.""" + instance_shape = self.get_argument("instance_shape") + return self.finish( + AquaDeploymentApp().get_deployment_default_params( + model_id=model_id, instance_shape=instance_shape + ) + ) + + @handle_exceptions + def post(self, *args, **kwargs): + """Handles post request for the deployment param handler API. 
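+
+        Expects a JSON body with `model_id` (required) and, optionally, `params`
+        and `container_family`.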
+ + Raises + ------ + HTTPError + Raises HTTPError if inputs are missing or are invalid. + """ + try: + input_data = self.get_json_body() + except Exception: + raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) + + if not input_data: + raise HTTPError(400, Errors.NO_INPUT_DATA) + + model_id = input_data.get("model_id") + if not model_id: + raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model_id")) + + params = input_data.get("params") + container_family = input_data.get("container_family") + return self.finish( + AquaDeploymentApp().validate_deployment_params( + model_id=model_id, + params=params, + container_family=container_family, + ) + ) + + __handlers__ = [ - ("deployments/?([^/]*)", AquaDeploymentHandler), + ("deployments/?([^/]*)/params", AquaDeploymentParamsHandler), ("deployments/config/?([^/]*)", AquaDeploymentHandler), + ("deployments/?([^/]*)", AquaDeploymentHandler), ("inference", AquaDeploymentInferenceHandler), ] diff --git a/ads/aqua/extension/errors.py b/ads/aqua/extension/errors.py new file mode 100644 index 000000000..2603ce33a --- /dev/null +++ b/ads/aqua/extension/errors.py @@ -0,0 +1,10 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + + +class Errors(str): + INVALID_INPUT_DATA_FORMAT = "Invalid format of input data." + NO_INPUT_DATA = "No input data provided." + MISSING_REQUIRED_PARAMETER = "Missing required parameter: '{}'" diff --git a/ads/aqua/extension/evaluation_handler.py b/ads/aqua/extension/evaluation_handler.py index 1555fb827..38ed373e9 100644 --- a/ads/aqua/extension/evaluation_handler.py +++ b/ads/aqua/extension/evaluation_handler.py @@ -7,9 +7,11 @@ from tornado.web import HTTPError -from ads.aqua.decorator import handle_exceptions -from ads.aqua.evaluation import AquaEvaluationApp, CreateAquaEvaluationDetails -from ads.aqua.extension.base_handler import AquaAPIhandler, Errors +from ads.aqua.common.decorator import handle_exceptions +from ads.aqua.evaluation import AquaEvaluationApp +from ads.aqua.evaluation.entities import CreateAquaEvaluationDetails +from ads.aqua.extension.errors import Errors +from ads.aqua.extension.base_handler import AquaAPIhandler from ads.aqua.extension.utils import validate_function_parameters from ads.config import COMPARTMENT_OCID diff --git a/ads/aqua/extension/evaluation_ws_msg_handler.py b/ads/aqua/extension/evaluation_ws_msg_handler.py new file mode 100644 index 000000000..99384a216 --- /dev/null +++ b/ads/aqua/extension/evaluation_ws_msg_handler.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*-- + +# Copyright (c) 2024 Oracle and/or its affiliates. 
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +from typing import List, Union + +from tornado.web import HTTPError + +from ads.aqua.common.decorator import handle_exceptions +from ads.aqua.evaluation import AquaEvaluationApp +from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler +from ads.aqua.extension.models.ws_models import ( + ListEvaluationsRequest, + ListEvaluationsResponse, + RequestResponseType, +) +from ads.config import COMPARTMENT_OCID + + +class AquaEvaluationWSMsgHandler(AquaWSMsgHandler): + @staticmethod + def get_message_types() -> List[RequestResponseType]: + return [RequestResponseType.ListEvaluations] + + def __init__(self, message: Union[str, bytes]): + super().__init__(message) + + @handle_exceptions + def process(self) -> ListEvaluationsResponse: + list_eval_request = ListEvaluationsRequest.from_json(self.message) + + eval_list = AquaEvaluationApp().list( + list_eval_request.compartment_id or COMPARTMENT_OCID, + list_eval_request.project_id, + ) + response = ListEvaluationsResponse( + message_id=list_eval_request.message_id, + kind=RequestResponseType.ListEvaluations, + data=eval_list, + ) + return response diff --git a/ads/aqua/extension/finetune_handler.py b/ads/aqua/extension/finetune_handler.py index 1809742ef..ba82b2dc1 100644 --- a/ads/aqua/extension/finetune_handler.py +++ b/ads/aqua/extension/finetune_handler.py @@ -8,10 +8,12 @@ from tornado.web import HTTPError -from ads.aqua.decorator import handle_exceptions -from ads.aqua.extension.base_handler import AquaAPIhandler, Errors +from ads.aqua.common.decorator import handle_exceptions +from ads.aqua.extension.errors import Errors +from ads.aqua.extension.base_handler import AquaAPIhandler from ads.aqua.extension.utils import validate_function_parameters -from ads.aqua.finetune import AquaFineTuningApp, CreateFineTuningDetails +from ads.aqua.finetuning import AquaFineTuningApp +from ads.aqua.finetuning.entities import CreateFineTuningDetails class AquaFineTuneHandler(AquaAPIhandler): @@ -59,7 +61,43 @@ def get_finetuning_config(self, model_id): return self.finish(AquaFineTuningApp().get_finetuning_config(model_id=model_id)) +class AquaFineTuneParamsHandler(AquaAPIhandler): + """Handler for Aqua finetuning params REST APIs.""" + + @handle_exceptions + def get(self, model_id): + """Handle GET request.""" + return self.finish( + AquaFineTuningApp().get_finetuning_default_params(model_id=model_id) + ) + + @handle_exceptions + def post(self, *args, **kwargs): + """Handles post request for the finetuning param handler API. + + Raises + ------ + HTTPError + Raises HTTPError if inputs are missing or are invalid. + """ + try: + input_data = self.get_json_body() + except Exception: + raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) + + if not input_data: + raise HTTPError(400, Errors.NO_INPUT_DATA) + + params = input_data.get("params", None) + return self.finish( + AquaFineTuningApp().validate_finetuning_params( + params=params, + ) + ) + + __handlers__ = [ + ("finetuning/?([^/]*)/params", AquaFineTuneParamsHandler), ("finetuning/?([^/]*)", AquaFineTuneHandler), ("finetuning/config/?([^/]*)", AquaFineTuneHandler), ] diff --git a/ads/aqua/extension/model_handler.py b/ads/aqua/extension/model_handler.py index 6715dc903..b1d1e40a9 100644 --- a/ads/aqua/extension/model_handler.py +++ b/ads/aqua/extension/model_handler.py @@ -3,11 +3,13 @@ # Copyright (c) 2024 Oracle and/or its affiliates. 
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +import re +from typing import Optional from urllib.parse import urlparse from tornado.web import HTTPError - -from ads.aqua.decorator import handle_exceptions +from ads.aqua.extension.errors import Errors +from ads.aqua.common.decorator import handle_exceptions from ads.aqua.extension.base_handler import AquaAPIhandler from ads.aqua.model import AquaModelApp @@ -41,7 +43,56 @@ def list(self): compartment_id = self.get_argument("compartment_id", default=None) # project_id is no needed. project_id = self.get_argument("project_id", default=None) - return self.finish(AquaModelApp().list(compartment_id, project_id)) + model_type = self.get_argument("model_type", default=None) + return self.finish( + AquaModelApp().list( + compartment_id=compartment_id, + project_id=project_id, + model_type=model_type, + ) + ) + + @handle_exceptions + def post(self, *args, **kwargs): + """ + Handles post request for the registering any Aqua model. + Raises + ------ + HTTPError + Raises HTTPError if inputs are missing or are invalid + """ + try: + input_data = self.get_json_body() + except Exception: + raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) + + if not input_data: + raise HTTPError(400, Errors.NO_INPUT_DATA) + + # required input parameters + model = input_data.get("model") + if not model: + raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model")) + os_path = input_data.get("os_path") + if not os_path: + raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("os_path")) + + inference_container = input_data.get("inference_container") + finetuning_container = input_data.get("finetuning_container") + compartment_id = input_data.get("compartment_id") + project_id = input_data.get("project_id") + + return self.finish( + AquaModelApp().register( + model=model, + os_path=os_path, + inference_container=inference_container, + finetuning_container=finetuning_container, + compartment_id=compartment_id, + project_id=project_id, + ) + ) + class AquaModelLicenseHandler(AquaAPIhandler): """Handler for Aqua Model license REST APIs.""" @@ -49,10 +100,11 @@ class AquaModelLicenseHandler(AquaAPIhandler): @handle_exceptions def get(self, model_id): """Handle GET request.""" - + model_id = model_id.split("/")[0] return self.finish(AquaModelApp().load_license(model_id)) + __handlers__ = [ ("model/?([^/]*)", AquaModelHandler), ("model/?([^/]*)/license", AquaModelLicenseHandler), diff --git a/ads/aqua/extension/models/__init__.py b/ads/aqua/extension/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ads/aqua/extension/models/ws_models.py b/ads/aqua/extension/models/ws_models.py new file mode 100644 index 000000000..d9d20afe7 --- /dev/null +++ b/ads/aqua/extension/models/ws_models.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*-- + +# Copyright (c) 2024 Oracle and/or its affiliates. 
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +from dataclasses import dataclass +from typing import List, Optional + +from ads.aqua.evaluation.entities import AquaEvaluationSummary +from ads.aqua.model.entities import AquaModelSummary +from ads.common.extended_enum import ExtendedEnumMeta +from ads.common.serializer import DataClassSerializable + + +class RequestResponseType(str, metaclass=ExtendedEnumMeta): + ListEvaluations = "ListEvaluations" + ListModels = "ListModels" + Error = "Error" + + +@dataclass +class BaseResponse(DataClassSerializable): + message_id: str + kind: RequestResponseType + data: object + + +@dataclass +class BaseRequest(DataClassSerializable): + message_id: str + kind: RequestResponseType + + +@dataclass +class ListEvaluationsRequest(BaseRequest): + compartment_id: Optional[str] = None + limit: Optional[int] = None + project_id: Optional[str] = None + kind = RequestResponseType.ListEvaluations + + +@dataclass +class ListModelsRequest(BaseRequest): + compartment_id: Optional[str] = None + + +@dataclass +class ListEvaluationsResponse(BaseResponse): + data: List[AquaEvaluationSummary] + + +@dataclass +class ListModelsResponse(BaseResponse): + data: List[AquaModelSummary] + + +@dataclass +class AquaWsError(DataClassSerializable): + status: str + message: str + service_payload: Optional[dict] + reason: Optional[str] + + +@dataclass +class ErrorResponse(BaseResponse): + data: AquaWsError + kind = RequestResponseType.Error diff --git a/ads/aqua/extension/ui_handler.py b/ads/aqua/extension/ui_handler.py index 05cf0d0c7..9aff616bb 100644 --- a/ads/aqua/extension/ui_handler.py +++ b/ads/aqua/extension/ui_handler.py @@ -3,17 +3,30 @@ # Copyright (c) 2024 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +from dataclasses import dataclass from urllib.parse import urlparse from tornado.web import HTTPError -from ads.aqua.data import Tags -from ads.aqua.decorator import handle_exceptions +from ads.aqua.common.decorator import handle_exceptions +from ads.aqua.common.enums import Tags +from ads.aqua.extension.errors import Errors from ads.aqua.extension.base_handler import AquaAPIhandler +from ads.aqua.extension.utils import validate_function_parameters +from ads.aqua.model.entities import ImportModelDetails from ads.aqua.ui import AquaUIApp from ads.config import COMPARTMENT_OCID +@dataclass +class CLIDetails: + """Interface to capture payload and command details for generating ads cli command""" + + command: str + subcommand: str + payload: dict + + class AquaUIHandler(AquaAPIhandler): """ Handler for Aqua UI REST APIs. 
@@ -63,6 +76,8 @@ def get(self, id=""):
             return self.get_shape_availability()
         elif paths.startswith("aqua/bucket/versioning"):
             return self.is_bucket_versioned()
+        elif paths.startswith("aqua/containers"):
+            return self.list_containers()
         else:
             raise HTTPError(400, f"The request {self.request.path} is invalid.")
 
@@ -92,6 +107,10 @@ def list_compartments(self):
         """Lists the compartments in a compartment specified by ODSC_MODEL_COMPARTMENT_OCID env variable."""
         return self.finish(AquaUIApp().list_compartments())
 
+    def list_containers(self):
+        """Lists the AQUA containers."""
+        return self.finish(AquaUIApp().list_containers())
+
     def get_default_compartment(self):
         """Returns user compartment ocid."""
         return self.finish(AquaUIApp().get_default_compartment())
@@ -103,7 +122,7 @@ def list_model_version_sets(self, **kwargs):
         return self.finish(
             AquaUIApp().list_model_version_sets(
                 compartment_id=compartment_id,
-                target_tag=Tags.AQUA_FINE_TUNING.value,
+                target_tag=Tags.AQUA_FINE_TUNING,
                 **kwargs,
             )
         )
@@ -115,7 +134,7 @@ def list_experiments(self, **kwargs):
         return self.finish(
             AquaUIApp().list_model_version_sets(
                 compartment_id=compartment_id,
-                target_tag=Tags.AQUA_EVALUATION.value,
+                target_tag=Tags.AQUA_EVALUATION,
                 **kwargs,
             )
         )
@@ -175,6 +194,46 @@ def is_bucket_versioned(self):
         return self.finish(AquaUIApp().is_bucket_versioned(bucket_uri=bucket_uri))
 
 
+class AquaCLIHandler(AquaAPIhandler):
+    """Handler for constructing Aqua CLI commands, e.g. the Aqua model register command.
+    command_interface_map maps command+subcommand to the corresponding API dataclass.
+    E.g., in the command `ads aqua model register ....`, the command is `model` and the subcommand is `register`.
+    The key in the map will be f"{command}_{sub_command}" and the value will be a DataClass.
+    """
+
+    command_interface_map = {"model_register": ImportModelDetails}
+
+    @handle_exceptions
+    def post(self, *args, **kwargs):
+        """Handles CLI command construction.
+
+        Raises
+        ------
+        HTTPError
+            Raises HTTPError if inputs are missing or are invalid.
+        """
+        try:
+            input_data = self.get_json_body()
+        except Exception:
+            raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
+
+        if not input_data:
+            raise HTTPError(400, Errors.NO_INPUT_DATA)
+
+        validate_function_parameters(data_class=CLIDetails, input_data=input_data)
+        command_details = CLIDetails(**input_data)
+
+        interface = AquaCLIHandler.command_interface_map[
+            f"{command_details.command}_{command_details.subcommand}"
+        ]
+
+        validate_function_parameters(
+            data_class=interface, input_data=command_details.payload
+        )
+        payload = interface(**command_details.payload)
+        self.finish({"command": payload.build_cli()})
+
+
 __handlers__ = [
     ("logging/?([^/]*)", AquaUIHandler),
     ("compartments/?([^/]*)", AquaUIHandler),
@@ -187,4 +246,6 @@ def is_bucket_versioned(self):
     ("subnets/?([^/]*)", AquaUIHandler),
     ("shapes/limit/?([^/]*)", AquaUIHandler),
     ("bucket/versioning/?([^/]*)", AquaUIHandler),
+    ("containers/?([^/]*)", AquaUIHandler),
+    ("cli/?([^/]*)", AquaCLIHandler),
 ]
diff --git a/ads/aqua/extension/ui_websocket_handler.py b/ads/aqua/extension/ui_websocket_handler.py
new file mode 100644
index 000000000..fa3a4e254
--- /dev/null
+++ b/ads/aqua/extension/ui_websocket_handler.py
@@ -0,0 +1,124 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright (c) 2024 Oracle and/or its affiliates.
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +import concurrent.futures +from asyncio.futures import Future +from concurrent.futures import ThreadPoolExecutor +from typing import Dict, List, Type, Union + +import tornado +from tornado import httputil +from tornado.ioloop import IOLoop +from tornado.websocket import WebSocketHandler + +from ads.aqua import logger +from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler +from ads.aqua.extension.evaluation_ws_msg_handler import AquaEvaluationWSMsgHandler +from ads.aqua.extension.models.ws_models import ( + AquaWsError, + BaseRequest, + BaseResponse, + ErrorResponse, + RequestResponseType, +) + +MAX_WORKERS = 20 + + +def get_aqua_internal_error_response(message_id: str) -> ErrorResponse: + error = AquaWsError( + status="500", + message="Internal Server Error", + service_payload={}, + reason="", + ) + return ErrorResponse( + message_id=message_id, + kind=RequestResponseType.Error, + data=error, + ) + + +class AquaUIWebSocketHandler(WebSocketHandler): + """Handler for Aqua Websocket.""" + + _handlers_: List[Type[AquaWSMsgHandler]] = [AquaEvaluationWSMsgHandler] + + thread_pool: ThreadPoolExecutor + + future_message_map: Dict[Future, BaseRequest] + message_type_handler_map: Dict[RequestResponseType, Type[AquaWSMsgHandler]] + + def __init__( + self, + application: tornado.web.Application, + request: httputil.HTTPServerRequest, + **kwargs, + ): + # Create a mapping of message type to handler and check for duplicates + self.future_message_map = {} + self.message_type_handler_map = {} + for handler in self._handlers_: + for message_type in handler.get_message_types(): + if message_type in self.message_type_handler_map: + raise ValueError( + f"Duplicate message type {message_type} in AQUA websocket handlers." + ) + else: + self.message_type_handler_map[message_type] = handler + + super().__init__(application, request, **kwargs) + + def open(self, *args, **kwargs): + self.thread_pool = ThreadPoolExecutor(max_workers=MAX_WORKERS) + logger.info("AQUA WebSocket opened") + + def on_message(self, message: Union[str, bytes]): + try: + request = BaseRequest.from_json(message, ignore_unknown=True) + except Exception as e: + logger.error( + f"Unable to parse WebSocket message {message}\nWith exception: {str(e)}" + ) + raise e + # Find the handler for the message type. + # Each handler is responsible for some specific message types + handler = self.message_type_handler_map.get(request.kind, None) + if handler is None: + self.write_message( + get_aqua_internal_error_response(request.message_id).to_json() + ) + raise ValueError(f"No handler found for message type {request.kind}") + else: + message_handler = handler(message) + future: Future = self.thread_pool.submit(message_handler.process) + self.future_message_map[future] = request + future.add_done_callback(self.on_message_processed) + + def on_message_processed(self, future: concurrent.futures.Future): + """Callback function to handle the response from the various AquaWSMsgHandlers.""" + try: + response: BaseResponse = future.result() + + # Any exception coming here is an unhandled exception in the handler. We should log it and return an internal server error. 
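+            # The error response built in the except block below is still sent to the client by the finally block.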
+ # In non WebSocket scenarios this would be handled by the tornado webserver + except Exception as e: + logger.error( + f"Unable to handle WebSocket message {self.future_message_map[future]}\nWith exception: {str(e)}" + ) + response: BaseResponse = get_aqua_internal_error_response( + self.future_message_map[future].message_id + ) + raise e + finally: + self.future_message_map.pop(future) + # Send the response back to the client on the event thread + IOLoop.current().run_sync(lambda: self.write_message(response.to_json())) + + def on_close(self) -> None: + self.thread_pool.shutdown() + logger.info("AQUA WebSocket closed") + + +__handlers__ = [("ws?([^/]*)", AquaUIWebSocketHandler)] diff --git a/ads/aqua/extension/utils.py b/ads/aqua/extension/utils.py index 5f8320498..c757d91e2 100644 --- a/ads/aqua/extension/utils.py +++ b/ads/aqua/extension/utils.py @@ -7,7 +7,7 @@ from tornado.web import HTTPError -from ads.aqua.extension.base_handler import Errors +from ads.aqua.extension.errors import Errors def validate_function_parameters(data_class, input_data: Dict): diff --git a/ads/aqua/finetuning/__init__.py b/ads/aqua/finetuning/__init__.py new file mode 100644 index 000000000..f6209b15d --- /dev/null +++ b/ads/aqua/finetuning/__init__.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +from ads.aqua.finetuning.finetuning import AquaFineTuningApp + +__all__ = ["AquaFineTuningApp"] diff --git a/ads/aqua/finetuning/constants.py b/ads/aqua/finetuning/constants.py new file mode 100644 index 000000000..1e7309e61 --- /dev/null +++ b/ads/aqua/finetuning/constants.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +from ads.common.extended_enum import ExtendedEnumMeta + + +class FineTuneCustomMetadata(str, metaclass=ExtendedEnumMeta): + FINE_TUNE_SOURCE = "fine_tune_source" + FINE_TUNE_SOURCE_NAME = "fine_tune_source_name" + FINE_TUNE_OUTPUT_PATH = "fine_tune_output_path" + FINE_TUNE_JOB_ID = "fine_tune_job_id" + FINE_TUNE_JOB_RUN_ID = "fine_tune_job_run_id" + SERVICE_MODEL_ARTIFACT_LOCATION = "artifact_location" + SERVICE_MODEL_DEPLOYMENT_CONTAINER = "deployment-container" + SERVICE_MODEL_FINE_TUNE_CONTAINER = "finetune-container" diff --git a/ads/aqua/finetuning/entities.py b/ads/aqua/finetuning/entities.py new file mode 100644 index 000000000..f17788f03 --- /dev/null +++ b/ads/aqua/finetuning/entities.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright (c) 2024 Oracle and/or its affiliates. 
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +from dataclasses import dataclass, field +from typing import List, Optional + +from ads.aqua.data import AquaJobSummary +from ads.common.serializer import DataClassSerializable + + +@dataclass(repr=False) +class AquaFineTuningParams(DataClassSerializable): + epochs: int + learning_rate: Optional[float] = None + sample_packing: Optional[bool] = "auto" + batch_size: Optional[ + int + ] = None # make it batch_size for user, but internally this is micro_batch_size + sequence_len: Optional[int] = None + pad_to_sequence_len: Optional[bool] = None + lora_r: Optional[int] = None + lora_alpha: Optional[int] = None + lora_dropout: Optional[float] = None + lora_target_linear: Optional[bool] = None + lora_target_modules: Optional[List] = None + + +@dataclass(repr=False) +class AquaFineTuningSummary(AquaJobSummary, DataClassSerializable): + parameters: AquaFineTuningParams = field(default_factory=AquaFineTuningParams) + + +@dataclass(repr=False) +class CreateFineTuningDetails(DataClassSerializable): + """Dataclass to create aqua model fine tuning. + + Fields + ------ + ft_source_id: str + The fine tuning source id. Must be model ocid. + ft_name: str + The name for fine tuning. + dataset_path: str + The dataset path for fine tuning. Could be either a local path from notebook session + or an object storage path. + report_path: str + The report path for fine tuning. Must be an object storage path. + ft_parameters: dict + The parameters for fine tuning. + shape_name: str + The shape name for fine tuning job infrastructure. + replica: int + The replica for fine tuning job runtime. + validation_set_size: float + The validation set size for fine tuning job. Must be a float in between [0,1). + ft_description: (str, optional). Defaults to `None`. + The description for fine tuning. + compartment_id: (str, optional). Defaults to `None`. + The compartment id for fine tuning. + project_id: (str, optional). Defaults to `None`. + The project id for fine tuning. + experiment_id: (str, optional). Defaults to `None`. + The fine tuning model version set id. If provided, + fine tuning model will be associated with it. + experiment_name: (str, optional). Defaults to `None`. + The fine tuning model version set name. If provided, + the fine tuning version set with the same name will be used if exists, + otherwise a new model version set will be created with the name. + experiment_description: (str, optional). Defaults to `None`. + The description for fine tuning model version set. + block_storage_size: (int, optional). Defaults to 256. + The storage for fine tuning job infrastructure. + subnet_id: (str, optional). Defaults to `None`. + The custom egress for fine tuning job. + log_group_id: (str, optional). Defaults to `None`. + The log group id for fine tuning job infrastructure. + log_id: (str, optional). Defaults to `None`. + The log id for fine tuning job infrastructure. + force_overwrite: (bool, optional). Defaults to `False`. + Whether to force overwrite the existing file in object storage. 
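+
+    Example
+    -------
+    A minimal construction sketch; the OCID, object storage paths and shape below
+    are illustrative placeholders only:
+
+    >>> CreateFineTuningDetails(
+    ...     ft_source_id="ocid1.datasciencemodel.oc1..<unique_id>",
+    ...     ft_name="my-fine-tuning-run",
+    ...     dataset_path="oci://<bucket>@<namespace>/train.jsonl",
+    ...     report_path="oci://<bucket>@<namespace>/ft_report/",
+    ...     ft_parameters={"epochs": 1, "learning_rate": 0.0002},
+    ...     shape_name="VM.GPU.A10.2",
+    ...     replica=1,
+    ...     validation_set_size=0.1,
+    ... )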
+ """ + + ft_source_id: str + ft_name: str + dataset_path: str + report_path: str + ft_parameters: dict + shape_name: str + replica: int + validation_set_size: float + ft_description: Optional[str] = None + compartment_id: Optional[str] = None + project_id: Optional[str] = None + experiment_id: Optional[str] = None + experiment_name: Optional[str] = None + experiment_description: Optional[str] = None + block_storage_size: Optional[int] = None + subnet_id: Optional[str] = None + log_id: Optional[str] = None + log_group_id: Optional[str] = None + force_overwrite: Optional[bool] = False diff --git a/ads/aqua/finetune.py b/ads/aqua/finetuning/finetuning.py similarity index 72% rename from ads/aqua/finetune.py rename to ads/aqua/finetuning/finetuning.py index ad2d1027f..826d5df70 100644 --- a/ads/aqua/finetune.py +++ b/ads/aqua/finetuning/finetuning.py @@ -5,9 +5,8 @@ import json import os -from dataclasses import asdict, dataclass, field -from enum import Enum -from typing import Dict, Optional +from dataclasses import asdict, fields, MISSING +from typing import Dict from oci.data_science.models import ( Metadata, @@ -16,11 +15,14 @@ ) from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger -from ads.aqua.base import AquaApp -from ads.aqua.data import AquaResourceIdentifier, Resource, Tags -from ads.aqua.exception import AquaFileExistsError, AquaValueError -from ads.aqua.job import AquaJobSummary -from ads.aqua.utils import ( +from ads.aqua.app import AquaApp +from ads.aqua.common.enums import Resource, Tags +from ads.aqua.common.errors import AquaFileExistsError, AquaValueError +from ads.aqua.common.utils import ( + get_container_image, + upload_local_to_os, +) +from ads.aqua.constants import ( DEFAULT_FT_BATCH_SIZE, DEFAULT_FT_BLOCK_STORAGE_SIZE, DEFAULT_FT_REPLICA, @@ -28,14 +30,16 @@ JOB_INFRASTRUCTURE_TYPE_DEFAULT_NETWORKING, UNKNOWN, UNKNOWN_DICT, - get_container_image, - upload_local_to_os, ) +from ads.aqua.config.config import get_finetuning_config_defaults +from ads.aqua.data import AquaResourceIdentifier +from ads.aqua.finetuning.constants import * +from ads.aqua.finetuning.entities import * from ads.common.auth import default_signer from ads.common.object_storage_details import ObjectStorageDetails -from ads.common.serializer import DataClassSerializable from ads.common.utils import get_console_link from ads.config import ( + AQUA_FINETUNING_CONTAINER_OVERRIDE_FLAG_METADATA_NAME, AQUA_JOB_SUBNET_ID, AQUA_MODEL_FINETUNING_CONFIG, COMPARTMENT_OCID, @@ -54,100 +58,6 @@ from ads.telemetry import telemetry -class FineTuneCustomMetadata(Enum): - FINE_TUNE_SOURCE = "fine_tune_source" - FINE_TUNE_SOURCE_NAME = "fine_tune_source_name" - FINE_TUNE_OUTPUT_PATH = "fine_tune_output_path" - FINE_TUNE_JOB_ID = "fine_tune_job_id" - FINE_TUNE_JOB_RUN_ID = "fine_tune_job_run_id" - SERVICE_MODEL_ARTIFACT_LOCATION = "artifact_location" - SERVICE_MODEL_DEPLOYMENT_CONTAINER = "deployment-container" - SERVICE_MODEL_FINE_TUNE_CONTAINER = "finetune-container" - - -@dataclass(repr=False) -class AquaFineTuningParams(DataClassSerializable): - epochs: int = None - learning_rate: float = None - sample_packing: str = "True" - - -@dataclass(repr=False) -class AquaFineTuningSummary(AquaJobSummary, DataClassSerializable): - parameters: AquaFineTuningParams = field(default_factory=AquaFineTuningParams) - - -@dataclass(repr=False) -class CreateFineTuningDetails(DataClassSerializable): - """Dataclass to create aqua model fine tuning. - - Fields - ------ - ft_source_id: str - The fine tuning source id. 
Must be model ocid. - ft_name: str - The name for fine tuning. - dataset_path: str - The dataset path for fine tuning. Could be either a local path from notebook session - or an object storage path. - report_path: str - The report path for fine tuning. Must be an object storage path. - ft_parameters: dict - The parameters for fine tuning. - shape_name: str - The shape name for fine tuning job infrastructure. - replica: int - The replica for fine tuning job runtime. - validation_set_size: float - The validation set size for fine tuning job. Must be a float in between [0,1). - ft_description: (str, optional). Defaults to `None`. - The description for fine tuning. - compartment_id: (str, optional). Defaults to `None`. - The compartment id for fine tuning. - project_id: (str, optional). Defaults to `None`. - The project id for fine tuning. - experiment_id: (str, optional). Defaults to `None`. - The fine tuning model version set id. If provided, - fine tuning model will be associated with it. - experiment_name: (str, optional). Defaults to `None`. - The fine tuning model version set name. If provided, - the fine tuning version set with the same name will be used if exists, - otherwise a new model version set will be created with the name. - experiment_description: (str, optional). Defaults to `None`. - The description for fine tuning model version set. - block_storage_size: (int, optional). Defaults to 256. - The storage for fine tuning job infrastructure. - subnet_id: (str, optional). Defaults to `None`. - The custom egress for fine tuning job. - log_group_id: (str, optional). Defaults to `None`. - The log group id for fine tuning job infrastructure. - log_id: (str, optional). Defaults to `None`. - The log id for fine tuning job infrastructure. - force_overwrite: (bool, optional). Defaults to `False`. - Whether to force overwrite the existing file in object storage. - """ - - ft_source_id: str - ft_name: str - dataset_path: str - report_path: str - ft_parameters: dict - shape_name: str - replica: int - validation_set_size: float - ft_description: Optional[str] = None - compartment_id: Optional[str] = None - project_id: Optional[str] = None - experiment_id: Optional[str] = None - experiment_name: Optional[str] = None - experiment_description: Optional[str] = None - block_storage_size: Optional[int] = None - subnet_id: Optional[str] = None - log_id: Optional[str] = None - log_group_id: Optional[str] = None - force_overwrite: Optional[bool] = False - - class AquaFineTuningApp(AquaApp): """Provides a suite of APIs to interact with Aqua fine-tuned models within the Oracle Cloud Infrastructure Data Science service, serving as an interface for creating fine-tuned models. @@ -190,9 +100,12 @@ def create( try: create_fine_tuning_details = CreateFineTuningDetails(**kwargs) except: + allowed_create_fine_tuning_details = ", ".join( + field.name for field in fields(CreateFineTuningDetails) + ).rstrip() raise AquaValueError( "Invalid create fine tuning parameters. Allowable parameters are: " - f"{', '.join(list(asdict(CreateFineTuningDetails).keys()))}." + f"{allowed_create_fine_tuning_details}." ) source = self.get_source(create_fine_tuning_details.ft_source_id) @@ -247,9 +160,12 @@ def create( **create_fine_tuning_details.ft_parameters, ) except: + allowed_fine_tuning_parameters = ", ".join( + field.name for field in fields(AquaFineTuningParams) + ).rstrip() raise AquaValueError( "Invalid fine tuning parameters. 
Fine tuning parameters should " - f"be a dictionary with keys: {', '.join(list(asdict(AquaFineTuningParams).keys()))}." + f"be a dictionary with keys: {allowed_fine_tuning_parameters}." ) experiment_model_version_set_id = create_fine_tuning_details.experiment_id @@ -307,19 +223,20 @@ def create( ft_model_custom_metadata = ModelCustomMetadata() ft_model_custom_metadata.add( - key=FineTuneCustomMetadata.FINE_TUNE_SOURCE.value, + key=FineTuneCustomMetadata.FINE_TUNE_SOURCE, value=create_fine_tuning_details.ft_source_id, ) ft_model_custom_metadata.add( - key=FineTuneCustomMetadata.FINE_TUNE_SOURCE_NAME.value, + key=FineTuneCustomMetadata.FINE_TUNE_SOURCE_NAME, value=source.display_name, ) service_model_artifact_location = source.custom_metadata_list.get( - FineTuneCustomMetadata.SERVICE_MODEL_ARTIFACT_LOCATION.value + FineTuneCustomMetadata.SERVICE_MODEL_ARTIFACT_LOCATION ) service_model_deployment_container = source.custom_metadata_list.get( - FineTuneCustomMetadata.SERVICE_MODEL_DEPLOYMENT_CONTAINER.value + FineTuneCustomMetadata.SERVICE_MODEL_DEPLOYMENT_CONTAINER ) + ft_model_custom_metadata.add( key=service_model_artifact_location.key, value=service_model_artifact_location.value, @@ -350,8 +267,8 @@ def create( ) ft_job_freeform_tags = { - Tags.AQUA_TAG.value: UNKNOWN, - Tags.AQUA_FINE_TUNED_MODEL_TAG.value: f"{source.id}#{source.display_name}", + Tags.AQUA_TAG: UNKNOWN, + Tags.AQUA_FINE_TUNED_MODEL_TAG: f"{source.id}#{source.display_name}", } ft_job = Job(name=ft_model.display_name).with_infrastructure( @@ -381,10 +298,19 @@ def create( ft_config = self.get_finetuning_config(source.id) ft_container = source.custom_metadata_list.get( - FineTuneCustomMetadata.SERVICE_MODEL_FINE_TUNE_CONTAINER.value + FineTuneCustomMetadata.SERVICE_MODEL_FINE_TUNE_CONTAINER ).value - - batch_size = ( + is_custom_container = False + try: + # Check if the container override flag is set. 
If set, then the user has chosen custom image + if source.custom_metadata_list.get( + AQUA_FINETUNING_CONTAINER_OVERRIDE_FLAG_METADATA_NAME + ).value: + is_custom_container = True + except Exception: + pass + + ft_parameters.batch_size = ft_parameters.batch_size or ( ft_config.get("shape", UNKNOWN_DICT) .get(create_fine_tuning_details.shape_name, UNKNOWN_DICT) .get("batch_size", DEFAULT_FT_BATCH_SIZE) @@ -398,7 +324,6 @@ def create( dataset_path=ft_dataset_path, report_path=create_fine_tuning_details.report_path, replica=create_fine_tuning_details.replica, - batch_size=batch_size, finetuning_params=finetuning_params, val_set_size=( create_fine_tuning_details.validation_set_size @@ -406,6 +331,7 @@ def create( ), parameters=ft_parameters, ft_container=ft_container, + is_custom_container=is_custom_container, ) ).create() logger.debug( @@ -422,11 +348,11 @@ def create( ) ft_model_custom_metadata.add( - key=FineTuneCustomMetadata.FINE_TUNE_JOB_ID.value, + key=FineTuneCustomMetadata.FINE_TUNE_JOB_ID, value=ft_job.id, ) ft_model_custom_metadata.add( - key=FineTuneCustomMetadata.FINE_TUNE_JOB_RUN_ID.value, + key=FineTuneCustomMetadata.FINE_TUNE_JOB_RUN_ID, value=ft_job_run.id, ) updated_custom_metadata_list = [ @@ -435,16 +361,16 @@ def create( ] source_freeform_tags = source.freeform_tags or {} - source_freeform_tags.pop(Tags.LICENSE.value, None) - source_freeform_tags.update({Tags.READY_TO_FINE_TUNE.value: "false"}) - source_freeform_tags.update({Tags.AQUA_TAG.value: UNKNOWN}) + source_freeform_tags.pop(Tags.LICENSE, None) + source_freeform_tags.update({Tags.READY_TO_FINE_TUNE: "false"}) + source_freeform_tags.update({Tags.AQUA_TAG: UNKNOWN}) self.update_model( model_id=ft_model.id, update_model_details=UpdateModelDetails( custom_metadata_list=updated_custom_metadata_list, freeform_tags={ - Tags.AQUA_FINE_TUNED_MODEL_TAG.value: ( + Tags.AQUA_FINE_TUNED_MODEL_TAG: ( f"{source.id}#{source.display_name}" ), **source_freeform_tags, @@ -489,7 +415,7 @@ def create( id=ft_model.id, name=ft_model.display_name, console_url=get_console_link( - resource=Resource.MODEL.value, + resource=Resource.MODEL, ocid=ft_model.id, region=self.region, ), @@ -500,7 +426,7 @@ def create( id=experiment_model_version_set_id, name=experiment_model_version_set_name, url=get_console_link( - resource=Resource.MODEL_VERSION_SET.value, + resource=Resource.MODEL_VERSION_SET, ocid=experiment_model_version_set_id, region=self.region, ), @@ -509,7 +435,7 @@ def create( id=source.id, name=source.display_name, url=get_console_link( - resource=Resource.MODEL.value, + resource=Resource.MODEL, ocid=source.id, region=self.region, ), @@ -518,18 +444,22 @@ def create( id=ft_job.id, name=ft_job.name, url=get_console_link( - resource=Resource.JOB.value, + resource=Resource.JOB, ocid=ft_job.id, region=self.region, ), ), tags=dict( - aqua_finetuning=Tags.AQUA_FINE_TUNING.value, + aqua_finetuning=Tags.AQUA_FINE_TUNING, finetuning_job_id=ft_job.id, finetuning_source=source.id, finetuning_experiment_id=experiment_model_version_set_id, ), - parameters=ft_parameters, + parameters={ + key: value + for key, value in asdict(ft_parameters).items() + if value is not None + }, ) def _build_fine_tuning_runtime( @@ -539,15 +469,19 @@ def _build_fine_tuning_runtime( dataset_path: str, report_path: str, replica: int, - batch_size: int, val_set_size: float, parameters: AquaFineTuningParams, ft_container: str = None, finetuning_params: str = None, + is_custom_container: bool = False, ) -> Runtime: """Builds fine tuning runtime for Job.""" - container = 
get_container_image( - container_type=ft_container, + container = ( + get_container_image( + container_type=ft_container, + ) + if not is_custom_container + else ft_container ) runtime = ( ContainerRuntime() @@ -562,9 +496,12 @@ def _build_fine_tuning_runtime( }, } ), - "OCI__LAUNCH_CMD": ( - f"--micro_batch_size {batch_size} --num_epochs {parameters.epochs} --learning_rate {parameters.learning_rate} --training_data {dataset_path} --output_dir {report_path} --val_set_size {val_set_size} --sample_packing {parameters.sample_packing} " - + (f"{finetuning_params}" if finetuning_params else "") + "OCI__LAUNCH_CMD": self._build_oci_launch_cmd( + dataset_path=dataset_path, + report_path=report_path, + val_set_size=val_set_size, + parameters=parameters, + finetuning_params=finetuning_params, ), "CONDA_BUCKET_NS": CONDA_BUCKET_NS, } @@ -575,6 +512,30 @@ def _build_fine_tuning_runtime( return runtime + @staticmethod + def _build_oci_launch_cmd( + dataset_path: str, + report_path: str, + val_set_size: float, + parameters: AquaFineTuningParams, + finetuning_params: str = None, + ) -> str: + """Builds the oci launch cmd for fine tuning container runtime.""" + oci_launch_cmd = f"--training_data {dataset_path} --output_dir {report_path} --val_set_size {val_set_size} " + for key, value in asdict(parameters).items(): + if value is not None: + if key == "batch_size": + oci_launch_cmd += f"--micro_{key} {value} " + elif key == "epochs": + oci_launch_cmd += f"--num_{key} {value} " + elif key == "lora_target_modules": + oci_launch_cmd += f"--{key} {','.join(str(k) for k in value)} " + else: + oci_launch_cmd += f"--{key} {value} " + + oci_launch_cmd += f"{finetuning_params}" if finetuning_params else "" + return oci_launch_cmd.rstrip() + @telemetry( entry_point="plugin=finetuning&action=get_finetuning_config", name="aqua" ) @@ -592,4 +553,69 @@ def get_finetuning_config(self, model_id: str) -> Dict: A dict of allowed finetuning configs. """ - return self.get_config(model_id, AQUA_MODEL_FINETUNING_CONFIG) + config = self.get_config(model_id, AQUA_MODEL_FINETUNING_CONFIG) + if not config: + logger.info(f"Fetching default fine-tuning config for model: {model_id}") + config = get_finetuning_config_defaults() + return config + + @telemetry( + entry_point="plugin=finetuning&action=get_finetuning_default_params", + name="aqua", + ) + def get_finetuning_default_params(self, model_id: str) -> Dict: + """Gets the default params set in the finetuning configs for the given model. Only the fields that are + available in AquaFineTuningParams will be accessible for user overrides. + + Parameters + ---------- + model_id: str + The OCID of the Aqua model. + + Returns + ------- + Dict: + Dict of parameters from the loaded from finetuning config json file. If config information is not available, + then an empty dict is returned. + """ + default_params = {"params": {}} + finetuning_config = self.get_finetuning_config(model_id) + config_parameters = finetuning_config.get("configuration", UNKNOWN_DICT) + dataclass_fields = {field.name for field in fields(AquaFineTuningParams)} + for name, value in config_parameters.items(): + if name == "micro_batch_size": + name = "batch_size" + if name in dataclass_fields: + default_params["params"][name] = value + + return default_params + + def validate_finetuning_params(self, params: Dict = None) -> Dict: + """Validate if the fine-tuning parameters passed by the user can be overridden. Parameter values are not + validated, only param keys are validated. 
+ + Parameters + ---------- + params :Dict, optional + Params passed by the user. + + Returns + ------- + Return a list of restricted params. + """ + try: + AquaFineTuningParams( + **params, + ) + except Exception as e: + logger.debug(str(e)) + allowed_fine_tuning_parameters = ", ".join( + f"{field.name} (required)" if field.default is MISSING else field.name + for field in fields(AquaFineTuningParams) + ).rstrip() + raise AquaValueError( + f"Invalid fine tuning parameters. Allowable parameters are: " + f"{allowed_fine_tuning_parameters}." + ) + + return dict(valid=True) diff --git a/ads/aqua/job.py b/ads/aqua/job.py deleted file mode 100644 index 2ff958283..000000000 --- a/ads/aqua/job.py +++ /dev/null @@ -1,29 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8; -*- - -# Copyright (c) 2024 Oracle and/or its affiliates. -# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ - - -import logging -from dataclasses import dataclass, field -from ads.common.serializer import DataClassSerializable -from ads.aqua.data import AquaResourceIdentifier - -logger = logging.getLogger(__name__) - - -@dataclass(repr=False) -class AquaJobSummary(DataClassSerializable): - """Represents an Aqua job summary.""" - - id: str - name: str - console_url: str - lifecycle_state: str - lifecycle_details: str - time_created: str - tags: dict - experiment: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) - source: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) - job: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) diff --git a/ads/aqua/model/__init__.py b/ads/aqua/model/__init__.py new file mode 100644 index 000000000..74fcea5a3 --- /dev/null +++ b/ads/aqua/model/__init__.py @@ -0,0 +1,8 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*-- + +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +from ads.aqua.model.model import AquaModelApp + +__all__ = ["AquaModelApp"] diff --git a/ads/aqua/model/constants.py b/ads/aqua/model/constants.py new file mode 100644 index 000000000..bedfe2202 --- /dev/null +++ b/ads/aqua/model/constants.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +""" +aqua.model.constants +~~~~~~~~~~~~~~~~~~~~ + +This module contains constants/enums used in Aqua Model. 
+""" +from ads.common.extended_enum import ExtendedEnumMeta + + +class ModelCustomMetadataFields(str, metaclass=ExtendedEnumMeta): + ARTIFACT_LOCATION = "artifact_location" + DEPLOYMENT_CONTAINER = "deployment-container" + EVALUATION_CONTAINER = "evaluation-container" + FINETUNE_CONTAINER = "finetune-container" + + +class ModelTask(str, metaclass=ExtendedEnumMeta): + TEXT_GENERATION = "text-generation" + + +class FineTuningMetricCategories(str, metaclass=ExtendedEnumMeta): + VALIDATION = "validation" + TRAINING = "training" + + +class ModelType(str, metaclass=ExtendedEnumMeta): + FT = "FT" # Fine Tuned Model + BASE = "BASE" # Base model + + +# TODO: merge metadata key used in create FT +class FineTuningCustomMetadata(str, metaclass=ExtendedEnumMeta): + FT_SOURCE = "fine_tune_source" + FT_SOURCE_NAME = "fine_tune_source_name" + FT_OUTPUT_PATH = "fine_tune_output_path" + FT_JOB_ID = "fine_tune_job_id" + FT_JOB_RUN_ID = "fine_tune_jobrun_id" + TRAINING_METRICS_FINAL = "train_metrics_final" + VALIDATION_METRICS_FINAL = "val_metrics_final" + TRAINING_METRICS_EPOCH = "train_metrics_epoch" + VALIDATION_METRICS_EPOCH = "val_metrics_epoch" diff --git a/ads/aqua/model/entities.py b/ads/aqua/model/entities.py new file mode 100644 index 000000000..8f6cb234b --- /dev/null +++ b/ads/aqua/model/entities.py @@ -0,0 +1,266 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +""" +aqua.model.entities +~~~~~~~~~~~~~~~~~~~ + +This module contains dataclasses for Aqua Model. +""" +import re +from dataclasses import InitVar, dataclass, field +from typing import List, Optional + +import oci + +from ads.aqua import logger +from ads.aqua.app import CLIBuilderMixin +from ads.aqua.common import utils +from ads.aqua.constants import UNKNOWN_VALUE +from ads.aqua.data import AquaResourceIdentifier +from ads.aqua.model.enums import FineTuningDefinedMetadata +from ads.aqua.training.exceptions import exit_code_dict +from ads.common.serializer import DataClassSerializable +from ads.common.utils import get_log_links +from ads.model.datascience_model import DataScienceModel +from ads.model.model_metadata import MetadataTaxonomyKeys + + +@dataclass(repr=False) +class FineTuningShapeInfo(DataClassSerializable): + instance_shape: str = field(default_factory=str) + replica: int = field(default_factory=int) + + +# TODO: give a better name +@dataclass(repr=False) +class AquaFineTuneValidation(DataClassSerializable): + type: str = "Automatic split" + value: str = "" + + +@dataclass(repr=False) +class AquaFineTuningMetric(DataClassSerializable): + name: str = field(default_factory=str) + category: str = field(default_factory=str) + scores: list = field(default_factory=list) + + +@dataclass(repr=False) +class AquaModelLicense(DataClassSerializable): + """Represents the response of Get Model License.""" + + id: str = field(default_factory=str) + license: str = field(default_factory=str) + + +@dataclass(repr=False) +class AquaModelSummary(DataClassSerializable): + """Represents a summary of Aqua model.""" + + compartment_id: str = None + icon: str = None + id: str = None + is_fine_tuned_model: bool = None + license: str = None + name: str = None + organization: str = None + project_id: str = None + tags: dict = None + task: str = None + time_created: str = None + console_link: str = None + search_text: str = None + ready_to_deploy: bool = True + ready_to_finetune: bool = False + 
ready_to_import: bool = False + + +@dataclass(repr=False) +class AquaModel(AquaModelSummary, DataClassSerializable): + """Represents an Aqua model.""" + + model_card: str = None + inference_container: str = None + finetuning_container: str = None + evaluation_container: str = None + + +@dataclass(repr=False) +class HFModelContainerInfo: + """Container defauls for model""" + + inference_container: str = None + finetuning_container: str = None + + +@dataclass(repr=False) +class AquaEvalFTCommon(DataClassSerializable): + """Represents common fields for evaluation and fine-tuning.""" + + lifecycle_state: str = None + lifecycle_details: str = None + job: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) + source: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) + experiment: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) + log_group: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) + log: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) + + model: InitVar = None + region: InitVar = None + jobrun: InitVar = None + + def __post_init__( + self, model, region: str, jobrun: oci.data_science.models.JobRun = None + ): + try: + log_id = jobrun.log_details.log_id + except Exception as e: + logger.debug(f"No associated log found. {str(e)}") + log_id = "" + + try: + loggroup_id = jobrun.log_details.log_group_id + except Exception as e: + logger.debug(f"No associated loggroup found. {str(e)}") + loggroup_id = "" + + loggroup_url = get_log_links(region=region, log_group_id=loggroup_id) + log_url = ( + get_log_links( + region=region, + log_group_id=loggroup_id, + log_id=log_id, + compartment_id=jobrun.compartment_id, + source_id=jobrun.id, + ) + if jobrun + else "" + ) + + log_name = None + loggroup_name = None + + if log_id: + try: + log = utils.query_resource(log_id, return_all=False) + log_name = log.display_name if log else "" + except: + pass + + if loggroup_id: + try: + loggroup = utils.query_resource(loggroup_id, return_all=False) + loggroup_name = loggroup.display_name if loggroup else "" + except: + pass + + experiment_id, experiment_name = utils._get_experiment_info(model) + + self.log_group = AquaResourceIdentifier( + loggroup_id, loggroup_name, loggroup_url + ) + self.log = AquaResourceIdentifier(log_id, log_name, log_url) + self.experiment = utils._build_resource_identifier( + id=experiment_id, name=experiment_name, region=region + ) + self.job = utils._build_job_identifier(job_run_details=jobrun, region=region) + self.lifecycle_details = ( + utils.LIFECYCLE_DETAILS_MISSING_JOBRUN + if not jobrun + else jobrun.lifecycle_details + ) + + +@dataclass(repr=False) +class AquaFineTuneModel(AquaModel, AquaEvalFTCommon, DataClassSerializable): + """Represents an Aqua Fine Tuned Model.""" + + dataset: str = field(default_factory=str) + validation: AquaFineTuneValidation = field(default_factory=AquaFineTuneValidation) + shape_info: FineTuningShapeInfo = field(default_factory=FineTuningShapeInfo) + metrics: List[AquaFineTuningMetric] = field(default_factory=list) + + def __post_init__( + self, + model: DataScienceModel, + region: str, + jobrun: oci.data_science.models.JobRun = None, + ): + super().__post_init__(model=model, region=region, jobrun=jobrun) + + if jobrun is not None: + jobrun_env_vars = ( + jobrun.job_configuration_override_details.environment_variables or {} + ) + self.shape_info = FineTuningShapeInfo( + instance_shape=jobrun.job_infrastructure_configuration_details.shape_name, + # 
TODO: use variable for `NODE_COUNT` in ads/jobs/builders/runtimes/base.py + replica=jobrun_env_vars.get("NODE_COUNT", UNKNOWN_VALUE), + ) + + try: + model_hyperparameters = model.defined_metadata_list.get( + MetadataTaxonomyKeys.HYPERPARAMETERS + ).value + except Exception as e: + logger.debug( + f"Failed to extract model hyperparameters from {model.id}: " f"{str(e)}" + ) + model_hyperparameters = {} + + self.dataset = model_hyperparameters.get( + FineTuningDefinedMetadata.TRAINING_DATA + ) + if not self.dataset: + logger.debug( + f"Key={FineTuningDefinedMetadata.TRAINING_DATA} not found in model hyperparameters." + ) + + self.validation = AquaFineTuneValidation( + value=model_hyperparameters.get(FineTuningDefinedMetadata.VAL_SET_SIZE) + ) + if not self.validation: + logger.debug( + f"Key={FineTuningDefinedMetadata.VAL_SET_SIZE} not found in model hyperparameters." + ) + + if self.lifecycle_details: + self.lifecycle_details = self._extract_job_lifecycle_details( + self.lifecycle_details + ) + + def _extract_job_lifecycle_details(self, lifecycle_details): + message = lifecycle_details + try: + # Extract exit code + match = re.search(r"exit code (\d+)", lifecycle_details) + if match: + exit_code = int(match.group(1)) + if exit_code == 1: + return message + # Match exit code to message + exception = exit_code_dict().get( + exit_code, + lifecycle_details, + ) + message = f"{exception.reason} (exit code {exit_code})" + except: + pass + + return message + + +@dataclass +class ImportModelDetails(CLIBuilderMixin): + model: str + os_path: str + inference_container: Optional[str] = None + finetuning_container: Optional[str] = None + compartment_id: Optional[str] = None + project_id: Optional[str] = None + + def __post_init__(self): + self._command = "model register" diff --git a/ads/aqua/model/enums.py b/ads/aqua/model/enums.py new file mode 100644 index 000000000..aee985a8f --- /dev/null +++ b/ads/aqua/model/enums.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +from ads.common.extended_enum import ExtendedEnumMeta + + +class FineTuningDefinedMetadata(str, metaclass=ExtendedEnumMeta): + """Represents the defined metadata keys used in Fine Tuning.""" + + VAL_SET_SIZE = "val_set_size" + TRAINING_DATA = "training_data" + + +class FineTuningCustomMetadata(str, metaclass=ExtendedEnumMeta): + """Represents the custom metadata keys used in Fine Tuning.""" + + FT_SOURCE = "fine_tune_source" + FT_SOURCE_NAME = "fine_tune_source_name" + FT_OUTPUT_PATH = "fine_tune_output_path" + FT_JOB_ID = "fine_tune_job_id" + FT_JOB_RUN_ID = "fine_tune_jobrun_id" + TRAINING_METRICS_FINAL = "train_metrics_final" + VALIDATION_METRICS_FINAL = "val_metrics_final" + TRAINING_METRICS_EPOCH = "train_metrics_epoch" + VALIDATION_METRICS_EPOCH = "val_metrics_epoch" diff --git a/ads/aqua/model.py b/ads/aqua/model/model.py similarity index 52% rename from ads/aqua/model.py rename to ads/aqua/model/model.py index 982553f50..d92a3ab63 100644 --- a/ads/aqua/model.py +++ b/ads/aqua/model/model.py @@ -3,292 +3,59 @@ # Copyright (c) 2024 Oracle and/or its affiliates. 
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import os -import re -from dataclasses import InitVar, dataclass, field from datetime import datetime, timedelta -from enum import Enum from threading import Lock -from typing import List, Union +from typing import List, Optional, Union -import oci from cachetools import TTLCache from oci.data_science.models import JobRun, Model -from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger, utils -from ads.aqua.base import AquaApp -from ads.aqua.constants import ( - TRAINING_METRICS_FINAL, - TRINING_METRICS, - UNKNOWN_VALUE, - VALIDATION_METRICS, - VALIDATION_METRICS_FINAL, - FineTuningDefinedMetadata, +from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID +from ads.aqua.app import AquaApp +from ads.aqua.common.enums import Tags +from ads.aqua.common.errors import AquaRuntimeError +from ads.aqua.common.utils import ( + create_word_icon, + get_artifact_path, + read_file, + copy_model_config, + load_config, ) -from ads.aqua.data import AquaResourceIdentifier, Tags -from ads.aqua.exception import AquaRuntimeError -from ads.aqua.training.exceptions import exit_code_dict -from ads.aqua.utils import ( +from ads.aqua.constants import ( LICENSE_TXT, + MODEL_BY_REFERENCE_OSS_PATH_KEY, README, READY_TO_DEPLOY_STATUS, READY_TO_FINE_TUNE_STATUS, + READY_TO_IMPORT_STATUS, + TRAINING_METRICS_FINAL, + TRINING_METRICS, UNKNOWN, - create_word_icon, - get_artifact_path, - read_file, + VALIDATION_METRICS, + VALIDATION_METRICS_FINAL, + AQUA_MODEL_ARTIFACT_CONFIG, + AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME, + AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE, + AQUA_MODEL_TYPE_CUSTOM, ) +from ads.aqua.model.constants import * +from ads.aqua.model.entities import * from ads.common.auth import default_signer -from ads.common.object_storage_details import ObjectStorageDetails from ads.common.oci_resource import SEARCH_TYPE, OCIResource -from ads.common.serializer import DataClassSerializable -from ads.common.utils import get_console_link, get_log_links +from ads.common.utils import get_console_link from ads.config import ( - AQUA_SERVICE_MODELS_BUCKET, + AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME, + AQUA_EVALUATION_CONTAINER_METADATA_NAME, + AQUA_FINETUNING_CONTAINER_METADATA_NAME, COMPARTMENT_OCID, - CONDA_BUCKET_NS, PROJECT_OCID, TENANCY_OCID, ) from ads.model import DataScienceModel -from ads.model.model_metadata import MetadataTaxonomyKeys, ModelCustomMetadata +from ads.model.model_metadata import ModelCustomMetadata, ModelCustomMetadataItem from ads.telemetry import telemetry -class FineTuningMetricCategories(Enum): - VALIDATION = "validation" - TRAINING = "training" - - -@dataclass(repr=False) -class FineTuningShapeInfo(DataClassSerializable): - instance_shape: str = field(default_factory=str) - replica: int = field(default_factory=int) - - -# TODO: give a better name -@dataclass(repr=False) -class AquaFineTuneValidation(DataClassSerializable): - type: str = "Automatic split" - value: str = "" - - -@dataclass(repr=False) -class AquaFineTuningMetric(DataClassSerializable): - name: str = field(default_factory=str) - category: str = field(default_factory=str) - scores: list = field(default_factory=list) - - -@dataclass(repr=False) -class AquaModelLicense(DataClassSerializable): - """Represents the response of Get Model License.""" - - id: str = field(default_factory=str) - license: str = field(default_factory=str) - - -@dataclass(repr=False) -class AquaModelSummary(DataClassSerializable): - """Represents a summary of Aqua model.""" - 
- compartment_id: str = None - icon: str = None - id: str = None - is_fine_tuned_model: bool = None - license: str = None - name: str = None - organization: str = None - project_id: str = None - tags: dict = None - task: str = None - time_created: str = None - console_link: str = None - search_text: str = None - ready_to_deploy: bool = True - ready_to_finetune: bool = False - - -@dataclass(repr=False) -class AquaModel(AquaModelSummary, DataClassSerializable): - """Represents an Aqua model.""" - - model_card: str = None - - -@dataclass(repr=False) -class AquaEvalFTCommon(DataClassSerializable): - """Represents common fields for evaluation and fine-tuning.""" - - lifecycle_state: str = None - lifecycle_details: str = None - job: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) - source: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) - experiment: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) - log_group: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) - log: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) - - model: InitVar = None - region: InitVar = None - jobrun: InitVar = None - - def __post_init__( - self, model, region: str, jobrun: oci.data_science.models.JobRun = None - ): - try: - log_id = jobrun.log_details.log_id - except Exception as e: - logger.debug(f"No associated log found. {str(e)}") - log_id = "" - - try: - loggroup_id = jobrun.log_details.log_group_id - except Exception as e: - logger.debug(f"No associated loggroup found. {str(e)}") - loggroup_id = "" - - loggroup_url = get_log_links(region=region, log_group_id=loggroup_id) - log_url = ( - get_log_links( - region=region, - log_group_id=loggroup_id, - log_id=log_id, - compartment_id=jobrun.compartment_id, - source_id=jobrun.id, - ) - if jobrun - else "" - ) - - log_name = None - loggroup_name = None - - if log_id: - try: - log = utils.query_resource(log_id, return_all=False) - log_name = log.display_name if log else "" - except: - pass - - if loggroup_id: - try: - loggroup = utils.query_resource(loggroup_id, return_all=False) - loggroup_name = loggroup.display_name if loggroup else "" - except: - pass - - experiment_id, experiment_name = utils._get_experiment_info(model) - - self.log_group = AquaResourceIdentifier( - loggroup_id, loggroup_name, loggroup_url - ) - self.log = AquaResourceIdentifier(log_id, log_name, log_url) - self.experiment = utils._build_resource_identifier( - id=experiment_id, name=experiment_name, region=region - ) - self.job = utils._build_job_identifier(job_run_details=jobrun, region=region) - self.lifecycle_details = ( - utils.LIFECYCLE_DETAILS_MISSING_JOBRUN - if not jobrun - else jobrun.lifecycle_details - ) - - -@dataclass(repr=False) -class AquaFineTuneModel(AquaModel, AquaEvalFTCommon, DataClassSerializable): - """Represents an Aqua Fine Tuned Model.""" - - dataset: str = field(default_factory=str) - validation: AquaFineTuneValidation = field(default_factory=AquaFineTuneValidation) - shape_info: FineTuningShapeInfo = field(default_factory=FineTuningShapeInfo) - metrics: List[AquaFineTuningMetric] = field(default_factory=list) - - def __post_init__( - self, - model: DataScienceModel, - region: str, - jobrun: oci.data_science.models.JobRun = None, - ): - super().__post_init__(model=model, region=region, jobrun=jobrun) - - if jobrun is not None: - jobrun_env_vars = ( - jobrun.job_configuration_override_details.environment_variables or {} - ) - self.shape_info = FineTuningShapeInfo( - 
instance_shape=jobrun.job_infrastructure_configuration_details.shape_name, - # TODO: use variable for `NODE_COUNT` in ads/jobs/builders/runtimes/base.py - replica=jobrun_env_vars.get("NODE_COUNT", UNKNOWN_VALUE), - ) - - try: - model_hyperparameters = model.defined_metadata_list.get( - MetadataTaxonomyKeys.HYPERPARAMETERS - ).value - except Exception as e: - logger.debug( - f"Failed to extract model hyperparameters from {model.id}: " f"{str(e)}" - ) - model_hyperparameters = {} - - self.dataset = model_hyperparameters.get( - FineTuningDefinedMetadata.TRAINING_DATA.value - ) - if not self.dataset: - logger.debug( - f"Key={FineTuningDefinedMetadata.TRAINING_DATA.value} not found in model hyperparameters." - ) - - self.validation = AquaFineTuneValidation( - value=model_hyperparameters.get( - FineTuningDefinedMetadata.VAL_SET_SIZE.value - ) - ) - if not self.validation: - logger.debug( - f"Key={FineTuningDefinedMetadata.VAL_SET_SIZE.value} not found in model hyperparameters." - ) - - if self.lifecycle_details: - self.lifecycle_details = self._extract_job_lifecycle_details( - self.lifecycle_details - ) - - def _extract_job_lifecycle_details(self, lifecycle_details): - message = lifecycle_details - try: - # Extract exit code - match = re.search(r"exit code (\d+)", lifecycle_details) - if match: - exit_code = int(match.group(1)) - if exit_code == 1: - return message - # Match exit code to message - exception = exit_code_dict().get( - exit_code, - lifecycle_details, - ) - message = f"{exception.reason} (exit code {exit_code})" - except: - pass - - return message - - -# TODO: merge metadata key used in create FT - - -class FineTuningCustomMetadata(Enum): - FT_SOURCE = "fine_tune_source" - FT_SOURCE_NAME = "fine_tune_source_name" - FT_OUTPUT_PATH = "fine_tune_output_path" - FT_JOB_ID = "fine_tune_job_id" - FT_JOB_RUN_ID = "fine_tune_jobrun_id" - TRAINING_METRICS_FINAL = "train_metrics_final" - VALIDATION_METRICS_FINAL = "val_metrics_final" - TRAINING_METRICS_EPOCH = "train_metrics_epoch" - VALIDATION_METRICS_EPOCH = "val_metrics_epoch" - - class AquaModelApp(AquaApp): """Provides a suite of APIs to interact with Aqua models within the Oracle Cloud Infrastructure Data Science service, serving as an interface for @@ -305,6 +72,7 @@ class AquaModelApp(AquaApp): Lists all Aqua models within a specified compartment and/or project. clear_model_list_cache() Allows clear list model cache items from the service models compartment. + register(model: str, os_path: str, local_dir: str = None) Note: This class is designed to work within the Oracle Cloud Infrastructure @@ -382,13 +150,15 @@ def create( return custom_model @telemetry(entry_point="plugin=model&action=get", name="aqua") - def get(self, model_id) -> "AquaModel": + def get(self, model_id: str, load_model_card: Optional[bool] = True) -> "AquaModel": """Gets the information of an Aqua model. Parameters ---------- model_id: str The model OCID. + load_model_card: (bool, optional). Defaults to `True`. + Whether to load model card from artifacts or not. 
Returns ------- @@ -407,38 +177,57 @@ def get(self, model_id) -> "AquaModel": is_fine_tuned_model = ( True if ds_model.freeform_tags - and ds_model.freeform_tags.get(Tags.AQUA_FINE_TUNED_MODEL_TAG.value) + and ds_model.freeform_tags.get(Tags.AQUA_FINE_TUNED_MODEL_TAG) else False ) # todo: consolidate this logic in utils for model and deployment use - try: - artifact_path = ds_model.custom_metadata_list.get( - utils.MODEL_BY_REFERENCE_OSS_PATH_KEY - ).value.rstrip("/") - if not ObjectStorageDetails.is_oci_path(artifact_path): - artifact_path = ObjectStorageDetails( - AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, artifact_path - ).path - except ValueError: - artifact_path = utils.UNKNOWN + is_verified_type = ( + ds_model.freeform_tags.get(Tags.READY_TO_IMPORT, "false").upper() + == READY_TO_IMPORT_STATUS + ) - if not artifact_path: - logger.debug("Failed to get artifact path from custom metadata.") + model_card = "" + if load_model_card: + artifact_path = get_artifact_path( + ds_model.custom_metadata_list._to_oci_metadata() + ) + if artifact_path != UNKNOWN: + model_card = str( + read_file( + file_path=( + f"{artifact_path.rstrip('/')}/config/{README}" + if is_verified_type + else f"{artifact_path.rstrip('/')}/{README}" + ), + auth=default_signer(), + ) + ) - aqua_model_atttributes = dict( + inference_container = ds_model.custom_metadata_list.get( + ModelCustomMetadataFields.DEPLOYMENT_CONTAINER, + ModelCustomMetadataItem(key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER), + ).value + evaluation_container = ds_model.custom_metadata_list.get( + ModelCustomMetadataFields.EVALUATION_CONTAINER, + ModelCustomMetadataItem(key=ModelCustomMetadataFields.EVALUATION_CONTAINER), + ).value + finetuning_container: str = ds_model.custom_metadata_list.get( + ModelCustomMetadataFields.FINETUNE_CONTAINER, + ModelCustomMetadataItem(key=ModelCustomMetadataFields.FINETUNE_CONTAINER), + ).value + + aqua_model_attributes = dict( **self._process_model(ds_model, self.region), project_id=ds_model.project_id, - model_card=str( - read_file( - file_path=f"{artifact_path}/{README}", - auth=self._auth, - ) - ), + model_card=model_card, + inference_container=inference_container, + finetuning_container=finetuning_container, + evaluation_container=evaluation_container, ) if not is_fine_tuned_model: - model_details = AquaModel(**aqua_model_atttributes) + model_details = AquaModel(**aqua_model_attributes) self._service_model_details_cache.__setitem__( key=model_id, value=model_details ) @@ -455,7 +244,7 @@ def get(self, model_id) -> "AquaModel": try: source_id = ds_model.custom_metadata_list.get( - FineTuningCustomMetadata.FT_SOURCE.value + FineTuningCustomMetadata.FT_SOURCE ).value except ValueError as e: logger.debug(str(e)) @@ -463,7 +252,7 @@ def get(self, model_id) -> "AquaModel": try: source_name = ds_model.custom_metadata_list.get( - FineTuningCustomMetadata.FT_SOURCE_NAME.value + FineTuningCustomMetadata.FT_SOURCE_NAME ).value except ValueError as e: logger.debug(str(e)) @@ -494,7 +283,7 @@ def get(self, model_id) -> "AquaModel": ) model_details = AquaFineTuneModel( - **aqua_model_atttributes, + **aqua_model_attributes, source=source_identifier, lifecycle_state=( Model.LIFECYCLE_STATE_ACTIVE @@ -541,29 +330,29 @@ def _build_ft_metrics( validation_metrics = self._fetch_metric_from_metadata( custom_metadata_list=custom_metadata_list, - target=FineTuningCustomMetadata.VALIDATION_METRICS_EPOCH.value, - category=FineTuningMetricCategories.VALIDATION.value, + target=FineTuningCustomMetadata.VALIDATION_METRICS_EPOCH, + 
category=FineTuningMetricCategories.VALIDATION, metric_name=VALIDATION_METRICS, ) training_metrics = self._fetch_metric_from_metadata( custom_metadata_list=custom_metadata_list, - target=FineTuningCustomMetadata.TRAINING_METRICS_EPOCH.value, - category=FineTuningMetricCategories.TRAINING.value, + target=FineTuningCustomMetadata.TRAINING_METRICS_EPOCH, + category=FineTuningMetricCategories.TRAINING, metric_name=TRINING_METRICS, ) validation_final = self._fetch_metric_from_metadata( custom_metadata_list=custom_metadata_list, - target=FineTuningCustomMetadata.VALIDATION_METRICS_FINAL.value, - category=FineTuningMetricCategories.VALIDATION.value, + target=FineTuningCustomMetadata.VALIDATION_METRICS_FINAL, + category=FineTuningMetricCategories.VALIDATION, metric_name=VALIDATION_METRICS_FINAL, ) training_final = self._fetch_metric_from_metadata( custom_metadata_list=custom_metadata_list, - target=FineTuningCustomMetadata.TRAINING_METRICS_FINAL.value, - category=FineTuningMetricCategories.TRAINING.value, + target=FineTuningCustomMetadata.TRAINING_METRICS_FINAL, + category=FineTuningMetricCategories.TRAINING, metric_name=TRAINING_METRICS_FINAL, ) @@ -623,23 +412,27 @@ def _process_model( ) freeform_tags = model.freeform_tags or {} - is_fine_tuned_model = Tags.AQUA_FINE_TUNED_MODEL_TAG.value in freeform_tags + is_fine_tuned_model = Tags.AQUA_FINE_TUNED_MODEL_TAG in freeform_tags ready_to_deploy = ( - freeform_tags.get(Tags.AQUA_TAG.value, "").upper() == READY_TO_DEPLOY_STATUS + freeform_tags.get(Tags.AQUA_TAG, "").upper() == READY_TO_DEPLOY_STATUS ) ready_to_finetune = ( - freeform_tags.get(Tags.READY_TO_FINE_TUNE.value, "").upper() + freeform_tags.get(Tags.READY_TO_FINE_TUNE, "").upper() == READY_TO_FINE_TUNE_STATUS ) + ready_to_import = ( + freeform_tags.get(Tags.READY_TO_IMPORT, "").upper() + == READY_TO_IMPORT_STATUS + ) return dict( compartment_id=model.compartment_id, icon=icon or UNKNOWN, id=model_id, - license=freeform_tags.get(Tags.LICENSE.value, UNKNOWN), + license=freeform_tags.get(Tags.LICENSE, UNKNOWN), name=model.display_name, - organization=freeform_tags.get(Tags.ORGANIZATION.value, UNKNOWN), - task=freeform_tags.get(Tags.TASK.value, UNKNOWN), + organization=freeform_tags.get(Tags.ORGANIZATION, UNKNOWN), + task=freeform_tags.get(Tags.TASK, UNKNOWN), time_created=model.time_created, is_fine_tuned_model=is_fine_tuned_model, tags=tags, @@ -647,11 +440,16 @@ def _process_model( search_text=search_text, ready_to_deploy=ready_to_deploy, ready_to_finetune=ready_to_finetune, + ready_to_import=ready_to_import, ) @telemetry(entry_point="plugin=model&action=list", name="aqua") def list( - self, compartment_id: str = None, project_id: str = None, **kwargs + self, + compartment_id: str = None, + project_id: str = None, + model_type: str = None, + **kwargs, ) -> List["AquaModelSummary"]: """Lists all Aqua models within a specified compartment and/or project. If `compartment_id` is not specified, the method defaults to returning @@ -665,6 +463,8 @@ def list( The compartment OCID. project_id: (str, optional). Defaults to `None`. The project OCID. + model_type: (str, optional). Defaults to `None`. + Model type represents the type of model in the user compartment, can be either FT or BASE. **kwargs: Additional keyword arguments that can be used to filter the results. 
@@ -682,7 +482,8 @@ def list( ) logger.info(f"Fetching custom models from compartment_id={compartment_id}.") - models = self._rqs(compartment_id) + model_type = model_type.upper() if model_type else ModelType.FT + models = self._rqs(compartment_id, model_type=model_type) else: # tracks number of times service model listing was called self.telemetry.record_event_async( @@ -751,16 +552,275 @@ def clear_model_list_cache( } return res + def _create_model_catalog_entry( + self, + os_path: str, + model_name: str, + inference_container: str, + finetuning_container: str, + verified_model: DataScienceModel, + compartment_id: Optional[str], + project_id: Optional[str], + ) -> DataScienceModel: + """Create model by reference from the object storage path + + Args: + os_path (str): OCI where the model is uploaded - oci://bucket@namespace/prefix + model_name (str): name of the model + inference_container (str): selects service defaults + finetuning_container (str): selects service defaults + verified_model (DataScienceModel): If set, then copies all the tags and custom metadata information from the service verified model + compartment_id (Optional[str]): Compartment Id of the compartment where the model has to be created + project_id (Optional[str]): Project id of the project where the model has to be created + + Returns: + DataScienceModel: Returns Datascience model instance. + """ + model = DataScienceModel() + tags = ( + { + **verified_model.freeform_tags, + Tags.AQUA_SERVICE_MODEL_TAG: verified_model.id, + } + if verified_model + else {Tags.AQUA_TAG: "active", Tags.BASE_MODEL_CUSTOM: "true"} + ) + tags.update({Tags.BASE_MODEL_CUSTOM: "true"}) + + # Remove `ready_to_import` tag that might get copied from service model. + tags.pop(Tags.READY_TO_IMPORT, None) + metadata = None + if verified_model: + # Verified model is a model in the service catalog that either has no artifacts but contains all the necessary metadata for deploying and fine tuning. + # If set, then we copy all the model metadata. + metadata = verified_model.custom_metadata_list + if verified_model.model_file_description: + model = model.with_model_file_description( + json_dict=verified_model.model_file_description + ) + + else: + metadata = ModelCustomMetadata() + if not inference_container: + raise AquaRuntimeError( + f"Require Inference container information. Model: {model_name} does not have associated inference container defaults. Check docs for more information on how to pass inference container." + ) + if finetuning_container: + tags[Tags.READY_TO_FINE_TUNE] = "true" + metadata.add( + key=AQUA_FINETUNING_CONTAINER_METADATA_NAME, + value=finetuning_container, + description=f"Fine-tuning container mapping for {model_name}", + category="Other", + ) + else: + logger.warn( + f"Proceeding with model registration without the fine-tuning container information. " + f"This model will not be available for fine tuning." + ) + + metadata.add( + key=AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME, + value=inference_container, + description=f"Inference container mapping for {model_name}", + category="Other", + ) + metadata.add( + key=AQUA_EVALUATION_CONTAINER_METADATA_NAME, + value="odsc-llm-evaluate", + description="Evaluation container mapping for SMC", + category="Other", + ) + # TODO: either get task and organization from user or a config file + # tags["task"] = "UNKNOWN" + # tags["organization"] = "UNKNOWN" + + try: + # If verified model already has a artifact json, use that. 
+
+            artifact_path = metadata.get(MODEL_BY_REFERENCE_OSS_PATH_KEY).value
+            logger.info(
+                f"Found model artifact in the service bucket. "
+                f"Using artifact from service bucket instead of {os_path}"
+            )
+
+            # todo: implement generic copy_folder method
+            # copy model config from artifact path to user bucket
+            copy_model_config(
+                artifact_path=artifact_path, os_path=os_path, auth=default_signer()
+            )
+
+        except:
+            # Add artifact from user bucket
+            metadata.add(
+                key=MODEL_BY_REFERENCE_OSS_PATH_KEY,
+                value=os_path,
+                description="artifact location",
+                category="Other",
+            )
+
+        model = (
+            model.with_custom_metadata_list(metadata)
+            .with_compartment_id(compartment_id or COMPARTMENT_OCID)
+            .with_project_id(project_id or PROJECT_OCID)
+            .with_artifact(os_path)
+            .with_display_name(model_name)
+            .with_freeform_tags(**tags)
+        ).create(model_by_reference=True)
+        logger.debug(model)
+        return model
+
+    def register(
+        self, import_model_details: ImportModelDetails = None, **kwargs
+    ) -> AquaModel:
+        """Loads the model from object storage and registers it as a model in the Data Science Model catalog.
+        The inference container and finetuning container can be of type Service Managed Container (SMC) or custom.
+        If custom, the full container URI is expected. If of type SMC, only the container family name is expected.
+
+        Args:
+            import_model_details (ImportModelDetails): Model details for importing the model.
+            kwargs:
+                model (str): name of the model or OCID of the service model that has inference and finetuning information
+                os_path (str): Object storage destination URI to store the downloaded model. Format: oci://bucket-name@namespace/prefix
+                inference_container (str): selects service defaults
+                finetuning_container (str): selects service defaults
+
+        Returns:
+            AquaModel:
+                The registered model as an AquaModel object.
+        """
+        verified_model_details: DataScienceModel = None
+
+        if not import_model_details:
+            import_model_details = ImportModelDetails(**kwargs)
+
+        try:
+            model_config = load_config(
+                file_path=import_model_details.os_path,
+                config_file_name=AQUA_MODEL_ARTIFACT_CONFIG,
+            )
+        except Exception as ex:
+            logger.error(
+                f"Exception occurred while loading config file from {import_model_details.os_path}. "
+                f"Exception message: {ex}"
+            )
+            raise AquaRuntimeError(
+                f"The model path {import_model_details.os_path} does not contain the file config.json. "
+                f"Please check if the path is correct or the model artifacts are available at this location."
+            )
+
+        model_service_id = None
+        # If the OCID of a model is passed, we need to copy the defaults for Tags and metadata from the service model.
+        if (
+            import_model_details.model.startswith("ocid")
+            and "datasciencemodel" in import_model_details.model
+        ):
+            model_service_id = import_model_details.model
+        else:
+            # If the user passes a model name, check if there is a model with the same name in the service model catalog. 
If it is there, then use that model + model_service_id = self._find_matching_aqua_model( + import_model_details.model + ) + logger.info( + f"Found service model for {import_model_details.model}: {model_service_id}" + ) + if model_service_id: + verified_model_details = DataScienceModel.from_id(model_service_id) + try: + metadata_model_type = verified_model_details.custom_metadata_list.get( + AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE + ).value + if metadata_model_type: + if AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config: + if ( + model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE] + != metadata_model_type + ): + raise AquaRuntimeError( + f"The {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in {AQUA_MODEL_ARTIFACT_CONFIG}" + f" at {import_model_details.os_path} is invalid, expected {metadata_model_type} for " + f"the model {import_model_details.model}. Please check if the path is correct or " + f"the correct model artifacts are available at this location." + f"" + ) + else: + logger.debug( + f"Could not find {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in " + f"{AQUA_MODEL_ARTIFACT_CONFIG}. Proceeding with model registration." + ) + except: + pass + + # Copy the model name from the service model if `model` is ocid + model_name = ( + verified_model_details.display_name + if verified_model_details + else import_model_details.model + ) + + # Create Model catalog entry with pass by reference + ds_model = self._create_model_catalog_entry( + os_path=import_model_details.os_path, + model_name=model_name, + inference_container=import_model_details.inference_container, + finetuning_container=import_model_details.finetuning_container, + verified_model=verified_model_details, + compartment_id=import_model_details.compartment_id, + project_id=import_model_details.project_id, + ) + # registered model will always have inference and evaluation container, but + # fine-tuning container may be not set + inference_container = ds_model.custom_metadata_list.get( + ModelCustomMetadataFields.DEPLOYMENT_CONTAINER + ).value + evaluation_container = ds_model.custom_metadata_list.get( + ModelCustomMetadataFields.EVALUATION_CONTAINER, + ).value + try: + finetuning_container = ds_model.custom_metadata_list.get( + ModelCustomMetadataFields.FINETUNE_CONTAINER, + ).value + except: + finetuning_container = None + + aqua_model_attributes = dict( + **self._process_model(ds_model, self.region), + project_id=ds_model.project_id, + model_card=str( + read_file( + file_path=f"{import_model_details.os_path.rstrip('/')}/{README}", + auth=default_signer(), + ) + ), + inference_container=inference_container, + finetuning_container=finetuning_container, + evaluation_container=evaluation_container, + ) + + if verified_model_details: + telemetry_model_name = model_name + else: + if AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config: + telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME]}" + elif AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config: + telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}" + else: + telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM + + self.telemetry.record_event_async( + category="aqua/model", + action="register", + detail=telemetry_model_name, + ) + + return AquaModel(**aqua_model_attributes) + def _if_show(self, model: DataScienceModel) -> bool: """Determine if the given model should be return by `list`.""" if model.freeform_tags is None: return False TARGET_TAGS = model.freeform_tags.keys() - 
return ( - Tags.AQUA_TAG.value in TARGET_TAGS - or Tags.AQUA_TAG.value.lower() in TARGET_TAGS - ) + return Tags.AQUA_TAG in TARGET_TAGS or Tags.AQUA_TAG.lower() in TARGET_TAGS def _load_icon(self, model_name: str) -> str: """Loads icon.""" @@ -772,10 +832,18 @@ def _load_icon(self, model_name: str) -> str: logger.debug(f"Failed to load icon for the model={model_name}: {str(e)}.") return None - def _rqs(self, compartment_id: str, **kwargs): + def _rqs(self, compartment_id: str, model_type="FT", **kwargs): """Use RQS to fetch models in the user tenancy.""" + if model_type == ModelType.FT: + filter_tag = Tags.AQUA_FINE_TUNED_MODEL_TAG + elif model_type == ModelType.BASE: + filter_tag = Tags.BASE_MODEL_CUSTOM + else: + raise ValueError( + f"Model of type {model_type} is unknown. The values should be in {ModelType.values()}" + ) - condition_tags = f"&& (freeformTags.key = '{Tags.AQUA_TAG.value}' && freeformTags.key = '{Tags.AQUA_FINE_TUNED_MODEL_TAG.value}')" + condition_tags = f"&& (freeformTags.key = '{Tags.AQUA_TAG}' && freeformTags.key = '{filter_tag}')" condition_lifecycle = "&& lifecycleState = 'ACTIVE'" query = f"query datasciencemodel resources where (compartmentId = '{compartment_id}' {condition_lifecycle} {condition_tags})" logger.info(query) @@ -820,3 +888,27 @@ def load_license(self, model_id: str) -> AquaModelLicense: ) return AquaModelLicense(id=model_id, license=content) + + def _find_matching_aqua_model(self, model_id: str) -> Optional[str]: + """ + Finds a matching model in AQUA based on the model ID from list of verified models. + + Parameters + ---------- + model_id (str): Verified model ID to match. + + Returns + ------- + Optional[str] + Returns model ocid that matches the model in the service catalog else returns None. + """ + # Convert the model ID to lowercase once + model_id_lower = model_id.lower() + + aqua_model_list = self.list() + + for aqua_model_summary in aqua_model_list: + if aqua_model_summary.name.lower() == model_id_lower: + return aqua_model_summary.id + + return None diff --git a/ads/aqua/modeldeployment/__init__.py b/ads/aqua/modeldeployment/__init__.py new file mode 100644 index 000000000..baf5c5b53 --- /dev/null +++ b/ads/aqua/modeldeployment/__init__.py @@ -0,0 +1,8 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +from ads.aqua.modeldeployment.deployment import AquaDeploymentApp +from ads.aqua.modeldeployment.inference import MDInferenceResponse + +__all__ = ["AquaDeploymentApp", "MDInferenceResponse"] diff --git a/ads/aqua/modeldeployment/constants.py b/ads/aqua/modeldeployment/constants.py new file mode 100644 index 000000000..3eed6884d --- /dev/null +++ b/ads/aqua/modeldeployment/constants.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +""" +aqua.modeldeployment.constants +~~~~~~~~~~~~~~ + +This module contains constants used in Aqua Model Deployment. 
+""" + +VLLMInferenceRestrictedParams = { + "--tensor-parallel-size", + "--port", + "--host", + "--served-model-name", + "--seed", +} +TGIInferenceRestrictedParams = { + "--port", + "--hostname", + "--num-shard", + "--sharded", + "--trust-remote-code", +} diff --git a/ads/aqua/deployment.py b/ads/aqua/modeldeployment/deployment.py similarity index 58% rename from ads/aqua/deployment.py rename to ads/aqua/modeldeployment/deployment.py index b42634707..f496a7104 100644 --- a/ads/aqua/deployment.py +++ b/ads/aqua/modeldeployment/deployment.py @@ -5,158 +5,68 @@ import json import logging -from dataclasses import dataclass, field, asdict from typing import Dict, List, Union -import requests -from oci.data_science.models import ModelDeployment, ModelDeploymentSummary +from oci.data_science.models import ModelDeployment -from ads.aqua.base import AquaApp, logger -from ads.aqua.exception import AquaRuntimeError, AquaValueError -from ads.aqua.model import AquaModelApp, Tags -from ads.aqua.utils import ( - UNKNOWN, - MODEL_BY_REFERENCE_OSS_PATH_KEY, - load_config, +from ads.aqua.app import AquaApp, logger +from ads.aqua.common.enums import ( + Tags, + InferenceContainerParamType, + InferenceContainerType, + InferenceContainerTypeFamily, +) +from ads.aqua.common.errors import AquaRuntimeError, AquaValueError +from ads.aqua.common.utils import ( + get_container_config, get_container_image, - UNKNOWN_DICT, - get_resource_name, get_model_by_reference_paths, get_ocid_substring, - AQUA_MODEL_TYPE_SERVICE, + get_combined_params, + get_params_dict, + get_params_list, + get_resource_name, + load_config, +) +from ads.aqua.constants import ( AQUA_MODEL_TYPE_CUSTOM, + AQUA_MODEL_TYPE_SERVICE, + MODEL_BY_REFERENCE_OSS_PATH_KEY, + UNKNOWN, + UNKNOWN_DICT, ) -from ads.aqua.finetune import FineTuneCustomMetadata from ads.aqua.data import AquaResourceIdentifier -from ads.common.utils import get_console_link, get_log_links -from ads.common.auth import default_signer +from ads.aqua.finetuning.finetuning import FineTuneCustomMetadata +from ads.aqua.model import AquaModelApp +from ads.aqua.modeldeployment.entities import ( + AquaDeployment, + AquaDeploymentDetail, + ContainerSpec, +) +from ads.aqua.modeldeployment.constants import ( + VLLMInferenceRestrictedParams, + TGIInferenceRestrictedParams, +) +from ads.common.object_storage_details import ObjectStorageDetails +from ads.common.utils import get_log_links +from ads.config import ( + AQUA_CONFIG_FOLDER, + AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME, + AQUA_DEPLOYMENT_CONTAINER_OVERRIDE_FLAG_METADATA_NAME, + AQUA_MODEL_DEPLOYMENT_CONFIG, + AQUA_MODEL_DEPLOYMENT_CONFIG_DEFAULTS, + COMPARTMENT_OCID, +) +from ads.model.datascience_model import DataScienceModel from ads.model.deployment import ( ModelDeployment, ModelDeploymentContainerRuntime, ModelDeploymentInfrastructure, ModelDeploymentMode, ) -from ads.common.serializer import DataClassSerializable -from ads.config import ( - AQUA_MODEL_DEPLOYMENT_CONFIG, - COMPARTMENT_OCID, - AQUA_CONFIG_FOLDER, - AQUA_MODEL_DEPLOYMENT_CONFIG_DEFAULTS, - AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME, - AQUA_SERVED_MODEL_NAME, -) -from ads.common.object_storage_details import ObjectStorageDetails from ads.telemetry import telemetry -@dataclass -class ShapeInfo: - instance_shape: str = None - instance_count: int = None - ocpus: float = None - memory_in_gbs: float = None - - -@dataclass(repr=False) -class AquaDeployment(DataClassSerializable): - """Represents an Aqua Model Deployment""" - - id: str = None - display_name: str = None - 
aqua_service_model: bool = None - aqua_model_name: str = None - state: str = None - description: str = None - created_on: str = None - created_by: str = None - endpoint: str = None - console_link: str = None - lifecycle_details: str = None - shape_info: field(default_factory=ShapeInfo) = None - tags: dict = None - - @classmethod - def from_oci_model_deployment( - cls, - oci_model_deployment: Union[ModelDeploymentSummary, ModelDeployment], - region: str, - ) -> "AquaDeployment": - """Converts oci model deployment response to AquaDeployment instance. - - Parameters - ---------- - oci_model_deployment: Union[ModelDeploymentSummary, ModelDeployment] - The instance of either oci.data_science.models.ModelDeployment or - oci.data_science.models.ModelDeploymentSummary class. - region: str - The region of this model deployment. - - Returns - ------- - AquaDeployment: - The instance of the Aqua model deployment. - """ - instance_configuration = ( - oci_model_deployment.model_deployment_configuration_details.model_configuration_details.instance_configuration - ) - instance_shape_config_details = ( - instance_configuration.model_deployment_instance_shape_config_details - ) - instance_count = ( - oci_model_deployment.model_deployment_configuration_details.model_configuration_details.scaling_policy.instance_count - ) - shape_info = ShapeInfo( - instance_shape=instance_configuration.instance_shape_name, - instance_count=instance_count, - ocpus=( - instance_shape_config_details.ocpus - if instance_shape_config_details - else None - ), - memory_in_gbs=( - instance_shape_config_details.memory_in_gbs - if instance_shape_config_details - else None - ), - ) - - freeform_tags = oci_model_deployment.freeform_tags or UNKNOWN_DICT - aqua_service_model_tag = freeform_tags.get( - Tags.AQUA_SERVICE_MODEL_TAG.value, None - ) - aqua_model_name = freeform_tags.get(Tags.AQUA_MODEL_NAME_TAG.value, UNKNOWN) - - return AquaDeployment( - id=oci_model_deployment.id, - display_name=oci_model_deployment.display_name, - aqua_service_model=aqua_service_model_tag is not None, - aqua_model_name=aqua_model_name, - shape_info=shape_info, - state=oci_model_deployment.lifecycle_state, - lifecycle_details=getattr( - oci_model_deployment, "lifecycle_details", UNKNOWN - ), - description=oci_model_deployment.description, - created_on=str(oci_model_deployment.time_created), - created_by=oci_model_deployment.created_by, - endpoint=oci_model_deployment.model_deployment_url, - console_link=get_console_link( - resource="model-deployments", - ocid=oci_model_deployment.id, - region=region, - ), - tags=freeform_tags, - ) - - -@dataclass(repr=False) -class AquaDeploymentDetail(AquaDeployment, DataClassSerializable): - """Represents a details of Aqua deployment.""" - - log_group: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) - log: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) - - class AquaDeploymentApp(AquaApp): """Provides a suite of APIs to interact with Aqua model deployments within the Oracle Cloud Infrastructure Data Science service, serving as an interface for deploying @@ -196,9 +106,10 @@ def create( description: str = None, bandwidth_mbps: int = None, web_concurrency: int = None, - server_port: int = 8080, - health_check_port: int = 8080, + server_port: int = None, + health_check_port: int = None, env_var: Dict = None, + container_family: str = None, ) -> "AquaDeployment": """ Creates a new Aqua deployment @@ -231,18 +142,21 @@ def create( The number of worker processes/threads to handle 
incoming requests with_bucket_uri(bucket_uri) Sets the bucket uri when uploading large size model. - server_port: (int). Defaults to 8080. + server_port: (int). The server port for docker container image. - health_check_port: (int). Defaults to 8080. + health_check_port: (int). The health check port for docker container image. env_var : dict, optional Environment variable for the deployment, by default None. + container_family: str + The image family of model deployment container runtime. Required for unverified Aqua models. Returns ------- AquaDeployment An Aqua deployment instance """ + # TODO validate if the service model has no artifact and if it requires import step before deployment. # Create a model catalog entry in the user compartment aqua_model = AquaModelApp().create( model_id=model_id, compartment_id=compartment_id, project_id=project_id @@ -250,45 +164,35 @@ def create( tags = {} for tag in [ - Tags.AQUA_SERVICE_MODEL_TAG.value, - Tags.AQUA_FINE_TUNED_MODEL_TAG.value, - Tags.AQUA_TAG.value, + Tags.AQUA_SERVICE_MODEL_TAG, + Tags.AQUA_FINE_TUNED_MODEL_TAG, + Tags.AQUA_TAG, ]: if tag in aqua_model.freeform_tags: tags[tag] = aqua_model.freeform_tags[tag] - tags.update({Tags.AQUA_MODEL_NAME_TAG.value: aqua_model.display_name}) + tags.update({Tags.AQUA_MODEL_NAME_TAG: aqua_model.display_name}) # Set up info to get deployment config config_source_id = model_id model_name = aqua_model.display_name - is_fine_tuned_model = ( - Tags.AQUA_FINE_TUNED_MODEL_TAG.value in aqua_model.freeform_tags - ) + is_fine_tuned_model = Tags.AQUA_FINE_TUNED_MODEL_TAG in aqua_model.freeform_tags if is_fine_tuned_model: try: config_source_id = aqua_model.custom_metadata_list.get( - FineTuneCustomMetadata.FINE_TUNE_SOURCE.value + FineTuneCustomMetadata.FINE_TUNE_SOURCE ).value model_name = aqua_model.custom_metadata_list.get( - FineTuneCustomMetadata.FINE_TUNE_SOURCE_NAME.value + FineTuneCustomMetadata.FINE_TUNE_SOURCE_NAME ).value except: raise AquaValueError( - f"Either {FineTuneCustomMetadata.FINE_TUNE_SOURCE.value} or {FineTuneCustomMetadata.FINE_TUNE_SOURCE_NAME.value} is missing " + f"Either {FineTuneCustomMetadata.FINE_TUNE_SOURCE} or {FineTuneCustomMetadata.FINE_TUNE_SOURCE_NAME} is missing " f"from custom metadata for the model {config_source_id}" ) - deployment_config = self.get_deployment_config(config_source_id) - vllm_params = ( - deployment_config.get("configuration", UNKNOWN_DICT) - .get(instance_shape, UNKNOWN_DICT) - .get("parameters", UNKNOWN_DICT) - .get("VLLM_PARAMS", UNKNOWN) - ) - # set up env vars if not env_var: env_var = dict() @@ -302,18 +206,11 @@ def create( f"{MODEL_BY_REFERENCE_OSS_PATH_KEY} key is not available in the custom metadata field." 
) - # todo: remove this after absolute path is removed from env var if ObjectStorageDetails.is_oci_path(model_path_prefix): os_path = ObjectStorageDetails.from_path(model_path_prefix) model_path_prefix = os_path.filepath.rstrip("/") env_var.update({"BASE_MODEL": f"{model_path_prefix}"}) - params = f"--served-model-name {AQUA_SERVED_MODEL_NAME} --seed 42 " - if vllm_params: - params += vllm_params - env_var.update({"PARAMS": params}) - env_var.update({"MODEL_DEPLOY_PREDICT_ENDPOINT": "/v1/completions"}) - env_var.update({"MODEL_DEPLOY_ENABLE_STREAMING": "true"}) if is_fine_tuned_model: _, fine_tune_output_path = get_model_by_reference_paths( @@ -330,28 +227,94 @@ def create( env_var.update({"FT_MODEL": f"{fine_tune_output_path}"}) - logging.info(f"Env vars used for deploying {aqua_model.id} :{env_var}") - + is_custom_container = False try: container_type_key = aqua_model.custom_metadata_list.get( AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME ).value except ValueError: - raise AquaValueError( - f"{AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} key is not available in the custom metadata field for model {aqua_model.id}" + message = ( + f"{AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} key is not available in the custom metadata field " + f"for model {aqua_model.id}." ) + logger.debug(message) + if not container_family: + raise AquaValueError( + f"{message}. For unverified Aqua models, container_family parameter should be " + f"set and value can be one of {', '.join(InferenceContainerTypeFamily.values())}." + ) + container_type_key = container_family + try: + # Check if the container override flag is set. If set, then the user has chosen custom image + if aqua_model.custom_metadata_list.get( + AQUA_DEPLOYMENT_CONTAINER_OVERRIDE_FLAG_METADATA_NAME + ).value: + is_custom_container = True + except Exception: + pass # fetch image name from config - container_image = get_container_image( - container_type=container_type_key, + # If the image is of type custom, then `container_type_key` is the inference image + container_image = ( + get_container_image( + container_type=container_type_key, + ) + if not is_custom_container + else container_type_key ) logging.info( f"Aqua Image used for deploying {aqua_model.id} : {container_image}" ) + # Fetch the startup cli command for the container + # container_index.json will have "containerSpec" section which will provide the cli params for a given container family + container_config = get_container_config() + container_spec = container_config.get(ContainerSpec.CONTAINER_SPEC, {}).get( + container_type_key, {} + ) + # these params cannot be overridden for Aqua deployments + params = container_spec.get(ContainerSpec.CLI_PARM, "") + server_port = server_port or container_spec.get( + ContainerSpec.SERVER_PORT + ) # Give precendece to the input parameter + health_check_port = health_check_port or container_spec.get( + ContainerSpec.HEALTH_CHECK_PORT + ) # Give precendece to the input parameter + + deployment_config = self.get_deployment_config(config_source_id) + vllm_params = ( + deployment_config.get("configuration", UNKNOWN_DICT) + .get(instance_shape, UNKNOWN_DICT) + .get("parameters", UNKNOWN_DICT) + .get(InferenceContainerParamType.PARAM_TYPE_VLLM, UNKNOWN) + ) + + # validate user provided params + user_params = env_var.get("PARAMS", UNKNOWN) + if user_params: + restricted_params = self._find_restricted_params( + params, user_params, container_type_key + ) + if restricted_params: + raise AquaValueError( + f"Parameters {restricted_params} are set by Aqua " + f"and cannot be 
overridden or are invalid." + ) + + deployment_params = get_combined_params(vllm_params, user_params) + + if deployment_params: + params = f"{params} {deployment_params}" + + env_var.update({"PARAMS": params}) + for env in container_spec.get(ContainerSpec.ENV_VARS, []): + if isinstance(env, dict): + env_var.update(env) + + logging.info(f"Env vars used for deploying {aqua_model.id} :{env_var}") + # Start model deployment # configure model deployment infrastructure - # todo : any other infrastructure params needed? infrastructure = ( ModelDeploymentInfrastructure() .with_project_id(project_id) @@ -370,7 +333,6 @@ def create( ) ) # configure model deployment runtime - # todo : any other runtime params needed? container_runtime = ( ModelDeploymentContainerRuntime() .with_image(container_image) @@ -384,7 +346,6 @@ def create( .with_remove_existing_artifact(True) ) # configure model deployment and deploy model on container runtime - # todo : any other deployment params needed? deployment = ( ModelDeployment() .with_display_name(display_name) @@ -447,8 +408,8 @@ def list(self, **kwargs) -> List["AquaDeployment"]: for model_deployment in model_deployments: oci_aqua = ( ( - Tags.AQUA_TAG.value in model_deployment.freeform_tags - or Tags.AQUA_TAG.value.lower() in model_deployment.freeform_tags + Tags.AQUA_TAG in model_deployment.freeform_tags + or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags ) if model_deployment.freeform_tags else False @@ -502,8 +463,8 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail": oci_aqua = ( ( - Tags.AQUA_TAG.value in model_deployment.freeform_tags - or Tags.AQUA_TAG.value.lower() in model_deployment.freeform_tags + Tags.AQUA_TAG in model_deployment.freeform_tags + or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags ) if model_deployment.freeform_tags else False @@ -575,71 +536,171 @@ def get_deployment_config(self, model_id: str) -> Dict: ) return config + def get_deployment_default_params( + self, + model_id: str, + instance_shape: str, + ) -> List[str]: + """Gets the default params set in the deployment configs for the given model and instance shape. -@dataclass -class ModelParams: - max_tokens: int = None - temperature: float = None - top_k: float = None - top_p: float = None - model: str = None + Parameters + ---------- + model_id: str + The OCID of the Aqua model. + instance_shape: (str). + The shape of the instance used for deployment. + + Returns + ------- + List[str]: + List of parameters from the loaded from deployment config json file. If not available, then an empty list + is returned. + + """ + default_params = [] + model = DataScienceModel.from_id(model_id) + try: + container_type_key = model.custom_metadata_list.get( + AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME + ).value + except ValueError: + container_type_key = UNKNOWN + logger.debug( + f"{AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} key is not available in the custom metadata field for model {model_id}." + ) -@dataclass -class MDInferenceResponse(AquaApp): - """Contains APIs for Aqua Model deployments Inference. 
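# --- Editor's illustration (not part of the patch) ---------------------------
# A minimal sketch of how the reworked AquaDeploymentApp.create API above might
# be called once this change lands. The OCID, shape, display name and PARAMS
# value are placeholders; `container_family` is only needed for unverified
# models whose custom metadata lacks a `deployment-container` entry.
from ads.aqua.modeldeployment import AquaDeploymentApp

deployment = AquaDeploymentApp().create(
    model_id="ocid1.datasciencemodel.oc1..<placeholder>",
    instance_shape="VM.GPU.A10.1",
    display_name="llm-deployment-example",
    container_family="odsc-vllm-serving",        # assumed family name for this sketch
    env_var={"PARAMS": "--max-model-len 4096"},  # merged with the container default params
)
print(deployment.endpoint)  # the returned AquaDeployment carries the predict endpoint
# -----------------------------------------------------------------------------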
+ if container_type_key: + container_type_key = container_type_key.lower() + if container_type_key in InferenceContainerTypeFamily.values(): + deployment_config = self.get_deployment_config(model_id) + config_parameters = ( + deployment_config.get("configuration", UNKNOWN_DICT) + .get(instance_shape, UNKNOWN_DICT) + .get("parameters", UNKNOWN_DICT) + ) + if InferenceContainerType.CONTAINER_TYPE_VLLM in container_type_key: + params = config_parameters.get( + InferenceContainerParamType.PARAM_TYPE_VLLM, UNKNOWN + ) + elif InferenceContainerType.CONTAINER_TYPE_TGI in container_type_key: + params = config_parameters.get( + InferenceContainerParamType.PARAM_TYPE_TGI, UNKNOWN + ) + else: + params = UNKNOWN + logger.debug( + f"Default inference parameters are not available for the model {model_id} and " + f"instance {instance_shape}." + ) + if params: + # account for param that can have --arg but no values, e.g. --trust-remote-code + default_params.extend(get_params_list(params)) - Attributes - ---------- + return default_params - model_params: Dict - prompt: string + def validate_deployment_params( + self, + model_id: str, + params: List[str] = None, + container_family: str = None, + ) -> Dict: + """Validate if the deployment parameters passed by the user can be overridden. Parameter values are not + validated, only param keys are validated. - Methods - ------- - get_model_deployment_response(self, **kwargs) -> "String" - Creates an instance of model deployment via Aqua - """ + Parameters + ---------- + model_id: str + The OCID of the Aqua model. + params : List[str], optional + Params passed by the user. + container_family: str + The image family of model deployment container runtime. Required for unverified Aqua models. - prompt: str = None - model_params: field(default_factory=ModelParams) = None + Returns + ------- + Return a list of restricted params. - @telemetry(entry_point="plugin=inference&action=get_response", name="aqua") - def get_model_deployment_response(self, endpoint): """ - Returns MD inference response + restricted_params = [] + if params: + model = DataScienceModel.from_id(model_id) + try: + container_type_key = model.custom_metadata_list.get( + AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME + ).value + except ValueError: + message = ( + f"{AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} key is not available in the custom metadata field " + f"for model {model_id}." + ) + logger.debug(message) + + if not container_family: + raise AquaValueError( + f"{message}. For unverified Aqua models, container_family parameter should be " + f"set and value can be one of {', '.join(InferenceContainerTypeFamily.values())}." + ) + container_type_key = container_family + + container_config = get_container_config() + container_spec = container_config.get(ContainerSpec.CONTAINER_SPEC, {}).get( + container_type_key, {} + ) + cli_params = container_spec.get(ContainerSpec.CLI_PARM, "") + + restricted_params = self._find_restricted_params( + cli_params, params, container_type_key + ) + + if restricted_params: + raise AquaValueError( + f"Parameters {restricted_params} are set by Aqua " + f"and cannot be overridden or are invalid." + ) + return dict(valid=True) + + @staticmethod + def _find_restricted_params( + default_params: Union[str, List[str]], + user_params: Union[str, List[str]], + container_family: str, + ) -> List[str]: + """Returns a list of restricted params that user chooses to override when creating an Aqua deployment. 
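# --- Editor's illustration (not part of the patch) ---------------------------
# The two helpers above are meant to be used together: fetch the shape-specific
# defaults from the deployment config, then confirm that user overrides are
# allowed before creating the deployment. The OCIDs, shape name and params are
# placeholders, and the model is assumed to name its deployment container in
# its custom metadata (otherwise container_family must be passed as well).
from ads.aqua.modeldeployment import AquaDeploymentApp

app = AquaDeploymentApp()
defaults = app.get_deployment_default_params(
    model_id="ocid1.datasciencemodel.oc1..<placeholder>",
    instance_shape="VM.GPU.A10.2",
)  # e.g. ["--max-model-len 4096", "--seed 42"], or [] if nothing is configured

validation = app.validate_deployment_params(
    model_id="ocid1.datasciencemodel.oc1..<placeholder>",
    params=["--max-model-len 2048"],
)  # returns {"valid": True}; raises AquaValueError for restricted keys
# -----------------------------------------------------------------------------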
+        The default parameters coming from the container index json file cannot be overridden. In addition to this,
+        a set of restricted parameters maintained for each container family cannot be overridden either.
         Parameters
         ----------
-        endpoint: str
-            MD predict url
-        prompt: str
-            User prompt.
-
-        model_params: (Dict, optional)
-            Model parameters to be associated with the message.
-            Currently supported VLLM+OpenAI parameters.
-
-            --model-params '{
-                "max_tokens":500,
-                "temperature": 0.5,
-                "top_k": 10,
-                "top_p": 0.5,
-                "model": "/opt/ds/model/deployed_model",
-                ...}'
+        default_params:
+            Inference container parameter string with default values.
+        user_params:
+            Inference container parameter string with user provided values.
+        container_family: str
+            The image family of model deployment container runtime.
         Returns
         -------
-            model_response_content
+            A list of parameter keys present in both default_params and user_params,
+            or otherwise restricted for the given container family.
+        """
+        restricted_params = []
+        if default_params and user_params:
+            default_params_dict = get_params_dict(default_params)
+            user_params_dict = get_params_dict(user_params)
+
+            for key, items in user_params_dict.items():
+                if (
+                    key in default_params_dict
+                    or (
+                        InferenceContainerType.CONTAINER_TYPE_VLLM in container_family
+                        and key in VLLMInferenceRestrictedParams
+                    )
+                    or (
+                        InferenceContainerType.CONTAINER_TYPE_TGI in container_family
+                        and key in TGIInferenceRestrictedParams
+                    )
+                ):
+                    restricted_params.append(key.lstrip("--"))
-        params_dict = asdict(self.model_params)
-        params_dict = {
-            key: value for key, value in params_dict.items() if value is not None
-        }
-        body = {"prompt": self.prompt, **params_dict}
-        request_kwargs = {"json": body, "headers": {"Content-Type": "application/json"}}
-        response = requests.post(
-            endpoint, auth=default_signer()["signer"], **request_kwargs
-        )
-        return json.loads(response.content)
+        return restricted_params
diff --git a/ads/aqua/modeldeployment/entities.py b/ads/aqua/modeldeployment/entities.py
new file mode 100644
index 000000000..cb0d47071
--- /dev/null
+++ b/ads/aqua/modeldeployment/entities.py
@@ -0,0 +1,142 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright (c) 2024 Oracle and/or its affiliates.
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
+
+from dataclasses import dataclass, field
+from typing import Union
+
+from oci.data_science.models import ModelDeployment, ModelDeploymentSummary
+
+from ads.aqua.common.enums import Tags
+from ads.aqua.constants import UNKNOWN, UNKNOWN_DICT
+from ads.aqua.data import AquaResourceIdentifier
+from ads.common.serializer import DataClassSerializable
+from ads.common.utils import get_console_link
+
+
+@dataclass
+class ModelParams:
+    max_tokens: int = None
+    temperature: float = None
+    top_k: float = None
+    top_p: float = None
+    model: str = None
+
+
+class ContainerSpec:
+    """
+    Class to hold keys within the container spec.
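# --- Editor's note: conceptual sketch (not part of the patch) -----------------
# A simplified approximation of what `_find_restricted_params` above does: both
# parameter strings are reduced to their `--key` tokens, and any user key that
# already appears in the container defaults (or in a per-family restricted set)
# is reported. The real code relies on `get_params_dict` and the restricted
# param enums; this standalone version only illustrates the key comparison.
def param_keys(param_string: str) -> set:
    return {token for token in param_string.split() if token.startswith("--")}

container_defaults = "--served-model-name odsc-llm --seed 42"
user_overrides = "--seed 7 --max-model-len 4096"
restricted = sorted(
    key.lstrip("-") for key in param_keys(container_defaults) & param_keys(user_overrides)
)
print(restricted)  # ['seed'] -> such keys make `create` raise AquaValueError
# -----------------------------------------------------------------------------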
+ """ + + CONTAINER_SPEC = "containerSpec" + CLI_PARM = "cliParam" + SERVER_PORT = "serverPort" + HEALTH_CHECK_PORT = "healthCheckPort" + ENV_VARS = "envVars" + + +@dataclass +class ShapeInfo: + instance_shape: str = None + instance_count: int = None + ocpus: float = None + memory_in_gbs: float = None + + +@dataclass(repr=False) +class AquaDeployment(DataClassSerializable): + """Represents an Aqua Model Deployment""" + + id: str = None + display_name: str = None + aqua_service_model: bool = None + aqua_model_name: str = None + state: str = None + description: str = None + created_on: str = None + created_by: str = None + endpoint: str = None + console_link: str = None + lifecycle_details: str = None + shape_info: field(default_factory=ShapeInfo) = None + tags: dict = None + + @classmethod + def from_oci_model_deployment( + cls, + oci_model_deployment: Union[ModelDeploymentSummary, ModelDeployment], + region: str, + ) -> "AquaDeployment": + """Converts oci model deployment response to AquaDeployment instance. + + Parameters + ---------- + oci_model_deployment: Union[ModelDeploymentSummary, ModelDeployment] + The instance of either oci.data_science.models.ModelDeployment or + oci.data_science.models.ModelDeploymentSummary class. + region: str + The region of this model deployment. + + Returns + ------- + AquaDeployment: + The instance of the Aqua model deployment. + """ + instance_configuration = ( + oci_model_deployment.model_deployment_configuration_details.model_configuration_details.instance_configuration + ) + instance_shape_config_details = ( + instance_configuration.model_deployment_instance_shape_config_details + ) + instance_count = ( + oci_model_deployment.model_deployment_configuration_details.model_configuration_details.scaling_policy.instance_count + ) + shape_info = ShapeInfo( + instance_shape=instance_configuration.instance_shape_name, + instance_count=instance_count, + ocpus=( + instance_shape_config_details.ocpus + if instance_shape_config_details + else None + ), + memory_in_gbs=( + instance_shape_config_details.memory_in_gbs + if instance_shape_config_details + else None + ), + ) + + freeform_tags = oci_model_deployment.freeform_tags or UNKNOWN_DICT + aqua_service_model_tag = freeform_tags.get(Tags.AQUA_SERVICE_MODEL_TAG, None) + aqua_model_name = freeform_tags.get(Tags.AQUA_MODEL_NAME_TAG, UNKNOWN) + + return AquaDeployment( + id=oci_model_deployment.id, + display_name=oci_model_deployment.display_name, + aqua_service_model=aqua_service_model_tag is not None, + aqua_model_name=aqua_model_name, + shape_info=shape_info, + state=oci_model_deployment.lifecycle_state, + lifecycle_details=getattr( + oci_model_deployment, "lifecycle_details", UNKNOWN + ), + description=oci_model_deployment.description, + created_on=str(oci_model_deployment.time_created), + created_by=oci_model_deployment.created_by, + endpoint=oci_model_deployment.model_deployment_url, + console_link=get_console_link( + resource="model-deployments", + ocid=oci_model_deployment.id, + region=region, + ), + tags=freeform_tags, + ) + + +@dataclass(repr=False) +class AquaDeploymentDetail(AquaDeployment, DataClassSerializable): + """Represents a details of Aqua deployment.""" + + log_group: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) + log: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) diff --git a/ads/aqua/modeldeployment/inference.py b/ads/aqua/modeldeployment/inference.py new file mode 100644 index 000000000..02f9bb408 --- /dev/null +++ 
b/ads/aqua/modeldeployment/inference.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*-- + +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +import json +from dataclasses import asdict, dataclass, field + +import requests + +from ads.aqua.app import AquaApp, logger +from ads.aqua.modeldeployment.entities import ModelParams +from ads.common.auth import default_signer +from ads.telemetry import telemetry + + +@dataclass +class MDInferenceResponse(AquaApp): + """Contains APIs for Aqua Model deployments Inference. + + Attributes + ---------- + + model_params: Dict + prompt: string + + Methods + ------- + get_model_deployment_response(self, **kwargs) -> "String" + Creates an instance of model deployment via Aqua + """ + + prompt: str = None + model_params: field(default_factory=ModelParams) = None + + @telemetry(entry_point="plugin=inference&action=get_response", name="aqua") + def get_model_deployment_response(self, endpoint): + """ + Returns MD inference response + + Parameters + ---------- + endpoint: str + MD predict url + prompt: str + User prompt. + + model_params: (Dict, optional) + Model parameters to be associated with the message. + Currently supported VLLM+OpenAI parameters. + + --model-params '{ + "max_tokens":500, + "temperature": 0.5, + "top_k": 10, + "top_p": 0.5, + "model": "/opt/ds/model/deployed_model", + ...}' + + Returns + ------- + model_response_content + """ + + params_dict = asdict(self.model_params) + params_dict = { + key: value for key, value in params_dict.items() if value is not None + } + body = {"prompt": self.prompt, **params_dict} + request_kwargs = {"json": body, "headers": {"Content-Type": "application/json"}} + response = requests.post( + endpoint, auth=default_signer()["signer"], **request_kwargs + ) + return json.loads(response.content) diff --git a/ads/aqua/ui.py b/ads/aqua/ui.py index 458f5b111..9daf26344 100644 --- a/ads/aqua/ui.py +++ b/ads/aqua/ui.py @@ -3,21 +3,24 @@ # Copyright (c) 2024 Oracle and/or its affiliates. 
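# --- Editor's illustration (not part of the patch) ---------------------------
# A rough sketch of how the relocated MDInferenceResponse above might be used:
# the prompt plus every non-null ModelParams field is POSTed to a model
# deployment predict endpoint. The endpoint URL below is a placeholder.
from ads.aqua.modeldeployment import MDInferenceResponse
from ads.aqua.modeldeployment.entities import ModelParams

inference = MDInferenceResponse(
    prompt="What is the capital of France?",
    model_params=ModelParams(max_tokens=100, temperature=0.5),
)
result = inference.get_model_deployment_response(
    endpoint="https://<model-deployment-host>/<md_ocid>/predict"  # placeholder URL
)
print(result)
# -----------------------------------------------------------------------------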
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import concurrent.futures +from dataclasses import dataclass, field from datetime import datetime, timedelta from threading import Lock +from typing import Dict, List from cachetools import TTLCache from oci.exceptions import ServiceError from oci.identity.models import Compartment from ads.aqua import logger -from ads.aqua.base import AquaApp -from ads.aqua.data import Tags -from ads.aqua.exception import AquaValueError, AquaResourceAccessError -from ads.aqua.utils import load_config, sanitize_response +from ads.aqua.app import AquaApp +from ads.aqua.common.enums import Tags +from ads.aqua.common.errors import AquaResourceAccessError, AquaValueError +from ads.aqua.common.utils import get_container_config, load_config, sanitize_response from ads.common import oci_client as oc from ads.common.auth import default_signer from ads.common.object_storage_details import ObjectStorageDetails +from ads.common.serializer import DataClassSerializable from ads.config import ( AQUA_CONFIG_FOLDER, AQUA_RESOURCE_LIMIT_NAMES_CONFIG, @@ -28,6 +31,70 @@ from ads.telemetry import telemetry +@dataclass(repr=False) +class AquaContainerConfigItem(DataClassSerializable): + """Represents an item of the AQUA container configuration.""" + + name: str = None + version: str = None + display_name: str = None + family: str = None + + +@dataclass(repr=False) +class AquaContainerConfig(DataClassSerializable): + """ + Represents a configuration with AQUA containers to be returned to the client. + """ + + inference: List[AquaContainerConfigItem] = field(default_factory=list) + finetune: List[AquaContainerConfigItem] = field(default_factory=list) + evaluate: List[AquaContainerConfigItem] = field(default_factory=list) + + @classmethod + def from_container_index_json(cls, config: Dict) -> "AquaContainerConfig": + """ + Create an AquaContainerConfig instance from a container index JSON. + + Parameters + ---------- + config : Dict + The container index JSON. + + Returns + ------- + AquaContainerConfig + The container configuration instance. + """ + config = config or {} + inference_items = [] + finetune_items = [] + evaluate_items = [] + + # extract inference containers + for container_type, containers in config.items(): + if isinstance(containers, list): + for container in containers: + container_item = AquaContainerConfigItem( + name=container.get("name", ""), + version=container.get("version", ""), + display_name=container.get( + "displayName", container.get("version", "") + ), + family=container_type, + ) + if container.get("type") == "inference": + inference_items.append(container_item) + elif container_type == "odsc-llm-fine-tuning": + finetune_items.append(container_item) + elif container_type == "odsc-llm-evaluate": + evaluate_items.append(container_item) + + return AquaContainerConfig( + inference=inference_items, finetune=finetune_items, evaluate=evaluate_items + ) + + class AquaUIApp(AquaApp): """Contains APIs for supporting Aqua UI. @@ -42,7 +109,8 @@ class AquaUIApp(AquaApp): Lists the specified log group's log objects. list_compartments(self, **kwargs) -> List[Dict] Lists the compartments in a specified compartment. - + list_containers(self, **kwargs) -> AquaContainerConfig + Containers config to be returned to the client. 
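# --- Editor's illustration (not part of the patch) ---------------------------
# How `AquaContainerConfig.from_container_index_json` above groups containers:
# list entries tagged with "type": "inference" land in `inference`, while the
# odsc-llm-fine-tuning / odsc-llm-evaluate keys feed `finetune` and `evaluate`.
# The dict below is a trimmed, made-up sample in the container_index.json shape.
from ads.aqua.ui import AquaContainerConfig

sample_index = {
    "odsc-vllm-serving": [
        {
            "name": "dsmc://odsc-vllm-serving",
            "version": "0.3.0.7",
            "displayName": "VLLM:0.3.0",
            "type": "inference",
        }
    ],
    "odsc-llm-evaluate": [{"name": "dsmc://odsc-llm-evaluate", "version": "0.1.2.0"}],
}
containers = AquaContainerConfig.from_container_index_json(config=sample_index)
print([item.display_name for item in containers.inference])  # ['VLLM:0.3.0']
print([item.version for item in containers.evaluate])        # ['0.1.2.0']
# -----------------------------------------------------------------------------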
""" _compartments_cache = TTLCache( @@ -219,9 +287,7 @@ def list_model_version_sets(self, target_tag: str = None, **kwargs) -> str: """ compartment_id = kwargs.pop("compartment_id", COMPARTMENT_OCID) target_resource = ( - "experiments" - if target_tag == Tags.AQUA_EVALUATION.value - else "modelversionsets" + "experiments" if target_tag == Tags.AQUA_EVALUATION else "modelversionsets" ) logger.info(f"Loading {target_resource} from compartment: {compartment_id}") @@ -451,3 +517,17 @@ def is_bucket_versioned(self, bucket_uri: str): message = f"Model artifact bucket {bucket_uri} is not versioned. Check if the path exists and enable versioning on the bucket to proceed with model creation." return dict(is_versioned=is_versioned, message=message) + + @telemetry(entry_point="plugin=ui&action=list_containers", name="aqua") + def list_containers(self) -> AquaContainerConfig: + """ + Lists the AQUA containers. + + Returns + ------- + AquaContainerConfig + The AQUA containers configuration. + """ + return AquaContainerConfig.from_container_index_json( + config=get_container_config() + ) diff --git a/ads/cli.py b/ads/cli.py index 819a8228a..872e7d177 100644 --- a/ads/cli.py +++ b/ads/cli.py @@ -4,19 +4,21 @@ # Copyright (c) 2021, 2024 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ -import traceback import sys +import traceback +from dataclasses import is_dataclass import fire -from dataclasses import is_dataclass + from ads.common import logger try: import click - import ads.opctl.cli + import ads.jobs.cli - import ads.pipeline.cli + import ads.opctl.cli import ads.opctl.operator.cli + import ads.pipeline.cli except Exception as ex: print( "Please run `pip install oracle-ads[opctl]` to install " @@ -33,6 +35,7 @@ else: import importlib_metadata as metadata + ADS_VERSION = metadata.version("oracle_ads") @@ -86,13 +89,58 @@ def serialize(data): print(str(data)) +def exit_program(ex: Exception, logger: "logging.Logger") -> None: + """ + Logs the exception and exits the program with a specific exit code. + + This function logs the full traceback and the exception message, then terminates + the program with an exit code. If the exception object has an 'exit_code' attribute, + it uses that as the exit code; otherwise, it defaults to 1. + + Parameters + ---------- + ex (Exception): + The exception that triggered the program exit. This exception + should ideally contain an 'exit_code' attribute, but it is not mandatory. + logger (Logger): + A logging.Logger instance used to log the traceback and the error message. + + Returns + ------- + None: + This function does not return anything because it calls sys.exit, + terminating the process. + + Examples + -------- + + >>> import logging + >>> logger = logging.getLogger('ExampleLogger') + >>> try: + ... raise ValueError("An error occurred") + ... except Exception as e: + ... 
exit_program(e, logger) + """ + + logger.debug(traceback.format_exc()) + logger.error(str(ex)) + + exit_code = getattr(ex, "exit_code", 1) + logger.error(f"Exit code: {exit_code}") + sys.exit(exit_code) + + def cli(): if len(sys.argv) > 1 and sys.argv[1] == "aqua": + from ads.aqua import logger as aqua_logger from ads.aqua.cli import AquaCommand - fire.Fire( - AquaCommand, command=sys.argv[2:], name="ads aqua", serialize=serialize - ) + try: + fire.Fire( + AquaCommand, command=sys.argv[2:], name="ads aqua", serialize=serialize + ) + except Exception as err: + exit_program(err, aqua_logger) else: click_cli() diff --git a/ads/common/serializer.py b/ads/common/serializer.py index a98be75e9..87a46b37d 100644 --- a/ads/common/serializer.py +++ b/ads/common/serializer.py @@ -79,7 +79,7 @@ def to_dict(self, **kwargs: Dict) -> Dict: @classmethod @abstractmethod - def from_dict(cls, obj_dict: dict) -> "Serializable": + def from_dict(cls, obj_dict: dict, **kwargs) -> "Serializable": """Returns an instance of the class instantiated by the dictionary provided. Parameters @@ -239,7 +239,7 @@ def from_json( Returns instance of the class """ if json_string: - return cls.from_dict(json.loads(json_string, cls=decoder)) + return cls.from_dict(json.loads(json_string, cls=decoder), **kwargs) if uri: json_dict = json.loads(cls._read_from_file(uri, **kwargs), cls=decoder) return cls.from_dict(json_dict) diff --git a/ads/config.py b/ads/config.py index e4d76e703..f55fa196f 100644 --- a/ads/config.py +++ b/ads/config.py @@ -41,7 +41,6 @@ ) MD_OCID = os.environ.get("MD_OCID") DATAFLOW_RUN_OCID = os.environ.get("DATAFLOW_RUN_ID") - RESOURCE_OCID = ( NB_SESSION_OCID or JOB_RUN_OCID or MD_OCID or PIPELINE_RUN_OCID or DATAFLOW_RUN_OCID ) @@ -66,6 +65,8 @@ AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME = "deployment-container" AQUA_FINETUNING_CONTAINER_METADATA_NAME = "finetune-container" AQUA_EVALUATION_CONTAINER_METADATA_NAME = "evaluation-container" +AQUA_DEPLOYMENT_CONTAINER_OVERRIDE_FLAG_METADATA_NAME = "deployment-container-custom" +AQUA_FINETUNING_CONTAINER_OVERRIDE_FLAG_METADATA_NAME = "finetune-container-custom" AQUA_MODEL_DEPLOYMENT_FOLDER = "/opt/ds/model/deployed_model/" AQUA_SERVED_MODEL_NAME = "odsc-llm" AQUA_CONFIG_FOLDER = os.path.join( diff --git a/ads/model/datascience_model.py b/ads/model/datascience_model.py index fc1df3dd9..19bc6d3e6 100644 --- a/ads/model/datascience_model.py +++ b/ads/model/datascience_model.py @@ -238,7 +238,7 @@ class DataScienceModel(Builder): CONST_MODEL_VERSION_ID: "version_id", CONST_TIME_CREATED: "time_created", CONST_LIFECYCLE_STATE: "lifecycle_state", - CONST_MODEL_FILE_DESCRIPTION: "model_file_description", + CONST_MODEL_FILE_DESCRIPTION: "model_description", } def __init__(self, spec: Dict = None, **kwargs) -> None: diff --git a/ads/model/deployment/model_deployment.py b/ads/model/deployment/model_deployment.py index 7a1f8afd1..df6b681a2 100644 --- a/ads/model/deployment/model_deployment.py +++ b/ads/model/deployment/model_deployment.py @@ -370,6 +370,17 @@ def model_deployment_id(self) -> str: """ return self.get_spec(self.CONST_ID, None) + @property + def id(self) -> str: + """The model deployment ocid. + + Returns + ------- + str + The model deployment ocid. + """ + return self.get_spec(self.CONST_ID, None) + @property def created_by(self) -> str: """The user that creates the model deployment. 
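# --- Editor's illustration (not part of the patch) ---------------------------
# `exit_program` above maps an exception to a process exit status: it logs the
# traceback at DEBUG, the message and exit code at ERROR, then calls sys.exit
# with the error's `exit_code` attribute (or 1 when the attribute is missing).
# The error class below is a stand-in with a made-up exit code; the ads CLI
# extras are assumed to be installed so that `ads.cli` imports cleanly.
import logging
from ads.cli import exit_program

class FakeAquaError(Exception):
    exit_code = 78  # illustrative value only

try:
    exit_program(FakeAquaError("compartment OCID is not set"), logging.getLogger("demo"))
except SystemExit as err:
    print(err.code)  # 78
# -----------------------------------------------------------------------------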
diff --git a/ads/model/model_metadata.py b/ads/model/model_metadata.py
index 228fec9af..2667b82ad 100644
--- a/ads/model/model_metadata.py
+++ b/ads/model/model_metadata.py
@@ -8,10 +8,10 @@
 import logging
 import os
 import sys
-from abc import ABC, abstractclassmethod, abstractmethod
+from abc import ABC, abstractmethod
 from dataclasses import dataclass, field, fields
 from pathlib import Path
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Union, Optional, Any
 
 import ads.dataset.factory as factory
 import fsspec
@@ -41,6 +41,8 @@
 _METADATA_EMPTY_VALUE = "NA"
 CURRENT_WORKING_DIR = "."
 
+_sentinel = object()
+
 
 class MetadataSizeTooLarge(ValueError):
     """Maximum allowed size for model metadata has been exceeded.
@@ -727,13 +729,18 @@ def __init__(self):
         """Initializes Model Metadata."""
         self._items = set()
 
-    def get(self, key: str) -> ModelMetadataItem:
+    def get(
+        self, key: str, value: Optional[Any] = _sentinel
+    ) -> Union[ModelMetadataItem, Any]:
         """Returns the model metadata item by provided key.
 
         Parameters
         ----------
         key: str
             The key of model metadata item.
+        value: (Any, optional)
+            A value to return if the specified key does not exist.
+            If no default value is provided, a ValueError is raised when the key is not found.
 
         Returns
         -------
@@ -750,7 +757,11 @@ def get(self, key: str) -> ModelMetadataItem:
         for item in self._items:
             if item.key.lower() == key.lower():
                 return item
-        raise ValueError(f"The metadata with {key} not found.")
+
+        if value is _sentinel:
+            raise ValueError(f"The metadata with {key} not found.")
+
+        return value
 
     def reset(self) -> None:
         """Resets all model metadata items to empty values.
@@ -952,7 +963,7 @@ def __repr__(self):
     def __len__(self):
         return len(self._items)
 
-    @abstractclassmethod
+    @abstractmethod
     def _from_oci_metadata(cls, metadata_list):
         pass
@@ -967,7 +978,7 @@ def to_dataframe(self) -> pd.DataFrame:
         """
         pass
 
-    @abstractclassmethod
+    @abstractmethod
     def from_dict(cls, data: Dict) -> "ModelMetadata":
         """Constructs an instance of `ModelMetadata` from a dictionary.
 
diff --git a/docs/source/release_notes.rst b/docs/source/release_notes.rst
index f0183978e..0bb91d4d3 100644
--- a/docs/source/release_notes.rst
+++ b/docs/source/release_notes.rst
@@ -2,6 +2,15 @@
 Release Notes
 =============
 
+2.11.10
+-------
+Release date: June 5, 2024
+
+* Support for Bring Your Own Model (BYOM) via AI Quick Actions.
+* Follow-up enhancements and fixes for AI Quick Actions.
+ + + 2.11.9 ------ Release date: April 24, 2024 diff --git a/pyproject.toml b/pyproject.toml index 87441c44c..f999c9fac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ build-backend = "flit_core.buildapi" # Required name = "oracle_ads" # the install (PyPI) name; name for local build in [tool.flit.module] section below -version = "2.11.9" +version = "2.11.10" # Optional description = "Oracle Accelerated Data Science SDK" @@ -61,7 +61,7 @@ dependencies = [ "fsspec>=0.8.7", "gitpython>=3.1.2", "jinja2>=2.11.2", - "matplotlib>=3.1.3", + "matplotlib>=3.1.3, <=3.8.4", "numpy>=1.19.2", "oci>=2.125.3", "ocifs>=1.1.3", @@ -115,7 +115,7 @@ opctl = [ "py-cpuinfo", "rich", "fire", - "cachetools", + "cachetools" ] optuna = ["optuna==2.9.0", "oracle_ads[viz]"] spark = ["pyspark>=3.0.0"] @@ -175,7 +175,7 @@ pii = [ "spacy==3.6.1", "report-creator==1.0.9", ] -llm = ["langchain-community<0.0.32", "langchain>=0.1.10,<0.1.14", "evaluate>=0.4.0"] +llm = ["langchain-community<0.0.32", "langchain>=0.1.10,<0.1.14", "evaluate>=0.4.0", "langchain-core<0.1.51"] aqua = ["jupyter_server"] # To reduce backtracking (decrese deps install time) during test/dev env setup reducing number of versions pip is diff --git a/tests/unitary/default_setup/model/test_model_metadata.py b/tests/unitary/default_setup/model/test_model_metadata.py index 9e5ffefc9..f38af703a 100644 --- a/tests/unitary/default_setup/model/test_model_metadata.py +++ b/tests/unitary/default_setup/model/test_model_metadata.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -# Copyright (c) 2021, 2023 Oracle and/or its affiliates. +# Copyright (c) 2021, 2024 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ """Unit tests for model metadata module. Includes tests for: @@ -839,6 +839,20 @@ def test_set_training_and_validation_dataset(self): ) assert metadata_custom["ValidationDatasetSize"].value == "(100,100)" + def test_get(self): + """Tests getting the model metadata item by provided key.""" + metadata_custom = ModelCustomMetadata() + metadata_custom._add_many([self.user_defined_item, self.dict_item]) + assert metadata_custom.get("My Own Meta") == self.user_defined_item + assert metadata_custom.get("My Meta With Dictionary") == self.dict_item + assert metadata_custom.get("NotExistingKey", None) == None + assert ( + metadata_custom.get("NotExistingKey", "some_default_value") + == "some_default_value" + ) + with pytest.raises(ValueError): + assert metadata_custom.get("NotExistingKey") + class TestModelTaxonomyMetadata: """Unit tests for ModelTaxonomyMetadata class.""" diff --git a/tests/unitary/with_extras/aqua/test_cli.py b/tests/unitary/with_extras/aqua/test_cli.py index d7dcc9512..6c3c97cc8 100644 --- a/tests/unitary/with_extras/aqua/test_cli.py +++ b/tests/unitary/with_extras/aqua/test_cli.py @@ -4,17 +4,19 @@ # Copyright (c) 2024 Oracle and/or its affiliates. 
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ -import os import logging +import os import subprocess -from unittest import TestCase -from unittest.mock import patch from importlib import reload +from unittest import TestCase +from unittest.mock import call, patch + from parameterized import parameterized import ads.aqua import ads.config from ads.aqua.cli import AquaCommand +from ads.aqua.common.errors import AquaCLIError, AquaConfigError class TestAquaCLI(TestCase): @@ -38,56 +40,137 @@ def test_entrypoint(self): @parameterized.expand( [ ("default", None, DEFAULT_AQUA_CLI_LOGGING_LEVEL), - ("set logging level", "info", "info"), + ("set logging level", dict(log_level="info"), "INFO"), + ("debug", dict(debug=True), "DEBUG"), + ("verbose", dict(verbose=True), "INFO"), + ("flag_priority", dict(debug=True, log_level="info"), "DEBUG"), ] ) + @patch.dict( + os.environ, {"ODSC_MODEL_COMPARTMENT_OCID": SERVICE_COMPARTMENT_ID}, clear=True + ) def test_aquacommand(self, name, arg, expected): """Tests aqua command initialization.""" - with patch.dict( - os.environ, - {"ODSC_MODEL_COMPARTMENT_OCID": TestAquaCLI.SERVICE_COMPARTMENT_ID}, - ): + + reload(ads.config) + reload(ads.aqua) + reload(ads.aqua.cli) + with patch("ads.aqua.cli.set_log_level") as mock_setting_log: + if arg: + AquaCommand(**arg) + else: + AquaCommand() + mock_setting_log.assert_called_with(expected) + + @parameterized.expand( + [ + ("conflict", dict(debug=True, verbose=True)), + ("invalid_value", dict(debug="abc")), + ("invalid_value", dict(verbose="abc")), + ] + ) + @patch.dict( + os.environ, {"ODSC_MODEL_COMPARTMENT_OCID": SERVICE_COMPARTMENT_ID}, clear=True + ) + def test_aquacommand_flag(self, name, arg): + """Tests aqua command initialization with wrong flag.""" + + reload(ads.config) + reload(ads.aqua) + reload(ads.aqua.cli) + with self.assertRaises(AquaCLIError): + AquaCommand(**arg) + + @parameterized.expand( + [ + ( + "default", + {"ODSC_MODEL_COMPARTMENT_OCID": ""}, + "ODSC_MODEL_COMPARTMENT_OCID environment variable is not set for Aqua.", + ), + ( + "using jupyter instance", + { + "ODSC_MODEL_COMPARTMENT_OCID": "", + "NB_SESSION_OCID": "nb-session-ocid", + }, + "Aqua is not available for the notebook session nb-session-ocid. For more information, please refer to the documentation.", + ), + ] + ) + def test_aqua_command_without_compartment_env_var( + self, name, mock_env_dict, expected_msg + ): + """Test whether exit is called when ODSC_MODEL_COMPARTMENT_OCID is not set. 
+ Also check if NB_SESSION_OCID is set then log the appropriate message.""" + + with patch.dict(os.environ, mock_env_dict): reload(ads.config) reload(ads.aqua) reload(ads.aqua.cli) - with patch("ads.aqua.cli.set_log_level") as mock_setting_log: - if arg: - AquaCommand(arg) - else: - AquaCommand() - mock_setting_log.assert_called_with(expected) + with self.assertRaises(AquaConfigError) as cm: + AquaCommand() + + self.assertEqual(str(cm.exception), expected_msg) + + @patch("sys.argv", ["ads", "aqua", "--some-option"]) + @patch("ads.cli.serialize") + @patch("fire.Fire") + @patch("ads.aqua.cli.AquaCommand") + @patch("ads.aqua.logger") + def test_aqua_cli(self, mock_logger, mock_aqua_command, mock_fire, mock_serialize): + """Tests when Aqua Cli being invoked.""" + from ads.cli import cli + + cli() + mock_fire.assert_called_once() + mock_fire.assert_called_with( + mock_aqua_command, + command=["--some-option"], + name="ads aqua", + serialize=mock_serialize, + ) @parameterized.expand( [ - ("default", None), - ("using jupyter instance", "nb-session-ocid"), + ( + "with_defined_exit_code", + AquaConfigError("test error"), + AquaConfigError.exit_code, + "test error", + ), + ( + "without_defined_exit_code", + ValueError("general error"), + 1, + "general error", + ), ] ) - def test_aqua_command_without_compartment_env_var(self, name, session_ocid): - """Test whether exit is called when ODSC_MODEL_COMPARTMENT_OCID is not set. Also check if NB_SESSION_OCID is - set then log the appropriate message.""" - - with patch("sys.exit") as mock_exit: - env_dict = {"ODSC_MODEL_COMPARTMENT_OCID": ""} - if session_ocid: - env_dict.update({"NB_SESSION_OCID": session_ocid}) - with patch.dict(os.environ, env_dict): - reload(ads.config) - reload(ads.aqua) - reload(ads.aqua.cli) - with patch("ads.aqua.cli.set_log_level") as mock_setting_log: - with patch("ads.aqua.logger.error") as mock_logger_error: - with patch("ads.aqua.logger.debug") as mock_logger_debug: - AquaCommand() - mock_setting_log.assert_called_with( - TestAquaCLI.DEFAULT_AQUA_CLI_LOGGING_LEVEL - ) - mock_logger_debug.assert_any_call( - "ODSC_MODEL_COMPARTMENT_OCID environment variable is not set for Aqua." - ) - if session_ocid: - mock_logger_error.assert_any_call( - f"Aqua is not available for the notebook session {session_ocid}. For more information, " - f"please refer to the documentation." 
- ) - mock_exit.assert_called_with(1) + @patch("sys.argv", ["ads", "aqua", "--error-option"]) + @patch("fire.Fire") + @patch("ads.aqua.cli.AquaCommand") + @patch("ads.aqua.logger.error") + @patch("sys.exit") + def test_aqua_cli_with_error( + self, + name, + mock_side_effect, + expected_code, + expected_logging_message, + mock_exit, + mock_logger_error, + mock_aqua_command, + mock_fire, + ): + """Tests when Aqua Cli gracefully exit when error raised.""" + mock_fire.side_effect = mock_side_effect + from ads.cli import cli + + cli() + calls = [ + call(expected_logging_message), + call(f"Exit code: {expected_code}"), + ] + mock_logger_error.assert_has_calls(calls) + mock_exit.assert_called_with(expected_code) diff --git a/tests/unitary/with_extras/aqua/test_common_handler.py b/tests/unitary/with_extras/aqua/test_common_handler.py index 2e85e4743..88e3e6e06 100644 --- a/tests/unitary/with_extras/aqua/test_common_handler.py +++ b/tests/unitary/with_extras/aqua/test_common_handler.py @@ -6,13 +6,14 @@ import os import unittest -from unittest.mock import MagicMock, patch from importlib import reload +from unittest.mock import MagicMock, patch + from notebook.base.handlers import IPythonHandler -import ads.config import ads.aqua -from ads.aqua.utils import AQUA_GA_LIST +import ads.config +from ads.aqua.constants import AQUA_GA_LIST from ads.aqua.extension.common_handler import CompatibilityCheckHandler diff --git a/tests/unitary/with_extras/aqua/test_data/finetuning/ft_config.json b/tests/unitary/with_extras/aqua/test_data/finetuning/ft_config.json new file mode 100644 index 000000000..4b80a47d3 --- /dev/null +++ b/tests/unitary/with_extras/aqua/test_data/finetuning/ft_config.json @@ -0,0 +1,34 @@ +{ + "configuration": { + "adapter": "lora", + "bf16": true, + "flash_attention": true, + "fp16": false, + "gradient_accumulation_steps": 1, + "gradient_checkpointing": true, + "learning_rate": 0.0002, + "logging_steps": 1, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_r": 32, + "lora_target_linear": true, + "lora_target_modules": [ + "q_proj", + "k_proj" + ], + "lr_scheduler": "cosine", + "micro_batch_size": 1, + "optimizer": "adamw_torch", + "pad_to_sequence_len": true, + "sample_packing": true, + "sequence_len": 2048, + "tf32": false, + "val_set_size": 0.1 + }, + "shape": { + "VM.GPU.A10.2": { + "batch_size": 2, + "replica": 1 + } + } +} diff --git a/tests/unitary/with_extras/aqua/test_data/ui/container_index.json b/tests/unitary/with_extras/aqua/test_data/ui/container_index.json new file mode 100644 index 000000000..04a88704e --- /dev/null +++ b/tests/unitary/with_extras/aqua/test_data/ui/container_index.json @@ -0,0 +1,76 @@ +{ + "containerSpec": { + "odsc-tgi-serving": { + "cliParam": "--sharded true --trust-remote-code", + "envVars": [ + { + "MODEL_DEPLOY_PREDICT_ENDPOINT": "/v1/completions" + }, + { + "MODEL_DEPLOY_ENABLE_STREAMING": "true" + }, + { + "PORT": "8080" + }, + { + "HEALTH_CHECK_PORT": "8080" + } + ], + "healthCheckPort": "8080", + "serverPort": "8080" + }, + "odsc-vllm-serving": { + "cliParam": "--served-model-name $(python -c 'import os; print(os.environ.get(\"ODSC_SERVED_MODEL_NAME\",\"odsc-llm\"))') --seed 42 ", + "envVars": [ + { + "MODEL_DEPLOY_PREDICT_ENDPOINT": "/v1/completions" + }, + { + "MODEL_DEPLOY_ENABLE_STREAMING": "true" + }, + { + "PORT": "8080" + }, + { + "HEALTH_CHECK_PORT": "8080" + } + ], + "healthCheckPort": "8080", + "serverPort": "8080" + } + }, + "odsc-llm-evaluate": [ + { + "name": "dsmc://odsc-llm-evaluate", + "version": "0.1.2.0" + } + ], + 
"odsc-llm-fine-tuning": [ + { + "name": "dsmc://odsc-llm-fine-tuning", + "version": "1.1.33.34" + } + ], + "odsc-tgi-serving": [ + { + "displayName": "TGI:1.4.5", + "name": "dsmc://odsc-tgi-serving", + "type": "inference", + "version": "1.4.5" + }, + { + "displayName": "TGI:2.0.2", + "name": "dsmc://odsc-tgi-serving", + "type": "inference", + "version": "2.0.2" + } + ], + "odsc-vllm-serving": [ + { + "displayName": "VLLM:0.3.0", + "name": "dsmc://odsc-vllm-serving", + "type": "inference", + "version": "0.3.0.7" + } + ] +} diff --git a/tests/unitary/with_extras/aqua/test_decorator.py b/tests/unitary/with_extras/aqua/test_decorator.py index 8617e898c..7e371e339 100644 --- a/tests/unitary/with_extras/aqua/test_decorator.py +++ b/tests/unitary/with_extras/aqua/test_decorator.py @@ -22,7 +22,7 @@ from parameterized import parameterized from tornado.web import HTTPError -from ads.aqua.exception import AquaError +from ads.aqua.common.errors import AquaError from ads.aqua.extension.base_handler import AquaAPIhandler @@ -179,7 +179,7 @@ def setUp(self, ipython_init_mock) -> None: @patch("uuid.uuid4") def test_handle_exceptions(self, name, error, expected_reply, mock_uuid): """Tests handling error decorator.""" - from ads.aqua.decorator import handle_exceptions + from ads.aqua.common.decorator import handle_exceptions mock_uuid.return_value = TestDataset.mock_request_id expected_call = json.dumps(expected_reply) diff --git a/tests/unitary/with_extras/aqua/test_deployment.py b/tests/unitary/with_extras/aqua/test_deployment.py index 795eff319..a4e99a28d 100644 --- a/tests/unitary/with_extras/aqua/test_deployment.py +++ b/tests/unitary/with_extras/aqua/test_deployment.py @@ -4,29 +4,31 @@ # Copyright (c) 2024 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ -import os +import copy import json +import os import unittest from dataclasses import asdict from importlib import reload from unittest.mock import MagicMock, patch + +import oci import pytest -import copy import yaml +from parameterized import parameterized -import oci -import ads.aqua.deployment +import ads.aqua.modeldeployment.deployment import ads.config -from ads.aqua.deployment import ( +from ads.aqua.modeldeployment import AquaDeploymentApp, MDInferenceResponse +from ads.aqua.modeldeployment.entities import ( AquaDeployment, AquaDeploymentDetail, - AquaDeploymentApp, - MDInferenceResponse, ModelParams, ) -from ads.aqua.exception import AquaRuntimeError +from ads.aqua.common.errors import AquaRuntimeError, AquaValueError from ads.model.datascience_model import DataScienceModel from ads.model.deployment.model_deployment import ModelDeployment +from ads.model.model_metadata import ModelCustomMetadata null = None @@ -167,7 +169,7 @@ def setUpClass(cls): os.environ["PROJECT_COMPARTMENT_OCID"] = TestDataset.USER_COMPARTMENT_ID reload(ads.config) reload(ads.aqua) - reload(ads.aqua.deployment) + reload(ads.aqua.modeldeployment.deployment) @classmethod def tearDownClass(cls): @@ -177,7 +179,7 @@ def tearDownClass(cls): os.environ.pop("PROJECT_COMPARTMENT_OCID", None) reload(ads.config) reload(ads.aqua) - reload(ads.aqua.deployment) + reload(ads.aqua.modeldeployment.deployment) def test_list_deployments(self): """Tests the list method in the AquaDeploymentApp class.""" @@ -200,7 +202,7 @@ def test_list_deployments(self): expected_attributes ), "Attributes mismatch" - @patch("ads.aqua.deployment.get_resource_name") + 
@patch("ads.aqua.modeldeployment.deployment.get_resource_name") def test_get_deployment(self, mock_get_resource_name): """Tests the get method in the AquaDeploymentApp class.""" @@ -213,8 +215,8 @@ def test_get_deployment(self, mock_get_resource_name): data=oci.data_science.models.ModelDeploymentSummary(**model_deployment), ) ) - mock_get_resource_name.side_effect = ( - lambda param: "log-group-name" + mock_get_resource_name.side_effect = lambda param: ( + "log-group-name" if param.startswith("ocid1.loggroup") else "log-name" if param.startswith("ocid1.log") @@ -253,7 +255,7 @@ def test_get_deployment_missing_tags(self): self.app.get(model_deployment_id=TestDataset.MODEL_DEPLOYMENT_ID) - @patch("ads.aqua.deployment.load_config") + @patch("ads.aqua.modeldeployment.deployment.load_config") def test_get_deployment_config(self, mock_load_config): """Test for fetching config details for a given deployment.""" @@ -272,11 +274,16 @@ def test_get_deployment_config(self, mock_load_config): result = self.app.get_deployment_config(TestDataset.MODEL_ID) assert result == config + @patch("ads.aqua.modeldeployment.deployment.get_container_config") @patch("ads.aqua.model.AquaModelApp.create") - @patch("ads.aqua.deployment.get_container_image") + @patch("ads.aqua.modeldeployment.deployment.get_container_image") @patch("ads.model.deployment.model_deployment.ModelDeployment.deploy") def test_create_deployment_for_foundation_model( - self, mock_deploy, mock_get_container_image, mock_create + self, + mock_deploy, + mock_get_container_image, + mock_create, + mock_get_container_config, ): """Test to create a deployment for foundational model""" aqua_model = os.path.join( @@ -290,6 +297,14 @@ def test_create_deployment_for_foundation_model( config = json.load(_file) self.app.get_deployment_config = MagicMock(return_value=config) + + container_index_json = os.path.join( + self.curr_dir, "test_data/ui/container_index.json" + ) + with open(container_index_json, "r") as _file: + container_index_config = json.load(_file) + mock_get_container_config.return_value = container_index_config + mock_get_container_image.return_value = TestDataset.DEPLOYMENT_IMAGE_NAME aqua_deployment = os.path.join( self.curr_dir, "test_data/deployment/aqua_create_deployment.yaml" @@ -324,11 +339,16 @@ def test_create_deployment_for_foundation_model( expected_result["state"] = "CREATING" assert actual_attributes == expected_result + @patch("ads.aqua.modeldeployment.deployment.get_container_config") @patch("ads.aqua.model.AquaModelApp.create") - @patch("ads.aqua.deployment.get_container_image") + @patch("ads.aqua.modeldeployment.deployment.get_container_image") @patch("ads.model.deployment.model_deployment.ModelDeployment.deploy") def test_create_deployment_for_fine_tuned_model( - self, mock_deploy, mock_get_container_image, mock_create + self, + mock_deploy, + mock_get_container_image, + mock_create, + mock_get_container_config, ): """Test to create a deployment for fine-tuned model""" @@ -357,6 +377,14 @@ def yaml_to_json(input_file): config = json.load(_file) self.app.get_deployment_config = MagicMock(return_value=config) + + container_index_json = os.path.join( + self.curr_dir, "test_data/ui/container_index.json" + ) + with open(container_index_json, "r") as _file: + container_index_config = json.load(_file) + mock_get_container_config.return_value = container_index_config + mock_get_container_image.return_value = TestDataset.DEPLOYMENT_IMAGE_NAME aqua_deployment = os.path.join( self.curr_dir, 
"test_data/deployment/aqua_create_deployment.yaml" @@ -391,6 +419,174 @@ def yaml_to_json(input_file): expected_result["state"] = "CREATING" assert actual_attributes == expected_result + @parameterized.expand( + [ + ( + "VLLM_PARAMS", + "odsc-vllm-serving", + ["--max-model-len 4096", "--seed 42", "--trust-remote-code"], + ), + ( + "VLLM_PARAMS", + "odsc-vllm-serving", + [], + ), + ( + "TGI_PARAMS", + "odsc-tgi-serving", + ["--sharded true", "--trust-remote-code"], + ), + ( + "CUSTOM_PARAMS", + "custom-container-key", + ["--max-model-len 4096", "--seed 42", "--trust-remote-code"], + ), + ] + ) + @patch("ads.model.datascience_model.DataScienceModel.from_id") + def test_get_deployment_default_params( + self, container_params_field, container_type_key, params, mock_from_id + ): + """Test for fetching config details for a given deployment.""" + + config_json = os.path.join( + self.curr_dir, "test_data/deployment/deployment_config.json" + ) + with open(config_json, "r") as _file: + config = json.load(_file) + # update config params for testing + config["configuration"][TestDataset.DEPLOYMENT_SHAPE_NAME]["parameters"][ + container_params_field + ] = " ".join(params) + + mock_model = MagicMock() + custom_metadata_list = ModelCustomMetadata() + custom_metadata_list.add( + **{"key": "deployment-container", "value": container_type_key} + ) + mock_model.custom_metadata_list = custom_metadata_list + mock_from_id.return_value = mock_model + + self.app.get_deployment_config = MagicMock(return_value=config) + result = self.app.get_deployment_default_params( + TestDataset.MODEL_ID, TestDataset.DEPLOYMENT_SHAPE_NAME + ) + if container_params_field == "CUSTOM_PARAMS": + assert result == [] + else: + assert result == params + + @parameterized.expand( + [ + ( + "odsc-vllm-serving", + ["--max-model-len 4096", "--seed 42", "--trust-remote-code"], + ), + ( + "odsc-vllm-serving", + [], + ), + ( + "odsc-tgi-serving", + ["--sharded true", "--trust-remote-code"], + ), + ( + "custom-container-key", + ["--max-model-len 4096", "--seed 42", "--trust-remote-code"], + ), + ( + "odsc-vllm-serving", + ["--tensor-parallel-size 2"], + ), + ( + "odsc-tgi-serving", + ["--port 8080"], + ), + ] + ) + @patch("ads.model.datascience_model.DataScienceModel.from_id") + @patch("ads.aqua.modeldeployment.deployment.get_container_config") + def test_validate_deployment_params( + self, container_type_key, params, mock_get_container_config, mock_from_id + ): + """Test for checking if overridden deployment params are valid.""" + mock_model = MagicMock() + custom_metadata_list = ModelCustomMetadata() + custom_metadata_list.add( + **{"key": "deployment-container", "value": container_type_key} + ) + mock_model.custom_metadata_list = custom_metadata_list + mock_from_id.return_value = mock_model + + container_index_json = os.path.join( + self.curr_dir, "test_data/ui/container_index.json" + ) + with open(container_index_json, "r") as _file: + container_index_config = json.load(_file) + mock_get_container_config.return_value = container_index_config + + if container_type_key in {"odsc-vllm-serving", "odsc-tgi-serving"} and params: + with pytest.raises(AquaValueError): + self.app.validate_deployment_params( + model_id="mock-model-id", + params=params, + ) + else: + result = self.app.validate_deployment_params( + model_id="mock-model-id", + params=params, + ) + assert result["valid"] is True + + @parameterized.expand( + [ + ( + "odsc-vllm-serving", + ["--max-model-len 4096"], + ), + ( + "odsc-tgi-serving", + ["--max_stop_sequences 5"], + ), + ( + "", 
+ ["--some_random_key some_random_value"], + ), + ] + ) + @patch("ads.model.datascience_model.DataScienceModel.from_id") + @patch("ads.aqua.modeldeployment.deployment.get_container_config") + def test_validate_deployment_params_for_unverified_models( + self, container_type_key, params, mock_get_container_config, mock_from_id + ): + """Test to check if container family is used when metadata does not have image information + for unverified models.""" + mock_model = MagicMock() + mock_model.custom_metadata_list = ModelCustomMetadata() + mock_from_id.return_value = mock_model + + container_index_json = os.path.join( + self.curr_dir, "test_data/ui/container_index.json" + ) + with open(container_index_json, "r") as _file: + container_index_config = json.load(_file) + mock_get_container_config.return_value = container_index_config + + if container_type_key in {"odsc-vllm-serving", "odsc-tgi-serving"} and params: + result = self.app.validate_deployment_params( + model_id="mock-model-id", + params=params, + container_family=container_type_key, + ) + assert result["valid"] is True + else: + with pytest.raises(AquaValueError): + self.app.validate_deployment_params( + model_id="mock-model-id", + params=params, + container_family=container_type_key, + ) + class TestMDInferenceResponse(unittest.TestCase): def setUp(self): diff --git a/tests/unitary/with_extras/aqua/test_deployment_handler.py b/tests/unitary/with_extras/aqua/test_deployment_handler.py index 91a5c8a0f..2da40cbcd 100644 --- a/tests/unitary/with_extras/aqua/test_deployment_handler.py +++ b/tests/unitary/with_extras/aqua/test_deployment_handler.py @@ -6,23 +6,25 @@ import os import unittest -from unittest.mock import MagicMock, patch from importlib import reload +from unittest.mock import MagicMock, patch +from parameterized import parameterized + from notebook.base.handlers import IPythonHandler -import pytest -import ads.config import ads.aqua +import ads.config from ads.aqua.extension.deployment_handler import ( AquaDeploymentHandler, AquaDeploymentInferenceHandler, + AquaDeploymentParamsHandler, ) -from ads.aqua.deployment import AquaDeploymentApp, MDInferenceResponse class TestDataset: USER_COMPARTMENT_ID = "ocid1.compartment.oc1.." USER_PROJECT_ID = "ocid1.datascienceproject.oc1.iad." 
+ INSTANCE_SHAPE = "VM.GPU.A10.1" deployment_request = { "model_id": "ocid1.datasciencemodel.oc1.iad.", "instance_shape": "VM.GPU.A10.1", @@ -65,7 +67,7 @@ def tearDownClass(cls): reload(ads.aqua) reload(ads.aqua.extension.deployment_handler) - @patch("ads.aqua.deployment.AquaDeploymentApp.get_deployment_config") + @patch("ads.aqua.modeldeployment.AquaDeploymentApp.get_deployment_config") def test_get_deployment_config(self, mock_get_deployment_config): """Test get method to return deployment config""" self.deployment_handler.request.path = "aqua/deployments/config" @@ -83,14 +85,14 @@ def test_get_deployment_config_without_id(self, mock_error): mock_error.assert_called_once() assert result["status"] == 400 - @patch("ads.aqua.deployment.AquaDeploymentApp.get") + @patch("ads.aqua.modeldeployment.AquaDeploymentApp.get") def test_get_deployment(self, mock_get): """Test get method to return deployment information.""" self.deployment_handler.request.path = "aqua/deployments" self.deployment_handler.get(id="mock-model-id") mock_get.assert_called() - @patch("ads.aqua.deployment.AquaDeploymentApp.list") + @patch("ads.aqua.modeldeployment.AquaDeploymentApp.list") def test_list_deployment(self, mock_list): """Test get method to return a list of model deployments.""" self.deployment_handler.request.path = "aqua/deployments" @@ -99,7 +101,7 @@ def test_list_deployment(self, mock_list): compartment_id=TestDataset.USER_COMPARTMENT_ID, project_id=None ) - @patch("ads.aqua.deployment.AquaDeploymentApp.create") + @patch("ads.aqua.modeldeployment.AquaDeploymentApp.create") def test_post(self, mock_create): """Test post method to create a model deployment.""" self.deployment_handler.get_json_body = MagicMock( @@ -119,6 +121,70 @@ def test_post(self, mock_create): access_log_id=None, predict_log_id=None, bandwidth_mbps=None, + web_concurrency=None, + server_port=None, + health_check_port=None, + env_var=None, + container_family=None, + ) + + +class AquaDeploymentParamsHandlerTestCase(unittest.TestCase): + default_params = ["--seed 42", "--trust-remote-code"] + + @patch.object(IPythonHandler, "__init__") + def setUp(self, ipython_init_mock) -> None: + ipython_init_mock.return_value = None + self.test_instance = AquaDeploymentParamsHandler(MagicMock(), MagicMock()) + + @patch("notebook.base.handlers.APIHandler.finish") + @patch("ads.aqua.modeldeployment.AquaDeploymentApp.get_deployment_default_params") + def test_get_deployment_default_params( + self, mock_get_deployment_default_params, mock_finish + ): + """Test to check the handler get method to return default params for model deployment.""" + + mock_get_deployment_default_params.return_value = self.default_params + mock_finish.side_effect = lambda x: x + + args = {"instance_shape": TestDataset.INSTANCE_SHAPE} + self.test_instance.get_argument = MagicMock( + side_effect=lambda arg, default=None: args.get(arg, default) + ) + result = self.test_instance.get(model_id="test_model_id") + self.assertCountEqual(result["data"], self.default_params) + + mock_get_deployment_default_params.assert_called_with( + model_id="test_model_id", instance_shape=TestDataset.INSTANCE_SHAPE + ) + + @parameterized.expand( + [ + None, + "container-family-name", + ] + ) + @patch("notebook.base.handlers.APIHandler.finish") + @patch("ads.aqua.modeldeployment.AquaDeploymentApp.validate_deployment_params") + def test_validate_deployment_params( + self, container_family_value, mock_validate_deployment_params, mock_finish + ): + mock_validate_deployment_params.return_value = dict(valid=True) 
+ mock_finish.side_effect = lambda x: x + + self.test_instance.get_json_body = MagicMock( + return_value=dict( + model_id="test-model-id", + params=self.default_params, + container_family=container_family_value, + ) + ) + result = self.test_instance.post() + assert result["valid"] is True + mock_validate_deployment_params.assert_called_with( + model_id="test-model-id", + params=self.default_params, + container_family=container_family_value, ) @@ -132,7 +198,7 @@ def setUp(self, ipython_init_mock) -> None: self.inference_handler.request = MagicMock() self.inference_handler.finish = MagicMock() - @patch("ads.aqua.deployment.MDInferenceResponse.get_model_deployment_response") + @patch("ads.aqua.modeldeployment.MDInferenceResponse.get_model_deployment_response") def test_post(self, mock_get_model_deployment_response): """Test post method to return model deployment response.""" self.inference_handler.get_json_body = MagicMock( diff --git a/tests/unitary/with_extras/aqua/test_evaluation.py b/tests/unitary/with_extras/aqua/test_evaluation.py index 5ea5b6e41..b75986d3d 100644 --- a/tests/unitary/with_extras/aqua/test_evaluation.py +++ b/tests/unitary/with_extras/aqua/test_evaluation.py @@ -15,21 +15,21 @@ import oci from parameterized import parameterized -from ads.aqua import utils -from ads.aqua.data import Tags -from ads.aqua.evaluation import ( - AquaEvalMetrics, - AquaEvalReport, - AquaEvaluationApp, - AquaEvaluationSummary, -) -from ads.aqua.exception import ( +from ads.aqua.common import utils +from ads.aqua.common.enums import Tags +from ads.aqua.common.errors import ( AquaFileNotFoundError, AquaMissingKeyError, AquaRuntimeError, ) +from ads.aqua.constants import EVALUATION_REPORT_JSON, EVALUATION_REPORT_MD, UNKNOWN +from ads.aqua.evaluation import AquaEvaluationApp +from ads.aqua.evaluation.entities import ( + AquaEvalMetrics, + AquaEvalReport, + AquaEvaluationSummary, +) from ads.aqua.extension.base_handler import AquaAPIhandler -from ads.aqua.utils import EVALUATION_REPORT_JSON, EVALUATION_REPORT_MD, UNKNOWN from ads.jobs.ads_job import DataScienceJob, DataScienceJobRun, Job from ads.model import DataScienceModel from ads.model.deployment.model_deployment import ModelDeployment @@ -423,7 +423,7 @@ def assert_payload(self, response, response_type): @patch("ads.jobs.ads_job.Job.name", new_callable=PropertyMock) @patch("ads.jobs.ads_job.Job.id", new_callable=PropertyMock) @patch.object(Job, "create") - @patch("ads.aqua.evaluation.get_container_image") + @patch("ads.aqua.evaluation.evaluation.get_container_image") @patch.object(DataScienceModel, "create") @patch.object(ModelVersionSet, "create") @patch.object(DataScienceModel, "from_id") @@ -527,36 +527,28 @@ def test_create_evaluation( def test_get_service_model_name(self): # get service model name from fine tuned model deployment - source = ( - ModelDeployment() - .with_freeform_tags( - **{ - Tags.AQUA_TAG.value: UNKNOWN, - Tags.AQUA_FINE_TUNED_MODEL_TAG.value: "test_service_model_id#test_service_model_name", - Tags.AQUA_MODEL_NAME_TAG.value: "test_fine_tuned_model_name" - } - ) + source = ModelDeployment().with_freeform_tags( + **{ + Tags.AQUA_TAG: UNKNOWN, + Tags.AQUA_FINE_TUNED_MODEL_TAG: "test_service_model_id#test_service_model_name", + Tags.AQUA_MODEL_NAME_TAG: "test_fine_tuned_model_name", + } ) service_model_name = self.app._get_service_model_name(source) assert service_model_name == "test_service_model_name" # get service model name from model deployment - source = ( - ModelDeployment() - .with_freeform_tags( - **{ - 
Tags.AQUA_TAG.value: "active", - Tags.AQUA_MODEL_NAME_TAG.value: "test_service_model_name" - } - ) + source = ModelDeployment().with_freeform_tags( + **{ + Tags.AQUA_TAG: "active", + Tags.AQUA_MODEL_NAME_TAG: "test_service_model_name", + } ) service_model_name = self.app._get_service_model_name(source) assert service_model_name == "test_service_model_name" # get service model name from service model - source = DataScienceModel( - display_name="test_service_model_name" - ) + source = DataScienceModel(display_name="test_service_model_name") service_model_name = self.app._get_service_model_name(source) assert service_model_name == "test_service_model_name" @@ -807,6 +799,7 @@ def test_get_status(self): @parameterized.expand( [ ( + "artifact_exist", dict( return_value=oci.response.Response( status=200, request=MagicMock(), headers=MagicMock(), data=None @@ -815,6 +808,7 @@ def test_get_status(self): "SUCCEEDED", ), ( + "artifact_missing", dict( side_effect=oci.exceptions.ServiceError( status=404, code=None, message="error test msg", headers={} @@ -825,7 +819,7 @@ def test_get_status(self): ] ) def test_get_status_when_missing_jobrun( - self, mock_head_model_artifact_response, expected_output + self, name, mock_head_model_artifact_response, expected_output ): """Tests getting evaluation status correctly when missing jobrun association.""" self.app.ds_client.get_model_provenance = MagicMock( @@ -839,13 +833,14 @@ def test_get_status_when_missing_jobrun( ) ) self.app._fetch_jobrun = MagicMock(return_value=None) - + self.app._deletion_cache.clear() self.app.ds_client.head_model_artifact = MagicMock( side_effect=mock_head_model_artifact_response.get("side_effect", None), return_value=mock_head_model_artifact_response.get("return_value", None), ) response = self.app.get_status(TestDataset.EVAL_ID) + self.app.ds_client.head_model_artifact.assert_called_with( model_id=TestDataset.EVAL_ID ) @@ -913,8 +908,8 @@ def setUp(self): def tearDown(self) -> None: self.app._eval_cache.clear() - @patch("ads.aqua.utils.query_resource") - @patch("ads.aqua.utils.query_resources") + @patch("ads.aqua.common.utils.query_resource") + @patch("ads.aqua.common.utils.query_resources") def test_skipping_fetch_jobrun(self, mock_query_resources, mock_query_resource): """Tests listing evalution.""" mock_query_resources.return_value = [ @@ -936,8 +931,8 @@ def test_skipping_fetch_jobrun(self, mock_query_resources, mock_query_resource): mock_query_resources.assert_called_once() mock_query_resource.assert_not_called() - @patch("ads.aqua.utils.query_resource") - @patch("ads.aqua.utils.query_resources") + @patch("ads.aqua.common.utils.query_resource") + @patch("ads.aqua.common.utils.query_resources") def test_error_in_fetch_job(self, mock_query_resources, mock_query_resource): """Tests when fetching job encounters error.""" mock_query_resources.return_value = [ @@ -956,7 +951,7 @@ def test_error_in_fetch_job(self, mock_query_resources, mock_query_resource): jobrun=None, ) - @patch("ads.aqua.utils.query_resources") + @patch("ads.aqua.common.utils.query_resources") def test_missing_info_in_custometadata(self, mock_query_resources): """Tests missing info in evaluation custom metadata.""" eval_without_meta = copy.deepcopy(TestDataset.resource_summary_object_eval[0]) @@ -985,14 +980,14 @@ def setUp(self): self.mock_model = DataScienceModel(id="model456") @patch.object(DataScienceJobRun, "cancel") - @patch("ads.aqua.evaluation.logger") + @patch("ads.aqua.evaluation.evaluation.logger") async def test_cancel(self, mock_logger, mock_cancel): 
await self.app._cancel_job_run(DataScienceJobRun(), self.mock_model) mock_cancel.assert_called_once() mock_logger.info.assert_called_once() - @patch("ads.aqua.evaluation.logger") + @patch("ads.aqua.evaluation.evaluation.logger") async def test_cancel_exception(self, mock_logger): mock_cancel = MagicMock( side_effect=oci.exceptions.ServiceError( @@ -1007,7 +1002,7 @@ async def test_cancel_exception(self, mock_logger): mock_cancel.assert_called_once() mock_logger.error.assert_called_once() - @patch("ads.aqua.evaluation.logger") + @patch("ads.aqua.evaluation.evaluation.logger") async def test_delete(self, mock_logger): mock_job = DataScienceJob() mock_job.dsc_job.delete = MagicMock() @@ -1019,7 +1014,7 @@ async def test_delete(self, mock_logger): self.mock_model.delete.assert_called_once() mock_logger.info.assert_called() - @patch("ads.aqua.evaluation.logger") + @patch("ads.aqua.evaluation.evaluation.logger") async def test_delete_exception(self, mock_logger): mock_job = DataScienceJob() mock_job.dsc_job.delete = MagicMock( diff --git a/tests/unitary/with_extras/aqua/test_evaluation_handler.py b/tests/unitary/with_extras/aqua/test_evaluation_handler.py index 4a9cb3135..6382c8d39 100644 --- a/tests/unitary/with_extras/aqua/test_evaluation_handler.py +++ b/tests/unitary/with_extras/aqua/test_evaluation_handler.py @@ -9,8 +9,9 @@ from notebook.base.handlers import IPythonHandler from parameterized import parameterized -from ads.aqua.evaluation import AquaEvaluationApp, CreateAquaEvaluationDetails -from ads.aqua.extension.base_handler import Errors +from ads.aqua.evaluation import AquaEvaluationApp +from ads.aqua.evaluation.entities import CreateAquaEvaluationDetails +from ads.aqua.extension.errors import Errors from ads.aqua.extension.evaluation_handler import AquaEvaluationHandler from tests.unitary.with_extras.aqua.utils import HandlerTestDataset as TestDataset diff --git a/tests/unitary/with_extras/aqua/test_finetuning.py b/tests/unitary/with_extras/aqua/test_finetuning.py index 61358c66a..93c32f114 100644 --- a/tests/unitary/with_extras/aqua/test_finetuning.py +++ b/tests/unitary/with_extras/aqua/test_finetuning.py @@ -4,22 +4,29 @@ # Copyright (c) 2024 Oracle and/or its affiliates. 
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ -from dataclasses import asdict -from importlib import reload import os +import json +import pytest +from parameterized import parameterized from unittest import TestCase from unittest.mock import MagicMock, PropertyMock - from mock import patch -import ads.config +from dataclasses import asdict +from importlib import reload + import ads.aqua -import ads.aqua.finetune -from ads.aqua.base import AquaApp -from ads.aqua.finetune import AquaFineTuningApp, FineTuneCustomMetadata -from ads.aqua.model import AquaFineTuneModel +import ads.aqua.finetuning.finetuning +from ads.aqua.model.entities import AquaFineTuneModel +import ads.config +from ads.aqua.app import AquaApp +from ads.aqua.finetuning import AquaFineTuningApp +from ads.aqua.finetuning.constants import FineTuneCustomMetadata +from ads.aqua.finetuning.entities import AquaFineTuningParams from ads.jobs.ads_job import Job from ads.model.datascience_model import DataScienceModel from ads.model.model_metadata import ModelCustomMetadata +from ads.aqua.common.errors import AquaValueError +from ads.aqua.config.config import get_finetuning_config_defaults class FineTuningTestCase(TestCase): @@ -30,23 +37,25 @@ def setUp(self): @classmethod def setUpClass(cls): + cls.curr_dir = os.path.dirname(os.path.abspath(__file__)) os.environ["ODSC_MODEL_COMPARTMENT_OCID"] = cls.SERVICE_COMPARTMENT_ID reload(ads.config) reload(ads.aqua) - reload(ads.aqua.finetune) + reload(ads.aqua.finetuning.finetuning) @classmethod def tearDownClass(cls): + cls.curr_dir = None os.environ.pop("ODSC_MODEL_COMPARTMENT_OCID", None) reload(ads.config) reload(ads.aqua) - reload(ads.aqua.finetune) + reload(ads.aqua.finetuning.finetuning) @patch.object(Job, "run") @patch("ads.jobs.ads_job.Job.name", new_callable=PropertyMock) @patch("ads.jobs.ads_job.Job.id", new_callable=PropertyMock) @patch.object(Job, "create") - @patch("ads.aqua.finetune.get_container_image") + @patch("ads.aqua.finetuning.finetuning.get_container_image") @patch.object(AquaFineTuningApp, "get_finetuning_config") @patch.object(AquaApp, "create_model_catalog") @patch.object(AquaApp, "create_model_version_set") @@ -65,15 +74,15 @@ def test_create_fine_tuning( ): custom_metadata_list = ModelCustomMetadata() custom_metadata_list.add( - key=FineTuneCustomMetadata.SERVICE_MODEL_ARTIFACT_LOCATION.value, + key=FineTuneCustomMetadata.SERVICE_MODEL_ARTIFACT_LOCATION, value="test_service_model_artifact_location", ) custom_metadata_list.add( - key=FineTuneCustomMetadata.SERVICE_MODEL_DEPLOYMENT_CONTAINER.value, + key=FineTuneCustomMetadata.SERVICE_MODEL_DEPLOYMENT_CONTAINER, value="test_service_model_deployment_container", ) custom_metadata_list.add( - key=FineTuneCustomMetadata.SERVICE_MODEL_FINE_TUNE_CONTAINER.value, + key=FineTuneCustomMetadata.SERVICE_MODEL_FINE_TUNE_CONTAINER, value="test_service_model_fine_tune_container", ) @@ -116,7 +125,11 @@ def test_create_fine_tuning( ft_name="test_ft_name", dataset_path="oci://ds_bucket@namespace/prefix/dataset.jsonl", report_path="oci://report_bucket@namespace/prefix/", - ft_parameters={"epochs": 1, "learning_rate": 0.02}, + ft_parameters={ + "epochs": 1, + "learning_rate": 0.02, + "lora_target_linear": False, + }, shape_name="VM.GPU.A10.1", replica=1, validation_set_size=0.2, @@ -145,7 +158,9 @@ def test_create_fine_tuning( "parameters": { "epochs": 1, "learning_rate": 0.02, - "sample_packing": "True", + "sample_packing": "auto", + "batch_size": 1, + "lora_target_linear": 
False, }, "source": { "id": f"{ft_source.id}", @@ -193,3 +208,112 @@ def test_exit_code_message(self): message, "Job run could not be started due to service issues. Please try again later.", ) + + def test_build_oci_launch_cmd(self): + dataset_path = "oci://ds_bucket@namespace/prefix/dataset.jsonl" + report_path = "oci://report_bucket@namespace/prefix/" + val_set_size = 0.1 + parameters = AquaFineTuningParams( + batch_size=1, + epochs=1, + sample_packing="True", + learning_rate=0.01, + sequence_len=2, + lora_target_modules=["q_proj", "k_proj"], + ) + finetuning_params = "--trust_remote_code True" + oci_launch_cmd = self.app._build_oci_launch_cmd( + dataset_path=dataset_path, + report_path=report_path, + val_set_size=val_set_size, + parameters=parameters, + finetuning_params=finetuning_params, + ) + + assert ( + oci_launch_cmd + == f"--training_data {dataset_path} --output_dir {report_path} --val_set_size {val_set_size} --num_epochs {parameters.epochs} --learning_rate {parameters.learning_rate} --sample_packing {parameters.sample_packing} --micro_batch_size {parameters.batch_size} --sequence_len {parameters.sequence_len} --lora_target_modules q_proj,k_proj {finetuning_params}" + ) + + def test_get_finetuning_config(self): + """Test for fetching config details for a given model to be finetuned.""" + + config_json = os.path.join(self.curr_dir, "test_data/finetuning/ft_config.json") + with open(config_json, "r") as _file: + config = json.load(_file) + + self.app.get_config = MagicMock(return_value=config) + result = self.app.get_finetuning_config(model_id="test-model-id") + assert result == config + + self.app.get_config = MagicMock(return_value=None) + result = self.app.get_finetuning_config(model_id="test-model-id") + assert result == get_finetuning_config_defaults() + + def test_get_finetuning_default_params(self): + """Test for fetching finetuning config params for a given model.""" + + params_dict = { + "params": { + "batch_size": 1, + "sequence_len": 2048, + "sample_packing": True, + "pad_to_sequence_len": True, + "learning_rate": 0.0002, + "lora_r": 32, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "lora_target_modules": ["q_proj", "k_proj"], + } + } + config_json = os.path.join(self.curr_dir, "test_data/finetuning/ft_config.json") + with open(config_json, "r") as _file: + config = json.load(_file) + + self.app.get_finetuning_config = MagicMock(return_value=config) + result = self.app.get_finetuning_default_params(model_id="test_model_id") + assert result == params_dict + + # check when config json is not available + self.app.get_finetuning_config = MagicMock(return_value={}) + result = self.app.get_finetuning_default_params(model_id="test_model_id") + assert result == {"params": {}} + + @parameterized.expand( + [ + ( + { + "epochs": 1, + "learning_rate": 0.0002, + "batch_size": 1, + "sequence_len": 2048, + "sample_packing": True, + "pad_to_sequence_len": True, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "lora_target_modules": ["q_proj", " k_proj"], + }, + True, + ), + ( + { + "micro_batch_size": 1, + "max_sequence_len": 2048, + "flash_attention": True, + "pad_to_sequence_len": True, + "lr_scheduler": "cosine", + }, + False, + ), + ] + ) + def test_validate_finetuning_params(self, params, is_valid): + """Test for checking if overridden fine-tuning params are valid.""" + if is_valid: + result = self.app.validate_finetuning_params(params) + assert result["valid"] is True + else: + with pytest.raises(AquaValueError): + 
self.app.validate_finetuning_params(params) diff --git a/tests/unitary/with_extras/aqua/test_finetuning_handler.py b/tests/unitary/with_extras/aqua/test_finetuning_handler.py index 66c059bc5..03b437771 100644 --- a/tests/unitary/with_extras/aqua/test_finetuning_handler.py +++ b/tests/unitary/with_extras/aqua/test_finetuning_handler.py @@ -6,11 +6,16 @@ from unittest import TestCase from unittest.mock import MagicMock + from mock import patch +from notebook.base.handlers import APIHandler, IPythonHandler -from notebook.base.handlers import IPythonHandler -from ads.aqua.extension.finetune_handler import AquaFineTuneHandler -from ads.aqua.finetune import AquaFineTuningApp, CreateFineTuningDetails +from ads.aqua.extension.finetune_handler import ( + AquaFineTuneHandler, + AquaFineTuneParamsHandler, +) +from ads.aqua.finetuning import AquaFineTuningApp +from ads.aqua.finetuning.entities import CreateFineTuningDetails class TestDataset: @@ -19,10 +24,7 @@ class TestDataset: ft_name="test_ft_name", dataset_path="oci://ds_bucket@namespace/prefix/dataset.jsonl", report_path="oci://report_bucket@namespace/prefix/", - ft_parameters={ - "epochs":1, - "learning_rate":0.02 - }, + ft_parameters={"epochs": 1, "learning_rate": 0.02}, shape_name="VM.GPU.A10.1", replica=1, validation_set_size=0.2, @@ -32,16 +34,12 @@ class TestDataset: mock_finetuning_config = { "shape": { - "VM.GPU.A10.1": { - "batch_size": 1, - "replica": 1 - }, + "VM.GPU.A10.1": {"batch_size": 1, "replica": 1}, } } class FineTuningHandlerTestCase(TestCase): - @patch.object(IPythonHandler, "__init__") def setUp(self, ipython_init_mock) -> None: ipython_init_mock.return_value = None @@ -65,13 +63,11 @@ def test_get(self, mock_urlparse, mock_get_finetuning_config): @patch.object(AquaFineTuningApp, "create") def test_post(self, mock_create): self.test_instance.get_json_body = MagicMock( - return_value = TestDataset.mock_valid_input + return_value=TestDataset.mock_valid_input ) self.test_instance.post() - self.test_instance.finish.assert_called_with( - mock_create.return_value - ) + self.test_instance.finish.assert_called_with(mock_create.return_value) mock_create.assert_called_with( CreateFineTuningDetails(**TestDataset.mock_valid_input) ) @@ -82,9 +78,54 @@ def test_get_finetuning_config(self, mock_get_finetuning_config): self.test_instance.get_finetuning_config(model_id="test_model_id") - self.test_instance.finish.assert_called_with( - TestDataset.mock_finetuning_config - ) - mock_get_finetuning_config.assert_called_with( - model_id="test_model_id" + self.test_instance.finish.assert_called_with(TestDataset.mock_finetuning_config) + mock_get_finetuning_config.assert_called_with(model_id="test_model_id") + + +class AquaFineTuneParamsHandlerTestCase(TestCase): + default_params = [ + "--batch_size 1", + "--sequence_len 2048", + "--sample_packing true", + "--pad_to_sequence_len true", + "--learning_rate 0.0002", + "--lora_r 32", + "--lora_alpha 16", + "--lora_dropout 0.05", + "--lora_target_linear true", + "--lora_target_modules q_proj,k_proj", + ] + + @patch.object(IPythonHandler, "__init__") + def setUp(self, ipython_init_mock) -> None: + ipython_init_mock.return_value = None + self.test_instance = AquaFineTuneParamsHandler(MagicMock(), MagicMock()) + + @patch("notebook.base.handlers.APIHandler.finish") + @patch("ads.aqua.finetuning.AquaFineTuningApp.get_finetuning_default_params") + def test_get_finetuning_default_params( + self, mock_get_finetuning_default_params, mock_finish + ): + """Test to check the handler get method to return default 
params for fine-tuning job.""" + mock_get_finetuning_default_params.return_value = self.default_params + mock_finish.side_effect = lambda x: x + + result = self.test_instance.get(model_id="test_model_id") + self.assertCountEqual(result["data"], self.default_params) + + mock_get_finetuning_default_params.assert_called_with(model_id="test_model_id") + + @patch("notebook.base.handlers.APIHandler.finish") + @patch("ads.aqua.finetuning.AquaFineTuningApp.validate_finetuning_params") + def test_validate_finetuning_params( + self, mock_validate_finetuning_params, mock_finish + ): + mock_validate_finetuning_params.return_value = dict(valid=True) + mock_finish.side_effect = lambda x: x + + self.test_instance.get_json_body = MagicMock( + return_value=dict(params=self.default_params) ) + result = self.test_instance.post() + assert result["valid"] is True + mock_validate_finetuning_params.assert_called_with(params=self.default_params) diff --git a/tests/unitary/with_extras/aqua/test_handlers.py b/tests/unitary/with_extras/aqua/test_handlers.py index e74b99d4f..daaaa7c83 100644 --- a/tests/unitary/with_extras/aqua/test_handlers.py +++ b/tests/unitary/with_extras/aqua/test_handlers.py @@ -18,13 +18,13 @@ from tornado.web import Application, HTTPError import ads.aqua -import ads.aqua.exception +import ads.aqua.common.errors import ads.aqua.extension import ads.aqua.extension.common_handler import ads.config +from ads.aqua.common.errors import AquaError from ads.aqua.data import AquaResourceIdentifier from ads.aqua.evaluation import AquaEvaluationApp -from ads.aqua.exception import AquaError from ads.aqua.extension.base_handler import AquaAPIhandler from ads.aqua.extension.common_handler import ( ADSVersionHandler, diff --git a/tests/unitary/with_extras/aqua/test_model.py b/tests/unitary/with_extras/aqua/test_model.py index 42328e59a..833adb0e7 100644 --- a/tests/unitary/with_extras/aqua/test_model.py +++ b/tests/unitary/with_extras/aqua/test_model.py @@ -5,20 +5,43 @@ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import os -import unittest from dataclasses import asdict from importlib import reload -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch -from mock import patch import oci +import pytest from parameterized import parameterized import ads.aqua.model +from ads.aqua.model.entities import AquaModelSummary, ImportModelDetails, AquaModel +import ads.common +import ads.common.oci_client import ads.config -from ads.aqua.model import AquaModelApp, AquaModelSummary +from ads.aqua.model import AquaModelApp +from ads.common.object_storage_details import ObjectStorageDetails from ads.model.datascience_model import DataScienceModel -from ads.model.model_metadata import ModelCustomMetadata, ModelProvenanceMetadata, ModelTaxonomyMetadata +from ads.model.model_metadata import ( + ModelCustomMetadata, + ModelProvenanceMetadata, + ModelTaxonomyMetadata, +) +from ads.aqua.common.errors import AquaRuntimeError, AquaFileNotFoundError +from ads.model.service.oci_datascience_model import OCIDataScienceModel + + +@pytest.fixture(autouse=True, scope="class") +def mock_auth(): + with patch("ads.common.auth.default_signer") as mock_default_signer: + yield mock_default_signer + + +@pytest.fixture(autouse=True, scope="class") +def mock_init_client(): + with patch( + "ads.common.oci_datascience.OCIDataScienceMixin.init_client" + ) as mock_client: + yield mock_client class TestDataset: @@ -29,7 +52,7 @@ class TestDataset: "defined_tags": 
{}, "display_name": "Model1", "freeform_tags": { - "OCI_AQUA": "", + "OCI_AQUA": "active", "aqua_service_model": "ocid1.datasciencemodel.oc1.iad.#Model1", "license": "UPL", "organization": "Oracle AI", @@ -40,6 +63,23 @@ class TestDataset: "project_id": "ocid1.datascienceproject.oc1.iad.", "time_created": "2024-01-19T17:57:39.158000+00:00", }, + { + "compartment_id": "ocid1.compartment.oc1..", + "created_by": "ocid1.datasciencenotebooksession.oc1.iad.", + "defined_tags": {}, + "display_name": "VerifiedModel", + "freeform_tags": { + "OCI_AQUA": "", + "license": "UPL", + "organization": "Oracle AI", + "task": "text_generation", + "ready_to_import": "true", + }, + "id": "ocid1.datasciencemodel.oc1.iad.", + "lifecycle_state": "ACTIVE", + "project_id": "ocid1.datascienceproject.oc1.iad.", + "time_created": "2024-01-19T17:57:39.158000+00:00", + }, ] resource_summary_objects = [ @@ -70,47 +110,83 @@ class TestDataset: COMPARTMENT_ID = "ocid1.compartment.oc1.." -class TestAquaModel(unittest.TestCase): +@patch("ads.config.COMPARTMENT_OCID", "ocid1.compartment.oc1.") +@patch("ads.config.PROJECT_OCID", "ocid1.datascienceproject.oc1.iad.") +class TestAquaModel: """Contains unittests for AquaModelApp.""" - def setUp(self): + @pytest.fixture(autouse=True, scope="class") + def mock_auth(cls): + with patch("ads.common.auth.default_signer") as mock_default_signer: + yield mock_default_signer + + @pytest.fixture(autouse=True, scope="class") + def mock_init_client(cls): + with patch( + "ads.common.oci_datascience.OCIDataScienceMixin.init_client" + ) as mock_client: + yield mock_client + + def setup_method(self): + self.default_signer_patch = patch( + "ads.common.auth.default_signer", new_callable=MagicMock + ) + self.create_signer_patch = patch( + "ads.common.auth.APIKey.create_signer", new_callable=MagicMock + ) + self.validate_config_patch = patch( + "oci.config.validate_config", new_callable=MagicMock + ) + self.create_client_patch = patch( + "ads.common.oci_client.OCIClientFactory.create_client", + new_callable=MagicMock, + ) + self.mock_default_signer = self.default_signer_patch.start() + self.mock_create_signer = self.create_signer_patch.start() + self.mock_validate_config = self.validate_config_patch.start() + self.mock_create_client = self.create_client_patch.start() self.app = AquaModelApp() + def teardown_method(self): + self.default_signer_patch.stop() + self.create_signer_patch.stop() + self.validate_config_patch.stop() + self.create_client_patch.stop() + @classmethod - def setUpClass(cls): + def setup_class(cls): os.environ["CONDA_BUCKET_NS"] = "test-namespace" os.environ["ODSC_MODEL_COMPARTMENT_OCID"] = TestDataset.SERVICE_COMPARTMENT_ID reload(ads.config) reload(ads.aqua) - reload(ads.aqua.model) + reload(ads.aqua.model.model) @classmethod - def tearDownClass(cls): + def teardown_class(cls): os.environ.pop("CONDA_BUCKET_NS", None) os.environ.pop("ODSC_MODEL_COMPARTMENT_OCID", None) reload(ads.config) reload(ads.aqua) - reload(ads.aqua.model) + reload(ads.aqua.model.model) @patch.object(DataScienceModel, "create") @patch("ads.model.datascience_model.validate") @patch.object(DataScienceModel, "from_id") def test_create_model(self, mock_from_id, mock_validate, mock_create): mock_model = MagicMock() - mock_model.model_file_description = {"test_key":"test_value"} + mock_model.model_file_description = {"test_key": "test_value"} mock_model.display_name = "test_display_name" mock_model.description = "test_description" mock_model.freeform_tags = { - "OCI_AQUA":"ACTIVE", - "license":"test_license", - 
"organization":"test_organization", - "task":"test_task", - "ready_to_fine_tune":"true" + "OCI_AQUA": "ACTIVE", + "license": "test_license", + "organization": "test_organization", + "task": "test_task", + "ready_to_fine_tune": "true", } custom_metadata_list = ModelCustomMetadata() custom_metadata_list.add( - key="test_metadata_item_key", - value="test_metadata_item_value" + **{"key": "test_metadata_item_key", "value": "test_metadata_item_value"} ) mock_model.custom_metadata_list = custom_metadata_list mock_model.provenance_metadata = ModelProvenanceMetadata( @@ -118,7 +194,7 @@ def test_create_model(self, mock_from_id, mock_validate, mock_create): ) mock_from_id.return_value = mock_model - # will not copy service model + # will not copy service model self.app.create( model_id="test_model_id", project_id="test_project_id", @@ -137,33 +213,50 @@ def test_create_model(self, mock_from_id, mock_validate, mock_create): model = self.app.create( model_id="test_model_id", project_id="test_project_id", - compartment_id="test_compartment_id" + compartment_id="test_compartment_id", ) mock_from_id.assert_called_with("test_model_id") mock_validate.assert_called() - mock_create.assert_called_with( - model_by_reference=True - ) + mock_create.assert_called_with(model_by_reference=True) assert model.display_name == "test_display_name" assert model.description == "test_description" assert model.description == "test_description" assert model.freeform_tags == { - "OCI_AQUA":"ACTIVE", - "license":"test_license", - "organization":"test_organization", - "task":"test_task", - "ready_to_fine_tune":"true" + "OCI_AQUA": "ACTIVE", + "license": "test_license", + "organization": "test_organization", + "task": "test_task", + "ready_to_fine_tune": "true", } - assert model.custom_metadata_list.get( - "test_metadata_item_key" - ).value == "test_metadata_item_value" + assert ( + model.custom_metadata_list.get("test_metadata_item_key").value + == "test_metadata_item_value" + ) assert model.provenance_metadata.training_id == "test_training_id" - @patch("ads.aqua.model.read_file") + @pytest.mark.parametrize( + "foundation_model_type", + [ + "service", + "verified", + ], + ) + @patch("ads.aqua.model.model.read_file") @patch.object(DataScienceModel, "from_id") - def test_get_model_not_fine_tuned(self, mock_from_id, mock_read_file): + @patch( + "ads.aqua.model.model.get_artifact_path", + return_value="oci://bucket@namespace/prefix", + ) + def test_get_foundation_models( + self, + mock_get_artifact_path, + mock_from_id, + mock_read_file, + foundation_model_type, + mock_auth, + ): ds_model = MagicMock() ds_model.id = "test_id" ds_model.compartment_id = "test_compartment_id" @@ -171,55 +264,108 @@ def test_get_model_not_fine_tuned(self, mock_from_id, mock_read_file): ds_model.display_name = "test_display_name" ds_model.description = "test_description" ds_model.freeform_tags = { - "OCI_AQUA":"ACTIVE", - "license":"test_license", - "organization":"test_organization", - "task":"test_task" + "OCI_AQUA": "" if foundation_model_type == "verified" else "ACTIVE", + "license": "test_license", + "organization": "test_organization", + "task": "test_task", } + if foundation_model_type == "verified": + ds_model.freeform_tags["ready_to_import"] = "true" ds_model.time_created = "2024-01-19T17:57:39.158000+00:00" custom_metadata_list = ModelCustomMetadata() custom_metadata_list.add( - key="artifact_location", - value="oci://bucket@namespace/prefix/" + **{ + "key": "artifact_location", + "value": "oci://bucket@namespace/prefix/", + } + ) + 
custom_metadata_list.add( + **{ + "key": "deployment-container", + "value": "odsc-vllm-serving", + } + ) + custom_metadata_list.add( + **{ + "key": "evaluation-container", + "value": "odsc-llm-evaluate", + } + ) + custom_metadata_list.add( + **{ + "key": "finetune-container", + "value": "odsc-llm-fine-tuning", + } ) ds_model.custom_metadata_list = custom_metadata_list mock_from_id.return_value = ds_model mock_read_file.return_value = "test_model_card" - aqua_model = self.app.get(model_id="test_model_id") - - mock_from_id.assert_called_with("test_model_id") - mock_read_file.assert_called_with( - file_path="oci://bucket@namespace/prefix/README.md", - auth=self.app._auth, + model_id = ( + "verified_model_id" + if foundation_model_type == "verified" + else "service_model_id" ) + aqua_model = self.app.get(model_id=model_id) + + mock_from_id.assert_called_with(model_id) + + if foundation_model_type == "verified": + mock_read_file.assert_called_with( + file_path="oci://bucket@namespace/prefix/config/README.md", + auth=mock_auth(), + ) + else: + mock_read_file.assert_called_with( + file_path="oci://bucket@namespace/prefix/README.md", + auth=mock_auth(), + ) assert asdict(aqua_model) == { - 'compartment_id': f'{ds_model.compartment_id}', - 'console_link': ( - f'https://cloud.oracle.com/data-science/models/{ds_model.id}?region={self.app.region}', + "compartment_id": f"{ds_model.compartment_id}", + "console_link": ( + f"https://cloud.oracle.com/data-science/models/{ds_model.id}?region={self.app.region}", + ), + "icon": "", + "id": f"{ds_model.id}", + "is_fine_tuned_model": False, + "license": f'{ds_model.freeform_tags["license"]}', + "model_card": f"{mock_read_file.return_value}", + "name": f"{ds_model.display_name}", + "organization": f'{ds_model.freeform_tags["organization"]}', + "project_id": f"{ds_model.project_id}", + "ready_to_deploy": False if foundation_model_type == "verified" else True, + "ready_to_finetune": False, + "ready_to_import": True if foundation_model_type == "verified" else False, + "search_text": ( + ",test_license,test_organization,test_task,true" + if foundation_model_type == "verified" + else "ACTIVE,test_license,test_organization,test_task" ), - 'icon': '', - 'id': f'{ds_model.id}', - 'is_fine_tuned_model': False, - 'license': f'{ds_model.freeform_tags["license"]}', - 'model_card': f'{mock_read_file.return_value}', - 'name': f'{ds_model.display_name}', - 'organization': f'{ds_model.freeform_tags["organization"]}', - 'project_id': f'{ds_model.project_id}', - 'ready_to_deploy': True, - 'ready_to_finetune': False, - 'search_text': 'ACTIVE,test_license,test_organization,test_task', - 'tags': ds_model.freeform_tags, - 'task': f'{ds_model.freeform_tags["task"]}', - 'time_created': f'{ds_model.time_created}' + "tags": ds_model.freeform_tags, + "task": f'{ds_model.freeform_tags["task"]}', + "time_created": f"{ds_model.time_created}", + "inference_container": "odsc-vllm-serving", + "finetuning_container": "odsc-llm-fine-tuning", + "evaluation_container": "odsc-llm-evaluate", } - @patch("ads.aqua.utils.query_resource") - @patch("ads.aqua.model.read_file") + @patch("ads.aqua.common.utils.query_resource") + @patch("ads.aqua.model.model.read_file") @patch.object(DataScienceModel, "from_id") - def test_get_model_fine_tuned(self, mock_from_id, mock_read_file, mock_query_resource): + @patch( + "ads.aqua.model.model.get_artifact_path", + return_value="oci://bucket@namespace/prefix", + ) + def test_get_model_fine_tuned( + self, + mock_get_artifact_path, + mock_from_id, + mock_read_file, + 
mock_query_resource, + mock_auth, + ): ds_model = MagicMock() ds_model.id = "test_id" ds_model.compartment_id = "test_model_compartment_id" @@ -229,32 +375,48 @@ def test_get_model_fine_tuned(self, mock_from_id, mock_read_file, mock_query_res ds_model.model_version_set_id = "test_model_version_set_id" ds_model.model_version_set_name = "test_model_version_set_name" ds_model.freeform_tags = { - "OCI_AQUA":"ACTIVE", - "license":"test_license", - "organization":"test_organization", - "task":"test_task", - "aqua_fine_tuned_model":"test_finetuned_model" + "OCI_AQUA": "ACTIVE", + "license": "test_license", + "organization": "test_organization", + "task": "test_task", + "aqua_fine_tuned_model": "test_finetuned_model", } + self.app._service_model_details_cache.get = MagicMock(return_value=None) ds_model.time_created = "2024-01-19T17:57:39.158000+00:00" ds_model.lifecycle_state = "ACTIVE" custom_metadata_list = ModelCustomMetadata() custom_metadata_list.add( - key="artifact_location", - value="oci://bucket@namespace/prefix/" + **{"key": "artifact_location", "value": "oci://bucket@namespace/prefix/"} + ) + custom_metadata_list.add( + **{"key": "fine_tune_source", "value": "test_fine_tuned_source_id"} + ) + custom_metadata_list.add( + **{"key": "fine_tune_source_name", "value": "test_fine_tuned_source_name"} + ) + custom_metadata_list.add( + **{ + "key": "deployment-container", + "value": "odsc-vllm-serving", + } ) custom_metadata_list.add( - key="fine_tune_source", - value="test_fine_tuned_source_id" + **{ + "key": "evaluation-container", + "value": "odsc-llm-evaluate", + } ) custom_metadata_list.add( - key="fine_tune_source_name", - value="test_fine_tuned_source_name" + **{ + "key": "finetune-container", + "value": "odsc-llm-fine-tuning", + } ) ds_model.custom_metadata_list = custom_metadata_list defined_metadata_list = ModelTaxonomyMetadata() defined_metadata_list["Hyperparameters"].value = { - "training_data" : "test_training_data", - "val_set_size" : "test_val_set_size" + "training_data": "test_training_data", + "val_set_size": "test_val_set_size", } ds_model.defined_metadata_list = defined_metadata_list ds_model.provenance_metadata = ModelProvenanceMetadata( @@ -269,26 +431,24 @@ def test_get_model_fine_tuned(self, mock_from_id, mock_read_file, mock_query_res job_run.id = "test_job_run_id" job_run.lifecycle_state = "SUCCEEDED" job_run.lifecycle_details = "test lifecycle details" - job_run.identifier = "test_job_id", + job_run.identifier = ("test_job_id",) job_run.display_name = "test_job_name" job_run.compartment_id = "test_job_run_compartment_id" job_infrastructure_configuration_details = MagicMock() job_infrastructure_configuration_details.shape_name = "test_shape_name" job_configuration_override_details = MagicMock() - job_configuration_override_details.environment_variables = { - "NODE_COUNT" : 1 - } - job_run.job_infrastructure_configuration_details = job_infrastructure_configuration_details + job_configuration_override_details.environment_variables = {"NODE_COUNT": 1} + job_run.job_infrastructure_configuration_details = ( + job_infrastructure_configuration_details + ) job_run.job_configuration_override_details = job_configuration_override_details log_details = MagicMock() log_details.log_id = "test_log_id" log_details.log_group_id = "test_log_group_id" job_run.log_details = log_details response.data = job_run - self.app.ds_client.get_job_run = MagicMock( - return_value = response - ) + self.app.ds_client.get_job_run = MagicMock(return_value=response) query_resource = MagicMock() 
         query_resource.display_name = "test_display_name"
 
@@ -299,84 +459,349 @@ def test_get_model_fine_tuned(self, mock_from_id, mock_read_file, mock_query_res
         mock_from_id.assert_called_with("test_model_id")
         mock_read_file.assert_called_with(
             file_path="oci://bucket@namespace/prefix/README.md",
-            auth=self.app._auth,
+            auth=mock_auth(),
         )
         mock_query_resource.assert_called()
 
         assert asdict(model) == {
-            'compartment_id': f'{ds_model.compartment_id}',
-            'console_link': (
-                f'https://cloud.oracle.com/data-science/models/{ds_model.id}?region={self.app.region}',
+            "compartment_id": f"{ds_model.compartment_id}",
+            "console_link": (
+                f"https://cloud.oracle.com/data-science/models/{ds_model.id}?region={self.app.region}",
             ),
-            'dataset': 'test_training_data',
-            'experiment': {'id': '', 'name': '', 'url': ''},
-            'icon': '',
-            'id': f'{ds_model.id}',
-            'is_fine_tuned_model': True,
-            'job': {'id': '', 'name': '', 'url': ''},
-            'license': 'test_license',
-            'lifecycle_details': f'{job_run.lifecycle_details}',
-            'lifecycle_state': f'{ds_model.lifecycle_state}',
-            'log': {
-                'id': f'{log_details.log_id}',
-                'name': f'{query_resource.display_name}',
-                'url': 'https://cloud.oracle.com/logging/search?searchQuery=search '
-                f'"{job_run.compartment_id}/{log_details.log_group_id}/{log_details.log_id}" | '
-                f"source='{job_run.id}' | sort by datetime desc&regions={self.app.region}"
-            },
-            'log_group': {
-                'id': f'{log_details.log_group_id}',
-                'name': f'{query_resource.display_name}',
-                'url': f'https://cloud.oracle.com/logging/log-groups/{log_details.log_group_id}?region={self.app.region}'
-            },
-            'metrics': [
-                {
-                    'category': 'validation',
-                    'name': 'validation_metrics',
-                    'scores': []
-                },
+            "dataset": "test_training_data",
+            "experiment": {"id": "", "name": "", "url": ""},
+            "icon": "",
+            "id": f"{ds_model.id}",
+            "is_fine_tuned_model": True,
+            "job": {"id": "", "name": "", "url": ""},
+            "license": "test_license",
+            "lifecycle_details": f"{job_run.lifecycle_details}",
+            "lifecycle_state": f"{ds_model.lifecycle_state}",
+            "log": {
+                "id": f"{log_details.log_id}",
+                "name": f"{query_resource.display_name}",
+                "url": "https://cloud.oracle.com/logging/search?searchQuery=search "
+                f'"{job_run.compartment_id}/{log_details.log_group_id}/{log_details.log_id}" | '
+                f"source='{job_run.id}' | sort by datetime desc&regions={self.app.region}",
+            },
+            "log_group": {
+                "id": f"{log_details.log_group_id}",
+                "name": f"{query_resource.display_name}",
+                "url": f"https://cloud.oracle.com/logging/log-groups/{log_details.log_group_id}?region={self.app.region}",
+            },
+            "metrics": [
+                {"category": "validation", "name": "validation_metrics", "scores": []},
+                {"category": "training", "name": "training_metrics", "scores": []},
                 {
-                    'category': 'training',
-                    'name': 'training_metrics',
-                    'scores': []
+                    "category": "validation",
+                    "name": "validation_metrics_final",
+                    "scores": [],
                 },
                 {
-                    'category': 'validation',
-                    'name': 'validation_metrics_final',
-                    'scores': []
+                    "category": "training",
+                    "name": "training_metrics_final",
+                    "scores": [],
                 },
-                {
-                    'category': 'training',
-                    'name': 'training_metrics_final',
-                    'scores': []
-                }
             ],
-            'model_card': f'{mock_read_file.return_value}',
-            'name': f'{ds_model.display_name}',
-            'organization': 'test_organization',
-            'project_id': f'{ds_model.project_id}',
-            'ready_to_deploy': True,
-            'ready_to_finetune': False,
-            'search_text': 'ACTIVE,test_license,test_organization,test_task,test_finetuned_model',
-            'shape_info': {
-                'instance_shape': f'{job_infrastructure_configuration_details.shape_name}',
'replica': 1, + "model_card": f"{mock_read_file.return_value}", + "name": f"{ds_model.display_name}", + "organization": "test_organization", + "project_id": f"{ds_model.project_id}", + "ready_to_deploy": True, + "ready_to_finetune": False, + "ready_to_import": False, + "search_text": "ACTIVE,test_license,test_organization,test_task,test_finetuned_model", + "shape_info": { + "instance_shape": f"{job_infrastructure_configuration_details.shape_name}", + "replica": 1, }, - 'source': {'id': '', 'name': '', 'url': ''}, - 'tags': ds_model.freeform_tags, - 'task': 'test_task', - 'time_created': f'{ds_model.time_created}', - 'validation': { - 'type': 'Automatic split', - 'value': 'test_val_set_size' - } + "source": {"id": "", "name": "", "url": ""}, + "tags": ds_model.freeform_tags, + "task": "test_task", + "time_created": f"{ds_model.time_created}", + "validation": {"type": "Automatic split", "value": "test_val_set_size"}, + "inference_container": "odsc-vllm-serving", + "finetuning_container": "odsc-llm-fine-tuning", + "evaluation_container": "odsc-llm-evaluate", } - @patch("ads.aqua.model.read_file") - @patch("ads.aqua.model.get_artifact_path") + @pytest.mark.parametrize( + "artifact_location_set", + [ + True, + False, + ], + ) + @patch("ads.aqua.common.utils.copy_file") + @patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects") + @patch("ads.aqua.common.utils.load_config", return_value={}) + def test_import_verified_model( + self, + mock_load_config, + mock_list_objects, + mock_copy_file, + artifact_location_set, + ): + ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True) + ads.common.oci_datascience.OCIDataScienceMixin.init_client = MagicMock() + DataScienceModel.upload_artifact = MagicMock() + DataScienceModel.sync = MagicMock() + OCIDataScienceModel.create = MagicMock() + + # The name attribute cannot be mocked during creation of the mock object, + # hence attach it separately to the mocked objects. 
+ artifact_path = "service_models/model-name/commit-id/artifact" + obj1 = MagicMock(etag="12345-1234-1234-1234-123456789", size=150) + obj1.name = f"{artifact_path}/config/deployment_config.json" + obj2 = MagicMock(etag="12345-1234-1234-1234-123456789", size=150) + obj2.name = f"{artifact_path}/config/ft_config.json" + objects = [obj1, obj2] + mock_list_objects.return_value = MagicMock(objects=objects) + + ds_model = DataScienceModel() + os_path = "oci://aqua-bkt@aqua-ns/prefix/path" + model_name = "oracle/aqua-1t-mega-model" + ds_freeform_tags = { + "OCI_AQUA": "ACTIVE", + "license": "aqua-license", + "organization": "oracle", + "task": "text-generation", + "ready_to_import": "true", + } + ds_model = ( + ds_model.with_compartment_id("test_model_compartment_id") + .with_project_id("test_project_id") + .with_display_name(model_name) + .with_description("test_description") + .with_model_version_set_id("test_model_version_set_id") + .with_freeform_tags(**ds_freeform_tags) + .with_version_id("ocid1.blah.blah") + ) + custom_metadata_list = ModelCustomMetadata() + custom_metadata_list.add( + **{"key": "deployment-container", "value": "odsc-tgi-serving"} + ) + custom_metadata_list.add( + **{"key": "evaluation-container", "value": "odsc-llm-evaluate"} + ) + if not artifact_location_set: + custom_metadata_list.add( + **{ + "key": "artifact_location", + "value": artifact_path, + "description": "artifact location", + } + ) + ds_model.with_custom_metadata_list(custom_metadata_list) + ds_model.set_spec(ds_model.CONST_MODEL_FILE_DESCRIPTION, {}) + ds_model.dsc_model = MagicMock(id="test_model_id") + DataScienceModel.from_id = MagicMock(return_value=ds_model) + reload(ads.aqua.model.model) + app = AquaModelApp() + model: AquaModel = app.register( + model="ocid1.datasciencemodel.xxx.xxxx.", + os_path=os_path, + ) + if not artifact_location_set: + mock_copy_file.assert_called() + ds_freeform_tags.pop( + "ready_to_import" + ) # The imported model should not have this tag + assert model.tags == { + "aqua_custom_base_model": "true", + "aqua_service_model": "test_model_id", + **ds_freeform_tags, + } + mock_load_config.assert_called() + + assert model.inference_container == "odsc-tgi-serving" + assert model.finetuning_container is None + assert model.evaluation_container == "odsc-llm-evaluate" + assert model.ready_to_import is False + assert model.ready_to_deploy is True + assert model.ready_to_finetune is False + + @patch("ads.aqua.common.utils.load_config", return_value={}) + def test_import_any_model_no_containers_specified(self, mock_load_config): + ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True) + ads.common.oci_datascience.OCIDataScienceMixin.init_client = MagicMock() + DataScienceModel.upload_artifact = MagicMock() + DataScienceModel.sync = MagicMock() + OCIDataScienceModel.create = MagicMock() + + ds_model = DataScienceModel() + os_path = "oci://aqua-bkt@aqua-ns/prefix/path" + model_name = "oracle/aqua-1t-mega-model" + ds_freeform_tags = { + "OCI_AQUA": "ACTIVE", + "license": "aqua-license", + "organization": "oracle", + "task": "text-generation", + } + + reload(ads.aqua.model.model) + app = AquaModelApp() + with pytest.raises(AquaRuntimeError): + with patch.object(AquaModelApp, "list") as aqua_model_mock_list: + aqua_model_mock_list.return_value = [ + AquaModelSummary( + id="test_id1", + name="organization1/name1", + organization="organization1", + ), + ] + model: DataScienceModel = app.register( + model=model_name, + os_path=os_path, + ) + + 
@patch("ads.aqua.common.utils.load_config", return_value={}) + def test_import_model_with_project_compartment_override(self, mock_load_config): + ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True) + ads.common.oci_datascience.OCIDataScienceMixin.init_client = MagicMock() + DataScienceModel.upload_artifact = MagicMock() + DataScienceModel.sync = MagicMock() + OCIDataScienceModel.create = MagicMock() + + ds_model = DataScienceModel() + os_path = "oci://aqua-bkt@aqua-ns/prefix/path" + model_name = "oracle/aqua-1t-mega-model" + compartment_override = "my.blah.compartment" + project_override = "my.blah.project" + ds_freeform_tags = { + "OCI_AQUA": "ACTIVE", + "license": "aqua-license", + "organization": "oracle", + "task": "text-generation", + } + ds_model = ( + ds_model.with_compartment_id("test_model_compartment_id") + .with_project_id("test_project_id") + .with_display_name(model_name) + .with_description("test_description") + .with_model_version_set_id("test_model_version_set_id") + .with_freeform_tags(**ds_freeform_tags) + .with_version_id("ocid1.blah.blah") + ) + custom_metadata_list = ModelCustomMetadata() + custom_metadata_list.add( + **{"key": "deployment-container", "value": "odsc-tgi-serving"} + ) + custom_metadata_list.add( + **{"key": "evaluation-container", "value": "odsc-llm-evaluate"} + ) + ds_model.with_custom_metadata_list(custom_metadata_list) + ds_model.set_spec(ds_model.CONST_MODEL_FILE_DESCRIPTION, {}) + DataScienceModel.from_id = MagicMock(return_value=ds_model) + reload(ads.aqua.model.model) + app = AquaModelApp() + model: AquaModel = app.register( + compartment_id=compartment_override, + project_id=project_override, + model="ocid1.datasciencemodel.xxx.xxxx.", + os_path=os_path, + ) + assert model.compartment_id == compartment_override + assert model.project_id == project_override + + @patch("ads.aqua.common.utils.load_config", side_effect=AquaFileNotFoundError) + def test_import_model_with_missing_config(self, mock_load_config): + """Test for validating if error is returned when model artifacts are incomplete or not available.""" + os_path = "oci://aqua-bkt@aqua-ns/prefix/path" + model_name = "oracle/aqua-1t-mega-model" + reload(ads.aqua.model.model) + app = AquaModelApp() + with pytest.raises(AquaRuntimeError): + model: AquaModel = app.register( + model=model_name, + os_path=os_path, + ) + + @patch("ads.aqua.common.utils.load_config", return_value={}) + def test_import_any_model_smc_container( + self, + mock_load_config, + ): + my_model = "oracle/aqua-1t-mega-model" + ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True) + ads.common.oci_datascience.OCIDataScienceMixin.init_client = MagicMock() + DataScienceModel.upload_artifact = MagicMock() + DataScienceModel.sync = MagicMock() + OCIDataScienceModel.create = MagicMock() + + os_path = "oci://aqua-bkt@aqua-ns/prefix/path" + ds_freeform_tags = { + "OCI_AQUA": "active", + } + + reload(ads.aqua.model.model) + app = AquaModelApp() + with patch.object(AquaModelApp, "list") as aqua_model_mock_list: + aqua_model_mock_list.return_value = [ + AquaModelSummary( + id="test_id1", + name="organization1/name1", + organization="organization1", + ), + AquaModelSummary( + id="test_id2", + name="organization1/name2", + organization="organization1", + ), + AquaModelSummary( + id="test_id3", + name="organization2/name3", + organization="organization2", + ), + ] + model: AquaModel = app.register( + model=my_model, + os_path=os_path, + inference_container="odsc-vllm-or-tgi-container", + 
finetuning_container="odsc-llm-fine-tuning", + ) + assert model.tags == { + "aqua_custom_base_model": "true", + "ready_to_fine_tune": "true", + **ds_freeform_tags, + } + assert model.inference_container == "odsc-vllm-or-tgi-container" + assert model.finetuning_container == "odsc-llm-fine-tuning" + assert model.evaluation_container == "odsc-llm-evaluate" + assert model.ready_to_import is False + assert model.ready_to_deploy is True + assert model.ready_to_finetune is True + + @parameterized.expand( + [ + ( + { + "os_path": "oci://aqua-bkt@aqua-ns/path", + "model": "oracle/oracle-1it", + "inference_container": "odsc-vllm-serving", + }, + f"ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --inference_container odsc-vllm-serving", + ), + ( + { + "os_path": "oci://aqua-bkt@aqua-ns/path", + "model": "ocid1.datasciencemodel.oc1.iad.", + }, + f"ads aqua model register --model ocid1.datasciencemodel.oc1.iad. --os_path oci://aqua-bkt@aqua-ns/path", + ), + ] + ) + def test_import_cli(self, data, expected_output): + import_details = ImportModelDetails(**data) + assert import_details.build_cli() == expected_output + + @patch("ads.aqua.model.model.read_file") + @patch("ads.aqua.model.model.get_artifact_path") def test_load_license(self, mock_get_artifact_path, mock_read_file): self.app.ds_client.get_model = MagicMock() - mock_get_artifact_path.return_value = "oci://bucket@namespace/prefix/config/LICENSE.txt" + mock_get_artifact_path.return_value = ( + "oci://bucket@namespace/prefix/config/LICENSE.txt" + ) mock_read_file.return_value = "test_license" license = self.app.load_license(model_id="test_model_id") @@ -384,9 +809,7 @@ def test_load_license(self, mock_get_artifact_path, mock_read_file): mock_get_artifact_path.assert_called() mock_read_file.assert_called() - assert asdict(license) == { - 'id': 'test_model_id', 'license': 'test_license' - } + assert asdict(license) == {"id": "test_model_id", "license": "test_license"} def test_list_service_models(self): """Tests listing service models succesfully.""" @@ -403,7 +826,7 @@ def test_list_service_models(self): received_args = self.app.list_resource.call_args.kwargs assert received_args.get("compartment_id") == TestDataset.SERVICE_COMPARTMENT_ID - assert len(results) == 1 + assert len(results) == 2 attributes = AquaModelSummary.__annotations__.keys() for r in results: @@ -426,7 +849,7 @@ def test_list_custom_models(self): results = self.app.list(TestDataset.COMPARTMENT_ID) - self.app._rqs.assert_called_with(TestDataset.COMPARTMENT_ID) + self.app._rqs.assert_called_with(TestDataset.COMPARTMENT_ID, model_type="FT") assert len(results) == 1 diff --git a/tests/unitary/with_extras/aqua/test_model_handler.py b/tests/unitary/with_extras/aqua/test_model_handler.py index 5bee1d954..78a6d5e5c 100644 --- a/tests/unitary/with_extras/aqua/test_model_handler.py +++ b/tests/unitary/with_extras/aqua/test_model_handler.py @@ -5,23 +5,23 @@ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ from unittest import TestCase -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch -from mock import patch from notebook.base.handlers import IPythonHandler - -from ads.aqua.extension.model_handler import AquaModelHandler, AquaModelLicenseHandler +from ads.aqua.extension.model_handler import ( + AquaModelHandler, + AquaModelLicenseHandler, +) from ads.aqua.model import AquaModelApp +from ads.aqua.model.entities import AquaModel class ModelHandlerTestCase(TestCase): - 
@patch.object(IPythonHandler, "__init__") def setUp(self, ipython_init_mock) -> None: ipython_init_mock.return_value = None self.model_handler = AquaModelHandler(MagicMock(), MagicMock()) self.model_handler.request = MagicMock() - self.model_handler.finish = MagicMock() @patch.object(AquaModelHandler, "list") def test_get_no_id(self, mock_list): @@ -35,37 +35,78 @@ def test_get_with_id(self, mock_read): @patch.object(AquaModelApp, "get") def test_read(self, mock_get): - self.model_handler.read(model_id="test_model_id") - self.model_handler.finish.assert_called_with( - mock_get.return_value - ) - mock_get.assert_called_with("test_model_id") + with patch( + "ads.aqua.extension.base_handler.AquaAPIhandler.finish" + ) as mock_finish: + mock_finish.side_effect = lambda x: x + self.model_handler.read(model_id="test_model_id") + mock_get.assert_called_with("test_model_id") @patch.object(AquaModelApp, "clear_model_list_cache") @patch("ads.aqua.extension.model_handler.urlparse") def test_delete(self, mock_urlparse, mock_clear_model_list_cache): request_path = MagicMock(path="aqua/model/cache") mock_urlparse.return_value = request_path - - self.model_handler.delete() - self.model_handler.finish.assert_called_with( - mock_clear_model_list_cache.return_value - ) - - mock_urlparse.assert_called() - mock_clear_model_list_cache.assert_called() + mock_clear_model_list_cache.return_value = { + "key": { + "compartment_id": "test-compartment-ocid", + }, + "cache_deleted": True, + } + + with patch( + "ads.aqua.extension.base_handler.AquaAPIhandler.finish" + ) as mock_finish: + mock_finish.side_effect = lambda x: x + result = self.model_handler.delete() + assert result["cache_deleted"] is True + mock_urlparse.assert_called() + mock_clear_model_list_cache.assert_called() @patch.object(AquaModelApp, "list") def test_list(self, mock_list): - self.model_handler.list() - - self.model_handler.finish.assert_called_with( - mock_list.return_value + with patch( + "ads.aqua.extension.base_handler.AquaAPIhandler.finish" + ) as mock_finish: + mock_finish.side_effect = lambda x: x + self.model_handler.list() + mock_list.assert_called_with( + compartment_id=None, project_id=None, model_type=None + ) + + @patch("notebook.base.handlers.APIHandler.finish") + @patch("ads.aqua.model.AquaModelApp.register") + def test_register(self, mock_register, mock_finish): + mock_register.return_value = AquaModel( + id="test_id", + inference_container="odsc-tgi-serving", + evaluation_container="odsc-llm-evaluate", + ) + mock_finish.side_effect = lambda x: x + + self.model_handler.get_json_body = MagicMock( + return_value=dict( + model="test_model_name", + os_path="test_os_path", + inference_container="odsc-tgi-serving", + ) ) - mock_list.assert_called_with(None, None) + result = self.model_handler.post() + mock_register.assert_called_with( + model="test_model_name", + os_path="test_os_path", + inference_container="odsc-tgi-serving", + finetuning_container=None, + compartment_id=None, + project_id=None, + ) + assert result["id"] == "test_id" + assert result["inference_container"] == "odsc-tgi-serving" + assert result["evaluation_container"] == "odsc-llm-evaluate" + assert result["finetuning_container"] is None + class ModelLicenseHandlerTestCase(TestCase): - @patch.object(IPythonHandler, "__init__") def setUp(self, ipython_init_mock) -> None: ipython_init_mock.return_value = None diff --git a/tests/unitary/with_extras/aqua/test_ui.py b/tests/unitary/with_extras/aqua/test_ui.py index dc4966559..fdcc9ac92 100644 --- 
+++ b/tests/unitary/with_extras/aqua/test_ui.py
@@ -4,19 +4,21 @@
 # Copyright (c) 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
-import os
+import hashlib
 import json
+import os
 import unittest
 from importlib import reload
 from unittest.mock import MagicMock, patch
+
+import oci
 import pytest
 from parameterized import parameterized
-import oci
 
 import ads.config
+from ads.aqua.common.errors import AquaValueError
+from ads.aqua.common.utils import load_config
 from ads.aqua.ui import AquaUIApp
-from ads.aqua.exception import AquaValueError
-from ads.aqua.utils import load_config
 from ads.config import AQUA_CONFIG_FOLDER, AQUA_RESOURCE_LIMIT_NAMES_CONFIG
 
 
@@ -453,3 +455,57 @@ def test_is_bucket_versioned(self, versioned, mock_from_path):
         mock_from_path.return_value.is_bucket_versioned.return_value = versioned
         result = self.app.is_bucket_versioned("oci://bucket-name-@namespace/prefix")
         assert result["is_versioned"] == versioned
+
+    @patch("ads.aqua.ui.get_container_config")
+    def test_list_containers(self, mock_get_container_config):
+        """Test to list AQUA containers."""
+
+        with open(
+            os.path.join(self.curr_dir, "test_data/ui/container_index.json"), "r"
+        ) as _file:
+            container_index_json = json.load(_file)
+
+        mock_get_container_config.return_value = container_index_json
+
+        test_result = self.app.list_containers().to_dict()
+
+        expected_result = {
+            "inference": [
+                {
+                    "name": "dsmc://odsc-tgi-serving",
+                    "version": "1.4.5",
+                    "display_name": "TGI:1.4.5",
+                    "family": "odsc-tgi-serving",
+                },
+                {
+                    "name": "dsmc://odsc-tgi-serving",
+                    "version": "2.0.2",
+                    "display_name": "TGI:2.0.2",
+                    "family": "odsc-tgi-serving",
+                },
+                {
+                    "name": "dsmc://odsc-vllm-serving",
+                    "version": "0.3.0.7",
+                    "display_name": "VLLM:0.3.0",
+                    "family": "odsc-vllm-serving",
+                },
+            ],
+            "finetune": [
+                {
+                    "name": "dsmc://odsc-llm-fine-tuning",
+                    "version": "1.1.33.34",
+                    "display_name": "1.1.33.34",
+                    "family": "odsc-llm-fine-tuning",
+                }
+            ],
+            "evaluate": [
+                {
+                    "name": "dsmc://odsc-llm-evaluate",
+                    "version": "0.1.2.0",
+                    "display_name": "0.1.2.0",
+                    "family": "odsc-llm-evaluate",
+                }
+            ],
+        }
+
+        assert test_result == expected_result
diff --git a/tests/unitary/with_extras/aqua/test_ui_handler.py b/tests/unitary/with_extras/aqua/test_ui_handler.py
index 7fefe8950..687536936 100644
--- a/tests/unitary/with_extras/aqua/test_ui_handler.py
+++ b/tests/unitary/with_extras/aqua/test_ui_handler.py
@@ -6,16 +6,16 @@
 
 import os
 import unittest
-from unittest.mock import MagicMock, patch
 from importlib import reload
+from unittest.mock import MagicMock, patch
+
+from notebook.base.handlers import IPythonHandler
 from parameterized import parameterized
 
-import ads.config
 import ads.aqua
-from notebook.base.handlers import IPythonHandler
+import ads.config
+from ads.aqua.common.enums import Tags
 from ads.aqua.extension.ui_handler import AquaUIHandler
-from ads.aqua.ui import AquaUIApp
-from ads.aqua.data import Tags
 
 
 class TestDataset:
@@ -48,7 +48,7 @@ def tearDownClass(cls):
         reload(ads.aqua)
         reload(ads.aqua.extension.ui_handler)
 
-    @patch.object(AquaUIApp, "list_log_groups")
+    @patch("ads.aqua.ui.AquaUIApp.list_log_groups")
    def test_list_log_groups(self, mock_list_log_groups):
         """Test get method to fetch log groups"""
         self.ui_handler.request.path = "aqua/logging"
@@ -57,49 +57,56 @@ def test_list_log_groups(self, mock_list_log_groups):
             compartment_id=TestDataset.USER_COMPARTMENT_ID
         )
 
-    @patch.object(AquaUIApp, "list_logs")
+    @patch("ads.aqua.ui.AquaUIApp.list_logs")
     def test_list_logs(self, mock_list_logs):
         """Test get method to fetch logs for a given log group."""
         self.ui_handler.request.path = "aqua/logging"
         self.ui_handler.get(id="mock-log-id")
         mock_list_logs.assert_called_with(log_group_id="mock-log-id")
 
-    @patch.object(AquaUIApp, "list_compartments")
+    @patch("ads.aqua.ui.AquaUIApp.list_compartments")
     def test_list_compartments(self, mock_list_compartments):
         """Test get method to fetch list of compartments."""
         self.ui_handler.request.path = "aqua/compartments"
         self.ui_handler.get()
         mock_list_compartments.assert_called()
 
-    @patch.object(AquaUIApp, "get_default_compartment")
+    @patch("ads.aqua.ui.AquaUIApp.list_containers")
+    def test_list_containers(self, mock_list_containers):
+        """Test get method to fetch list of containers."""
+        self.ui_handler.request.path = "aqua/containers"
+        self.ui_handler.get()
+        mock_list_containers.assert_called()
+
+    @patch("ads.aqua.ui.AquaUIApp.get_default_compartment")
     def test_get_default_compartment(self, mock_get_default_compartment):
         """Test get method to fetch default compartment."""
         self.ui_handler.request.path = "aqua/compartments/default"
         self.ui_handler.get()
         mock_get_default_compartment.assert_called()
 
-    @patch.object(AquaUIApp, "list_model_version_sets")
+    @patch("ads.aqua.ui.AquaUIApp.list_model_version_sets")
     def test_list_experiments(self, mock_list_experiments):
         """Test get method to fetch list of experiments."""
         self.ui_handler.request.path = "aqua/experiment"
         self.ui_handler.get()
         mock_list_experiments.assert_called_with(
             compartment_id=TestDataset.USER_COMPARTMENT_ID,
-            target_tag=Tags.AQUA_EVALUATION.value,
+            target_tag=Tags.AQUA_EVALUATION,
         )
 
-    @patch.object(AquaUIApp, "list_model_version_sets")
+    @patch("ads.aqua.ui.AquaUIApp.list_model_version_sets")
     def test_list_model_version_sets(self, mock_list_model_version_sets):
         """Test get method to fetch version sets."""
         self.ui_handler.request.path = "aqua/versionsets"
         self.ui_handler.get()
         mock_list_model_version_sets.assert_called_with(
             compartment_id=TestDataset.USER_COMPARTMENT_ID,
-            target_tag=Tags.AQUA_FINE_TUNING.value,
+            target_tag=Tags.AQUA_FINE_TUNING,
         )
 
     @parameterized.expand(["true", ""])
-    @patch.object(AquaUIApp, "list_buckets")
+    @patch("ads.aqua.ui.AquaUIApp.list_buckets")
     def test_list_buckets(self, versioned, mock_list_buckets):
         """Test get method to fetch list of buckets."""
         self.ui_handler.request.path = "aqua/buckets"
@@ -113,7 +120,7 @@ def test_list_buckets(self, versioned, mock_list_buckets):
             versioned=True if versioned == "true" else False,
         )
 
-    @patch.object(AquaUIApp, "list_job_shapes")
+    @patch("ads.aqua.ui.AquaUIApp.list_job_shapes")
     def test_list_job_shapes(self, mock_list_job_shapes):
         """Test get method to fetch jobs shapes list."""
         self.ui_handler.request.path = "aqua/job/shapes"
@@ -122,14 +129,14 @@ def test_list_job_shapes(self, mock_list_job_shapes):
             compartment_id=TestDataset.USER_COMPARTMENT_ID
         )
 
-    @patch.object(AquaUIApp, "list_vcn")
+    @patch("ads.aqua.ui.AquaUIApp.list_vcn")
     def test_list_vcn(self, mock_list_vcn):
         """Test get method to fetch list of vcns."""
         self.ui_handler.request.path = "aqua/vcn"
         self.ui_handler.get()
         mock_list_vcn.assert_called_with(compartment_id=TestDataset.USER_COMPARTMENT_ID)
 
-    @patch.object(AquaUIApp, "list_subnets")
+    @patch("ads.aqua.ui.AquaUIApp.list_subnets")
     def test_list_subnets(self, mock_list_subnets):
         """Test the get method to fetch list of subnets."""
         self.ui_handler.request.path = "aqua/subnets"
@@ -142,7 +149,7 @@ def test_list_subnets(self, mock_list_subnets):
             compartment_id=TestDataset.USER_COMPARTMENT_ID, vcn_id="mock-vcn-id"
         )
 
-    @patch.object(AquaUIApp, "get_shape_availability")
+    @patch("ads.aqua.ui.AquaUIApp.get_shape_availability")
     def test_get_shape_availability(self, mock_get_shape_availability):
         """Test get shape availability."""
         self.ui_handler.request.path = "aqua/shapes/limit"
@@ -156,7 +163,7 @@ def test_get_shape_availability(self, mock_get_shape_availability):
             instance_shape=TestDataset.DEPLOYMENT_SHAPE_NAME,
         )
 
-    @patch.object(AquaUIApp, "is_bucket_versioned")
+    @patch("ads.aqua.ui.AquaUIApp.is_bucket_versioned")
     def test_is_bucket_versioned(self, mock_is_bucket_versioned):
         """Test get method to check if a bucket is versioned."""
         self.ui_handler.request.path = "aqua/bucket/versioning"
diff --git a/tests/unitary/with_extras/aqua/test_ui_websocket_handler.py b/tests/unitary/with_extras/aqua/test_ui_websocket_handler.py
new file mode 100644
index 000000000..898c76058
--- /dev/null
+++ b/tests/unitary/with_extras/aqua/test_ui_websocket_handler.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*--
+
+# Copyright (c) 2024 Oracle and/or its affiliates.
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
+
+import os
+import unittest
+from concurrent.futures import Future
+from unittest.mock import MagicMock, patch
+
+from tornado.websocket import WebSocketHandler
+
+from ads.aqua.extension.evaluation_ws_msg_handler import AquaEvaluationWSMsgHandler
+from ads.aqua.extension.ui_websocket_handler import AquaUIWebSocketHandler
+
+
+class TestAquaUIWebSocketHandler(unittest.TestCase):
+    @patch.object(WebSocketHandler, "__init__")
+    def setUp(self, webSocketInitMock) -> None:
+        webSocketInitMock.return_value = None
+        self.web_socket_handler = AquaUIWebSocketHandler(MagicMock(), MagicMock())
+
+    def test_throws_error_on_duplicate_msg_handlers(self):
+        """Test that an error is thrown when duplicate message handlers are added."""
+        with self.assertRaises(ValueError):
+            AquaUIWebSocketHandler._handlers_.append(AquaEvaluationWSMsgHandler)
+            AquaUIWebSocketHandler(MagicMock(), MagicMock())
+        AquaUIWebSocketHandler._handlers_.pop()
+
+    def test_throws_error_on_bad_request(self):
+        """Test that an error is thrown when a bad request is made."""
+        with self.assertRaises(ValueError):
+            self.web_socket_handler.on_message("test")
+
+    @patch.object(AquaUIWebSocketHandler, "write_message")
+    def test_throws_error_on_unexpected_kind(self, write_message_mock: MagicMock):
+        """Test that an error is thrown when an unexpected kind is received."""
+        write_message_mock.return_value = None
+        with self.assertRaises(ValueError):
+            self.web_socket_handler.on_message(
+                '{"message_id": "test", "kind": "test", "data": {}}'
+            )
+        assert write_message_mock.called
+
+    @patch.object(AquaUIWebSocketHandler, "write_message")
+    @patch("ads.aqua.extension.ui_websocket_handler.IOLoop")
+    def test_throws_internal_error_on_future_error(
+        self, ioloop_mock: MagicMock, write_message_mock: MagicMock
+    ):
+        future = Future()
+        future.set_exception(ValueError())
+        self.web_socket_handler.future_message_map[future] = MagicMock()
+        with self.assertRaises(ValueError):
+            self.web_socket_handler.on_message_processed(future)
+        assert ioloop_mock.current().run_sync.called
diff --git a/tests/unitary/with_extras/aqua/test_utils.py b/tests/unitary/with_extras/aqua/test_utils.py
index 926b1f468..b68d0789f 100644
--- a/tests/unitary/with_extras/aqua/test_utils.py
+++ b/tests/unitary/with_extras/aqua/test_utils.py
@@ -10,8 +10,8 @@
 from oci.resource_search.models.resource_summary import ResourceSummary
 from parameterized import parameterized
 
-from ads.aqua import utils
-from ads.aqua.exception import AquaRuntimeError
+from ads.aqua.common import utils
+from ads.aqua.common.errors import AquaRuntimeError
 from ads.common.oci_resource import SEARCH_TYPE, OCIResource
 from ads.config import TENANCY_OCID