Skip to content

Commit

Permalink
Fix validation for model containing both formats. (#948)
Browse files Browse the repository at this point in the history
  • Loading branch information
lu-ohai authored Oct 3, 2024
2 parents 9e8e5d7 + 7d966b5 commit eb785ca
Showing 1 changed file with 171 additions and 124 deletions.
295 changes: 171 additions & 124 deletions ads/aqua/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger
from ads.aqua.app import AquaApp
from ads.aqua.common.enums import Tags
from ads.aqua.common.enums import InferenceContainerTypeFamily, Tags
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
from ads.aqua.common.utils import (
LifecycleStatus,
Expand Down Expand Up @@ -933,139 +933,186 @@ def _validate_model(
# now as we know that at least one type of model files exist, validate the content of oss path.
# for safetensors, we check if config.json files exist, and for gguf format we check if files with
# gguf extension exist.
for model_format in model_formats:
if {ModelFormat.SAFETENSORS, ModelFormat.GGUF}.issubset(set(model_formats)):
if (
model_format == ModelFormat.SAFETENSORS
and len(safetensors_model_files) > 0
import_model_details.inference_container.lower() == InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY
):
if import_model_details.download_from_hf:
# validates config.json exists for safetensors model from hugginface
if not hf_download_config_present:
raise AquaRuntimeError(
f"The model {model_name} does not contain {AQUA_MODEL_ARTIFACT_CONFIG} file as required "
f"by {ModelFormat.SAFETENSORS.value} format model."
f" Please check if the model name is correct in Hugging Face repository."
)
else:
try:
model_config = load_config(
file_path=import_model_details.os_path,
config_file_name=AQUA_MODEL_ARTIFACT_CONFIG,
)
except Exception as ex:
logger.error(
f"Exception occurred while loading config file from {import_model_details.os_path}"
f"Exception message: {ex}"
)
raise AquaRuntimeError(
f"The model path {import_model_details.os_path} does not contain the file config.json. "
f"Please check if the path is correct or the model artifacts are available at this location."
) from ex
else:
try:
metadata_model_type = (
verified_model.custom_metadata_list.get(
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
).value
)
if metadata_model_type:
if (
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
in model_config
):
if (
model_config[
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
]
!= metadata_model_type
):
raise AquaRuntimeError(
f"The {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in {AQUA_MODEL_ARTIFACT_CONFIG}"
f" at {import_model_details.os_path} is invalid, expected {metadata_model_type} for "
f"the model {model_name}. Please check if the path is correct or "
f"the correct model artifacts are available at this location."
f""
)
else:
logger.debug(
f"Could not find {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in "
f"{AQUA_MODEL_ARTIFACT_CONFIG}. Proceeding with model registration."
)
except Exception:
pass
if verified_model:
validation_result.telemetry_model_name = (
verified_model.display_name
)
elif (
model_config is not None
and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config
):
validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME]}"
elif (
model_config is not None
and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config
self._validate_gguf_format(
import_model_details=import_model_details,
verified_model=verified_model,
gguf_model_files=gguf_model_files,
validation_result=validation_result,
model_name=model_name
)
else:
self._validate_safetensor_format(
import_model_details=import_model_details,
verified_model=verified_model,
validation_result=validation_result,
hf_download_config_present=hf_download_config_present,
model_name=model_name
)
elif ModelFormat.SAFETENSORS in model_formats:
self._validate_safetensor_format(
import_model_details=import_model_details,
verified_model=verified_model,
validation_result=validation_result,
hf_download_config_present=hf_download_config_present,
model_name=model_name
)
elif ModelFormat.GGUF in model_formats:
self._validate_gguf_format(
import_model_details=import_model_details,
verified_model=verified_model,
gguf_model_files=gguf_model_files,
validation_result=validation_result,
model_name=model_name
)

return validation_result

@staticmethod
def _validate_safetensor_format(
import_model_details: ImportModelDetails = None,
verified_model: DataScienceModel = None,
validation_result: ModelValidationResult = None,
hf_download_config_present: bool = None,
model_name: str = None
):
if import_model_details.download_from_hf:
# validates config.json exists for safetensors model from hugginface
if not hf_download_config_present:
raise AquaRuntimeError(
f"The model {model_name} does not contain {AQUA_MODEL_ARTIFACT_CONFIG} file as required "
f"by {ModelFormat.SAFETENSORS.value} format model."
f" Please check if the model name is correct in Hugging Face repository."
)
else:
try:
model_config = load_config(
file_path=import_model_details.os_path,
config_file_name=AQUA_MODEL_ARTIFACT_CONFIG,
)
except Exception as ex:
logger.error(
f"Exception occurred while loading config file from {import_model_details.os_path}"
f"Exception message: {ex}"
)
raise AquaRuntimeError(
f"The model path {import_model_details.os_path} does not contain the file config.json. "
f"Please check if the path is correct or the model artifacts are available at this location."
) from ex
else:
try:
metadata_model_type = (
verified_model.custom_metadata_list.get(
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
).value
)
if metadata_model_type:
if (
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
in model_config
):
validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}"
if (
model_config[
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
]
!= metadata_model_type
):
raise AquaRuntimeError(
f"The {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in {AQUA_MODEL_ARTIFACT_CONFIG}"
f" at {import_model_details.os_path} is invalid, expected {metadata_model_type} for "
f"the model {model_name}. Please check if the path is correct or "
f"the correct model artifacts are available at this location."
f""
)
else:
validation_result.telemetry_model_name = (
AQUA_MODEL_TYPE_CUSTOM
logger.debug(
f"Could not find {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in "
f"{AQUA_MODEL_ARTIFACT_CONFIG}. Proceeding with model registration."
)
elif model_format == ModelFormat.GGUF and len(gguf_model_files) > 0:
if import_model_details.finetuning_container and not safetensors_model_files:
raise AquaValueError(
"Fine-tuning is currently not supported with GGUF model format."
)
except Exception:
pass
if verified_model:
try:
model_file = verified_model.custom_metadata_list.get(
AQUA_MODEL_ARTIFACT_FILE
).value
except ValueError as err:
raise AquaRuntimeError(
f"The model {verified_model.display_name} does not contain the custom metadata {AQUA_MODEL_ARTIFACT_FILE}. "
f"Please check if the model has the valid metadata."
) from err
else:
model_file = import_model_details.model_file

model_files = gguf_model_files
# todo: have a separate error validation class for different type of error messages.
if model_file:
if model_file not in model_files:
raise AquaRuntimeError(
f"The model path {import_model_details.os_path} or the Hugging Face "
f"model repository for {model_name} does not contain the file "
f"{model_file}. Please check if the path is correct or the model "
f"artifacts are available at this location."
)
else:
validation_result.model_file = model_file
elif len(model_files) == 0:
raise AquaRuntimeError(
f"The model path {import_model_details.os_path} or the Hugging Face model "
f"repository for {model_name} does not contain any GGUF format files. "
f"Please check if the path is correct or the model artifacts are available "
f"at this location."
)
elif len(model_files) > 1:
raise AquaRuntimeError(
f"The model path {import_model_details.os_path} or the Hugging Face model "
f"repository for {model_name} contains multiple GGUF format files. "
f"Please specify the file that needs to be deployed using the model_file "
f"parameter."
validation_result.telemetry_model_name = (
verified_model.display_name
)
elif (
model_config is not None
and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config
):
validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME]}"
elif (
model_config is not None
and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config
):
validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}"
else:
validation_result.model_file = model_files[0]
validation_result.telemetry_model_name = (
AQUA_MODEL_TYPE_CUSTOM
)

