Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ODSC-65517] Support freeform and defined tags for resource creation in Aqua #1021

Merged
merged 10 commits into from
Dec 12, 2024
14 changes: 12 additions & 2 deletions ads/aqua/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import json
import os
from dataclasses import fields
from typing import Dict, Union
Expand Down Expand Up @@ -135,6 +136,8 @@ def create_model_version_set(
description: str = None,
compartment_id: str = None,
project_id: str = None,
freeform_tags: dict = None,
defined_tags: dict = None,
**kwargs,
) -> tuple:
"""Creates ModelVersionSet from given ID or Name.
Expand All @@ -153,7 +156,10 @@ def create_model_version_set(
Project OCID.
tag: (str, optional)
calling tag, can be Tags.AQUA_FINE_TUNING or Tags.AQUA_EVALUATION

freeform_tags: (dict, optional)
Freeform tags for the model version set
defined_tags: (dict, optional)
Defined tags for the model version set
Returns
-------
tuple: (model_version_set_id, model_version_set_name)
Expand Down Expand Up @@ -182,13 +188,15 @@ def create_model_version_set(
mvs_freeform_tags = {
tag: tag,
}
mvs_freeform_tags = {**mvs_freeform_tags, **(freeform_tags or {})}
model_version_set = (
ModelVersionSet()
.with_compartment_id(compartment_id)
.with_project_id(project_id)
.with_name(model_version_set_name)
.with_description(description)
.with_freeform_tags(**mvs_freeform_tags)
.with_defined_tags(**(defined_tags or {}))
# TODO: decide what parameters will be needed
# when refactor eval to use this method, we need to pass tag here.
.create(**kwargs)
Expand Down Expand Up @@ -340,7 +348,9 @@ def build_cli(self) -> str:
"""
cmd = f"ads aqua {self._command}"
params = [
f"--{field.name} {getattr(self,field.name)}"
f"--{field.name} {json.dumps(getattr(self, field.name))}"
if isinstance(getattr(self, field.name), dict)
else f"--{field.name} {getattr(self, field.name)}"
for field in fields(self.__class__)
if getattr(self, field.name) is not None
]
Expand Down
6 changes: 6 additions & 0 deletions ads/aqua/evaluation/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ class CreateAquaEvaluationDetails(Serializable):
The metrics for the evaluation.
force_overwrite: (bool, optional). Defaults to `False`.
Whether to force overwrite the existing file in object storage.
freeform_tags: (dict, optional)
Freeform tags for the evaluation model
defined_tags: (dict, optional)
Defined tags for the evaluation model
"""

evaluation_source_id: str
Expand All @@ -85,6 +89,8 @@ class CreateAquaEvaluationDetails(Serializable):
log_id: Optional[str] = None
metrics: Optional[List[Dict[str, Any]]] = None
force_overwrite: Optional[bool] = False
freeform_tags: Optional[dict] = None
defined_tags: Optional[dict] = None

class Config:
extra = "ignore"
Expand Down
28 changes: 25 additions & 3 deletions ads/aqua/evaluation/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,10 @@ def create(
evaluation_mvs_freeform_tags = {
Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION,
}
evaluation_mvs_freeform_tags = {
**evaluation_mvs_freeform_tags,
**(create_aqua_evaluation_details.freeform_tags or {}),
}

model_version_set = (
ModelVersionSet()
Expand All @@ -307,6 +311,9 @@ def create(
create_aqua_evaluation_details.experiment_description
)
.with_freeform_tags(**evaluation_mvs_freeform_tags)
.with_defined_tags(
**(create_aqua_evaluation_details.defined_tags or {})
)
# TODO: decide what parameters will be needed
.create(**kwargs)
)
Expand Down Expand Up @@ -369,6 +376,10 @@ def create(
Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION,
Tags.AQUA_EVALUATION_MODEL_ID: evaluation_model.id,
}
evaluation_job_freeform_tags = {
**evaluation_job_freeform_tags,
**(create_aqua_evaluation_details.freeform_tags or {}),
}

evaluation_job = Job(name=evaluation_model.display_name).with_infrastructure(
DataScienceJob()
Expand All @@ -379,6 +390,7 @@ def create(
.with_shape_name(create_aqua_evaluation_details.shape_name)
.with_block_storage_size(create_aqua_evaluation_details.block_storage_size)
.with_freeform_tag(**evaluation_job_freeform_tags)
.with_defined_tag(**(create_aqua_evaluation_details.defined_tags or {}))
)
if (
create_aqua_evaluation_details.memory_in_gbs
Expand Down Expand Up @@ -425,6 +437,7 @@ def create(
evaluation_job_run = evaluation_job.run(
name=evaluation_model.display_name,
freeform_tags=evaluation_job_freeform_tags,
defined_tags=(create_aqua_evaluation_details.defined_tags or {}),
wait=False,
)
logger.debug(
Expand All @@ -444,13 +457,20 @@ def create(
for metadata in evaluation_model_custom_metadata.to_dict()["data"]
]

evaluation_model_freeform_tags = {
Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION,
**(create_aqua_evaluation_details.freeform_tags or {}),
}
evaluation_model_defined_tags = (
create_aqua_evaluation_details.defined_tags or {}
)

self.ds_client.update_model(
model_id=evaluation_model.id,
update_model_details=UpdateModelDetails(
custom_metadata_list=updated_custom_metadata_list,
freeform_tags={
Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION,
},
freeform_tags=evaluation_model_freeform_tags,
defined_tags=evaluation_model_defined_tags,
),
)

Expand Down Expand Up @@ -524,6 +544,8 @@ def create(
"evaluation_job_id": evaluation_job.id,
"evaluation_source": create_aqua_evaluation_details.evaluation_source_id,
"evaluation_experiment_id": experiment_model_version_set_id,
**evaluation_model_freeform_tags,
**evaluation_model_defined_tags,
},
parameters=AquaEvalParams(),
)
Expand Down
12 changes: 8 additions & 4 deletions ads/aqua/extension/deployment_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def delete(self, model_deployment_id):
return self.finish(AquaDeploymentApp().delete(model_deployment_id))

@handle_exceptions
def put(self, *args, **kwargs):
def put(self, *args, **kwargs): # noqa: ARG002
"""
Handles put request for the activating and deactivating OCI datascience model deployments
Raises
Expand All @@ -82,7 +82,7 @@ def put(self, *args, **kwargs):
raise HTTPError(400, f"The request {self.request.path} is invalid.")

