From b983a7fd01e825aaf217339c586c2f5a750bac55 Mon Sep 17 00:00:00 2001 From: Lu Peng Date: Wed, 30 Oct 2024 16:22:57 -0400 Subject: [PATCH] Improved evaluation model parameters. --- ads/aqua/evaluation/entities.py | 89 ++++++++++--------- ads/aqua/evaluation/evaluation.py | 82 +++++------------ ads/aqua/extension/evaluation_handler.py | 5 -- .../with_extras/aqua/test_evaluation.py | 9 +- .../aqua/test_evaluation_handler.py | 33 ------- 5 files changed, 72 insertions(+), 146 deletions(-) diff --git a/ads/aqua/evaluation/entities.py b/ads/aqua/evaluation/entities.py index d626995a6..27a6d8b8f 100644 --- a/ads/aqua/evaluation/entities.py +++ b/ads/aqua/evaluation/entities.py @@ -9,19 +9,18 @@ This module contains dataclasses for aqua evaluation. """ -from dataclasses import dataclass, field -from typing import List, Optional, Union +from pydantic import Field +from typing import Any, Dict, List, Optional, Union from ads.aqua.data import AquaResourceIdentifier -from ads.common.serializer import DataClassSerializable +from ads.aqua.config.utils.serializer import Serializable -@dataclass(repr=False) -class CreateAquaEvaluationDetails(DataClassSerializable): - """Dataclass to create aqua model evaluation. +class CreateAquaEvaluationDetails(Serializable): + """Class for creating aqua model evaluation. - Fields - ------ + Properties + ---------- evaluation_source_id: str The evaluation source id. Must be either model or model deployment ocid. evaluation_name: str @@ -83,69 +82,74 @@ class CreateAquaEvaluationDetails(DataClassSerializable): ocpus: Optional[float] = None log_group_id: Optional[str] = None log_id: Optional[str] = None - metrics: Optional[List] = None + metrics: Optional[List[str]] = None force_overwrite: Optional[bool] = False + class Config: + extra = "ignore" -@dataclass(repr=False) -class AquaEvalReport(DataClassSerializable): +class AquaEvalReport(Serializable): evaluation_id: str = "" content: str = "" + class Config: + extra = "ignore" -@dataclass(repr=False) -class ModelParams(DataClassSerializable): +class ModelParams(Serializable): max_tokens: str = "" top_p: str = "" top_k: str = "" temperature: str = "" presence_penalty: Optional[float] = 0.0 frequency_penalty: Optional[float] = 0.0 - stop: Optional[Union[str, List[str]]] = field(default_factory=list) + stop: Optional[Union[str, List[str]]] = Field(default_factory=list) model: Optional[str] = "odsc-llm" + class Config: + extra = "allow" -@dataclass(repr=False) -class AquaEvalParams(ModelParams, DataClassSerializable): +class AquaEvalParams(ModelParams): shape: str = "" dataset_path: str = "" report_path: str = "" - -@dataclass(repr=False) -class AquaEvalMetric(DataClassSerializable): +class AquaEvalMetric(Serializable): key: str name: str description: str = "" + class Config: + extra = "ignore" -@dataclass(repr=False) -class AquaEvalMetricSummary(DataClassSerializable): +class AquaEvalMetricSummary(Serializable): metric: str = "" score: str = "" grade: str = "" + class Config: + extra = "ignore" -@dataclass(repr=False) -class AquaEvalMetrics(DataClassSerializable): +class AquaEvalMetrics(Serializable): id: str report: str - metric_results: List[AquaEvalMetric] = field(default_factory=list) - metric_summary_result: List[AquaEvalMetricSummary] = field(default_factory=list) + metric_results: List[AquaEvalMetric] = Field(default_factory=list) + metric_summary_result: List[AquaEvalMetricSummary] = Field(default_factory=list) + class Config: + extra = "ignore" -@dataclass(repr=False) -class AquaEvaluationCommands(DataClassSerializable): +class AquaEvaluationCommands(Serializable): evaluation_id: str evaluation_target_id: str - input_data: dict - metrics: list + input_data: Dict[str, Any] + metrics: List[str] output_dir: str - params: dict + params: Dict[str, Any] + class Config: + extra = "ignore" -@dataclass(repr=False) -class AquaEvaluationSummary(DataClassSerializable): +class AquaEvaluationSummary(Serializable): """Represents a summary of Aqua evalution.""" id: str @@ -154,17 +158,18 @@ class AquaEvaluationSummary(DataClassSerializable): lifecycle_state: str lifecycle_details: str time_created: str - tags: dict - experiment: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) - source: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) - job: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) - parameters: AquaEvalParams = field(default_factory=AquaEvalParams) + tags: Dict[str, Any] + experiment: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier) + source: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier) + job: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier) + parameters: AquaEvalParams = Field(default_factory=AquaEvalParams) + class Config: + extra = "ignore" -@dataclass(repr=False) -class AquaEvaluationDetail(AquaEvaluationSummary, DataClassSerializable): +class AquaEvaluationDetail(AquaEvaluationSummary): """Represents a details of Aqua evalution.""" - log_group: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) - log: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier) - introspection: dict = field(default_factory=dict) + log_group: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier) + log: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier) + introspection: dict = Field(default_factory=dict) diff --git a/ads/aqua/evaluation/evaluation.py b/ads/aqua/evaluation/evaluation.py index 7f7349beb..0a3cb974a 100644 --- a/ads/aqua/evaluation/evaluation.py +++ b/ads/aqua/evaluation/evaluation.py @@ -7,7 +7,6 @@ import re import tempfile from concurrent.futures import ThreadPoolExecutor, as_completed -from dataclasses import asdict, fields from datetime import datetime, timedelta from pathlib import Path from threading import Lock @@ -46,7 +45,6 @@ upload_local_to_os, ) from ads.aqua.config.config import get_evaluation_service_config -from ads.aqua.config.evaluation.evaluation_service_config import EvaluationServiceConfig from ads.aqua.constants import ( CONSOLE_LINK_RESOURCE_TYPE_MAPPING, EVALUATION_REPORT, @@ -75,7 +73,6 @@ AquaEvaluationSummary, AquaResourceIdentifier, CreateAquaEvaluationDetails, - ModelParams, ) from ads.aqua.evaluation.errors import EVALUATION_JOB_EXIT_CODE_MESSAGE from ads.aqua.ui import AquaContainerConfig @@ -164,7 +161,7 @@ def create( raise AquaValueError( "Invalid create evaluation parameters. " "Allowable parameters are: " - f"{', '.join([field.name for field in fields(CreateAquaEvaluationDetails)])}." + f"{', '.join([field for field in CreateAquaEvaluationDetails.model_fields])}." ) from ex if not is_valid_ocid(create_aqua_evaluation_details.evaluation_source_id): @@ -175,15 +172,7 @@ def create( # The model to evaluate evaluation_source = None - # The evaluation service config - evaluation_config: EvaluationServiceConfig = get_evaluation_service_config() - # The evaluation inference configuration. The inference configuration will be extracted - # based on the inferencing container family. eval_inference_configuration: Dict = {} - # The evaluation inference model sampling params. The system parameters that will not be - # visible for user, but will be applied implicitly for evaluation. The service model params - # will be extracted based on the container family and version. - eval_inference_service_model_params: Dict = {} if ( DataScienceResource.MODEL_DEPLOYMENT @@ -200,29 +189,14 @@ def create( runtime = ModelDeploymentContainerRuntime.from_dict( evaluation_source.runtime.to_dict() ) - container_config = AquaContainerConfig.from_container_index_json( + inference_config = AquaContainerConfig.from_container_index_json( enable_spec=True - ) - for ( - inference_container_family, - inference_container_info, - ) in container_config.inference.items(): - if ( - inference_container_info.name - == runtime.image[: runtime.image.rfind(":")] - ): + ).inference + for container in inference_config.values(): + if container.name == runtime.image[: runtime.image.rfind(":")]: eval_inference_configuration = ( - evaluation_config.get_merged_inference_params( - inference_container_family - ).to_dict() - ) - eval_inference_service_model_params = ( - evaluation_config.get_merged_inference_model_params( - inference_container_family, - inference_container_info.version, - ) + container.spec.evaluation_configuration ) - except Exception: logger.debug( f"Could not load inference config details for the evaluation source id: " @@ -277,19 +251,12 @@ def create( ) evaluation_dataset_path = dst_uri - evaluation_model_parameters = None - try: - evaluation_model_parameters = AquaEvalParams( - shape=create_aqua_evaluation_details.shape_name, - dataset_path=evaluation_dataset_path, - report_path=create_aqua_evaluation_details.report_path, - **create_aqua_evaluation_details.model_parameters, - ) - except Exception as ex: - raise AquaValueError( - "Invalid model parameters. Model parameters should " - f"be a dictionary with keys: {', '.join(list(ModelParams.__annotations__.keys()))}." - ) from ex + evaluation_model_parameters = AquaEvalParams( + shape=create_aqua_evaluation_details.shape_name, + dataset_path=evaluation_dataset_path, + report_path=create_aqua_evaluation_details.report_path, + **create_aqua_evaluation_details.model_parameters, + ) target_compartment = ( create_aqua_evaluation_details.compartment_id or COMPARTMENT_OCID @@ -370,7 +337,7 @@ def create( evaluation_model_taxonomy_metadata = ModelTaxonomyMetadata() evaluation_model_taxonomy_metadata[ MetadataTaxonomyKeys.HYPERPARAMETERS - ].value = {"model_params": dict(asdict(evaluation_model_parameters))} + ].value = {"model_params": evaluation_model_parameters.to_dict()} evaluation_model = ( DataScienceModel() @@ -443,7 +410,6 @@ def create( dataset_path=evaluation_dataset_path, report_path=create_aqua_evaluation_details.report_path, model_parameters={ - **eval_inference_service_model_params, **create_aqua_evaluation_details.model_parameters, }, metrics=create_aqua_evaluation_details.metrics, @@ -580,16 +546,14 @@ def _build_evaluation_runtime( **{ "AIP_SMC_EVALUATION_ARGUMENTS": json.dumps( { - **asdict( - self._build_launch_cmd( - evaluation_id=evaluation_id, - evaluation_source_id=evaluation_source_id, - dataset_path=dataset_path, - report_path=report_path, - model_parameters=model_parameters, - metrics=metrics, - ), - ), + **self._build_launch_cmd( + evaluation_id=evaluation_id, + evaluation_source_id=evaluation_source_id, + dataset_path=dataset_path, + report_path=report_path, + model_parameters=model_parameters, + metrics=metrics, + ).to_dict(), **(inference_configuration or {}), }, ), @@ -662,9 +626,9 @@ def _build_launch_cmd( "format": Path(dataset_path).suffix, "url": dataset_path, }, - metrics=metrics, + metrics=metrics or [], output_dir=report_path, - params=model_parameters, + params=model_parameters or {}, ) @telemetry(entry_point="plugin=evaluation&action=get", name="aqua") diff --git a/ads/aqua/extension/evaluation_handler.py b/ads/aqua/extension/evaluation_handler.py index ed040f5c4..288440525 100644 --- a/ads/aqua/extension/evaluation_handler.py +++ b/ads/aqua/extension/evaluation_handler.py @@ -12,7 +12,6 @@ from ads.aqua.evaluation.entities import CreateAquaEvaluationDetails from ads.aqua.extension.base_handler import AquaAPIhandler from ads.aqua.extension.errors import Errors -from ads.aqua.extension.utils import validate_function_parameters from ads.config import COMPARTMENT_OCID @@ -47,10 +46,6 @@ def post(self, *args, **kwargs): # noqa if not input_data: raise HTTPError(400, Errors.NO_INPUT_DATA) - validate_function_parameters( - data_class=CreateAquaEvaluationDetails, input_data=input_data - ) - self.finish( # TODO: decide what other kwargs will be needed for create aqua evaluation. AquaEvaluationApp().create( diff --git a/tests/unitary/with_extras/aqua/test_evaluation.py b/tests/unitary/with_extras/aqua/test_evaluation.py index 0a64732f7..ab31bcfb9 100644 --- a/tests/unitary/with_extras/aqua/test_evaluation.py +++ b/tests/unitary/with_extras/aqua/test_evaluation.py @@ -9,7 +9,6 @@ import json import os import unittest -from dataclasses import asdict from unittest.mock import MagicMock, PropertyMock, patch import oci @@ -419,14 +418,13 @@ def assert_payload(self, response, response_type): """Checks each field is not empty.""" attributes = response_type.__annotations__.keys() - rdict = asdict(response) + rdict = response.to_dict() for attr in attributes: if attr == "lifecycle_details": # can be empty when jobrun is succeed continue assert rdict.get(attr), f"{attr} is empty" - @patch("ads.aqua.evaluation.evaluation.get_evaluation_service_config") @patch.object(Job, "run") @patch("ads.jobs.ads_job.Job.name", new_callable=PropertyMock) @patch("ads.jobs.ads_job.Job.id", new_callable=PropertyMock) @@ -445,7 +443,6 @@ def test_create_evaluation( mock_job_id, mock_job_name, mock_job_run, - mock_get_evaluation_service_config, ): foundation_model = MagicMock() foundation_model.display_name = "test_foundation_model" @@ -475,8 +472,6 @@ def test_create_evaluation( evaluation_job_run.lifecycle_state = "IN_PROGRESS" mock_job_run.return_value = evaluation_job_run - mock_get_evaluation_service_config.return_value = EvaluationServiceConfig() - self.app.ds_client.update_model = MagicMock() self.app.ds_client.update_model_provenance = MagicMock() @@ -494,7 +489,7 @@ def test_create_evaluation( ) aqua_evaluation_summary = self.app.create(**create_aqua_evaluation_details) - assert asdict(aqua_evaluation_summary) == { + assert aqua_evaluation_summary.to_dict() == { "console_url": f"https://cloud.oracle.com/data-science/models/{evaluation_model.id}?region={self.app.region}", "experiment": { "id": f"{experiment.id}", diff --git a/tests/unitary/with_extras/aqua/test_evaluation_handler.py b/tests/unitary/with_extras/aqua/test_evaluation_handler.py index 6382c8d39..28c22c309 100644 --- a/tests/unitary/with_extras/aqua/test_evaluation_handler.py +++ b/tests/unitary/with_extras/aqua/test_evaluation_handler.py @@ -11,7 +11,6 @@ from ads.aqua.evaluation import AquaEvaluationApp from ads.aqua.evaluation.entities import CreateAquaEvaluationDetails -from ads.aqua.extension.errors import Errors from ads.aqua.extension.evaluation_handler import AquaEvaluationHandler from tests.unitary.with_extras.aqua.utils import HandlerTestDataset as TestDataset @@ -58,38 +57,6 @@ def test_post(self, mock_create): ) ) - @parameterized.expand( - [ - ( - dict(return_value=TestDataset.mock_invalid_input), - 400, - "Missing required parameter:", - ), - (dict(side_effect=Exception()), 400, Errors.INVALID_INPUT_DATA_FORMAT), - (dict(return_value=None), 400, Errors.NO_INPUT_DATA), - ] - ) - def test_post_fail( - self, mock_get_json_body_response, expected_status_code, expected_error_msg - ): - """Tests POST when encounter error.""" - self.test_instance.get_json_body = MagicMock( - side_effect=mock_get_json_body_response.get("side_effect", None), - return_value=mock_get_json_body_response.get("return_value", None), - ) - self.test_instance.write_error = MagicMock() - - self.test_instance.post() - - assert ( - self.test_instance.write_error.call_args[1].get("status_code") - == expected_status_code - ), "Raised wrong status code." - - assert expected_error_msg in self.test_instance.write_error.call_args[1].get( - "reason" - ), "Error message is incorrect." - @parameterized.expand( [ ("", TestDataset.MOCK_OCID, "get"),