diff --git a/ads/aqua/common/utils.py b/ads/aqua/common/utils.py index 7a0d9d46b..6e1e09aca 100644 --- a/ads/aqua/common/utils.py +++ b/ads/aqua/common/utils.py @@ -788,13 +788,14 @@ def get_ocid_substring(ocid: str, key_len: int) -> str: return ocid[-key_len:] if ocid and len(ocid) > key_len else "" -def upload_folder(os_path: str, local_dir: str, model_name: str) -> str: +def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None) -> str: """Upload the local folder to the object storage Args: os_path (str): object storage URI with prefix. This is the path to upload local_dir (str): Local directory where the object is downloaded model_name (str): Name of the huggingface model + exclude_pattern (optional, str): The matching pattern of files to be excluded from uploading. Retuns: str: Object name inside the bucket """ @@ -804,6 +805,8 @@ def upload_folder(os_path: str, local_dir: str, model_name: str) -> str: auth_state = AuthState() object_path = os_details.filepath.rstrip("/") + "/" + model_name + "/" command = f"oci os object bulk-upload --src-dir {local_dir} --prefix {object_path} -bn {os_details.bucket} -ns {os_details.namespace} --auth {auth_state.oci_iam_type} --profile {auth_state.oci_key_profile} --no-overwrite" + if exclude_pattern: + command += f" --exclude {exclude_pattern}" try: logger.info(f"Running: {command}") subprocess.check_call(shlex.split(command)) diff --git a/ads/aqua/constants.py b/ads/aqua/constants.py index 76406d0d7..0b03a1507 100644 --- a/ads/aqua/constants.py +++ b/ads/aqua/constants.py @@ -35,6 +35,7 @@ AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME = "_name_or_path" AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE = "model_type" AQUA_MODEL_ARTIFACT_FILE = "model_file" +HF_METADATA_FOLDER = ".cache/" HF_LOGIN_DEFAULT_TIMEOUT = 2 TRAINING_METRICS_FINAL = "training_metrics_final" diff --git a/ads/aqua/extension/model_handler.py b/ads/aqua/extension/model_handler.py index 8a5e490ea..1a322d801 100644 --- a/ads/aqua/extension/model_handler.py +++ b/ads/aqua/extension/model_handler.py @@ -129,6 +129,8 @@ def post(self, *args, **kwargs): str(input_data.get("download_from_hf", "false")).lower() == "true" ) inference_container_uri = input_data.get("inference_container_uri") + allow_patterns = input_data.get("allow_patterns") + ignore_patterns = input_data.get("ignore_patterns") return self.finish( AquaModelApp().register( @@ -141,6 +143,8 @@ def post(self, *args, **kwargs): project_id=project_id, model_file=model_file, inference_container_uri=inference_container_uri, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, ) ) diff --git a/ads/aqua/model/entities.py b/ads/aqua/model/entities.py index 3ba884da9..2d94fb3d6 100644 --- a/ads/aqua/model/entities.py +++ b/ads/aqua/model/entities.py @@ -289,6 +289,8 @@ class ImportModelDetails(CLIBuilderMixin): project_id: Optional[str] = None model_file: Optional[str] = None inference_container_uri: Optional[str] = None + allow_patterns: Optional[List[str]] = None + ignore_patterns: Optional[List[str]] = None def __post_init__(self): self._command = "model register" diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index ce8d523d2..2d01022ae 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -40,6 +40,7 @@ AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE, AQUA_MODEL_ARTIFACT_FILE, AQUA_MODEL_TYPE_CUSTOM, + HF_METADATA_FOLDER, LICENSE_TXT, MODEL_BY_REFERENCE_OSS_PATH_KEY, README, @@ -1274,6 +1275,8 @@ def _download_model_from_hf( model_name: str, os_path: str, local_dir: str = None, + allow_patterns: List[str] = None, + ignore_patterns: List[str] = None, ) -> str: """This helper function downloads the model artifact from Hugging Face to a local folder, then uploads to object storage location. @@ -1283,6 +1286,12 @@ def _download_model_from_hf( model_name (str): The huggingface model name. os_path (str): The OS path where the model files are located. local_dir (str): The local temp dir to store the huggingface model. + allow_patterns (list): Model files matching at least one pattern are downloaded. + Example: ["*.json"] will download all .json files. ["folder/*"] will download all files under `folder`. + Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html + ignore_patterns (list): Model files matching any of the patterns are not downloaded. + Example: ["*.json"] will ignore all .json files. ["folder/*"] will ignore all files under `folder`. + Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html Returns ------- @@ -1293,30 +1302,19 @@ def _download_model_from_hf( if not local_dir: local_dir = os.path.join(os.path.expanduser("~"), "cached-model") local_dir = os.path.join(local_dir, model_name) - retry = 10 - i = 0 - huggingface_download_err_message = None - while i < retry: - try: - # Download to cache folder. The while loop retries when there is a network failure - snapshot_download(repo_id=model_name) - except Exception as e: - huggingface_download_err_message = str(e) - i += 1 - else: - break - if i == retry: - raise Exception( - f"Could not download the model {model_name} from https://huggingface.co with message {huggingface_download_err_message}" - ) os.makedirs(local_dir, exist_ok=True) - # Copy the model from the cache to destination - snapshot_download(repo_id=model_name, local_dir=local_dir) - # Upload to object storage + snapshot_download( + repo_id=model_name, + local_dir=local_dir, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + ) + # Upload to object storage and skip .cache/huggingface/ folder model_artifact_path = upload_folder( os_path=os_path, local_dir=local_dir, model_name=model_name, + exclude_pattern=f"{HF_METADATA_FOLDER}*" ) return model_artifact_path @@ -1335,6 +1333,12 @@ def register( os_path (str): Object storage destination URI to store the downloaded model. Format: oci://bucket-name@namespace/prefix inference_container (str): selects service defaults finetuning_container (str): selects service defaults + allow_patterns (list): Model files matching at least one pattern are downloaded. + Example: ["*.json"] will download all .json files. ["folder/*"] will download all files under `folder`. + Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html + ignore_patterns (list): Model files matching any of the patterns are not downloaded. + Example: ["*.json"] will ignore all .json files. ["folder/*"] will ignore all files under `folder`. + Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html Returns: AquaModel: @@ -1381,6 +1385,8 @@ def register( model_name=model_name, os_path=import_model_details.os_path, local_dir=import_model_details.local_dir, + allow_patterns=import_model_details.allow_patterns, + ignore_patterns=import_model_details.ignore_patterns, ).rstrip("/") else: artifact_path = import_model_details.os_path.rstrip("/") diff --git a/pyproject.toml b/pyproject.toml index 1535e89c1..c77206338 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,7 +125,7 @@ opctl = [ "rich", "fire", "cachetools", - "huggingface_hub==0.23.4" + "huggingface_hub==0.26.2" ] optuna = ["optuna==2.9.0", "oracle_ads[viz]"] spark = ["pyspark>=3.0.0"] diff --git a/tests/unitary/with_extras/aqua/test_model.py b/tests/unitary/with_extras/aqua/test_model.py index f84dd604c..bc6fc21d7 100644 --- a/tests/unitary/with_extras/aqua/test_model.py +++ b/tests/unitary/with_extras/aqua/test_model.py @@ -12,6 +12,7 @@ from unittest.mock import MagicMock, patch import oci +from ads.aqua.constants import HF_METADATA_FOLDER import pytest from ads.aqua.ui import ModelFormat from parameterized import parameterized @@ -746,14 +747,18 @@ def test_import_verified_model( os_path=os_path, local_dir=str(tmpdir), download_from_hf=True, + allow_patterns=["*.json"], + ignore_patterns=["test.json"] ) mock_snapshot_download.assert_called_with( repo_id=model_name, local_dir=f"{str(tmpdir)}/{model_name}", + allow_patterns=["*.json"], + ignore_patterns=["test.json"] ) mock_subprocess.assert_called_with( shlex.split( - f"oci os object bulk-upload --src-dir {str(tmpdir)}/{model_name} --prefix prefix/path/{model_name}/ -bn aqua-bkt -ns aqua-ns --auth api_key --profile DEFAULT --no-overwrite" + f"oci os object bulk-upload --src-dir {str(tmpdir)}/{model_name} --prefix prefix/path/{model_name}/ -bn aqua-bkt -ns aqua-ns --auth api_key --profile DEFAULT --no-overwrite --exclude {HF_METADATA_FOLDER}*" ) ) else: diff --git a/tests/unitary/with_extras/aqua/test_model_handler.py b/tests/unitary/with_extras/aqua/test_model_handler.py index d4b741463..0367d4c3c 100644 --- a/tests/unitary/with_extras/aqua/test_model_handler.py +++ b/tests/unitary/with_extras/aqua/test_model_handler.py @@ -132,10 +132,10 @@ def test_list(self, mock_list): @parameterized.expand( [ - (None, None, False, None), - ("odsc-llm-fine-tuning", None, False, None), - (None, "test.gguf", True, None), - (None, None, True, "iad.ocir.io//:"), + (None, None, False, None, None, None), + ("odsc-llm-fine-tuning", None, False, None, None, ["test.json"]), + (None, "test.gguf", True, None, ["*.json"], None), + (None, None, True, "iad.ocir.io//:", ["*.json"], ["test.json"]), ], ) @patch("notebook.base.handlers.APIHandler.finish") @@ -146,6 +146,8 @@ def test_register( model_file, download_from_hf, inference_container_uri, + allow_patterns, + ignore_patterns, mock_register, mock_finish, ): @@ -165,6 +167,8 @@ def test_register( model_file=model_file, download_from_hf=download_from_hf, inference_container_uri=inference_container_uri, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns ) ) result = self.model_handler.post() @@ -178,6 +182,8 @@ def test_register( model_file=model_file, download_from_hf=download_from_hf, inference_container_uri=inference_container_uri, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns ) assert result["id"] == "test_id" assert result["inference_container"] == "odsc-tgi-serving"