From 82d1ae26398d5847cb5d53c0003515c59bc15ced Mon Sep 17 00:00:00 2001 From: Vipul Date: Mon, 9 Dec 2024 22:16:29 -0800 Subject: [PATCH] update model and deployment handlers --- ads/aqua/extension/deployment_handler.py | 12 ++++++++---- ads/aqua/extension/model_handler.py | 16 +++++++++------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/ads/aqua/extension/deployment_handler.py b/ads/aqua/extension/deployment_handler.py index 88ad84272..2a3e827c3 100644 --- a/ads/aqua/extension/deployment_handler.py +++ b/ads/aqua/extension/deployment_handler.py @@ -59,7 +59,7 @@ def delete(self, model_deployment_id): return self.finish(AquaDeploymentApp().delete(model_deployment_id)) @handle_exceptions - def put(self, *args, **kwargs): + def put(self, *args, **kwargs): # noqa: ARG002 """ Handles put request for the activating and deactivating OCI datascience model deployments Raises @@ -82,7 +82,7 @@ def put(self, *args, **kwargs): raise HTTPError(400, f"The request {self.request.path} is invalid.") @handle_exceptions - def post(self, *args, **kwargs): + def post(self, *args, **kwargs): # noqa: ARG002 """ Handles post request for the deployment APIs Raises @@ -132,6 +132,8 @@ def post(self, *args, **kwargs): private_endpoint_id = input_data.get("private_endpoint_id") container_image_uri = input_data.get("container_image_uri") cmd_var = input_data.get("cmd_var") + freeform_tags = input_data.get("freeform_tags") + defined_tags = input_data.get("defined_tags") self.finish( AquaDeploymentApp().create( @@ -157,6 +159,8 @@ def post(self, *args, **kwargs): private_endpoint_id=private_endpoint_id, container_image_uri=container_image_uri, cmd_var=cmd_var, + freeform_tags=freeform_tags, + defined_tags=defined_tags, ) ) @@ -196,7 +200,7 @@ def validate_predict_url(endpoint): return False @handle_exceptions - def post(self, *args, **kwargs): + def post(self, *args, **kwargs): # noqa: ARG002 """ Handles inference request for the Active Model Deployments Raises @@ -262,7 +266,7 @@ def get(self, model_id): ) @handle_exceptions - def post(self, *args, **kwargs): + def post(self, *args, **kwargs): # noqa: ARG002 """Handles post request for the deployment param handler API. Raises diff --git a/ads/aqua/extension/model_handler.py b/ads/aqua/extension/model_handler.py index 1a322d801..42f90ffef 100644 --- a/ads/aqua/extension/model_handler.py +++ b/ads/aqua/extension/model_handler.py @@ -96,7 +96,7 @@ def list(self): ) @handle_exceptions - def post(self, *args, **kwargs): + def post(self, *args, **kwargs): # noqa: ARG002 """ Handles post request for the registering any Aqua model. Raises @@ -131,6 +131,8 @@ def post(self, *args, **kwargs): inference_container_uri = input_data.get("inference_container_uri") allow_patterns = input_data.get("allow_patterns") ignore_patterns = input_data.get("ignore_patterns") + freeform_tags = input_data.get("freeform_tags") + defined_tags = input_data.get("defined_tags") return self.finish( AquaModelApp().register( @@ -145,6 +147,8 @@ def post(self, *args, **kwargs): inference_container_uri=inference_container_uri, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, + freeform_tags=freeform_tags, + defined_tags=defined_tags, ) ) @@ -170,11 +174,9 @@ def put(self, id): enable_finetuning = input_data.get("enable_finetuning") task = input_data.get("task") - app=AquaModelApp() + app = AquaModelApp() self.finish( - app.edit_registered_model( - id, inference_container, enable_finetuning, task - ) + app.edit_registered_model(id, inference_container, enable_finetuning, task) ) app.clear_model_details_cache(model_id=id) @@ -218,7 +220,7 @@ def _find_matching_aqua_model(model_id: str) -> Optional[AquaModelSummary]: return None @handle_exceptions - def get(self, *args, **kwargs): + def get(self, *args, **kwargs): # noqa: ARG002 """ Finds a list of matching models from hugging face based on query string provided from users. @@ -239,7 +241,7 @@ def get(self, *args, **kwargs): return self.finish({"models": models}) @handle_exceptions - def post(self, *args, **kwargs): + def post(self, *args, **kwargs): # noqa: ARG002 """Handles post request for the HF Models APIs Raises