Skip to content

Commit

Permalink
Merge branch 'main' into dataflow_changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ahosler authored Dec 16, 2024
2 parents 320aeeb + eab7c47 commit 21a813d
Show file tree
Hide file tree
Showing 18 changed files with 422 additions and 159 deletions.
14 changes: 12 additions & 2 deletions ads/aqua/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import json
import os
from dataclasses import fields
from typing import Dict, Union
Expand Down Expand Up @@ -135,6 +136,8 @@ def create_model_version_set(
description: str = None,
compartment_id: str = None,
project_id: str = None,
freeform_tags: dict = None,
defined_tags: dict = None,
**kwargs,
) -> tuple:
"""Creates ModelVersionSet from given ID or Name.
Expand All @@ -153,7 +156,10 @@ def create_model_version_set(
Project OCID.
tag: (str, optional)
calling tag, can be Tags.AQUA_FINE_TUNING or Tags.AQUA_EVALUATION
freeform_tags: (dict, optional)
Freeform tags for the model version set
defined_tags: (dict, optional)
Defined tags for the model version set
Returns
-------
tuple: (model_version_set_id, model_version_set_name)
Expand Down Expand Up @@ -182,13 +188,15 @@ def create_model_version_set(
mvs_freeform_tags = {
tag: tag,
}
mvs_freeform_tags = {**mvs_freeform_tags, **(freeform_tags or {})}
model_version_set = (
ModelVersionSet()
.with_compartment_id(compartment_id)
.with_project_id(project_id)
.with_name(model_version_set_name)
.with_description(description)
.with_freeform_tags(**mvs_freeform_tags)
.with_defined_tags(**(defined_tags or {}))
# TODO: decide what parameters will be needed
# when refactor eval to use this method, we need to pass tag here.
.create(**kwargs)
Expand Down Expand Up @@ -340,7 +348,9 @@ def build_cli(self) -> str:
"""
cmd = f"ads aqua {self._command}"
params = [
f"--{field.name} {getattr(self,field.name)}"
f"--{field.name} {json.dumps(getattr(self, field.name))}"
if isinstance(getattr(self, field.name), dict)
else f"--{field.name} {getattr(self, field.name)}"
for field in fields(self.__class__)
if getattr(self, field.name) is not None
]
Expand Down
6 changes: 6 additions & 0 deletions ads/aqua/evaluation/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ class CreateAquaEvaluationDetails(Serializable):
The metrics for the evaluation.
force_overwrite: (bool, optional). Defaults to `False`.
Whether to force overwrite the existing file in object storage.
freeform_tags: (dict, optional)
Freeform tags for the evaluation model
defined_tags: (dict, optional)
Defined tags for the evaluation model
"""

evaluation_source_id: str
Expand All @@ -85,6 +89,8 @@ class CreateAquaEvaluationDetails(Serializable):
log_id: Optional[str] = None
metrics: Optional[List[Dict[str, Any]]] = None
force_overwrite: Optional[bool] = False
freeform_tags: Optional[dict] = None
defined_tags: Optional[dict] = None

class Config:
extra = "ignore"
Expand Down
28 changes: 25 additions & 3 deletions ads/aqua/evaluation/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,10 @@ def create(
evaluation_mvs_freeform_tags = {
Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION,
}
evaluation_mvs_freeform_tags = {
**evaluation_mvs_freeform_tags,
**(create_aqua_evaluation_details.freeform_tags or {}),
}

model_version_set = (
ModelVersionSet()
Expand All @@ -307,6 +311,9 @@ def create(
create_aqua_evaluation_details.experiment_description
)
.with_freeform_tags(**evaluation_mvs_freeform_tags)
.with_defined_tags(
**(create_aqua_evaluation_details.defined_tags or {})
)
# TODO: decide what parameters will be needed
.create(**kwargs)
)
Expand Down Expand Up @@ -369,6 +376,10 @@ def create(
Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION,
Tags.AQUA_EVALUATION_MODEL_ID: evaluation_model.id,
}
evaluation_job_freeform_tags = {
**evaluation_job_freeform_tags,
**(create_aqua_evaluation_details.freeform_tags or {}),
}

evaluation_job = Job(name=evaluation_model.display_name).with_infrastructure(
DataScienceJob()
Expand All @@ -379,6 +390,7 @@ def create(
.with_shape_name(create_aqua_evaluation_details.shape_name)
.with_block_storage_size(create_aqua_evaluation_details.block_storage_size)
.with_freeform_tag(**evaluation_job_freeform_tags)
.with_defined_tag(**(create_aqua_evaluation_details.defined_tags or {}))
)
if (
create_aqua_evaluation_details.memory_in_gbs
Expand Down Expand Up @@ -425,6 +437,7 @@ def create(
evaluation_job_run = evaluation_job.run(
name=evaluation_model.display_name,
freeform_tags=evaluation_job_freeform_tags,
defined_tags=(create_aqua_evaluation_details.defined_tags or {}),
wait=False,
)
logger.debug(
Expand All @@ -444,13 +457,20 @@ def create(
for metadata in evaluation_model_custom_metadata.to_dict()["data"]
]

evaluation_model_freeform_tags = {
Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION,
**(create_aqua_evaluation_details.freeform_tags or {}),
}
evaluation_model_defined_tags = (
create_aqua_evaluation_details.defined_tags or {}
)

self.ds_client.update_model(
model_id=evaluation_model.id,
update_model_details=UpdateModelDetails(
custom_metadata_list=updated_custom_metadata_list,
freeform_tags={
Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION,
},
freeform_tags=evaluation_model_freeform_tags,
defined_tags=evaluation_model_defined_tags,
),
)

Expand Down Expand Up @@ -524,6 +544,8 @@ def create(
"evaluation_job_id": evaluation_job.id,
"evaluation_source": create_aqua_evaluation_details.evaluation_source_id,
"evaluation_experiment_id": experiment_model_version_set_id,
**evaluation_model_freeform_tags,
**evaluation_model_defined_tags,
},
parameters=AquaEvalParams(),
)
Expand Down
12 changes: 8 additions & 4 deletions ads/aqua/extension/deployment_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def delete(self, model_deployment_id):
return self.finish(AquaDeploymentApp().delete(model_deployment_id))

@handle_exceptions
def put(self, *args, **kwargs):
def put(self, *args, **kwargs): # noqa: ARG002
"""
Handles put request for the activating and deactivating OCI datascience model deployments
Raises
Expand All @@ -82,7 +82,7 @@ def put(self, *args, **kwargs):
raise HTTPError(400, f"The request {self.request.path} is invalid.")

