From 69ead2ef98493f5e4642920ba7f0fe5ec492d5de Mon Sep 17 00:00:00 2001 From: Vipul Date: Wed, 30 Oct 2024 10:21:03 -0700 Subject: [PATCH 1/2] Replaced hide_index to hide --- ads/dataset/dataset.py | 4 ++-- ads/dataset/factory.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ads/dataset/dataset.py b/ads/dataset/dataset.py index 102876ad0..025667188 100644 --- a/ads/dataset/dataset.py +++ b/ads/dataset/dataset.py @@ -202,7 +202,7 @@ def _repr_html_(self): self.sampled_df.head(5) .style.set_table_styles(utils.get_dataframe_styles()) .set_table_attributes("class=table") - .hide_index() + .hide() .to_html() ) ) @@ -261,7 +261,7 @@ def _repr_html_(self): utils.horizontal_scrollable_div( self.style.set_table_styles(utils.get_dataframe_styles()) .set_table_attributes("class=table") - .hide_index() + .hide() .to_html() ) ) diff --git a/ads/dataset/factory.py b/ads/dataset/factory.py index 1611f4946..c7bfb5139 100644 --- a/ads/dataset/factory.py +++ b/ads/dataset/factory.py @@ -366,7 +366,7 @@ def list_snapshots(snapshot_dir=None, name="", storage_options=None, **kwargs): display( HTML( list_df.style.set_table_attributes("class=table") - .hide_index() + .hide() .to_html() ) ) From c951a1d77b9578cf04de1d43dcb7bcb9a9207b26 Mon Sep 17 00:00:00 2001 From: Vipul Date: Tue, 5 Nov 2024 12:01:15 -0800 Subject: [PATCH 2/2] allow customizing tei deployment --- ads/aqua/extension/deployment_handler.py | 4 +++ ads/aqua/extension/model_handler.py | 2 ++ .../aqua/test_deployment_handler.py | 2 ++ .../with_extras/aqua/test_model_handler.py | 30 ++++++++++++++++--- 4 files changed, 34 insertions(+), 4 deletions(-) diff --git a/ads/aqua/extension/deployment_handler.py b/ads/aqua/extension/deployment_handler.py index e8652e46f..ed0bfdd74 100644 --- a/ads/aqua/extension/deployment_handler.py +++ b/ads/aqua/extension/deployment_handler.py @@ -103,6 +103,8 @@ def post(self, *args, **kwargs): memory_in_gbs = input_data.get("memory_in_gbs") model_file = input_data.get("model_file") private_endpoint_id = input_data.get("private_endpoint_id") + container_image_uri = input_data.get("container_image_uri") + cmd_var = input_data.get("cmd_var") self.finish( AquaDeploymentApp().create( @@ -126,6 +128,8 @@ def post(self, *args, **kwargs): memory_in_gbs=memory_in_gbs, model_file=model_file, private_endpoint_id=private_endpoint_id, + container_image_uri=container_image_uri, + cmd_var=cmd_var, ) ) diff --git a/ads/aqua/extension/model_handler.py b/ads/aqua/extension/model_handler.py index 5fa25992f..96f2826d0 100644 --- a/ads/aqua/extension/model_handler.py +++ b/ads/aqua/extension/model_handler.py @@ -123,6 +123,7 @@ def post(self, *args, **kwargs): download_from_hf = ( str(input_data.get("download_from_hf", "false")).lower() == "true" ) + inference_container_uri = input_data.get("inference_container_uri") return self.finish( AquaModelApp().register( @@ -134,6 +135,7 @@ def post(self, *args, **kwargs): compartment_id=compartment_id, project_id=project_id, model_file=model_file, + inference_container_uri=inference_container_uri, ) ) diff --git a/tests/unitary/with_extras/aqua/test_deployment_handler.py b/tests/unitary/with_extras/aqua/test_deployment_handler.py index d4def6bb8..a91955160 100644 --- a/tests/unitary/with_extras/aqua/test_deployment_handler.py +++ b/tests/unitary/with_extras/aqua/test_deployment_handler.py @@ -130,6 +130,8 @@ def test_post(self, mock_create): ocpus=None, model_file=None, private_endpoint_id=None, + container_image_uri=None, + cmd_var=None, ) diff --git a/tests/unitary/with_extras/aqua/test_model_handler.py b/tests/unitary/with_extras/aqua/test_model_handler.py index cb7a27080..d217684fb 100644 --- a/tests/unitary/with_extras/aqua/test_model_handler.py +++ b/tests/unitary/with_extras/aqua/test_model_handler.py @@ -10,6 +10,7 @@ from huggingface_hub.hf_api import HfApi, ModelInfo from huggingface_hub.utils import GatedRepoError from notebook.base.handlers import IPythonHandler +from parameterized import parameterized from ads.aqua.common.errors import AquaRuntimeError from ads.aqua.common.utils import get_hf_model_info @@ -90,9 +91,25 @@ def test_list(self, mock_list): compartment_id=None, project_id=None, model_type=None ) + @parameterized.expand( + [ + (None, None, False, None), + ("odsc-llm-fine-tuning", None, False, None), + (None, "test.gguf", True, None), + (None, None, True, "iad.ocir.io//:"), + ], + ) @patch("notebook.base.handlers.APIHandler.finish") @patch("ads.aqua.model.AquaModelApp.register") - def test_register(self, mock_register, mock_finish): + def test_register( + self, + finetuning_container, + model_file, + download_from_hf, + inference_container_uri, + mock_register, + mock_finish, + ): mock_register.return_value = AquaModel( id="test_id", inference_container="odsc-tgi-serving", @@ -105,6 +122,10 @@ def test_register(self, mock_register, mock_finish): model="test_model_name", os_path="test_os_path", inference_container="odsc-tgi-serving", + finetuning_container=finetuning_container, + model_file=model_file, + download_from_hf=download_from_hf, + inference_container_uri=inference_container_uri, ) ) result = self.model_handler.post() @@ -112,11 +133,12 @@ def test_register(self, mock_register, mock_finish): model="test_model_name", os_path="test_os_path", inference_container="odsc-tgi-serving", - finetuning_container=None, + finetuning_container=finetuning_container, compartment_id=None, project_id=None, - model_file=None, - download_from_hf=False, + model_file=model_file, + download_from_hf=download_from_hf, + inference_container_uri=inference_container_uri, ) assert result["id"] == "test_id" assert result["inference_container"] == "odsc-tgi-serving"