Skip to content

Commit

Permalink
Allow customizing TEI deployment
Browse files: browse the repository at this point in the history
  • Loading branch information
VipulMascarenhas committed Nov 5, 2024
1 parent 60d48f2 commit c951a1d
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 4 deletions.
4 changes: 4 additions & 0 deletions ads/aqua/extension/deployment_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ def post(self, *args, **kwargs):
memory_in_gbs = input_data.get("memory_in_gbs")
model_file = input_data.get("model_file")
private_endpoint_id = input_data.get("private_endpoint_id")
container_image_uri = input_data.get("container_image_uri")
cmd_var = input_data.get("cmd_var")

self.finish(
AquaDeploymentApp().create(
Expand All @@ -126,6 +128,8 @@ def post(self, *args, **kwargs):
memory_in_gbs=memory_in_gbs,
model_file=model_file,
private_endpoint_id=private_endpoint_id,
container_image_uri=container_image_uri,
cmd_var=cmd_var,
)
)

Expand Down
2 changes: 2 additions & 0 deletions ads/aqua/extension/model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def post(self, *args, **kwargs):
download_from_hf = (
str(input_data.get("download_from_hf", "false")).lower() == "true"
)
inference_container_uri = input_data.get("inference_container_uri")

return self.finish(
AquaModelApp().register(
Expand All @@ -134,6 +135,7 @@ def post(self, *args, **kwargs):
compartment_id=compartment_id,
project_id=project_id,
model_file=model_file,
inference_container_uri=inference_container_uri,
)
)

Expand Down
2 changes: 2 additions & 0 deletions tests/unitary/with_extras/aqua/test_deployment_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ def test_post(self, mock_create):
ocpus=None,
model_file=None,
private_endpoint_id=None,
container_image_uri=None,
cmd_var=None,
)


Expand Down
30 changes: 26 additions & 4 deletions tests/unitary/with_extras/aqua/test_model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from huggingface_hub.hf_api import HfApi, ModelInfo
from huggingface_hub.utils import GatedRepoError
from notebook.base.handlers import IPythonHandler
from parameterized import parameterized

from ads.aqua.common.errors import AquaRuntimeError
from ads.aqua.common.utils import get_hf_model_info
Expand Down Expand Up @@ -90,9 +91,25 @@ def test_list(self, mock_list):
compartment_id=None, project_id=None, model_type=None
)

@parameterized.expand(
[
(None, None, False, None),
("odsc-llm-fine-tuning", None, False, None),
(None, "test.gguf", True, None),
(None, None, True, "iad.ocir.io/<namespace>/<image>:<tag>"),
],
)
@patch("notebook.base.handlers.APIHandler.finish")
@patch("ads.aqua.model.AquaModelApp.register")
def test_register(self, mock_register, mock_finish):
def test_register(
self,
finetuning_container,
model_file,
download_from_hf,
inference_container_uri,
mock_register,
mock_finish,
):
mock_register.return_value = AquaModel(
id="test_id",
inference_container="odsc-tgi-serving",
Expand All @@ -105,18 +122,23 @@ def test_register(self, mock_register, mock_finish):
model="test_model_name",
os_path="test_os_path",
inference_container="odsc-tgi-serving",
finetuning_container=finetuning_container,
model_file=model_file,
download_from_hf=download_from_hf,
inference_container_uri=inference_container_uri,
)
)
result = self.model_handler.post()
mock_register.assert_called_with(
model="test_model_name",
os_path="test_os_path",
inference_container="odsc-tgi-serving",
finetuning_container=None,
finetuning_container=finetuning_container,
compartment_id=None,
project_id=None,
model_file=None,
download_from_hf=False,
model_file=model_file,
download_from_hf=download_from_hf,
inference_container_uri=inference_container_uri,
)
assert result["id"] == "test_id"
assert result["inference_container"] == "odsc-tgi-serving"
Expand Down

0 comments on commit c951a1d

Please sign in to comment.