Skip to content

Commit

Permalink
Improved register BYOM model (#1005)
Browse files Browse the repository at this point in the history
  • Loading branch information
mrDzurb authored Nov 13, 2024
2 parents 8d358c7 + 89d6052 commit 0f446a5
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 26 deletions.
5 changes: 4 additions & 1 deletion ads/aqua/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,13 +788,14 @@ def get_ocid_substring(ocid: str, key_len: int) -> str:
return ocid[-key_len:] if ocid and len(ocid) > key_len else ""


def upload_folder(os_path: str, local_dir: str, model_name: str) -> str:
def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None) -> str:
"""Upload the local folder to the object storage
Args:
os_path (str): object storage URI with prefix. This is the path to upload
local_dir (str): Local directory where the object is downloaded
model_name (str): Name of the huggingface model
exclude_pattern (optional, str): The matching pattern of files to be excluded from uploading.
Returns:
str: Object name inside the bucket
"""
Expand All @@ -804,6 +805,8 @@ def upload_folder(os_path: str, local_dir: str, model_name: str) -> str:
auth_state = AuthState()
object_path = os_details.filepath.rstrip("/") + "/" + model_name + "/"
command = f"oci os object bulk-upload --src-dir {local_dir} --prefix {object_path} -bn {os_details.bucket} -ns {os_details.namespace} --auth {auth_state.oci_iam_type} --profile {auth_state.oci_key_profile} --no-overwrite"
if exclude_pattern:
command += f" --exclude {exclude_pattern}"
try:
logger.info(f"Running: {command}")
subprocess.check_call(shlex.split(command))
Expand Down
1 change: 1 addition & 0 deletions ads/aqua/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME = "_name_or_path"
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE = "model_type"
AQUA_MODEL_ARTIFACT_FILE = "model_file"
HF_METADATA_FOLDER = ".cache/"
HF_LOGIN_DEFAULT_TIMEOUT = 2

TRAINING_METRICS_FINAL = "training_metrics_final"
Expand Down
4 changes: 4 additions & 0 deletions ads/aqua/extension/model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ def post(self, *args, **kwargs):
str(input_data.get("download_from_hf", "false")).lower() == "true"
)
inference_container_uri = input_data.get("inference_container_uri")
allow_patterns = input_data.get("allow_patterns")
ignore_patterns = input_data.get("ignore_patterns")

return self.finish(
AquaModelApp().register(
Expand All @@ -141,6 +143,8 @@ def post(self, *args, **kwargs):
project_id=project_id,
model_file=model_file,
inference_container_uri=inference_container_uri,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
)
)

Expand Down
2 changes: 2 additions & 0 deletions ads/aqua/model/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,8 @@ class ImportModelDetails(CLIBuilderMixin):
project_id: Optional[str] = None
model_file: Optional[str] = None
inference_container_uri: Optional[str] = None
allow_patterns: Optional[List[str]] = None
ignore_patterns: Optional[List[str]] = None