if verified_model:
validation_result.telemetry_model_name = verified_model.display_name
elif import_model_details.download_from_hf:
validation_result.telemetry_model_name = model_name
else:
validation_result.telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM
@staticmethod
def _validate_gguf_format(
import_model_details: ImportModelDetails = None,
verified_model: DataScienceModel = None,
gguf_model_files: List[str] = None,
validation_result: ModelValidationResult = None,
model_name: str = None,
):
if import_model_details.finetuning_container:
raise AquaValueError(
"Fine-tuning is currently not supported with GGUF model format."
)
if verified_model:
try:
model_file = verified_model.custom_metadata_list.get(
AQUA_MODEL_ARTIFACT_FILE
).value
except ValueError as err:
raise AquaRuntimeError(
f"The model {verified_model.display_name} does not contain the custom metadata {AQUA_MODEL_ARTIFACT_FILE}. "
f"Please check if the model has the valid metadata."
) from err
else:
model_file = import_model_details.model_file

return validation_result
model_files = gguf_model_files
# todo: have a separate error validation class for different type of error messages.
if model_file:
if model_file not in model_files:
raise AquaRuntimeError(
f"The model path {import_model_details.os_path} or the Hugging Face "
f"model repository for {model_name} does not contain the file "
f"{model_file}. Please check if the path is correct or the model "
f"artifacts are available at this location."
)
else:
validation_result.model_file = model_file
elif len(model_files) == 0:
raise AquaRuntimeError(
f"The model path {import_model_details.os_path} or the Hugging Face model "
f"repository for {model_name} does not contain any GGUF format files. "
f"Please check if the path is correct or the model artifacts are available "
f"at this location."
)
elif len(model_files) > 1:
raise AquaRuntimeError(
f"The model path {import_model_details.os_path} or the Hugging Face model "
f"repository for {model_name} contains multiple GGUF format files. "
f"Please specify the file that needs to be deployed using the model_file "
f"parameter."
)
else:
validation_result.model_file = model_files[0]

if verified_model:
validation_result.telemetry_model_name = verified_model.display_name
elif import_model_details.download_from_hf:
validation_result.telemetry_model_name = model_name
else:
validation_result.telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM

@staticmethod
def _download_model_from_hf(
Expand Down

0 comments on commit eb785ca

Please sign in to comment.