@handle_exceptions
def post(self, *args, **kwargs):
def post(self, *args, **kwargs): # noqa: ARG002
"""
Handles post request for the deployment APIs
Raises
Expand Down Expand Up @@ -132,6 +132,8 @@ def post(self, *args, **kwargs):
private_endpoint_id = input_data.get("private_endpoint_id")
container_image_uri = input_data.get("container_image_uri")
cmd_var = input_data.get("cmd_var")
freeform_tags = input_data.get("freeform_tags")
defined_tags = input_data.get("defined_tags")

self.finish(
AquaDeploymentApp().create(
Expand All @@ -157,6 +159,8 @@ def post(self, *args, **kwargs):
private_endpoint_id=private_endpoint_id,
container_image_uri=container_image_uri,
cmd_var=cmd_var,
freeform_tags=freeform_tags,
defined_tags=defined_tags,
)
)

Expand Down Expand Up @@ -196,7 +200,7 @@ def validate_predict_url(endpoint):
return False

@handle_exceptions
def post(self, *args, **kwargs):
def post(self, *args, **kwargs): # noqa: ARG002
"""
Handles inference request for the Active Model Deployments
Raises
Expand Down Expand Up @@ -262,7 +266,7 @@ def get(self, model_id):
)

@handle_exceptions
def post(self, *args, **kwargs):
def post(self, *args, **kwargs): # noqa: ARG002
"""Handles post request for the deployment param handler API.

Raises
Expand Down
16 changes: 9 additions & 7 deletions ads/aqua/extension/model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def list(self):
)

@handle_exceptions
def post(self, *args, **kwargs):
def post(self, *args, **kwargs): # noqa: ARG002
"""
Handles post request for the registering any Aqua model.
Raises
Expand Down Expand Up @@ -131,6 +131,8 @@ def post(self, *args, **kwargs):
inference_container_uri = input_data.get("inference_container_uri")
allow_patterns = input_data.get("allow_patterns")
ignore_patterns = input_data.get("ignore_patterns")
freeform_tags = input_data.get("freeform_tags")
defined_tags = input_data.get("defined_tags")

return self.finish(
AquaModelApp().register(
Expand All @@ -145,6 +147,8 @@ def post(self, *args, **kwargs):
inference_container_uri=inference_container_uri,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
freeform_tags=freeform_tags,
defined_tags=defined_tags,
)
)

Expand All @@ -170,11 +174,9 @@ def put(self, id):

enable_finetuning = input_data.get("enable_finetuning")
task = input_data.get("task")
app=AquaModelApp()
app = AquaModelApp()
self.finish(
app.edit_registered_model(
id, inference_container, enable_finetuning, task
)
app.edit_registered_model(id, inference_container, enable_finetuning, task)
)
app.clear_model_details_cache(model_id=id)

Expand Down Expand Up @@ -218,7 +220,7 @@ def _find_matching_aqua_model(model_id: str) -> Optional[AquaModelSummary]:
return None

@handle_exceptions
def get(self, *args, **kwargs):
def get(self, *args, **kwargs): # noqa: ARG002
"""
Finds a list of matching models from hugging face based on query string provided from users.

Expand All @@ -239,7 +241,7 @@ def get(self, *args, **kwargs):
return self.finish({"models": models})

@handle_exceptions
def post(self, *args, **kwargs):
def post(self, *args, **kwargs): # noqa: ARG002
"""Handles post request for the HF Models APIs

Raises
Expand Down
6 changes: 6 additions & 0 deletions ads/aqua/finetuning/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ class CreateFineTuningDetails(DataClassSerializable):
The log id for fine tuning job infrastructure.
force_overwrite: (bool, optional). Defaults to `False`.
Whether to force overwrite the existing file in object storage.
freeform_tags: (dict, optional)
Freeform tags for the fine-tuning model
defined_tags: (dict, optional)
Defined tags for the fine-tuning model
"""

ft_source_id: str
Expand All @@ -101,3 +105,5 @@ class CreateFineTuningDetails(DataClassSerializable):
log_id: Optional[str] = None
log_group_id: Optional[str] = None
force_overwrite: Optional[bool] = False
freeform_tags: Optional[dict] = None
defined_tags: Optional[dict] = None
Loading
Loading