Skip to content

Commit

Permalink
Merge branch 'main' into feature/hf-list-model-api-changes
Browse files Browse the repository at this point in the history
  • Loading branch information
kumar-shivam-ranjan authored Oct 23, 2024
2 parents 82b2613 + 52dfede commit 5000146
Showing 1 changed file with 49 additions and 48 deletions.
97 changes: 49 additions & 48 deletions ads/aqua/modeldeployment/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
from ads.config import (
AQUA_CONFIG_FOLDER,
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME,
AQUA_DEPLOYMENT_CONTAINER_OVERRIDE_FLAG_METADATA_NAME,
AQUA_MODEL_DEPLOYMENT_CONFIG,
AQUA_MODEL_DEPLOYMENT_CONFIG_DEFAULTS,
COMPARTMENT_OCID,
Expand Down Expand Up @@ -87,26 +86,26 @@ class AquaDeploymentApp(AquaApp):

@telemetry(entry_point="plugin=deployment&action=create", name="aqua")
def create(
self,
model_id: str,
instance_shape: str,
display_name: str,
instance_count: int = None,
log_group_id: str = None,
access_log_id: str = None,
predict_log_id: str = None,
compartment_id: str = None,
project_id: str = None,
description: str = None,
bandwidth_mbps: int = None,
web_concurrency: int = None,
server_port: int = None,
health_check_port: int = None,
env_var: Dict = None,
container_family: str = None,
memory_in_gbs: Optional[float] = None,
ocpus: Optional[float] = None,
model_file: Optional[str] = None,
self,
model_id: str,
instance_shape: str,
display_name: str,
instance_count: int = None,
log_group_id: str = None,
access_log_id: str = None,
predict_log_id: str = None,
compartment_id: str = None,
project_id: str = None,
description: str = None,
bandwidth_mbps: int = None,
web_concurrency: int = None,
server_port: int = None,
health_check_port: int = None,
env_var: Dict = None,
container_family: str = None,
memory_in_gbs: Optional[float] = None,
ocpus: Optional[float] = None,
model_file: Optional[str] = None,
) -> "AquaDeployment":
"""
Creates a new Aqua deployment
Expand Down Expand Up @@ -175,6 +174,7 @@ def create(
tags[tag] = aqua_model.freeform_tags[tag]

tags.update({Tags.AQUA_MODEL_NAME_TAG: aqua_model.display_name})
tags.update({Tags.TASK: aqua_model.freeform_tags.get(Tags.TASK, None)})

# Set up info to get deployment config
config_source_id = model_id
Expand Down Expand Up @@ -231,8 +231,7 @@ def create(
env_var.update({"FT_MODEL": f"{fine_tune_output_path}"})

container_type_key = self._get_container_type_key(
model=aqua_model,
container_family=container_family
model=aqua_model, container_family=container_family
)

# fetch image name from config
Expand All @@ -248,7 +247,11 @@ def create(
model_format = model_formats_str.split(",")

# Figure out a better way to handle this in a future release
if ModelFormat.GGUF.value in model_format and container_type_key.lower() == InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY:
if (
ModelFormat.GGUF.value in model_format
and container_type_key.lower()
== InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY
):
if model_file is not None:
logger.info(
f"Overriding {model_file} as model_file for model {aqua_model.id}."
Expand Down Expand Up @@ -299,8 +302,8 @@ def create(
if user_params:
# todo: remove this check in the future version, logic to be moved to container_index
if (
container_type_key.lower()
== InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY
container_type_key.lower()
== InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY
):
# AQUA_LLAMA_CPP_CONTAINER_FAMILY container uses uvicorn, which requires model/server params
# to be set as env vars
Expand Down Expand Up @@ -422,9 +425,8 @@ def _get_container_type_key(model: DataScienceModel, container_family: str) -> s
f"for model {model.id}. For unverified Aqua models, {AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} should be"
f"set and value can be one of {', '.join(InferenceContainerTypeFamily.values())}."
) from err

return container_type_key


@telemetry(entry_point="plugin=deployment&action=list", name="aqua")
def list(self, **kwargs) -> List["AquaDeployment"]:
Expand Down Expand Up @@ -453,8 +455,8 @@ def list(self, **kwargs) -> List["AquaDeployment"]:
for model_deployment in model_deployments:
oci_aqua = (
(
Tags.AQUA_TAG in model_deployment.freeform_tags
or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags
Tags.AQUA_TAG in model_deployment.freeform_tags
or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags
)
if model_deployment.freeform_tags
else False
Expand Down Expand Up @@ -508,8 +510,8 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":

oci_aqua = (
(
Tags.AQUA_TAG in model_deployment.freeform_tags
or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags
Tags.AQUA_TAG in model_deployment.freeform_tags
or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags
)
if model_deployment.freeform_tags
else False
Expand All @@ -526,8 +528,8 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
log_group_name = ""

logs = (
model_deployment.category_log_details.access
or model_deployment.category_log_details.predict
model_deployment.category_log_details.access
or model_deployment.category_log_details.predict
)
if logs:
log_id = logs.log_id
Expand Down Expand Up @@ -582,9 +584,9 @@ def get_deployment_config(self, model_id: str) -> Dict:
return config

def get_deployment_default_params(
self,
model_id: str,
instance_shape: str,
self,
model_id: str,
instance_shape: str,
) -> List[str]:
"""Gets the default params set in the deployment configs for the given model and instance shape.
Expand Down Expand Up @@ -616,8 +618,8 @@ def get_deployment_default_params(
)

if (
container_type_key
and container_type_key in InferenceContainerTypeFamily.values()
container_type_key
and container_type_key in InferenceContainerTypeFamily.values()
):
deployment_config = self.get_deployment_config(model_id)
config_params = (
Expand All @@ -640,10 +642,10 @@ def get_deployment_default_params(
return default_params

def validate_deployment_params(
self,
model_id: str,
params: List[str] = None,
container_family: str = None,
self,
model_id: str,
params: List[str] = None,
container_family: str = None,
) -> Dict:
"""Validate if the deployment parameters passed by the user can be overridden. Parameter values are not
validated, only param keys are validated.
Expand All @@ -666,8 +668,7 @@ def validate_deployment_params(
if params:
model = DataScienceModel.from_id(model_id)
container_type_key = self._get_container_type_key(
model=model,
container_family=container_family
model=model, container_family=container_family
)

container_config = get_container_config()
Expand All @@ -689,9 +690,9 @@ def validate_deployment_params(

@staticmethod
def _find_restricted_params(
default_params: Union[str, List[str]],
user_params: Union[str, List[str]],
container_family: str,
default_params: Union[str, List[str]],
user_params: Union[str, List[str]],
container_family: str,
) -> List[str]:
"""Returns a list of restricted params that user chooses to override when creating an Aqua deployment.
The default parameters coming from the container index json file cannot be overridden.
Expand Down

0 comments on commit 5000146

Please sign in to comment.