diff --git a/ads/aqua/extension/model_handler.py b/ads/aqua/extension/model_handler.py index 42f90ffef..c7d050e84 100644 --- a/ads/aqua/extension/model_handler.py +++ b/ads/aqua/extension/model_handler.py @@ -133,6 +133,10 @@ def post(self, *args, **kwargs): # noqa: ARG002 ignore_patterns = input_data.get("ignore_patterns") freeform_tags = input_data.get("freeform_tags") defined_tags = input_data.get("defined_tags") + ignore_model_artifact_check = ( + str(input_data.get("ignore_model_artifact_check", "false")).lower() + == "true" + ) return self.finish( AquaModelApp().register( @@ -149,6 +153,7 @@ def post(self, *args, **kwargs): # noqa: ARG002 ignore_patterns=ignore_patterns, freeform_tags=freeform_tags, defined_tags=defined_tags, + ignore_model_artifact_check=ignore_model_artifact_check, ) ) diff --git a/ads/aqua/model/entities.py b/ads/aqua/model/entities.py index ecdb8b8e7..2d6d93cd8 100644 --- a/ads/aqua/model/entities.py +++ b/ads/aqua/model/entities.py @@ -293,6 +293,7 @@ class ImportModelDetails(CLIBuilderMixin): ignore_patterns: Optional[List[str]] = None freeform_tags: Optional[dict] = None defined_tags: Optional[dict] = None + ignore_model_artifact_check: Optional[bool] = None def __post_init__(self): self._command = "model register" diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index 02e0df00f..a7952dde3 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -19,7 +19,11 @@ InferenceContainerTypeFamily, Tags, ) -from ads.aqua.common.errors import AquaRuntimeError, AquaValueError +from ads.aqua.common.errors import ( + AquaFileNotFoundError, + AquaRuntimeError, + AquaValueError, +) from ads.aqua.common.utils import ( LifecycleStatus, _build_resource_identifier, @@ -972,13 +976,23 @@ def get_model_files(os_path: str, model_format: ModelFormat) -> List[str]: # todo: revisit this logic to account for .bin files. In the current state, .bin and .safetensor models # are grouped in one category and validation checks for config.json files only. if model_format == ModelFormat.SAFETENSORS: + model_files.extend( + list_os_files_with_extension(oss_path=os_path, extension=".safetensors") + ) try: load_config( file_path=os_path, config_file_name=AQUA_MODEL_ARTIFACT_CONFIG, ) - except Exception: - pass + except Exception as ex: + message = ( + f"The model path {os_path} does not contain the file config.json. " + f"Please check if the path is correct or the model artifacts are available at this location." + ) + logger.warning( + f"{message}\n" + f"Details: {ex.reason if isinstance(ex, AquaFileNotFoundError) else str(ex)}\n" + ) else: model_files.append(AQUA_MODEL_ARTIFACT_CONFIG) @@ -1022,10 +1036,12 @@ def get_hf_model_files(model_name: str, model_format: ModelFormat) -> List[str]: for model_sibling in model_siblings: extension = pathlib.Path(model_sibling.rfilename).suffix[1:].upper() - if model_format == ModelFormat.SAFETENSORS: - if model_sibling.rfilename == AQUA_MODEL_ARTIFACT_CONFIG: - model_files.append(model_sibling.rfilename) - elif extension == model_format.value: + if ( + model_format == ModelFormat.SAFETENSORS + and model_sibling.rfilename == AQUA_MODEL_ARTIFACT_CONFIG + ): + model_files.append(model_sibling.rfilename) + if extension == model_format.value: model_files.append(model_sibling.rfilename) return model_files @@ -1061,7 +1077,10 @@ def _validate_model( safetensors_model_files = self.get_hf_model_files( model_name, ModelFormat.SAFETENSORS ) - if safetensors_model_files: + if ( + safetensors_model_files + and AQUA_MODEL_ARTIFACT_CONFIG in safetensors_model_files + ): hf_download_config_present = True gguf_model_files = self.get_hf_model_files(model_name, ModelFormat.GGUF) else: @@ -1117,8 +1136,11 @@ def _validate_model( Tags.LICENSE: license_value, } validation_result.tags = hf_tags - except Exception: - pass + except Exception as ex: + logger.debug( + f"An error occurred while getting tag information for model {model_name}. " + f"Error: {str(ex)}" + ) validation_result.model_formats = model_formats @@ -1173,40 +1195,55 @@ def _validate_safetensor_format( model_name: str = None, ): if import_model_details.download_from_hf: - # validates config.json exists for safetensors model from hugginface - if not hf_download_config_present: + # validates config.json exists for safetensors model from huggingface + if not ( + hf_download_config_present + or import_model_details.ignore_model_artifact_check + ): raise AquaRuntimeError( f"The model {model_name} does not contain {AQUA_MODEL_ARTIFACT_CONFIG} file as required " f"by {ModelFormat.SAFETENSORS.value} format model." f" Please check if the model name is correct in Hugging Face repository." ) + validation_result.telemetry_model_name = model_name else: + # validate if config.json is available from object storage, and get model name for telemetry + model_config = None try: model_config = load_config( file_path=import_model_details.os_path, config_file_name=AQUA_MODEL_ARTIFACT_CONFIG, ) except Exception as ex: - logger.error( - f"Exception occurred while loading config file from {import_model_details.os_path}" - f"Exception message: {ex}" - ) - raise AquaRuntimeError( + message = ( f"The model path {import_model_details.os_path} does not contain the file config.json. " f"Please check if the path is correct or the model artifacts are available at this location." - ) from ex - else: + ) + if not import_model_details.ignore_model_artifact_check: + logger.error( + f"{message}\n" + f"Details: {ex.reason if isinstance(ex, AquaFileNotFoundError) else str(ex)}" + ) + raise AquaRuntimeError(message) from ex + else: + logger.warning( + f"{message}\n" + f"Proceeding with model registration as ignore_model_artifact_check field is set." + ) + + if verified_model: + # model_type validation, log message if metadata field doesn't match. try: metadata_model_type = verified_model.custom_metadata_list.get( AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE ).value - if metadata_model_type: + if metadata_model_type and model_config is not None: if AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config: if ( model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE] != metadata_model_type ): - raise AquaRuntimeError( + logger.debug( f"The {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in {AQUA_MODEL_ARTIFACT_CONFIG}" f" at {import_model_details.os_path} is invalid, expected {metadata_model_type} for " f"the model {model_name}. Please check if the path is correct or " @@ -1218,22 +1255,26 @@ def _validate_safetensor_format( f"Could not find {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in " f"{AQUA_MODEL_ARTIFACT_CONFIG}. Proceeding with model registration." ) - except Exception: - pass - if verified_model: - validation_result.telemetry_model_name = verified_model.display_name - elif ( - model_config is not None - and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config - ): - validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME]}" - elif ( - model_config is not None - and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config - ): - validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}" - else: - validation_result.telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM + except Exception as ex: + # todo: raise exception if model_type doesn't match. Currently log message and pass since service + # models do not have this metadata. + logger.debug( + f"Error occurred while processing metadata for model {model_name}. " + f"Exception: {str(ex)}" + ) + validation_result.telemetry_model_name = verified_model.display_name + elif ( + model_config is not None + and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config + ): + validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME]}" + elif ( + model_config is not None + and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config + ): + validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}" + else: + validation_result.telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM @staticmethod def _validate_gguf_format( @@ -1416,7 +1457,6 @@ def register( ).rstrip("/") else: artifact_path = import_model_details.os_path.rstrip("/") - # Create Model catalog entry with pass by reference ds_model = self._create_model_catalog_entry( os_path=artifact_path, diff --git a/tests/unitary/with_extras/aqua/test_model.py b/tests/unitary/with_extras/aqua/test_model.py index cabb8c523..569902c02 100644 --- a/tests/unitary/with_extras/aqua/test_model.py +++ b/tests/unitary/with_extras/aqua/test_model.py @@ -920,10 +920,18 @@ def test_import_model_with_project_compartment_override( assert model.project_id == project_override @pytest.mark.parametrize( - "download_from_hf", - [True, False], + ("ignore_artifact_check", "download_from_hf"), + [ + (True, True), + (True, False), + (False, True), + (False, False), + (None, False), + (None, True), + ], ) @patch("ads.model.service.oci_datascience_model.OCIDataScienceModel.create") + @patch("ads.model.datascience_model.DataScienceModel.sync") @patch("ads.model.datascience_model.DataScienceModel.upload_artifact") @patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects") @patch("ads.aqua.common.utils.load_config", side_effect=AquaFileNotFoundError) @@ -936,45 +944,65 @@ def test_import_model_with_missing_config( mock_load_config, mock_list_objects, mock_upload_artifact, + mock_sync, mock_ocidsc_create, - mock_get_container_config, + ignore_artifact_check, download_from_hf, mock_get_hf_model_info, mock_init_client, ): - """Test for validating if error is returned when model artifacts are incomplete or not available.""" - - os_path = "oci://aqua-bkt@aqua-ns/prefix/path" - model_name = "oracle/aqua-1t-mega-model" + my_model = "oracle/aqua-1t-mega-model" ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True) - mock_list_objects.return_value = MagicMock(objects=[]) - reload(ads.aqua.model.model) - app = AquaModelApp() - app.list = MagicMock(return_value=[]) + # set object list from OSS without config.json + os_path = "oci://aqua-bkt@aqua-ns/prefix/path" + # set object list from HF without config.json if download_from_hf: - with pytest.raises(AquaValueError): - mock_get_hf_model_info.return_value.siblings = [] - with tempfile.TemporaryDirectory() as tmpdir: - model: AquaModel = app.register( - model=model_name, - os_path=os_path, - local_dir=str(tmpdir), - download_from_hf=True, - ) + mock_get_hf_model_info.return_value.siblings = [ + MagicMock(rfilename="model.safetensors") + ] else: - with pytest.raises(AquaRuntimeError): + obj1 = MagicMock(etag="12345-1234-1234-1234-123456789", size=150) + obj1.name = f"prefix/path/model.safetensors" + objects = [obj1] + mock_list_objects.return_value = MagicMock(objects=objects) + + reload(ads.aqua.model.model) + app = AquaModelApp() + with patch.object(AquaModelApp, "list") as aqua_model_mock_list: + aqua_model_mock_list.return_value = [ + AquaModelSummary( + id="test_id1", + name="organization1/name1", + organization="organization1", + ) + ] + + if ignore_artifact_check: model: AquaModel = app.register( - model=model_name, + model=my_model, os_path=os_path, - download_from_hf=False, + inference_container="odsc-vllm-or-tgi-container", + finetuning_container="odsc-llm-fine-tuning", + download_from_hf=download_from_hf, + ignore_model_artifact_check=ignore_artifact_check, ) + assert model.ready_to_deploy is True + else: + with pytest.raises(AquaRuntimeError): + model: AquaModel = app.register( + model=my_model, + os_path=os_path, + inference_container="odsc-vllm-or-tgi-container", + finetuning_container="odsc-llm-fine-tuning", + download_from_hf=download_from_hf, + ignore_model_artifact_check=ignore_artifact_check, + ) @patch("ads.model.service.oci_datascience_model.OCIDataScienceModel.create") @patch("ads.model.datascience_model.DataScienceModel.sync") @patch("ads.model.datascience_model.DataScienceModel.upload_artifact") @patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects") - @patch.object(HfApi, "model_info") @patch("ads.aqua.common.utils.load_config", return_value={}) def test_import_any_model_smc_container( self, @@ -1230,6 +1258,15 @@ def test_import_model_with_input_tags( "--download_from_hf True --inference_container odsc-vllm-serving --freeform_tags " '{"ftag1": "fvalue1", "ftag2": "fvalue2"} --defined_tags {"dtag1": "dvalue1", "dtag2": "dvalue2"}', ), + ( + { + "os_path": "oci://aqua-bkt@aqua-ns/path", + "model": "oracle/oracle-1it", + "inference_container": "odsc-vllm-serving", + "ignore_model_artifact_check": True, + }, + "ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --inference_container odsc-vllm-serving --ignore_model_artifact_check True", + ), ], ) def test_import_cli(self, data, expected_output): diff --git a/tests/unitary/with_extras/aqua/test_model_handler.py b/tests/unitary/with_extras/aqua/test_model_handler.py index bf02174b9..16202f477 100644 --- a/tests/unitary/with_extras/aqua/test_model_handler.py +++ b/tests/unitary/with_extras/aqua/test_model_handler.py @@ -132,7 +132,7 @@ def test_list(self, mock_list): @parameterized.expand( [ - (None, None, False, None, None, None, None, None), + (None, None, False, None, None, None, None, None, True), ( "odsc-llm-fine-tuning", None, @@ -142,8 +142,9 @@ def test_list(self, mock_list): ["test.json"], None, None, + False, ), - (None, "test.gguf", True, None, ["*.json"], None, None, None), + (None, "test.gguf", True, None, ["*.json"], None, None, None, False), ( None, None, @@ -153,6 +154,7 @@ def test_list(self, mock_list): ["test.json"], None, None, + False, ), ( None, @@ -163,6 +165,7 @@ def test_list(self, mock_list): None, {"ftag1": "fvalue1"}, {"dtag1": "dvalue1"}, + False, ), ], ) @@ -178,6 +181,7 @@ def test_register( ignore_patterns, freeform_tags, defined_tags, + ignore_model_artifact_check, mock_register, mock_finish, ): @@ -201,6 +205,7 @@ def test_register( ignore_patterns=ignore_patterns, freeform_tags=freeform_tags, defined_tags=defined_tags, + ignore_model_artifact_check=ignore_model_artifact_check, ) ) result = self.model_handler.post() @@ -218,6 +223,7 @@ def test_register( ignore_patterns=ignore_patterns, freeform_tags=freeform_tags, defined_tags=defined_tags, + ignore_model_artifact_check=ignore_model_artifact_check, ) assert result["id"] == "test_id" assert result["inference_container"] == "odsc-tgi-serving"