From 8dc7bbac883a53a71e5af3ebe7ca8250a4b5bfaf Mon Sep 17 00:00:00 2001 From: kumar shivam ranjan Date: Thu, 21 Nov 2024 02:59:55 +0530 Subject: [PATCH] Adding available shapes API for model deployment --- ads/aqua/extension/ui_handler.py | 14 +++++++++++++- ads/aqua/ui.py | 26 ++++++++++++++++++++++++-- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/ads/aqua/extension/ui_handler.py b/ads/aqua/extension/ui_handler.py index 732151bf0..685fa55d9 100644 --- a/ads/aqua/extension/ui_handler.py +++ b/ads/aqua/extension/ui_handler.py @@ -68,6 +68,8 @@ def get(self, id=""): return self.list_buckets() elif paths.startswith("aqua/job/shapes"): return self.list_job_shapes() + elif paths.startswith("aqua/modeldeployment/shapes"): + return self.list_model_deployment_shapes() elif paths.startswith("aqua/vcn"): return self.list_vcn() elif paths.startswith("aqua/subnets"): @@ -160,6 +162,15 @@ def list_job_shapes(self, **kwargs): AquaUIApp().list_job_shapes(compartment_id=compartment_id, **kwargs) ) + def list_model_deployment_shapes(self, **kwargs): + """Lists model deployment shapes available in the specified compartment.""" + compartment_id = self.get_argument("compartment_id", default=COMPARTMENT_OCID) + return self.finish( + AquaUIApp().list_model_deployment_shapes( + compartment_id=compartment_id, **kwargs + ) + ) + def list_vcn(self, **kwargs): """Lists the virtual cloud networks (VCNs) in the specified compartment.""" compartment_id = self.get_argument("compartment_id", default=COMPARTMENT_OCID) @@ -255,8 +266,9 @@ def post(self, *args, **kwargs): __handlers__ = [ ("logging/?([^/]*)", AquaUIHandler), ("compartments/?([^/]*)", AquaUIHandler), - # TODO: change url to evaluation/experiements/?([^/]*) + # TODO: change url to evaluation/experiments/?([^/]*) ("experiment/?([^/]*)", AquaUIHandler), + ("modeldeployment/?([^/]*)", AquaUIHandler), ("versionsets/?([^/]*)", AquaUIHandler), ("buckets/?([^/]*)", AquaUIHandler), ("job/shapes/?([^/]*)", AquaUIHandler), diff --git a/ads/aqua/ui.py b/ads/aqua/ui.py index 39fa63f09..10ada6524 100644 --- a/ads/aqua/ui.py +++ b/ads/aqua/ui.py @@ -481,12 +481,12 @@ def _is_bucket_versioned(response): @telemetry(entry_point="plugin=ui&action=list_job_shapes", name="aqua") def list_job_shapes(self, **kwargs) -> list: - """Lists all availiable job shapes for the specified compartment. + """Lists all available job shapes for the specified compartment. Parameters ---------- **kwargs - Addtional arguments, such as `compartment_id`, + Additional arguments, such as `compartment_id`, for `list_job_shapes `_ Returns @@ -500,6 +500,28 @@ def list_job_shapes(self, **kwargs) -> list: ).data return sanitize_response(oci_client=self.ds_client, response=res) + @telemetry(entry_point="plugin=ui&action=list_model_deployment_shapes", name="aqua") + def list_model_deployment_shapes(self, **kwargs) -> list: + """Lists all available shapes for model deployment in the specified compartment. + + Parameters + ---------- + **kwargs + Additional arguments, such as `compartment_id`, + for `list_model_deployment_shapes `_ + + Returns + ------- + str has json representation of `oci.data_science.models.ModelDeploymentShapeSummary`.""" + compartment_id = kwargs.pop("compartment_id", COMPARTMENT_OCID) + logger.info( + f"Loading model deployment shape summary from compartment: {compartment_id}" + ) + res = self.ds_client.list_model_deployment_shapes( + compartment_id=compartment_id, **kwargs + ).data + return sanitize_response(oci_client=self.ds_client, response=res) + @telemetry(entry_point="plugin=ui&action=list_vcn", name="aqua") def list_vcn(self, **kwargs) -> list: """Lists the virtual cloud networks (VCNs) in the specified compartment.