def __post_init__(self):
self._command = "model register"
44 changes: 25 additions & 19 deletions ads/aqua/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE,
AQUA_MODEL_ARTIFACT_FILE,
AQUA_MODEL_TYPE_CUSTOM,
HF_METADATA_FOLDER,
LICENSE_TXT,
MODEL_BY_REFERENCE_OSS_PATH_KEY,
README,
Expand Down Expand Up @@ -1274,6 +1275,8 @@ def _download_model_from_hf(
model_name: str,
os_path: str,
local_dir: str = None,
allow_patterns: List[str] = None,
ignore_patterns: List[str] = None,
) -> str:
"""This helper function downloads the model artifact from Hugging Face to a local folder, then uploads
to object storage location.
Expand All @@ -1283,6 +1286,12 @@ def _download_model_from_hf(
model_name (str): The huggingface model name.
os_path (str): The OS path where the model files are located.
local_dir (str): The local temp dir to store the huggingface model.
allow_patterns (list): Model files matching at least one pattern are downloaded.
Example: ["*.json"] will download all .json files. ["folder/*"] will download all files under `folder`.
Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
ignore_patterns (list): Model files matching any of the patterns are not downloaded.
Example: ["*.json"] will ignore all .json files. ["folder/*"] will ignore all files under `folder`.
Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
Returns
-------
Expand All @@ -1293,30 +1302,19 @@ def _download_model_from_hf(
if not local_dir:
local_dir = os.path.join(os.path.expanduser("~"), "cached-model")
local_dir = os.path.join(local_dir, model_name)
retry = 10
i = 0
huggingface_download_err_message = None
while i < retry:
try:
# Download to cache folder. The while loop retries when there is a network failure
snapshot_download(repo_id=model_name)
except Exception as e:
huggingface_download_err_message = str(e)
i += 1
else:
break
if i == retry:
raise Exception(
f"Could not download the model {model_name} from https://huggingface.co with message {huggingface_download_err_message}"
)
os.makedirs(local_dir, exist_ok=True)
# Copy the model from the cache to destination
snapshot_download(repo_id=model_name, local_dir=local_dir)
# Upload to object storage
snapshot_download(
repo_id=model_name,
local_dir=local_dir,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
)
# Upload to object storage, skipping the Hugging Face metadata folder (.cache/)
model_artifact_path = upload_folder(
os_path=os_path,
local_dir=local_dir,
model_name=model_name,
exclude_pattern=f"{HF_METADATA_FOLDER}*"
)

return model_artifact_path
Expand All @@ -1335,6 +1333,12 @@ def register(
os_path (str): Object storage destination URI to store the downloaded model. Format: oci://bucket-name@namespace/prefix
inference_container (str): selects service defaults
finetuning_container (str): selects service defaults
allow_patterns (list): Model files matching at least one pattern are downloaded.
Example: ["*.json"] will download all .json files. ["folder/*"] will download all files under `folder`.
Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
ignore_patterns (list): Model files matching any of the patterns are not downloaded.
Example: ["*.json"] will ignore all .json files. ["folder/*"] will ignore all files under `folder`.
Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
Returns:
AquaModel:
Expand Down Expand Up @@ -1381,6 +1385,8 @@ def register(
model_name=model_name,
os_path=import_model_details.os_path,
local_dir=import_model_details.local_dir,
allow_patterns=import_model_details.allow_patterns,
ignore_patterns=import_model_details.ignore_patterns,
).rstrip("/")
else:
artifact_path = import_model_details.os_path.rstrip("/")
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ opctl = [
"rich",
"fire",
"cachetools",
"huggingface_hub==0.23.4"
"huggingface_hub==0.26.2"
]
optuna = ["optuna==2.9.0", "oracle_ads[viz]"]
spark = ["pyspark>=3.0.0"]
Expand Down
7 changes: 6 additions & 1 deletion tests/unitary/with_extras/aqua/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from unittest.mock import MagicMock, patch

import oci
from ads.aqua.constants import HF_METADATA_FOLDER
import pytest
from ads.aqua.ui import ModelFormat
from parameterized import parameterized
Expand Down Expand Up @@ -746,14 +747,18 @@ def test_import_verified_model(
os_path=os_path,
local_dir=str(tmpdir),
download_from_hf=True,
allow_patterns=["*.json"],
ignore_patterns=["test.json"]
)
mock_snapshot_download.assert_called_with(
repo_id=model_name,
local_dir=f"{str(tmpdir)}/{model_name}",
allow_patterns=["*.json"],
ignore_patterns=["test.json"]
)
mock_subprocess.assert_called_with(
shlex.split(
f"oci os object bulk-upload --src-dir {str(tmpdir)}/{model_name} --prefix prefix/path/{model_name}/ -bn aqua-bkt -ns aqua-ns --auth api_key --profile DEFAULT --no-overwrite"
f"oci os object bulk-upload --src-dir {str(tmpdir)}/{model_name} --prefix prefix/path/{model_name}/ -bn aqua-bkt -ns aqua-ns --auth api_key --profile DEFAULT --no-overwrite --exclude {HF_METADATA_FOLDER}*"
)
)
else:
Expand Down
14 changes: 10 additions & 4 deletions tests/unitary/with_extras/aqua/test_model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,10 @@ def test_list(self, mock_list):

@parameterized.expand(
[
(None, None, False, None),
("odsc-llm-fine-tuning", None, False, None),
(None, "test.gguf", True, None),
(None, None, True, "iad.ocir.io/<namespace>/<image>:<tag>"),
(None, None, False, None, None, None),
("odsc-llm-fine-tuning", None, False, None, None, ["test.json"]),
(None, "test.gguf", True, None, ["*.json"], None),
(None, None, True, "iad.ocir.io/<namespace>/<image>:<tag>", ["*.json"], ["test.json"]),
],
)
@patch("notebook.base.handlers.APIHandler.finish")
Expand All @@ -146,6 +146,8 @@ def test_register(
model_file,
download_from_hf,
inference_container_uri,
allow_patterns,
ignore_patterns,
mock_register,
mock_finish,
):
Expand All @@ -165,6 +167,8 @@ def test_register(
model_file=model_file,
download_from_hf=download_from_hf,
inference_container_uri=inference_container_uri,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns
)
)
result = self.model_handler.post()
Expand All @@ -178,6 +182,8 @@ def test_register(
model_file=model_file,
download_from_hf=download_from_hf,
inference_container_uri=inference_container_uri,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns
)
assert result["id"] == "test_id"
assert result["inference_container"] == "odsc-tgi-serving"
Expand Down

0 comments on commit 0f446a5

Please sign in to comment.