Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding API to fetch list of available shapes for model deployment #1012

Merged
merged 3 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion ads/aqua/extension/ui_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def get(self, id=""):
return self.list_buckets()
elif paths.startswith("aqua/job/shapes"):
return self.list_job_shapes()
elif paths.startswith("aqua/modeldeployment/shapes"):
return self.list_model_deployment_shapes()
elif paths.startswith("aqua/vcn"):
return self.list_vcn()
elif paths.startswith("aqua/subnets"):
Expand Down Expand Up @@ -160,6 +162,15 @@ def list_job_shapes(self, **kwargs):
AquaUIApp().list_job_shapes(compartment_id=compartment_id, **kwargs)
)

def list_model_deployment_shapes(self, **kwargs):
"""Lists model deployment shapes available in the specified compartment."""
compartment_id = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
return self.finish(
AquaUIApp().list_model_deployment_shapes(
compartment_id=compartment_id, **kwargs
)
)

def list_vcn(self, **kwargs):
"""Lists the virtual cloud networks (VCNs) in the specified compartment."""
compartment_id = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
Expand Down Expand Up @@ -255,8 +266,9 @@ def post(self, *args, **kwargs):
__handlers__ = [
("logging/?([^/]*)", AquaUIHandler),
("compartments/?([^/]*)", AquaUIHandler),
# TODO: change url to evaluation/experiements/?([^/]*)
# TODO: change url to evaluation/experiments/?([^/]*)
("experiment/?([^/]*)", AquaUIHandler),
("modeldeployment/?([^/]*)", AquaUIHandler),
("versionsets/?([^/]*)", AquaUIHandler),
("buckets/?([^/]*)", AquaUIHandler),
("job/shapes/?([^/]*)", AquaUIHandler),
Expand Down
26 changes: 24 additions & 2 deletions ads/aqua/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,12 +481,12 @@ def _is_bucket_versioned(response):

@telemetry(entry_point="plugin=ui&action=list_job_shapes", name="aqua")
def list_job_shapes(self, **kwargs) -> list:
"""Lists all availiable job shapes for the specified compartment.
"""Lists all available job shapes for the specified compartment.

Parameters
----------
**kwargs
Addtional arguments, such as `compartment_id`,
Additional arguments, such as `compartment_id`,
for `list_job_shapes <https://docs.oracle.com/en-us/iaas/tools/python/2.122.0/api/data_science/client/oci.data_science.DataScienceClient.html#oci.data_science.DataScienceClient.list_job_shapes>`_

Returns
Expand All @@ -500,6 +500,28 @@ def list_job_shapes(self, **kwargs) -> list:
).data
return sanitize_response(oci_client=self.ds_client, response=res)

@telemetry(entry_point="plugin=ui&action=list_model_deployment_shapes", name="aqua")
def list_model_deployment_shapes(self, **kwargs) -> list:
"""Lists all available shapes for model deployment in the specified compartment.

Parameters
----------
**kwargs
Additional arguments, such as `compartment_id`,
for `list_model_deployment_shapes <https://docs.oracle.com/en-us/iaas/api/#/en/data-science/20190101/ModelDeploymentShapeSummary/ListModelDeploymentShapes>`_

Returns
-------
str has json representation of `oci.data_science.models.ModelDeploymentShapeSummary`."""
compartment_id = kwargs.pop("compartment_id", COMPARTMENT_OCID)
logger.info(
f"Loading model deployment shape summary from compartment: {compartment_id}"
)
res = self.ds_client.list_model_deployment_shapes(
compartment_id=compartment_id, **kwargs
).data
return sanitize_response(oci_client=self.ds_client, response=res)

@telemetry(entry_point="plugin=ui&action=list_vcn", name="aqua")
def list_vcn(self, **kwargs) -> list:
"""Lists the virtual cloud networks (VCNs) in the specified compartment.
Expand Down
Loading