
Commit

Improved evaluation model parameters.
lu-ohai committed Oct 30, 2024
1 parent 2799630 commit b983a7f
Showing 5 changed files with 72 additions and 146 deletions.
89 changes: 47 additions & 42 deletions ads/aqua/evaluation/entities.py
@@ -9,19 +9,18 @@
 This module contains dataclasses for aqua evaluation.
 """
 
-from dataclasses import dataclass, field
-from typing import List, Optional, Union
+from pydantic import Field
+from typing import Any, Dict, List, Optional, Union
 
 from ads.aqua.data import AquaResourceIdentifier
-from ads.common.serializer import DataClassSerializable
+from ads.aqua.config.utils.serializer import Serializable
 
 
-@dataclass(repr=False)
-class CreateAquaEvaluationDetails(DataClassSerializable):
-    """Dataclass to create aqua model evaluation.
+class CreateAquaEvaluationDetails(Serializable):
+    """Class for creating aqua model evaluation.
 
-    Fields
-    ------
+    Properties
+    ----------
     evaluation_source_id: str
         The evaluation source id. Must be either model or model deployment ocid.
     evaluation_name: str
@@ -83,69 +82,74 @@ class CreateAquaEvaluationDetails(DataClassSerializable)
     ocpus: Optional[float] = None
     log_group_id: Optional[str] = None
     log_id: Optional[str] = None
-    metrics: Optional[List] = None
+    metrics: Optional[List[str]] = None
     force_overwrite: Optional[bool] = False
 
+    class Config:
+        extra = "ignore"
+
 
-@dataclass(repr=False)
-class AquaEvalReport(DataClassSerializable):
+class AquaEvalReport(Serializable):
     evaluation_id: str = ""
     content: str = ""
 
+    class Config:
+        extra = "ignore"
+
 
-@dataclass(repr=False)
-class ModelParams(DataClassSerializable):
+class ModelParams(Serializable):
     max_tokens: str = ""
     top_p: str = ""
     top_k: str = ""
     temperature: str = ""
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
-    stop: Optional[Union[str, List[str]]] = field(default_factory=list)
+    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     model: Optional[str] = "odsc-llm"
 
+    class Config:
+        extra = "allow"
+
 
-@dataclass(repr=False)
-class AquaEvalParams(ModelParams, DataClassSerializable):
+class AquaEvalParams(ModelParams):
     shape: str = ""
     dataset_path: str = ""
     report_path: str = ""
 
 
-@dataclass(repr=False)
-class AquaEvalMetric(DataClassSerializable):
+class AquaEvalMetric(Serializable):
     key: str
     name: str
     description: str = ""
 
+    class Config:
+        extra = "ignore"
+
 
-@dataclass(repr=False)
-class AquaEvalMetricSummary(DataClassSerializable):
+class AquaEvalMetricSummary(Serializable):
     metric: str = ""
     score: str = ""
     grade: str = ""
 
+    class Config:
+        extra = "ignore"
+
 
-@dataclass(repr=False)
-class AquaEvalMetrics(DataClassSerializable):
+class AquaEvalMetrics(Serializable):
     id: str
     report: str
-    metric_results: List[AquaEvalMetric] = field(default_factory=list)
-    metric_summary_result: List[AquaEvalMetricSummary] = field(default_factory=list)
+    metric_results: List[AquaEvalMetric] = Field(default_factory=list)
+    metric_summary_result: List[AquaEvalMetricSummary] = Field(default_factory=list)
 
+    class Config:
+        extra = "ignore"
+
 
-@dataclass(repr=False)
-class AquaEvaluationCommands(DataClassSerializable):
+class AquaEvaluationCommands(Serializable):
     evaluation_id: str
     evaluation_target_id: str
-    input_data: dict
-    metrics: list
+    input_data: Dict[str, Any]
+    metrics: List[str]
     output_dir: str
-    params: dict
+    params: Dict[str, Any]
 
+    class Config:
+        extra = "ignore"
+
 
-@dataclass(repr=False)
-class AquaEvaluationSummary(DataClassSerializable):
+class AquaEvaluationSummary(Serializable):
     """Represents a summary of Aqua evalution."""
 
     id: str
@@ -154,17 +158,18 @@ class AquaEvaluationSummary(DataClassSerializable):
     lifecycle_state: str
     lifecycle_details: str
     time_created: str
-    tags: dict
-    experiment: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
-    source: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
-    job: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
-    parameters: AquaEvalParams = field(default_factory=AquaEvalParams)
+    tags: Dict[str, Any]
+    experiment: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
+    source: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
+    job: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
+    parameters: AquaEvalParams = Field(default_factory=AquaEvalParams)
 
+    class Config:
+        extra = "ignore"
+
 
-@dataclass(repr=False)
-class AquaEvaluationDetail(AquaEvaluationSummary, DataClassSerializable):
+class AquaEvaluationDetail(AquaEvaluationSummary):
     """Represents a details of Aqua evalution."""
 
-    log_group: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
-    log: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
-    introspection: dict = field(default_factory=dict)
+    log_group: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
+    log: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
+    introspection: dict = Field(default_factory=dict)
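
A minimal sketch of the behavior the migrated entities rely on, using pydantic.BaseModel directly; Serializable (from ads.aqua.config.utils.serializer) is assumed to wrap BaseModel and supply the to_dict() helper used elsewhere in this commit:

from typing import List, Optional, Union

from pydantic import BaseModel, Field


class ModelParams(BaseModel):
    max_tokens: str = ""
    temperature: str = ""
    # Field(default_factory=list) replaces dataclasses.field: each instance
    # gets a fresh list rather than sharing one mutable default.
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)

    class Config:
        extra = "allow"  # unknown sampling params are kept on the instance


class AquaEvalReport(BaseModel):
    evaluation_id: str = ""
    content: str = ""

    class Config:
        extra = "ignore"  # unknown keys are dropped instead of raising


params = ModelParams(max_tokens="512", seed=42)  # "seed" survives (extra="allow")
report = AquaEvalReport(evaluation_id="ocid1...", junk="x")  # "junk" is discarded
print(params.dict(), report.dict())  # pydantic v1-style dump, akin to to_dict()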
82 changes: 23 additions & 59 deletions ads/aqua/evaluation/evaluation.py
@@ -7,7 +7,6 @@
 import re
 import tempfile
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from dataclasses import asdict, fields
 from datetime import datetime, timedelta
 from pathlib import Path
 from threading import Lock
@@ -46,7 +45,6 @@
     upload_local_to_os,
 )
 from ads.aqua.config.config import get_evaluation_service_config
-from ads.aqua.config.evaluation.evaluation_service_config import EvaluationServiceConfig
 from ads.aqua.constants import (
     CONSOLE_LINK_RESOURCE_TYPE_MAPPING,
     EVALUATION_REPORT,
@@ -75,7 +73,6 @@
     AquaEvaluationSummary,
     AquaResourceIdentifier,
     CreateAquaEvaluationDetails,
-    ModelParams,
 )
 from ads.aqua.evaluation.errors import EVALUATION_JOB_EXIT_CODE_MESSAGE
 from ads.aqua.ui import AquaContainerConfig
@@ -164,7 +161,7 @@ def create(
             raise AquaValueError(
                 "Invalid create evaluation parameters. "
                 "Allowable parameters are: "
-                f"{', '.join([field.name for field in fields(CreateAquaEvaluationDetails)])}."
+                f"{', '.join([field for field in CreateAquaEvaluationDetails.model_fields])}."
             ) from ex
 
         if not is_valid_ocid(create_aqua_evaluation_details.evaluation_source_id):
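
The allowable-parameter listing above now comes from pydantic's model_fields instead of dataclasses.fields(). A hedged sketch of that path (assuming pydantic v2, where model_fields is a class-level dict of field name to FieldInfo; the model below is a reduced stand-in for the real CreateAquaEvaluationDetails):

from pydantic import BaseModel, ValidationError


class CreateAquaEvaluationDetails(BaseModel):
    evaluation_source_id: str
    evaluation_name: str


try:
    # Missing the required evaluation_name, so construction fails.
    CreateAquaEvaluationDetails(evaluation_source_id="ocid1.datasciencemodel...")
except ValidationError as ex:
    # Mirrors the message built in create(): join the declared field names.
    allowed = ", ".join(field for field in CreateAquaEvaluationDetails.model_fields)
    print(f"Invalid create evaluation parameters. Allowable parameters are: {allowed}.")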
@@ -175,15 +172,7 @@
 
         # The model to evaluate
         evaluation_source = None
-        # The evaluation service config
-        evaluation_config: EvaluationServiceConfig = get_evaluation_service_config()
-        # The evaluation inference configuration. The inference configuration will be extracted
-        # based on the inferencing container family.
         eval_inference_configuration: Dict = {}
-        # The evaluation inference model sampling params. The system parameters that will not be
-        # visible for user, but will be applied implicitly for evaluation. The service model params
-        # will be extracted based on the container family and version.
-        eval_inference_service_model_params: Dict = {}
 
         if (
             DataScienceResource.MODEL_DEPLOYMENT
@@ -200,29 +189,14 @@
                 runtime = ModelDeploymentContainerRuntime.from_dict(
                     evaluation_source.runtime.to_dict()
                 )
-                container_config = AquaContainerConfig.from_container_index_json(
+                inference_config = AquaContainerConfig.from_container_index_json(
                     enable_spec=True
-                )
-                for (
-                    inference_container_family,
-                    inference_container_info,
-                ) in container_config.inference.items():
-                    if (
-                        inference_container_info.name
-                        == runtime.image[: runtime.image.rfind(":")]
-                    ):
+                ).inference
+                for container in inference_config.values():
+                    if container.name == runtime.image[: runtime.image.rfind(":")]:
                         eval_inference_configuration = (
-                            evaluation_config.get_merged_inference_params(
-                                inference_container_family
-                            ).to_dict()
-                        )
-                        eval_inference_service_model_params = (
-                            evaluation_config.get_merged_inference_model_params(
-                                inference_container_family,
-                                inference_container_info.version,
-                            )
+                            container.spec.evaluation_configuration
                         )
 
             except Exception:
                 logger.debug(
                     f"Could not load inference config details for the evaluation source id: "
@@ -277,19 +251,12 @@ def create(
             )
             evaluation_dataset_path = dst_uri
 
-        evaluation_model_parameters = None
-        try:
-            evaluation_model_parameters = AquaEvalParams(
-                shape=create_aqua_evaluation_details.shape_name,
-                dataset_path=evaluation_dataset_path,
-                report_path=create_aqua_evaluation_details.report_path,
-                **create_aqua_evaluation_details.model_parameters,
-            )
-        except Exception as ex:
-            raise AquaValueError(
-                "Invalid model parameters. Model parameters should "
-                f"be a dictionary with keys: {', '.join(list(ModelParams.__annotations__.keys()))}."
-            ) from ex
+        evaluation_model_parameters = AquaEvalParams(
+            shape=create_aqua_evaluation_details.shape_name,
+            dataset_path=evaluation_dataset_path,
+            report_path=create_aqua_evaluation_details.report_path,
+            **create_aqua_evaluation_details.model_parameters,
+        )
 
         target_compartment = (
             create_aqua_evaluation_details.compartment_id or COMPARTMENT_OCID
@@ -370,7 +337,7 @@ def create(
         evaluation_model_taxonomy_metadata = ModelTaxonomyMetadata()
         evaluation_model_taxonomy_metadata[
             MetadataTaxonomyKeys.HYPERPARAMETERS
-        ].value = {"model_params": dict(asdict(evaluation_model_parameters))}
+        ].value = {"model_params": evaluation_model_parameters.to_dict()}
 
         evaluation_model = (
             DataScienceModel()
@@ -443,7 +410,6 @@ def create(
             dataset_path=evaluation_dataset_path,
             report_path=create_aqua_evaluation_details.report_path,
             model_parameters={
-                **eval_inference_service_model_params,
                 **create_aqua_evaluation_details.model_parameters,
             },
             metrics=create_aqua_evaluation_details.metrics,
@@ -580,16 +546,14 @@ def _build_evaluation_runtime(
             **{
                 "AIP_SMC_EVALUATION_ARGUMENTS": json.dumps(
                     {
-                        **asdict(
-                            self._build_launch_cmd(
-                                evaluation_id=evaluation_id,
-                                evaluation_source_id=evaluation_source_id,
-                                dataset_path=dataset_path,
-                                report_path=report_path,
-                                model_parameters=model_parameters,
-                                metrics=metrics,
-                            ),
-                        ),
+                        **self._build_launch_cmd(
+                            evaluation_id=evaluation_id,
+                            evaluation_source_id=evaluation_source_id,
+                            dataset_path=dataset_path,
+                            report_path=report_path,
+                            model_parameters=model_parameters,
+                            metrics=metrics,
+                        ).to_dict(),
                         **(inference_configuration or {}),
                     },
                 ),
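
With _build_launch_cmd() returning a pydantic model, its to_dict() output is merged straight into the AIP_SMC_EVALUATION_ARGUMENTS payload instead of going through dataclasses.asdict. A rough, simplified sketch (AquaEvaluationCommands is trimmed; the inference_configuration value is hypothetical, and to_dict() is assumed to match pydantic's dict()/model_dump()):

import json
from typing import Any, Dict, List

from pydantic import BaseModel


class AquaEvaluationCommands(BaseModel):
    evaluation_id: str
    evaluation_target_id: str
    input_data: Dict[str, Any]
    metrics: List[str]
    output_dir: str
    params: Dict[str, Any]

    def to_dict(self) -> Dict[str, Any]:
        return self.dict()


cmd = AquaEvaluationCommands(
    evaluation_id="ocid1...eval",
    evaluation_target_id="ocid1...deployment",
    input_data={"format": ".jsonl", "url": "oci://bucket/ds.jsonl"},
    metrics=["bertscore"],
    output_dir="oci://bucket/report",
    params={"max_tokens": "512"},
)
inference_configuration = {"some_flag": True}  # hypothetical container config
env_value = json.dumps({**cmd.to_dict(), **(inference_configuration or {})})
print({"AIP_SMC_EVALUATION_ARGUMENTS": env_value})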
@@ -662,9 +626,9 @@ def _build_launch_cmd(
                 "format": Path(dataset_path).suffix,
                 "url": dataset_path,
             },
-            metrics=metrics,
+            metrics=metrics or [],
             output_dir=report_path,
-            params=model_parameters,
+            params=model_parameters or {},
         )
 
     @telemetry(entry_point="plugin=evaluation&action=get", name="aqua")
5 changes: 0 additions & 5 deletions ads/aqua/extension/evaluation_handler.py
@@ -12,7 +12,6 @@
 from ads.aqua.evaluation.entities import CreateAquaEvaluationDetails
 from ads.aqua.extension.base_handler import AquaAPIhandler
 from ads.aqua.extension.errors import Errors
-from ads.aqua.extension.utils import validate_function_parameters
 from ads.config import COMPARTMENT_OCID
 
 
@@ -47,10 +46,6 @@ def post(self, *args, **kwargs): # noqa
         if not input_data:
             raise HTTPError(400, Errors.NO_INPUT_DATA)
 
-        validate_function_parameters(
-            data_class=CreateAquaEvaluationDetails, input_data=input_data
-        )
-
         self.finish(
             # TODO: decide what other kwargs will be needed for create aqua evaluation.
             AquaEvaluationApp().create(
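
Dropping validate_function_parameters here should be safe: AquaEvaluationApp().create() now builds CreateAquaEvaluationDetails(**input_data) itself and, per the evaluation.py diff above, converts any construction failure into an AquaValueError that lists the allowable parameters, so the handler no longer needs a separate pre-check.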
9 changes: 2 additions & 7 deletions tests/unitary/with_extras/aqua/test_evaluation.py
@@ -9,7 +9,6 @@
 import json
 import os
 import unittest
-from dataclasses import asdict
 from unittest.mock import MagicMock, PropertyMock, patch
 
 import oci
@@ -419,14 +418,13 @@ def assert_payload(self, response, response_type):
         """Checks each field is not empty."""
 
         attributes = response_type.__annotations__.keys()
-        rdict = asdict(response)
+        rdict = response.to_dict()
 
         for attr in attributes:
             if attr == "lifecycle_details":  # can be empty when jobrun is succeed
                 continue
             assert rdict.get(attr), f"{attr} is empty"
 
-    @patch("ads.aqua.evaluation.evaluation.get_evaluation_service_config")
     @patch.object(Job, "run")
     @patch("ads.jobs.ads_job.Job.name", new_callable=PropertyMock)
     @patch("ads.jobs.ads_job.Job.id", new_callable=PropertyMock)
@@ -445,7 +443,6 @@ def test_create_evaluation(
         mock_job_id,
         mock_job_name,
         mock_job_run,
-        mock_get_evaluation_service_config,
     ):
         foundation_model = MagicMock()
         foundation_model.display_name = "test_foundation_model"
@@ -475,8 +472,6 @@
         evaluation_job_run.lifecycle_state = "IN_PROGRESS"
         mock_job_run.return_value = evaluation_job_run
 
-        mock_get_evaluation_service_config.return_value = EvaluationServiceConfig()
-
         self.app.ds_client.update_model = MagicMock()
         self.app.ds_client.update_model_provenance = MagicMock()
 
@@ -494,7 +489,7 @@
         )
         aqua_evaluation_summary = self.app.create(**create_aqua_evaluation_details)
 
-        assert asdict(aqua_evaluation_summary) == {
+        assert aqua_evaluation_summary.to_dict() == {
             "console_url": f"https://cloud.oracle.com/data-science/models/{evaluation_model.id}?region={self.app.region}",
             "experiment": {
                 "id": f"{experiment.id}",