@handle_exceptions
def post(self, *args, **kwargs):
def post(self, *args, **kwargs): # noqa: ARG002
"""
Handles post request for the deployment APIs
Raises
Expand Down Expand Up @@ -132,6 +132,8 @@ def post(self, *args, **kwargs):
private_endpoint_id = input_data.get("private_endpoint_id")
container_image_uri = input_data.get("container_image_uri")
cmd_var = input_data.get("cmd_var")
freeform_tags = input_data.get("freeform_tags")
defined_tags = input_data.get("defined_tags")

self.finish(
AquaDeploymentApp().create(
Expand All @@ -157,6 +159,8 @@ def post(self, *args, **kwargs):
private_endpoint_id=private_endpoint_id,
container_image_uri=container_image_uri,
cmd_var=cmd_var,
freeform_tags=freeform_tags,
defined_tags=defined_tags,
)
)

Expand Down Expand Up @@ -196,7 +200,7 @@ def validate_predict_url(endpoint):
return False

@handle_exceptions
def post(self, *args, **kwargs):
def post(self, *args, **kwargs): # noqa: ARG002
"""
Handles inference request for the Active Model Deployments
Raises
Expand Down Expand Up @@ -262,7 +266,7 @@ def get(self, model_id):
)

@handle_exceptions
def post(self, *args, **kwargs):
def post(self, *args, **kwargs): # noqa: ARG002
"""Handles post request for the deployment param handler API.
Raises
Expand Down
16 changes: 9 additions & 7 deletions ads/aqua/extension/model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def list(self):
)

@handle_exceptions
def post(self, *args, **kwargs):
def post(self, *args, **kwargs): # noqa: ARG002
"""
Handles post request for the registering any Aqua model.
Raises
Expand Down Expand Up @@ -131,6 +131,8 @@ def post(self, *args, **kwargs):
inference_container_uri = input_data.get("inference_container_uri")
allow_patterns = input_data.get("allow_patterns")
ignore_patterns = input_data.get("ignore_patterns")
freeform_tags = input_data.get("freeform_tags")
defined_tags = input_data.get("defined_tags")

return self.finish(
AquaModelApp().register(
Expand All @@ -145,6 +147,8 @@ def post(self, *args, **kwargs):
inference_container_uri=inference_container_uri,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
freeform_tags=freeform_tags,
defined_tags=defined_tags,
)
)

Expand All @@ -170,11 +174,9 @@ def put(self, id):

enable_finetuning = input_data.get("enable_finetuning")
task = input_data.get("task")
app=AquaModelApp()
app = AquaModelApp()
self.finish(
app.edit_registered_model(
id, inference_container, enable_finetuning, task
)
app.edit_registered_model(id, inference_container, enable_finetuning, task)
)
app.clear_model_details_cache(model_id=id)

Expand Down Expand Up @@ -218,7 +220,7 @@ def _find_matching_aqua_model(model_id: str) -> Optional[AquaModelSummary]:
return None

@handle_exceptions
def get(self, *args, **kwargs):
def get(self, *args, **kwargs): # noqa: ARG002
"""
Finds a list of matching models from hugging face based on query string provided from users.
Expand All @@ -239,7 +241,7 @@ def get(self, *args, **kwargs):
return self.finish({"models": models})

@handle_exceptions
def post(self, *args, **kwargs):
def post(self, *args, **kwargs): # noqa: ARG002
"""Handles post request for the HF Models APIs
Raises
Expand Down
6 changes: 6 additions & 0 deletions ads/aqua/finetuning/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ class CreateFineTuningDetails(DataClassSerializable):
The log id for fine tuning job infrastructure.
force_overwrite: (bool, optional). Defaults to `False`.
Whether to force overwrite the existing file in object storage.
freeform_tags: (dict, optional)
Freeform tags for the fine-tuning model
defined_tags: (dict, optional)
Defined tags for the fine-tuning model
"""

ft_source_id: str
Expand All @@ -101,3 +105,5 @@ class CreateFineTuningDetails(DataClassSerializable):
log_id: Optional[str] = None
log_group_id: Optional[str] = None
force_overwrite: Optional[bool] = False
freeform_tags: Optional[dict] = None
defined_tags: Optional[dict] = None
Loading

0 comments on commit 21a813d

Please sign in to comment.