From 595b38ae4551240c6021b7afb54784cfca687378 Mon Sep 17 00:00:00 2001
From: Hafizhan Aliady <105761044+tkpd-hafizhan@users.noreply.github.com>
Date: Tue, 3 Oct 2023 21:43:00 +0700
Subject: [PATCH] Add standard transformer Simulate to SDK (#463)

**What this PR does / why we need it**:
- This PR adds the Standard Transformer simulator to the SDK, so a transformer configuration can be tested against a sample payload without deploying it (a usage sketch is appended after the diff)

**Which issue(s) this PR fixes**:
Fixes #

**Does this PR introduce a user-facing change?**:
```release-note
```

**Checklist**
- [ ] Added unit test, integration, and/or e2e tests
- [ ] Tested locally
- [ ] Updated documentation
- [ ] Updated Swagger spec if the PR introduces API changes
- [ ] Regenerated Golang and Python client if the PR introduces API changes

---
 .../client/api/standard_transformer_api.py    |  60 +-
 ...standard_transformer_simulation_request.py |  51 +-
 python/sdk/merlin/client.py                   | 130 ++-
 python/sdk/merlin/fluent.py                   | 133 +--
 python/sdk/merlin/merlin.py                   | 124 ++-
 python/sdk/merlin/model.py                    | 810 ++++++++++--------
 python/sdk/merlin/transformer.py              |  83 +-
 python/sdk/test/integration_test.py           |  57 ++
 .../sim_exp_resp_valid_w_tracing.json         | 224 +++++
 .../sim_exp_resp_valid_wo_tracing.json        |  20 +
 python/sdk/test/transformer_test.py           |  27 +-
 11 files changed, 1165 insertions(+), 554 deletions(-)
 create mode 100644 python/sdk/test/transformer/sim_exp_resp_valid_w_tracing.json
 create mode 100644 python/sdk/test/transformer/sim_exp_resp_valid_wo_tracing.json

diff --git a/python/sdk/client/api/standard_transformer_api.py b/python/sdk/client/api/standard_transformer_api.py index 50acfec33..996909ea2 100644 --- a/python/sdk/client/api/standard_transformer_api.py +++ b/python/sdk/client/api/standard_transformer_api.py @@ -46,11 +46,15 @@ def standard_transformer_simulate_post(self, **kwargs): # noqa: E501 If the method is called asynchronously, returns the request thread. """ - kwargs['_return_http_data_only'] = True - if kwargs.get('async_req'): - return self.standard_transformer_simulate_post_with_http_info(**kwargs) # noqa: E501 + kwargs["_return_http_data_only"] = True + if kwargs.get("async_req"): + return self.standard_transformer_simulate_post_with_http_info( + **kwargs + ) # noqa: E501 else: - (data) = self.standard_transformer_simulate_post_with_http_info(**kwargs) # noqa: E501 + (data) = self.standard_transformer_simulate_post_with_http_info( + **kwargs + ) # noqa: E501 return data def standard_transformer_simulate_post_with_http_info(self, **kwargs): # noqa: E501 @@ -68,21 +72,21 @@ def standard_transformer_simulate_post_with_http_info(self, **kwargs): # noqa: returns the request thread.
""" - all_params = ['body'] # noqa: E501 - all_params.append('async_req') - all_params.append('_return_http_data_only') - all_params.append('_preload_content') - all_params.append('_request_timeout') + all_params = ["body"] # noqa: E501 + all_params.append("async_req") + all_params.append("_return_http_data_only") + all_params.append("_preload_content") + all_params.append("_request_timeout") params = locals() - for key, val in six.iteritems(params['kwargs']): + for key, val in six.iteritems(params["kwargs"]): if key not in all_params: raise TypeError( "Got an unexpected keyword argument '%s'" " to method standard_transformer_simulate_post" % key ) params[key] = val - del params['kwargs'] + del params["kwargs"] collection_formats = {} @@ -96,31 +100,37 @@ def standard_transformer_simulate_post_with_http_info(self, **kwargs): # noqa: local_var_files = {} body_params = None - if 'body' in params: - body_params = params['body'] + if "body" in params: + body_params = params["body"] # HTTP header `Accept` - header_params['Accept'] = self.api_client.select_header_accept( - ['*/*']) # noqa: E501 + header_params["Accept"] = self.api_client.select_header_accept( + ["*/*"] + ) # noqa: E501 # HTTP header `Content-Type` - header_params['Content-Type'] = self.api_client.select_header_content_type( # noqa: E501 - ['*/*']) # noqa: E501 + header_params[ + "Content-Type" + ] = self.api_client.select_header_content_type( # noqa: E501 + ["*/*"] + ) # noqa: E501 # Authentication setting - auth_settings = ['Bearer'] # noqa: E501 + auth_settings = ["Bearer"] # noqa: E501 return self.api_client.call_api( - '/standard_transformer/simulate', 'POST', + "/standard_transformer/simulate", + "POST", path_params, query_params, header_params, body=body_params, post_params=form_params, files=local_var_files, - response_type='StandardTransformerSimulationResponse', # noqa: E501 + response_type="StandardTransformerSimulationResponse", # noqa: E501 auth_settings=auth_settings, - async_req=params.get('async_req'), - _return_http_data_only=params.get('_return_http_data_only'), - _preload_content=params.get('_preload_content', True), - _request_timeout=params.get('_request_timeout'), - collection_formats=collection_formats) + async_req=params.get("async_req"), + _return_http_data_only=params.get("_return_http_data_only"), + _preload_content=params.get("_preload_content", True), + _request_timeout=params.get("_request_timeout"), + collection_formats=collection_formats, + ) diff --git a/python/sdk/client/models/standard_transformer_simulation_request.py b/python/sdk/client/models/standard_transformer_simulation_request.py index f97520259..1324cc5db 100644 --- a/python/sdk/client/models/standard_transformer_simulation_request.py +++ b/python/sdk/client/models/standard_transformer_simulation_request.py @@ -15,11 +15,13 @@ import six + class StandardTransformerSimulationRequest(object): """NOTE: This class is auto generated by the swagger code generator program. Do not edit the class manually. """ + """ Attributes: swagger_types (dict): The key is attribute name @@ -28,22 +30,29 @@ class StandardTransformerSimulationRequest(object): and the value is json key in definition. 
""" swagger_types = { - 'payload': 'FreeFormObject', - 'headers': 'FreeFormObject', - 'config': 'FreeFormObject', - 'model_prediction_config': 'ModelPredictionConfig', - 'protocol': 'Protocol' + "payload": "FreeFormObject", + "headers": "FreeFormObject", + "config": "FreeFormObject", + "model_prediction_config": "ModelPredictionConfig", + "protocol": "Protocol", } attribute_map = { - 'payload': 'payload', - 'headers': 'headers', - 'config': 'config', - 'model_prediction_config': 'model_prediction_config', - 'protocol': 'protocol' + "payload": "payload", + "headers": "headers", + "config": "config", + "model_prediction_config": "model_prediction_config", + "protocol": "protocol", } - def __init__(self, payload=None, headers=None, config=None, model_prediction_config=None, protocol=None): # noqa: E501 + def __init__( + self, + payload=None, + headers=None, + config=None, + model_prediction_config=None, + protocol=None, + ): # noqa: E501 """StandardTransformerSimulationRequest - a model defined in Swagger""" # noqa: E501 self._payload = None self._headers = None @@ -174,18 +183,20 @@ def to_dict(self): for attr, _ in six.iteritems(self.swagger_types): value = getattr(self, attr) if isinstance(value, list): - result[attr] = list(map( - lambda x: x.to_dict() if hasattr(x, "to_dict") else x, - value - )) + result[attr] = list( + map(lambda x: x.to_dict() if hasattr(x, "to_dict") else x, value) + ) elif hasattr(value, "to_dict"): result[attr] = value.to_dict() elif isinstance(value, dict): - result[attr] = dict(map( - lambda item: (item[0], item[1].to_dict()) - if hasattr(item[1], "to_dict") else item, - value.items() - )) + result[attr] = dict( + map( + lambda item: (item[0], item[1].to_dict()) + if hasattr(item[1], "to_dict") + else item, + value.items(), + ) + ) else: result[attr] = value if issubclass(StandardTransformerSimulationRequest, dict): diff --git a/python/sdk/merlin/client.py b/python/sdk/merlin/client.py index b8bc940d0..633de203b 100644 --- a/python/sdk/merlin/client.py +++ b/python/sdk/merlin/client.py @@ -14,12 +14,23 @@ import warnings from sys import version_info -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Any import urllib3 from caraml_auth.id_token_credentials import get_default_id_token_credentials -from client import (ApiClient, Configuration, EndpointApi, EnvironmentApi, - ModelsApi, ProjectApi, VersionApi) +from client import ( + ApiClient, + Configuration, + EndpointApi, + EnvironmentApi, + ModelsApi, + ProjectApi, + StandardTransformerApi, + VersionApi, + StandardTransformerSimulationRequest, + FreeFormObject, +) + from google.auth.transport.requests import Request from google.auth.transport.urllib3 import AuthorizedHttp from merlin.autoscaling import AutoscalingPolicy @@ -36,7 +47,7 @@ class MerlinClient: - def __init__(self, merlin_url: str, use_google_oauth: bool=True): + def __init__(self, merlin_url: str, use_google_oauth: bool = True): self._merlin_url = merlin_url config = Configuration() config.host = self._merlin_url + "/v1" @@ -50,13 +61,14 @@ def __init__(self, merlin_url: str, use_google_oauth: bool=True): authorized_http = AuthorizedHttp(credentials, urllib3.PoolManager()) self._api_client.rest_client.pool_manager = authorized_http - python_version = f'{version_info.major}.{version_info.minor}.{version_info.micro}' # capture user's python version + python_version = f"{version_info.major}.{version_info.minor}.{version_info.micro}" # capture user's python version self._api_client.user_agent = f"merlin-sdk/{VERSION} 
python/{python_version}" self._project_api = ProjectApi(self._api_client) self._model_api = ModelsApi(self._api_client) self._version_api = VersionApi(self._api_client) self._endpoint_api = EndpointApi(self._api_client) self._env_api = EnvironmentApi(self._api_client) + self._standard_transformer_api = StandardTransformerApi(self._api_client) @property def url(self): @@ -74,7 +86,6 @@ def list_environment(self) -> List[Environment]: envs.append(Environment(env)) return envs - def get_environment(self, env_name: str) -> Optional[Environment]: """ Get environment for given env name @@ -128,12 +139,12 @@ def get_project(self, project_name: str) -> Project: """ if not valid_name_check(project_name): raise ValueError( - '''Your project/model name contains invalid characters.\ + """Your project/model name contains invalid characters.\ \nUse only the following characters\ \n- Characters: a-z (Lowercase ONLY)\ \n- Numbers: 0-9\ \n- Symbols: - - ''' + """ ) p_list = self._project_api.projects_get(name=project_name) @@ -143,14 +154,15 @@ def get_project(self, project_name: str) -> Project: p = prj if p is None: - raise Exception(f"{project_name} does not exist or you don't have access to the project. Please create new " - f"project using MLP console or ask the project's administrator to be able to access " - f"existing project.") + raise Exception( + f"{project_name} does not exist or you don't have access to the project. Please create new " + f"project using MLP console or ask the project's administrator to be able to access " + f"existing project." + ) return Project(p, self.url, self._api_client) - def get_model(self, model_name: str, project_name: str) \ - -> Optional[Model]: + def get_model(self, model_name: str, project_name: str) -> Optional[Model]: """ Get model with given name @@ -160,7 +172,8 @@ def get_model(self, model_name: str, project_name: str) \ """ prj = self.get_project(project_name) m_list = self._model_api.projects_project_id_models_get( - project_id=int(prj.id), name=model_name) + project_id=int(prj.id), name=model_name + ) model = m_list[0] if len(m_list) == 0: return None @@ -171,9 +184,9 @@ def get_model(self, model_name: str, project_name: str) \ return Model(model, prj, self._api_client) - def get_or_create_model(self, model_name: str, - project_name: str, - model_type: ModelType = None) -> Model: + def get_or_create_model( + self, model_name: str, project_name: str, model_type: ModelType = None + ) -> Model: """ Get or create a model under a project @@ -188,17 +201,18 @@ def get_or_create_model(self, model_name: str, """ if not valid_name_check(model_name): raise ValueError( - '''Your project/model name contains invalid characters.\ + """Your project/model name contains invalid characters.\ \nUse only the following characters\ \n- Characters: a-z (Lowercase ONLY)\ \n- Numbers: 0-9\ \n- Symbols: - - ''' + """ ) prj = self.get_project(project_name) m_list = self._model_api.projects_project_id_models_get( - project_id=int(prj.id), name=model_name) + project_id=int(prj.id), name=model_name + ) model = None for mdl in m_list: @@ -207,18 +221,20 @@ def get_or_create_model(self, model_name: str, if model is None: if model_type is None: - raise ValueError(f"model {model_name} is not found, specify " - f"{model_type} to create it") + raise ValueError( + f"model {model_name} is not found, specify " + f"{model_type} to create it" + ) model = self._model_api.projects_project_id_models_post( - project_id=int(prj.id), body={ - "name": model_name, - "type": model_type.value - }) + 
project_id=int(prj.id), + body={"name": model_name, "type": model_type.value}, + ) return Model(model, prj, self._api_client) - def new_model_version(self, model_name: str, project_name: str, labels: Dict[str, str] = None) \ - -> ModelVersion: + def new_model_version( + self, model_name: str, project_name: str, labels: Dict[str, str] = None + ) -> ModelVersion: """ Create new model version for the given model and project @@ -232,18 +248,48 @@ def new_model_version(self, model_name: str, project_name: str, labels: Dict[str raise ValueError(f"Model with name: {model_name} is not found") return mdl.new_model_version(labels=labels) - def deploy(self, model_version: ModelVersion, - environment_name: str = None, - resource_request: ResourceRequest = None, - env_vars: Dict[str, str] = None, - transformer: Transformer = None, - logger: Logger = None, - deployment_mode: DeploymentMode = DeploymentMode.SERVERLESS, - autoscaling_policy: AutoscalingPolicy = None, - protocol: Protocol = Protocol.HTTP_JSON) -> VersionEndpoint: - return model_version.deploy(environment_name, resource_request, env_vars, transformer, logger, deployment_mode, - autoscaling_policy, protocol) - - def undeploy(self, model_version: ModelVersion, - environment_name: str = None): + def deploy( + self, + model_version: ModelVersion, + environment_name: str = None, + resource_request: ResourceRequest = None, + env_vars: Dict[str, str] = None, + transformer: Transformer = None, + logger: Logger = None, + deployment_mode: DeploymentMode = DeploymentMode.SERVERLESS, + autoscaling_policy: AutoscalingPolicy = None, + protocol: Protocol = Protocol.HTTP_JSON, + ) -> VersionEndpoint: + return model_version.deploy( + environment_name, + resource_request, + env_vars, + transformer, + logger, + deployment_mode, + autoscaling_policy, + protocol, + ) + + def undeploy(self, model_version: ModelVersion, environment_name: str = None): model_version.undeploy(environment_name) + + def standard_transformer_simulate( + self, + payload: Dict, + headers: Optional[Dict[Any, Any]] = None, + config: Optional[Dict[Any, Any]] = None, + model_prediction_config: Dict = None, + protocol: str = "HTTP_JSON", + ): + request = StandardTransformerSimulationRequest( + payload=payload, + headers=headers, + config=config, + model_prediction_config=model_prediction_config, + protocol=protocol, + ) + + return self._standard_transformer_api.standard_transformer_simulate_post( + body=request.to_dict() + ) diff --git a/python/sdk/merlin/fluent.py b/python/sdk/merlin/fluent.py index 8440b1493..14b9cd566 100644 --- a/python/sdk/merlin/fluent.py +++ b/python/sdk/merlin/fluent.py @@ -133,9 +133,9 @@ def set_model(model_name, model_type: ModelType = None): """ _check_active_project() active_project_name = _active_project.name # type: ignore - mdl = _merlin_client.get_or_create_model(model_name, # type: ignore - active_project_name, - model_type) + mdl = _merlin_client.get_or_create_model( # type: ignore + model_name, active_project_name, model_type + ) global _active_model _active_model = mdl @@ -276,12 +276,15 @@ def log_artifact(local_path: str, artifact_path: str = None): :param artifact_path: destination directory in artifact store """ _check_active_model_version() - _active_model_version.log_artifact(local_path, # type: ignore - artifact_path) + _active_model_version.log_artifact(local_path, artifact_path) # type: ignore -def log_pyfunc_model(model_instance: Any, conda_env: str, code_dir: List[str] = None, - artifacts: Dict[str, str] = None): +def log_pyfunc_model( + 
model_instance: Any, + conda_env: str, + code_dir: List[str] = None, + artifacts: Dict[str, str] = None, +): """ Upload PyFunc based model into artifact storage. @@ -298,9 +301,9 @@ def log_pyfunc_model(model_instance: Any, conda_env: str, code_dir: List[str] = :param artifacts: dictionary of artifact that will be stored together with the model. This will be passed to PythonModel.initialize. Example: {"config": "config/staging.yaml"} """ _check_active_model_version() - _active_model_version.log_pyfunc_model(model_instance, # type: ignore - conda_env, - code_dir, artifacts) + _active_model_version.log_pyfunc_model( # type: ignore + model_instance, conda_env, code_dir, artifacts + ) def log_pytorch_model(model_dir: str, model_class_name: str = None): @@ -311,8 +314,7 @@ def log_pytorch_model(model_dir: str, model_class_name: str = None): :param model_class_name: class name of PyTorch model. By default the model class name is 'PyTorchModel' """ _check_active_model_version() - _active_model_version.log_pytorch_model(model_dir, # type: ignore - model_class_name) + _active_model_version.log_pytorch_model(model_dir, model_class_name) # type: ignore def log_model(model_dir): @@ -326,10 +328,9 @@ def log_model(model_dir): _active_model_version.log_model(model_dir) # type: ignore -def log_custom_model(image: str, - model_dir: str = None, - command: str = "", - args: str = ""): +def log_custom_model( + image: str, model_dir: str = None, command: str = "", args: str = "" +): """ Upload model to artifact storage. This method is used to upload model for custom model type. @@ -340,21 +341,22 @@ def log_custom_model(image: str, :param args: Arguments that needs to be specified when running docker """ _check_active_model_version() - _active_model_version.log_custom_model(image=image, # type: ignore - model_dir=model_dir, - command=command, - args=args) - -def deploy(model_version: ModelVersion = None, - environment_name: str = None, - resource_request: ResourceRequest = None, - env_vars: Dict[str, str] = None, - transformer: Transformer = None, - logger: Logger = None, - deployment_mode: DeploymentMode = DeploymentMode.SERVERLESS, - autoscaling_policy: AutoscalingPolicy = None, - protocol: Protocol = Protocol.HTTP_JSON - ) -> VersionEndpoint: + _active_model_version.log_custom_model( # type: ignore + image=image, model_dir=model_dir, command=command, args=args + ) + + +def deploy( + model_version: ModelVersion = None, + environment_name: str = None, + resource_request: ResourceRequest = None, + env_vars: Dict[str, str] = None, + transformer: Transformer = None, + logger: Logger = None, + deployment_mode: DeploymentMode = DeploymentMode.SERVERLESS, + autoscaling_policy: AutoscalingPolicy = None, + protocol: Protocol = Protocol.HTTP_JSON, +) -> VersionEndpoint: """ Deploy a model version. 
@@ -372,28 +374,31 @@ def deploy(model_version: ModelVersion = None, _check_active_client() if model_version is None: _check_active_model_version() - return _active_model_version.deploy(environment_name, # type: ignore - resource_request, - env_vars, - transformer, - logger, - deployment_mode, - autoscaling_policy, - protocol) - - return _merlin_client.deploy(model_version, # type: ignore - environment_name, - resource_request, - env_vars, - transformer, - logger, - deployment_mode, - autoscaling_policy, - protocol) - - -def undeploy(model_version=None, - environment_name: str = None): + return _active_model_version.deploy( # type: ignore + environment_name, + resource_request, + env_vars, + transformer, + logger, + deployment_mode, + autoscaling_policy, + protocol, + ) + + return _merlin_client.deploy( # type: ignore + model_version, + environment_name, + resource_request, + env_vars, + transformer, + logger, + deployment_mode, + autoscaling_policy, + protocol, + ) + + +def undeploy(model_version=None, environment_name: str = None): """ Delete deployment of a model version. @@ -408,8 +413,9 @@ def undeploy(model_version=None, _merlin_client.undeploy(model_version, environment_name) # type: ignore -def serve_traffic(traffic_rule: Dict['VersionEndpoint', int], - environment_name: str = None) -> ModelEndpoint: +def serve_traffic( + traffic_rule: Dict["VersionEndpoint", int], environment_name: str = None +) -> ModelEndpoint: """ Update traffic rule of the active model. @@ -418,8 +424,7 @@ def serve_traffic(traffic_rule: Dict['VersionEndpoint', int], :return: ModelEndpoint """ _check_active_model() - return _active_model.serve_traffic(traffic_rule, # type: ignore - environment_name) + return _active_model.serve_traffic(traffic_rule, environment_name) # type: ignore def stop_serving_traffic(environment_name: str = None): @@ -453,7 +458,9 @@ def list_model_endpoints() -> List[ModelEndpoint]: return _active_model.list_endpoint() # type: ignore -def create_prediction_job(job_config: PredictionJobConfig, sync: bool = True) -> PredictionJob: +def create_prediction_job( + job_config: PredictionJobConfig, sync: bool = True +) -> PredictionJob: """ :param sync: @@ -470,14 +477,12 @@ def create_prediction_job(job_config: PredictionJobConfig, sync: bool = True) -> def _check_active_project(): if _active_project is None: - raise Exception( - "Active project isn't set, use set_project(...) to set it") + raise Exception("Active project isn't set, use set_project(...) to set it") def _check_active_client(): if _merlin_client is None: - raise Exception( - "URL is not set, use set_url(...) to set it") + raise Exception("URL is not set, use set_url(...) to set it") def _check_active_model(): @@ -488,5 +493,5 @@ def _check_active_model(): def _check_active_model_version(): if _active_model_version is None: raise Exception( - "Active model version isn't set, use new_model_version(...) to " - "create it") + "Active model version isn't set, use new_model_version(...) 
to " "create it" + ) diff --git a/python/sdk/merlin/merlin.py b/python/sdk/merlin/merlin.py index e9d2c68f4..42d87934c 100644 --- a/python/sdk/merlin/merlin.py +++ b/python/sdk/merlin/merlin.py @@ -20,7 +20,8 @@ from cookiecutter.main import cookiecutter from merlin.util import valid_name_check -warnings.filterwarnings('ignore') +warnings.filterwarnings("ignore") + @click.group() def cli(): @@ -32,20 +33,52 @@ def cli(): """ pass -@cli.command('deploy', short_help='Deploy the model') -@click.option('--env', '-e', required=True, help='The environment of model deployment') -@click.option('--url', '-u', required=True, help='The endpoint of model deployment') -@click.option('--project', '-p', required=True, help='The project name of model deployment') -@click.option('--model-dir', '-m', required=True, help='The directory with model for deployment') -@click.option('--model-name', '-n', required=True, help='The model name for deployment') -@click.option('--model-type', '-t', required=True, help='The type of machine learning algorithm') -@click.option('--min-replica', required=False, help='The minimum number of replicas to create for this deployment') -@click.option('--max-replica', required=False, help='The maximum number of replicas to create for this deployment') -@click.option('--cpu-request', required=False, help='The CPU resource requirement requests for this deployment. Example: 100m.') -@click.option('--memory-request', required=False, help='The memory resource requirement requests for this deployment. Example: 256Mi.') -def deploy(env, model_name, model_type, model_dir, project, url, - min_replica, max_replica, cpu_request, memory_request): +@cli.command("deploy", short_help="Deploy the model") +@click.option("--env", "-e", required=True, help="The environment of model deployment") +@click.option("--url", "-u", required=True, help="The endpoint of model deployment") +@click.option( + "--project", "-p", required=True, help="The project name of model deployment" +) +@click.option( + "--model-dir", "-m", required=True, help="The directory with model for deployment" +) +@click.option("--model-name", "-n", required=True, help="The model name for deployment") +@click.option( + "--model-type", "-t", required=True, help="The type of machine learning algorithm" +) +@click.option( + "--min-replica", + required=False, + help="The minimum number of replicas to create for this deployment", +) +@click.option( + "--max-replica", + required=False, + help="The maximum number of replicas to create for this deployment", +) +@click.option( + "--cpu-request", + required=False, + help="The CPU resource requirement requests for this deployment. Example: 100m.", +) +@click.option( + "--memory-request", + required=False, + help="The memory resource requirement requests for this deployment. 
Example: 256Mi.", +) +def deploy( + env, + model_name, + model_type, + model_dir, + project, + url, + min_replica, + max_replica, + cpu_request, + memory_request, +): merlin.set_url(url) target_env = merlin.get_environment(env) @@ -69,17 +102,21 @@ def deploy(env, model_name, model_type, model_dir, project, url, try: endpoint = merlin.deploy(v, env, resource_request) if endpoint: - print('Model deployed to {}'.format(endpoint)) + print("Model deployed to {}".format(endpoint)) except Exception as e: print(e) -@cli.command('undeploy', short_help='Undeploy the model') -@click.option('--url', '-u', required=True, help='The endpoint of model deployment') -@click.option('--project', '-p', required=True, help='The project name of model deployment') -@click.option('--model-name', '-n', required=True, help='The model name for deployment') -@click.option('--model-version', '-v', required=True, help='The model version for deployment') -def undeploy(model_name, model_version, project, url): +@cli.command("undeploy", short_help="Undeploy the model") +@click.option("--url", "-u", required=True, help="The endpoint of model deployment") +@click.option( + "--project", "-p", required=True, help="The project name of model deployment" +) +@click.option("--model-name", "-n", required=True, help="The model name for deployment") +@click.option( + "--model-version", "-v", required=True, help="The model version for deployment" +) +def undeploy(model_name, model_version, project, url): merlin.set_url(url) merlin.set_project(project) merlin.set_model(model_name) @@ -88,10 +125,14 @@ def undeploy(model_name, model_version, project, url): all_versions = merlin_active_model.list_version() try: - wanted_model_info = [model_info for model_info in all_versions if model_info._id == int(model_version)][0] + wanted_model_info = [ + model_info + for model_info in all_versions + if model_info._id == int(model_version) + ][0] except Exception as e: print(e) - print('Model Version {} is not found.'.format(model_version)) + print("Model Version {} is not found.".format(model_version)) try: merlin.undeploy(wanted_model_info) @@ -99,26 +140,45 @@ def undeploy(model_name, model_version, project, url): print(e) -@cli.command('scaffold', short_help='Generate PyFunc project') -@click.option('--project', '-p', required=True, help='The merlin project name of PyFunc server') -@click.option('--model-name', '-m', required=True, help='The model name which will be listed in merlin') -@click.option('--env', '-e', required=True, help='The environment which PyFunc server will be deployed, available environment are id and global') +@cli.command("scaffold", short_help="Generate PyFunc project") +@click.option( + "--project", "-p", required=True, help="The merlin project name of PyFunc server" +) +@click.option( + "--model-name", + "-m", + required=True, + help="The model name which will be listed in merlin", +) +@click.option( + "--env", + "-e", + required=True, + help="The environment which PyFunc server will be deployed, available environment are id and global", +) def scaffold(project, model_name, env): if not valid_name_check(project) or not valid_name_check(model_name): print( - '''Your project/model name contains invalid characters.\ + """Your project/model name contains invalid characters.\ \nUse only the following characters\ \n- Characters: a-z (Lowercase ONLY)\ \n- Numbers: 0-9\ \n- Symbols: - - ''' + """ ) else: try: - cookiecutter("git@github.com:caraml-dev/merlin/python/pyfunc-scaffolding", + cookiecutter( + 
"git@github.com:caraml-dev/merlin/python/pyfunc-scaffolding", checkout="tags/v0.1", - no_input=True, directory="python/pyfunc-scaffolding", - extra_context={'project_name': project, 'model_name': model_name, 'environment_name': env}) + no_input=True, + directory="python/pyfunc-scaffolding", + extra_context={ + "project_name": project, + "model_name": model_name, + "environment_name": env, + }, + ) except Exception as e: print(e) diff --git a/python/sdk/merlin/model.py b/python/sdk/merlin/model.py index e727d66e8..5e1dde158 100644 --- a/python/sdk/merlin/model.py +++ b/python/sdk/merlin/model.py @@ -28,29 +28,40 @@ import docker import pyprind import yaml -from client import (EndpointApi, EnvironmentApi, ModelEndpointsApi, ModelsApi, - SecretApi, VersionApi) +from client import ( + EndpointApi, + EnvironmentApi, + ModelEndpointsApi, + ModelsApi, + SecretApi, + VersionApi, +) from docker import APIClient from docker.errors import BuildError from docker.models.containers import Container from merlin import pyfunc -from merlin.autoscaling import (RAW_DEPLOYMENT_DEFAULT_AUTOSCALING_POLICY, - SERVERLESS_DEFAULT_AUTOSCALING_POLICY, - AutoscalingPolicy) +from merlin.autoscaling import ( + RAW_DEPLOYMENT_DEFAULT_AUTOSCALING_POLICY, + SERVERLESS_DEFAULT_AUTOSCALING_POLICY, + AutoscalingPolicy, +) from merlin.batch.config import PredictionJobConfig from merlin.batch.job import PredictionJob from merlin.batch.sink import BigQuerySink from merlin.batch.source import BigQuerySource from merlin.deployment_mode import DeploymentMode -from merlin.docker.docker import (copy_pyfunc_dockerfile, - copy_standard_dockerfile) +from merlin.docker.docker import copy_pyfunc_dockerfile, copy_standard_dockerfile from merlin.endpoint import ModelEndpoint, Status, VersionEndpoint from merlin.logger import Logger from merlin.protocol import Protocol from merlin.resource_request import ResourceRequest from merlin.transformer import Transformer -from merlin.util import (autostr, download_files_from_gcs, guess_mlp_ui_url, - valid_name_check) +from merlin.util import ( + autostr, + download_files_from_gcs, + guess_mlp_ui_url, + valid_name_check, +) from merlin.validation import validate_model_dir from merlin.version import VERSION from mlflow.entities import Run, RunData @@ -79,7 +90,6 @@ class ModelEndpointDeploymentError(Exception): - def __init__(self, model_name: str, version: int, details: str): self._model_name = model_name self._version = version @@ -100,8 +110,9 @@ def details(self): @autostr class Project: - def __init__(self, project: client.Project, mlp_url: str, - api_client: client.ApiClient): + def __init__( + self, project: client.Project, mlp_url: str, api_client: client.ApiClient + ): self._id = project.id self._name = project.name self._mlflow_tracking_url = project.mlflow_tracking_url @@ -144,23 +155,22 @@ def updated_at(self) -> datetime: def url(self) -> str: return self._url - def list_model(self) -> List['Model']: + def list_model(self) -> List["Model"]: """ List all model available within the project :return: list of Model """ model_api = ModelsApi(self._api_client) - m_list = model_api.projects_project_id_models_get( - project_id=int(self.id)) + m_list = model_api.projects_project_id_models_get(project_id=int(self.id)) result = [] for model in m_list: - result.append( - Model(model, self, self._api_client)) + result.append(Model(model, self, self._api_client)) return result - def get_or_create_model(self, model_name: str, - model_type: 'ModelType' = None) -> 'Model': + def get_or_create_model( + 
self, model_name: str, model_type: "ModelType" = None + ) -> "Model": """ Get or create a model with given name @@ -170,25 +180,29 @@ def get_or_create_model(self, model_name: str, """ if not valid_name_check(model_name): raise ValueError( - '''Your project/model name contains invalid characters.\ + """Your project/model name contains invalid characters.\ \nUse only the following characters\ \n- Characters: a-z (Lowercase ONLY)\ \n- Numbers: 0-9\ \n- Symbols: - - ''' + """ ) model_api = ModelsApi(self._api_client) m_list = model_api.projects_project_id_models_get( - project_id=int(self.id), name=model_name) + project_id=int(self.id), name=model_name + ) if len(m_list) == 0: if model_type is None: - raise ValueError(f"model {model_name} is not found, specify " - f"{model_type} to create it") + raise ValueError( + f"model {model_name} is not found, specify " + f"{model_type} to create it" + ) model = model_api.projects_project_id_models_post( - project_id=int(self.id), body={"name": model_name, - "type": model_type.value}) + project_id=int(self.id), + body={"name": model_name, "type": model_type.value}, + ) else: model = m_list[0] @@ -203,11 +217,9 @@ def create_secret(self, name: str, data: str): :return: """ secret_api = SecretApi(self._api_client) - secret_api.projects_project_id_secrets_post(project_id=int(self.id), - body={ - "name": name, - "data": data - }) + secret_api.projects_project_id_secrets_post( + project_id=int(self.id), body={"name": name, "data": data} + ) def list_secret(self) -> List[str]: """ @@ -216,8 +228,7 @@ def list_secret(self) -> List[str]: :return: """ secret_api = SecretApi(self._api_client) - secrets = secret_api.projects_project_id_secrets_get( - project_id=int(self.id)) + secrets = secret_api.projects_project_id_secrets_get(project_id=int(self.id)) secret_names = [] for s in secrets: secret_names.append(s.name) @@ -234,13 +245,11 @@ def update_secret(self, name: str, data: str): secret_api = SecretApi(self._api_client) match = self._find_secret(name) - secret_api.projects_project_id_secrets_secret_id_patch(project_id=int(self.id), - secret_id=int( - match.id), - body={ - "name": name, - "data": data - }) + secret_api.projects_project_id_secrets_secret_id_patch( + project_id=int(self.id), + secret_id=int(match.id), + body={"name": name, "data": data}, + ) def delete_secret(self, name: str): """ @@ -252,20 +261,19 @@ def delete_secret(self, name: str): secret_api = SecretApi(self._api_client) match = self._find_secret(name) - secret_api.projects_project_id_secrets_secret_id_delete(project_id=int(self.id), - secret_id=int(match.id)) + secret_api.projects_project_id_secrets_secret_id_delete( + project_id=int(self.id), secret_id=int(match.id) + ) def _find_secret(self, name: str): secret_api = SecretApi(self._api_client) - secrets = secret_api.projects_project_id_secrets_get( - project_id=int(self.id)) + secrets = secret_api.projects_project_id_secrets_get(project_id=int(self.id)) match = None for s in secrets: if s.name == name: match = s if match is None: - raise ValueError( - f"unable to find secret {name} in project {self.name}") + raise ValueError(f"unable to find secret {name} in project {self.name}") return match @@ -273,6 +281,7 @@ class ModelType(Enum): """ Model type supported by merlin """ + XGBOOST = "xgboost" TENSORFLOW = "tensorflow" SKLEARN = "sklearn" @@ -289,8 +298,9 @@ class Model: Model representation """ - def __init__(self, model: client.Model, project: Project, - api_client: client.ApiClient): + def __init__( + self, model: client.Model, 
project: Project, api_client: client.ApiClient + ): self._id = model.id self._name = model.name self._mlflow_experiment_id = model.mlflow_experiment_id @@ -341,8 +351,9 @@ def endpoint(self) -> Optional[ModelEndpoint]: :return: Endpoint if exist, otherwise None """ mdl_endpoints_api = ModelEndpointsApi(self._api_client) - mdl_endpoints_list = \ - mdl_endpoints_api.models_model_id_endpoints_get(model_id=self.id) + mdl_endpoints_list = mdl_endpoints_api.models_model_id_endpoints_get( + model_id=self.id + ) for endpoint in mdl_endpoints_list: if endpoint.environment.is_default: return ModelEndpoint(endpoint) @@ -356,14 +367,15 @@ def list_endpoint(self) -> List[ModelEndpoint]: :return: List[ModelEndpoint] """ mdl_endpoints_api = ModelEndpointsApi(self._api_client) - mdl_endpoints_list = \ - mdl_endpoints_api.models_model_id_endpoints_get(model_id=self.id) + mdl_endpoints_list = mdl_endpoints_api.models_model_id_endpoints_get( + model_id=self.id + ) mdl_endpoints = [] for mdl_ep in mdl_endpoints_list: mdl_endpoints.append(ModelEndpoint(mdl_ep)) return mdl_endpoints - def get_version(self, id: int) -> Optional['ModelVersion']: + def get_version(self, id: int) -> Optional["ModelVersion"]: """ Get version with specific ID @@ -377,18 +389,20 @@ def get_version(self, id: int) -> Optional['ModelVersion']: return ModelVersion(v, self, self._api_client) return None - def list_version(self, labels: Dict[str, List[str]] = None) -> List['ModelVersion']: + def list_version(self, labels: Dict[str, List[str]] = None) -> List["ModelVersion"]: """ List all version of the model :return: list of ModelVersion """ - result: List['ModelVersion'] = [] + result: List["ModelVersion"] = [] search_dsl = self._build_search_labels_dsl(labels) versions, cursor = self._list_version_pagination(search=search_dsl) result = result + versions while cursor != "": - versions, cursor = self._list_version_pagination(cursor=cursor, search=search_dsl) + versions, cursor = self._list_version_pagination( + cursor=cursor, search=search_dsl + ) result = result + versions return result @@ -403,7 +417,9 @@ def _build_search_labels_dsl(self, labels: Dict[str, List[str]] = None): return f"labels:{','.join(all_search_kv_pair)}" - def _list_version_pagination(self, limit=DEFAULT_MODEL_VERSION_LIMIT, cursor="", search="") -> Tuple[List['ModelVersion'], str]: + def _list_version_pagination( + self, limit=DEFAULT_MODEL_VERSION_LIMIT, cursor="", search="" + ) -> Tuple[List["ModelVersion"], str]: """ List version of the model with pagination :param limit: integer, max number of rows will be returned @@ -414,15 +430,20 @@ def _list_version_pagination(self, limit=DEFAULT_MODEL_VERSION_LIMIT, cursor="", :return: next cursor to fetch next page of version """ version_api = VersionApi(self._api_client) - (versions, _, headers) = version_api.models_model_id_versions_get_with_http_info( - int(self.id), limit=limit, cursor=cursor, search=search) + ( + versions, + _, + headers, + ) = version_api.models_model_id_versions_get_with_http_info( + int(self.id), limit=limit, cursor=cursor, search=search + ) next_cursor = headers.get("Next-Cursor") or "" result = [] for v in versions: result.append(ModelVersion(v, self, self._api_client)) return result, next_cursor - def new_model_version(self, labels: Dict[str, str] = None) -> 'ModelVersion': + def new_model_version(self, labels: Dict[str, str] = None) -> "ModelVersion": """ Create a new version of this model @@ -430,12 +451,15 @@ def new_model_version(self, labels: Dict[str, str] = None) -> 'ModelVersion': 
:return: new ModelVersion """ version_api = VersionApi(self._api_client) - python_version = f'{version_info.major}.{version_info.minor}.*' # capture user's python version - v = version_api.models_model_id_versions_post(int(self.id), body={"labels": labels, "python_version": python_version}) + python_version = f"{version_info.major}.{version_info.minor}.*" # capture user's python version + v = version_api.models_model_id_versions_post( + int(self.id), body={"labels": labels, "python_version": python_version} + ) return ModelVersion(v, self, self._api_client) - def serve_traffic(self, traffic_rule: Dict['VersionEndpoint', int], - environment_name: str = None) -> ModelEndpoint: + def serve_traffic( + self, traffic_rule: Dict["VersionEndpoint", int], environment_name: str = None + ) -> ModelEndpoint: """ Set traffic rule for this model. @@ -445,7 +469,8 @@ def serve_traffic(self, traffic_rule: Dict['VersionEndpoint', int], """ if not isinstance(traffic_rule, dict): raise ValueError( - f"Traffic_rule should be dictionary, got: {type(traffic_rule)}") + f"Traffic_rule should be dictionary, got: {type(traffic_rule)}" + ) if len(traffic_rule) > 1: raise ValueError("Traffic splitting is not yet supported") @@ -459,14 +484,17 @@ def serve_traffic(self, traffic_rule: Dict['VersionEndpoint', int], target_env = env.name if target_env is None: - raise ValueError("Unable to find default environment, " - "pass environment_name to the method") + raise ValueError( + "Unable to find default environment, " + "pass environment_name to the method" + ) total_traffic = 0 for version_endpoint, traffic_split in traffic_rule.items(): if version_endpoint.environment_name != target_env: - raise ValueError("Version Endpoint must have same " - "environment as target") + raise ValueError( + "Version Endpoint must have same " "environment as target" + ) if traffic_split <= 0: raise ValueError("Traffic percentage should be non negative") @@ -479,8 +507,7 @@ def serve_traffic(self, traffic_rule: Dict['VersionEndpoint', int], # get existing model endpoint mdl_epi_api = ModelEndpointsApi(self._api_client) - endpoints = mdl_epi_api.models_model_id_endpoints_get( - model_id=self.id) + endpoints = mdl_epi_api.models_model_id_endpoints_get(model_id=self.id) prev_endpoint = None for endpoint in endpoints: if endpoint.environment_name == target_env: @@ -489,25 +516,27 @@ def serve_traffic(self, traffic_rule: Dict['VersionEndpoint', int], if prev_endpoint is None: # create dst = client.ModelEndpointRuleDestination( - version_endpoint_id=version_endpoint.id, weight=100) + version_endpoint_id=version_endpoint.id, weight=100 + ) rule = client.ModelEndpointRule(destinations=[dst]) - ep = client.ModelEndpoint(model_id=self.id, - environment_name=target_env, - rule=rule) - ep = mdl_epi_api.models_model_id_endpoints_post(model_id=self.id, - body=ep.to_dict()) + ep = client.ModelEndpoint( + model_id=self.id, environment_name=target_env, rule=rule + ) + ep = mdl_epi_api.models_model_id_endpoints_post( + model_id=self.id, body=ep.to_dict() + ) else: # update: GET and PUT ep = mdl_epi_api.models_model_id_endpoints_model_endpoint_id_get( - model_id=self.id, - model_endpoint_id=prev_endpoint.id) - ep.rule.destinations[ - 0].version_endpoint_id = version_endpoint.id + model_id=self.id, model_endpoint_id=prev_endpoint.id + ) + ep.rule.destinations[0].version_endpoint_id = version_endpoint.id ep.rule.destinations[0].weight = 100 ep = mdl_epi_api.models_model_id_endpoints_model_endpoint_id_put( model_id=int(self.id), 
model_endpoint_id=prev_endpoint.id, - body=ep.to_dict()) + body=ep.to_dict(), + ) return ModelEndpoint(ep) @@ -526,8 +555,10 @@ def stop_serving_traffic(self, environment_name: str = None): target_env = env.name if target_env is None: - raise ValueError("Unable to find default environment, " - "pass environment_name to the method") + raise ValueError( + "Unable to find default environment, " + "pass environment_name to the method" + ) mdl_epi_api = ModelEndpointsApi(self._api_client) endpoints = mdl_epi_api.models_model_id_endpoints_get(model_id=self.id) @@ -538,16 +569,20 @@ def stop_serving_traffic(self, environment_name: str = None): target_endpoint = endpoint if target_endpoint is None: - raise ValueError(f"there is no model endpoint for model " - f"{self.name} in {target_env} environment") + raise ValueError( + f"there is no model endpoint for model " + f"{self.name} in {target_env} environment" + ) - print(f"Stopping serving traffic for model {self.name} " - f"in {target_env} environment") - mdl_epi_api \ - .models_model_id_endpoints_model_endpoint_id_delete(self.id, target_endpoint.id) + print( + f"Stopping serving traffic for model {self.name} " + f"in {target_env} environment" + ) + mdl_epi_api.models_model_id_endpoints_model_endpoint_id_delete( + self.id, target_endpoint.id + ) - def set_traffic(self, traffic_rule: Dict['ModelVersion', int]) \ - -> ModelEndpoint: + def set_traffic(self, traffic_rule: Dict["ModelVersion", int]) -> ModelEndpoint: """ Set traffic rule for this model. @@ -556,11 +591,13 @@ def set_traffic(self, traffic_rule: Dict['ModelVersion', int]) \ :param traffic_rule: dict of model version and the percentage of traffic. :return: ModelEndpoint """ - print("This method is going to be deprecated, please use " - "serve_traffic instead") + print( + "This method is going to be deprecated, please use " "serve_traffic instead" + ) if not isinstance(traffic_rule, dict): raise ValueError( - f"Traffic_rule should be dictionary, got: {type(traffic_rule)}") + f"Traffic_rule should be dictionary, got: {type(traffic_rule)}" + ) if len(traffic_rule) > 1: raise ValueError("Traffic splitting is not yet supported") @@ -571,51 +608,56 @@ def set_traffic(self, traffic_rule: Dict['ModelVersion', int]) \ raise ValueError("Traffic percentage should be non negative") total_traffic += traffic_split - if mdl_version.endpoint is None or mdl_version.endpoint.status \ - != Status.RUNNING: + if ( + mdl_version.endpoint is None + or mdl_version.endpoint.status != Status.RUNNING + ): raise ValueError( - f"Model version with id {mdl_version.id} is not running") + f"Model version with id {mdl_version.id} is not running" + ) if total_traffic != 100: - raise ValueError( - f"Total traffic should be 100, got {total_traffic}") + raise ValueError(f"Total traffic should be 100, got {total_traffic}") mdl_version = traffic_rule.popitem()[0] model_endpoint_api = ModelEndpointsApi(self._api_client) if mdl_version.endpoint is None: - raise ValueError(f"there is no version endpoint for model version " - f"{mdl_version.id} in default environment") + raise ValueError( + f"there is no version endpoint for model version " + f"{mdl_version.id} in default environment" + ) def_version_endpoint = mdl_version.endpoint if self.endpoint is None: # create model endpoint - ep = model_endpoint_api.models_model_id_endpoints_post(body={ - "model_id": self.id, - "rule": { - "destinations": [ - { - "version_endpoint_id": def_version_endpoint.id, - "weight": 100 - } - ] - } - }, model_id=int(self.id)) + ep = 
model_endpoint_api.models_model_id_endpoints_post( + body={ + "model_id": self.id, + "rule": { + "destinations": [ + { + "version_endpoint_id": def_version_endpoint.id, + "weight": 100, + } + ] + }, + }, + model_id=int(self.id), + ) return ModelEndpoint(ep) else: def_model_endpoint = self.endpoint # GET and PUT - ep = model_endpoint_api \ - .models_model_id_endpoints_model_endpoint_id_get( - model_id=int(self.id), - model_endpoint_id=def_model_endpoint.id) - ep.rule.destinations[0] \ - .version_endpoint_id = def_version_endpoint.id + ep = model_endpoint_api.models_model_id_endpoints_model_endpoint_id_get( + model_id=int(self.id), model_endpoint_id=def_model_endpoint.id + ) + ep.rule.destinations[0].version_endpoint_id = def_version_endpoint.id ep.rule.destinations[0].weight = 100 - ep = model_endpoint_api \ - .models_model_id_endpoints_model_endpoint_id_put( - model_id=int(self.id), - model_endpoint_id=def_model_endpoint.id, - body=ep.to_dict()) + ep = model_endpoint_api.models_model_id_endpoints_model_endpoint_id_put( + model_id=int(self.id), + model_endpoint_id=def_model_endpoint.id, + body=ep.to_dict(), + ) return ModelEndpoint(ep) @@ -627,8 +669,9 @@ def delete_model(self) -> int: :return: id of deleted model """ model_api = ModelsApi(self._api_client) - return model_api.projects_project_id_models_model_id_delete(int(self.project.id), int(self.id)) - + return model_api.projects_project_id_models_model_id_delete( + int(self.project.id), int(self.id) + ) @autostr @@ -636,6 +679,7 @@ class ModelVersion: """ Representation of version in a model """ + MODEL_TYPE_TO_IMAGE_MAP = { ModelType.SKLEARN: "gcr.io/kfserving/sklearnserver:0.2.2", ModelType.TENSORFLOW: "tensorflow/serving:1.14.0", @@ -643,8 +687,9 @@ class ModelVersion: ModelType.PYTORCH: "gcr.io/kfserving/pytorchserver:0.2.2", } - def __init__(self, version: client.Version, model: Model, - api_client: client.ApiClient): + def __init__( + self, version: client.Version, model: Model, api_client: client.ApiClient + ): self._api_client = api_client self._id = version.id self._mlflow_run_id = version.mlflow_run_id @@ -854,12 +899,12 @@ def download_artifact(self, destination_path): """ run = self.get_run() if run is None: - raise Exception('There is no mlflow run for this model version') + raise Exception("There is no mlflow run for this model version") run_info = run.info artifact_uri = run_info.artifact_uri if artifact_uri is None or artifact_uri == "": - raise Exception('There is no artifact uri for this model version') + raise Exception("There is no artifact uri for this model version") download_files_from_gcs(artifact_uri, destination_path) @@ -883,11 +928,13 @@ def log_artifact(self, local_path, artifact_path=None): """ mlflow.log_artifact(local_path, artifact_path) - def log_pyfunc_model(self, - model_instance: PythonModel, - conda_env: Union[str, Dict[str, Any]], - code_dir: Optional[List[str]] = None, - artifacts: Dict[str, str] =None): + def log_pyfunc_model( + self, + model_instance: PythonModel, + conda_env: Union[str, Dict[str, Any]], + code_dir: Optional[List[str]] = None, + artifacts: Dict[str, str] = None, + ): """ Upload PyFunc based model into artifact storage. User has to specify model_instance and @@ -900,18 +947,22 @@ def log_pyfunc_model(self, :param code_dir: additional code directory that will be loaded with ModelType.PYFUNC model :param artifacts: dictionary of artifact that will be stored together with the model. This will be passed to PythonModel.initialize. 
Example: {"config" : "config/staging.yaml"} """ - if self._model.type != ModelType.PYFUNC and self._model.type != ModelType.PYFUNC_V2: - raise ValueError( - "log_pyfunc_model is only for PyFunc and PyFuncV2 model") + if ( + self._model.type != ModelType.PYFUNC + and self._model.type != ModelType.PYFUNC_V2 + ): + raise ValueError("log_pyfunc_model is only for PyFunc and PyFuncV2 model") # add/replace python version in conda to match that used to create model version conda_env = _process_conda_env(conda_env, self._python_version) - mlflow.pyfunc.log_model(DEFAULT_MODEL_PATH, - python_model=model_instance, - code_path=code_dir, - conda_env=conda_env, - artifacts=artifacts) + mlflow.pyfunc.log_model( + DEFAULT_MODEL_PATH, + python_model=model_instance, + code_path=code_dir, + conda_env=conda_env, + artifacts=artifacts, + ) def log_pytorch_model(self, model_dir, model_class_name=None): """ @@ -923,8 +974,10 @@ def log_pytorch_model(self, model_dir, model_class_name=None): if self._model.type != ModelType.PYTORCH: raise ValueError("log_pytorch_model is only for PyTorch model") - warnings.warn("'log_pytorch_model' is deprecated, use 'log_model' instead", - DeprecationWarning) + warnings.warn( + "'log_pytorch_model' is deprecated, use 'log_model' instead", + DeprecationWarning, + ) self.log_model(model_dir) def log_model(self, model_dir=None): @@ -935,17 +988,18 @@ def log_model(self, model_dir=None): :param model_dir: directory which contain serialized model """ - if self._model.type == ModelType.PYFUNC or self._model.type == ModelType.PYFUNC_V2: + if ( + self._model.type == ModelType.PYFUNC + or self._model.type == ModelType.PYFUNC_V2 + ): raise ValueError("use log_pyfunc_model to log pyfunc model") validate_model_dir(self._model.type, model_dir) mlflow.log_artifacts(model_dir, DEFAULT_MODEL_PATH) - def log_custom_model(self, - image: str, - model_dir: str = None, - command: str = "", - args: str = ""): + def log_custom_model( + self, image: str, model_dir: str = None, command: str = "", args: str = "" + ): """ Upload model to artifact storage. This method is used to upload model for custom model type. 
@@ -963,14 +1017,14 @@ def log_custom_model, if model_dir is None: """ - Create temp directory, which later on will be uploaded - The reason is iff no data that will be uploaded to mlflow artifact (gcs), given artifact URI will not exist - Hence will raise error when creating inferenceservice + Create a temp directory, which will later be uploaded. + The reason: if no data is uploaded to the mlflow artifact store (gcs), the given artifact URI will not exist, + hence creating the inference service will raise an error. """ is_using_temp_dir = True model_dir = tempfile.mkdtemp(suffix="merlin-custom-model") - with open(os.path.join(model_dir, model_properties_file), 'w') as writer: + with open(os.path.join(model_dir, model_properties_file), "w") as writer: writer.write(f"image = {image}\n") writer.write(f"command = {command}\n") writer.write(f"args = {args}\n") @@ -986,9 +1040,14 @@ def log_custom_model, shutil.rmtree(model_dir) version_api = VersionApi(self._api_client) - custom_predictor_body = client.CustomPredictor(image=image, command=command, args=args) + custom_predictor_body = client.CustomPredictor( + image=image, command=command, args=args + ) version_api.models_model_id_versions_version_id_patch( - int(self.model.id), int(self.id), body={"custom_predictor": custom_predictor_body}) + int(self.model.id), + int(self.id), + body={"custom_predictor": custom_predictor_body}, + ) def list_endpoint(self) -> List[VersionEndpoint]: """ @@ -997,24 +1056,26 @@ def log_custom_model, :return: List of VersionEndpoint """ endpoint_api = EndpointApi(self._api_client) - ep_list = endpoint_api. \ - models_model_id_versions_version_id_endpoint_get( - model_id=self.model.id, version_id=self.id) + ep_list = endpoint_api.models_model_id_versions_version_id_endpoint_get( + model_id=self.model.id, version_id=self.id + ) endpoints = [] for ep in ep_list: endpoints.append(VersionEndpoint(ep)) return endpoints - def deploy(self, environment_name: str = None, - resource_request: ResourceRequest = None, - env_vars: Dict[str, str] = None, - transformer: Transformer = None, - logger: Logger = None, - deployment_mode: DeploymentMode = DeploymentMode.SERVERLESS, - autoscaling_policy: AutoscalingPolicy = None, - protocol: Protocol = Protocol.HTTP_JSON - ) -> VersionEndpoint: + def deploy( + self, + environment_name: str = None, + resource_request: ResourceRequest = None, + env_vars: Dict[str, str] = None, + transformer: Transformer = None, + logger: Logger = None, + deployment_mode: DeploymentMode = DeploymentMode.SERVERLESS, + autoscaling_policy: AutoscalingPolicy = None, + protocol: Protocol = Protocol.HTTP_JSON, + ) -> VersionEndpoint: """ Deploy current model to MLP One of log_model, log_pytorch_model, and log_pyfunc_model has to be called beforehand @@ -1039,8 +1100,10 @@ def deploy(self, environment_name: str = None, target_env_name = env.name if target_env_name is None: - raise ValueError("Unable to find default environment, " - "pass environment_name to the method") + raise ValueError( + "Unable to find default environment, " + "pass environment_name to the method" + ) if resource_request is None: env_api = EnvironmentApi(self._api_client) @@ -1053,7 +1116,7 @@ def deploy(self, environment_name: str = None, env.default_resource_request.max_replica, env.default_resource_request.cpu_request, env.default_resource_request.memory_request, - ) + ) # This case is when the default resource request is not specified in the environment config if resource_request is None: @@ -1062,10
+1125,16 @@ def deploy(self, environment_name: str = None, resource_request.validate() target_resource_request = client.ResourceRequest( - resource_request.min_replica, resource_request.max_replica, - resource_request.cpu_request, resource_request.memory_request) + resource_request.min_replica, + resource_request.max_replica, + resource_request.cpu_request, + resource_request.memory_request, + ) - if resource_request.gpu_request is not None and resource_request.gpu_name is not None: + if ( + resource_request.gpu_request is not None + and resource_request.gpu_name is not None + ): env_api = EnvironmentApi(self._api_client) env_list = env_api.environments_get() @@ -1073,27 +1142,32 @@ def deploy(self, environment_name: str = None, for gpu in env.gpus: if resource_request.gpu_name == gpu.name: if resource_request.gpu_request not in gpu.values: - raise ValueError(f"Invalid GPU request count. Supported GPUs count for {resource_request.gpu_name} is {gpu.values}") + raise ValueError( + f"Invalid GPU request count. Supported GPUs count for {resource_request.gpu_name} is {gpu.values}" + ) target_resource_request.gpu_name = resource_request.gpu_name - target_resource_request.gpu_request = resource_request.gpu_request + target_resource_request.gpu_request = ( + resource_request.gpu_request + ) break target_env_vars = [] if env_vars is not None: if not isinstance(env_vars, dict): raise ValueError( - f"env_vars should be dictionary, got: {type(env_vars)}") + f"env_vars should be dictionary, got: {type(env_vars)}" + ) if len(env_vars) > 0: for name, value in env_vars.items(): - target_env_vars.append( - client.EnvVar(str(name), str(value))) + target_env_vars.append(client.EnvVar(str(name), str(value))) target_transformer = None if transformer is not None: target_transformer = self.create_transformer_spec( - transformer, target_env_name) + transformer, target_env_name + ) target_logger = None if logger is not None: @@ -1108,16 +1182,18 @@ def deploy(self, environment_name: str = None, model = self._model endpoint_api = EndpointApi(self._api_client) - endpoint = client.VersionEndpoint(environment_name=target_env_name, - resource_request=target_resource_request, - env_vars=target_env_vars, - transformer=target_transformer, - logger=target_logger, - deployment_mode=deployment_mode.value, - autoscaling_policy=client.AutoscalingPolicy(autoscaling_policy.metrics_type.value, - autoscaling_policy.target_value), - protocol=protocol.value - ) + endpoint = client.VersionEndpoint( + environment_name=target_env_name, + resource_request=target_resource_request, + env_vars=target_env_vars, + transformer=target_transformer, + logger=target_logger, + deployment_mode=deployment_mode.value, + autoscaling_policy=client.AutoscalingPolicy( + autoscaling_policy.metrics_type.value, autoscaling_policy.target_value + ), + protocol=protocol.value, + ) current_endpoint = self.endpoint if current_endpoint is not None: # This allows a serving deployment to be update while it is serving @@ -1125,26 +1201,26 @@ def deploy(self, environment_name: str = None, endpoint.status = Status.SERVING.value else: endpoint.status = Status.RUNNING.value - endpoint = endpoint_api \ - .models_model_id_versions_version_id_endpoint_endpoint_id_put(int(model.id), - int(self.id), - current_endpoint.id, - body=endpoint.to_dict()) + endpoint = endpoint_api.models_model_id_versions_version_id_endpoint_endpoint_id_put( + int(model.id), + int(self.id), + current_endpoint.id, + body=endpoint.to_dict(), + ) else: - endpoint = endpoint_api \ - 
.models_model_id_versions_version_id_endpoint_post(int(model.id), - int(self.id), - body=endpoint.to_dict()) - bar = pyprind.ProgBar(100, track_time=True, - title=f"Deploying model {model.name} version " - f"{self.id}") + endpoint = endpoint_api.models_model_id_versions_version_id_endpoint_post( + int(model.id), int(self.id), body=endpoint.to_dict() + ) + bar = pyprind.ProgBar( + 100, + track_time=True, + title=f"Deploying model {model.name} version " f"{self.id}", + ) while endpoint.status == "pending": - endpoint = endpoint_api \ - .models_model_id_versions_version_id_endpoint_endpoint_id_get( - model_id=int(model.id), - version_id=int(self.id), - endpoint_id=endpoint.id) + endpoint = endpoint_api.models_model_id_versions_version_id_endpoint_endpoint_id_get( + model_id=int(model.id), version_id=int(self.id), endpoint_id=endpoint.id + ) bar.update() sleep(5) bar.stop() @@ -1153,14 +1229,18 @@ def deploy(self, environment_name: str = None, raise ModelEndpointDeploymentError(model.name, self.id, endpoint.message) log_url = f"{self.url}/{self.id}/endpoints/{endpoint.id}/logs" - print(f"Model {model.name} version {self.id} is deployed." - f"\nView model version logs: {log_url}") + print( + f"Model {model.name} version {self.id} is deployed." + f"\nView model version logs: {log_url}" + ) self._version_endpoints = self.list_endpoint() return VersionEndpoint(endpoint, log_url) - def create_transformer_spec(self, transformer: Transformer, target_env_name: str) -> client.Transformer: + def create_transformer_spec( + self, transformer: Transformer, target_env_name: str + ) -> client.Transformer: resource_request = transformer.resource_request if resource_request is None: env_api = EnvironmentApi(self._api_client) @@ -1171,7 +1251,8 @@ def create_transformer_spec(self, transformer: Transformer, target_env_name: str env.default_resource_request.min_replica, env.default_resource_request.max_replica, env.default_resource_request.cpu_request, - env.default_resource_request.memory_request) + env.default_resource_request.memory_request, + ) # This case is when the default resource request is not specified in the environment config if resource_request is None: raise ValueError("resource request must be specified") @@ -1179,24 +1260,32 @@ def create_transformer_spec(self, transformer: Transformer, target_env_name: str resource_request.validate() target_resource_request = client.ResourceRequest( - resource_request.min_replica, resource_request.max_replica, - resource_request.cpu_request, resource_request.memory_request) + resource_request.min_replica, + resource_request.max_replica, + resource_request.cpu_request, + resource_request.memory_request, + ) target_env_vars = [] if transformer.env_vars is not None: if not isinstance(transformer.env_vars, dict): raise ValueError( - f"transformer.env_vars should be dictionary, got: {type(transformer.env_vars)}") + f"transformer.env_vars should be dictionary, got: {type(transformer.env_vars)}" + ) if len(transformer.env_vars) > 0: for name, value in transformer.env_vars.items(): - target_env_vars.append( - client.EnvVar(str(name), str(value))) + target_env_vars.append(client.EnvVar(str(name), str(value))) return client.Transformer( - transformer.enabled, transformer.transformer_type.value, - transformer.image, transformer.command, transformer.args, - target_resource_request, target_env_vars) + transformer.enabled, + transformer.transformer_type.value, + transformer.image, + transformer.command, + transformer.args, + target_resource_request, + target_env_vars, + ) def 
undeploy(self, environment_name: str = None):
        """
@@ -1213,13 +1302,15 @@ def undeploy(self, environment_name: str = None):
             target_env = env.name

         if target_env is None:
-            raise ValueError("Unable to find default environment, "
-                             "pass environment_name to the method")
+            raise ValueError(
+                "Unable to find default environment, "
+                "pass environment_name to the method"
+            )

         endpoint_api = EndpointApi(self._api_client)
-        endpoints = \
-            endpoint_api.models_model_id_versions_version_id_endpoint_get(
-                model_id=self.model.id, version_id=self.id)
+        endpoints = endpoint_api.models_model_id_versions_version_id_endpoint_get(
+            model_id=self.model.id, version_id=self.id
+        )
         target_endpoint = None
         for endpoint in endpoints:
             if endpoint.environment_name == target_env:
@@ -1229,14 +1320,18 @@
             print(f"No endpoint found for environment: {target_endpoint}")
             return

-        print(f"Deleting deployment of model {self._model.name} "
-              f"version {self.id} from enviroment {target_env}")
+        print(
+            f"Deleting deployment of model {self._model.name} "
+            f"version {self.id} from environment {target_env}"
+        )
         endpoint_api = EndpointApi(self._api_client)
-        endpoint_api \
-            .models_model_id_versions_version_id_endpoint_endpoint_id_delete(self.model.id, self.id,
-                                                                             target_endpoint.id)
+        endpoint_api.models_model_id_versions_version_id_endpoint_endpoint_id_delete(
+            self.model.id, self.id, target_endpoint.id
+        )

-    def create_prediction_job(self, job_config: PredictionJobConfig, sync: bool = True) -> PredictionJob:
+    def create_prediction_job(
+        self, job_config: PredictionJobConfig, sync: bool = True
+    ) -> PredictionJob:
         """
         Create and run prediction job with given config using this model version

@@ -1246,31 +1341,35 @@ def create_prediction_job(self, job_config: PredictionJobConfig, sync: bool = Tr
         if self.model.type != ModelType.PYFUNC_V2:
             raise ValueError(
-                f"model type is not supported for prediction job: {self.model.type}")
+                f"model type is not supported for prediction job: {self.model.type}"
+            )

-        job_cfg = client.PredictionJobConfig(version=V1, kind=PREDICTION_JOB, model={
-            "type": self.model.type.value.upper(),
-            "uri": os.path.join(self.artifact_uri, DEFAULT_MODEL_PATH),
-            "result": {
-                "type": job_config.result_type.value,
-                "item_type": job_config.item_type.value
-            }
-        })
+        job_cfg = client.PredictionJobConfig(
+            version=V1,
+            kind=PREDICTION_JOB,
+            model={
+                "type": self.model.type.value.upper(),
+                "uri": os.path.join(self.artifact_uri, DEFAULT_MODEL_PATH),
+                "result": {
+                    "type": job_config.result_type.value,
+                    "item_type": job_config.item_type.value,
+                },
+            },
+        )

         if isinstance(job_config.source, BigQuerySource):
             job_cfg.bigquery_source = job_config.source.to_dict()
         else:
-            raise ValueError(
-                f"source type is not supported {type(job_config.source)}")
+            raise ValueError(f"source type is not supported {type(job_config.source)}")

         if isinstance(job_config.sink, BigQuerySink):
             job_cfg.bigquery_sink = job_config.sink.to_dict()
         else:
-            raise ValueError(
-                f"sink type is not supported {type(job_config.sink)}")
+            raise ValueError(f"sink type is not supported {type(job_config.sink)}")

-        cfg = client.Config(job_config=job_cfg,
-                            service_account_name=job_config.service_account_name)
+        cfg = client.Config(
+            job_config=job_cfg, service_account_name=job_config.service_account_name
+        )

         if job_config.resource_request is not None:
             cfg.resource_request = job_config.resource_request.to_dict()

@@ -1279,46 +1378,49 @@ def create_prediction_job(self, job_config:
PredictionJobConfig, sync: bool = Tr if job_config.env_vars is not None: if not isinstance(job_config.env_vars, dict): raise ValueError( - f"env_vars should be dictionary, got: {type(job_config.env_vars)}") + f"env_vars should be dictionary, got: {type(job_config.env_vars)}" + ) if len(job_config.env_vars) > 0: for name, value in job_config.env_vars.items(): target_env_vars.append(client.EnvVar(name, value)) cfg.env_vars = target_env_vars - req = client.PredictionJob(version_id=self.id, - model_id=self.model.id, - config=cfg) + req = client.PredictionJob( + version_id=self.id, model_id=self.model.id, config=cfg + ) job_client = client.PredictionJobsApi(self._api_client) j = job_client.models_model_id_versions_version_id_jobs_post( - model_id=self.model.id, - version_id=self.id, - body=req) + model_id=self.model.id, version_id=self.id, body=req + ) - bar = pyprind.ProgBar(100, track_time=True, - title=f"Running prediction job {j.id} from model {self.model.name} version {self.id} " - f"under project {self.model.project.name}") + bar = pyprind.ProgBar( + 100, + track_time=True, + title=f"Running prediction job {j.id} from model {self.model.name} version {self.id} " + f"under project {self.model.project.name}", + ) retry = DEFAULT_API_CALL_RETRY - while j.status == "pending" or \ - j.status == "running" or \ - j.status == "terminating": + while ( + j.status == "pending" or j.status == "running" or j.status == "terminating" + ): if not sync: - j = job_client.models_model_id_versions_version_id_jobs_job_id_get(model_id=self.model.id, - version_id=self.id, - job_id=j.id) + j = job_client.models_model_id_versions_version_id_jobs_job_id_get( + model_id=self.model.id, version_id=self.id, job_id=j.id + ) return PredictionJob(j, self._api_client) else: try: - j = job_client.models_model_id_versions_version_id_jobs_job_id_get(model_id=self.model.id, - version_id=self.id, - job_id=j.id) + j = job_client.models_model_id_versions_version_id_jobs_job_id_get( + model_id=self.model.id, version_id=self.id, job_id=j.id + ) retry = DEFAULT_API_CALL_RETRY except Exception: - retry -= 1 - if retry == 0: - j.status = "failed" - break - sleep(DEFAULT_PREDICTION_JOB_RETRY_DELAY) + retry -= 1 + if retry == 0: + j.status = "failed" + break + sleep(DEFAULT_PREDICTION_JOB_RETRY_DELAY) bar.update() sleep(DEFAULT_PREDICTION_JOB_DELAY) bar.stop() @@ -1336,19 +1438,22 @@ def list_prediction_job(self) -> List[PredictionJob]: """ job_client = client.PredictionJobsApi(self._api_client) res = job_client.models_model_id_versions_version_id_jobs_get( - model_id=self.model.id, - version_id=self.id) + model_id=self.model.id, version_id=self.id + ) jobs = [] for j in res: jobs.append(PredictionJob(j, self._api_client)) return jobs - def start_server(self, env_vars: Dict[str, str] = None, - port: int = 8080, - pyfunc_base_image: str = None, - kill_existing_server: bool = False, - tmp_dir: Optional[str] = os.environ.get("MERLIN_TMP_DIR"), - build_image: bool = False): + def start_server( + self, + env_vars: Dict[str, str] = None, + port: int = 8080, + pyfunc_base_image: str = None, + kill_existing_server: bool = False, + tmp_dir: Optional[str] = os.environ.get("MERLIN_TMP_DIR"), + build_image: bool = False, + ): """ Start a local server running the model version @@ -1366,14 +1471,16 @@ def start_server(self, env_vars: Dict[str, str] = None, pathlib.Path(artifact_path).mkdir(parents=True, exist_ok=True) if len(os.listdir(artifact_path)) < 1: print( - f"Downloading model artifact for model {self.model.name} version {self.id}") + 
f"Downloading model artifact for model {self.model.name} version {self.id}" + ) self.download_artifact(artifact_path) # stop all previous containers to avoid port conflict client = docker.from_env() if kill_existing_server: started_containers = client.containers.list( - filters={"name": self._container_name()}) + filters={"name": self._container_name()} + ) for started_container in started_containers: print(f"Stopping model server {started_container.name}") started_container.remove(force=True) @@ -1381,24 +1488,33 @@ def start_server(self, env_vars: Dict[str, str] = None, model_type = self.model.type if model_type == ModelType.PYFUNC: self._run_pyfunc_local_server( - artifact_path, env_vars, port, pyfunc_base_image) + artifact_path, env_vars, port, pyfunc_base_image + ) return - if model_type == ModelType.TENSORFLOW \ - or model_type == ModelType.XGBOOST \ - or model_type == ModelType.SKLEARN \ - or model_type == ModelType.PYTORCH: + if ( + model_type == ModelType.TENSORFLOW + or model_type == ModelType.XGBOOST + or model_type == ModelType.SKLEARN + or model_type == ModelType.PYTORCH + ): self._run_standard_model_local_server( - artifact_path, env_vars, port, build_image) + artifact_path, env_vars, port, build_image + ) return raise ValueError( - f"running local model server is not supported for model type: {model_type}") + f"running local model server is not supported for model type: {model_type}" + ) def _create_launch_command(self): model_type = self.model.type print(f"model type: {model_type}") - if model_type == ModelType.SKLEARN or model_type == ModelType.XGBOOST or model_type == ModelType.PYTORCH: + if ( + model_type == ModelType.SKLEARN + or model_type == ModelType.XGBOOST + or model_type == ModelType.PYTORCH + ): return f"--port=9000 --rest_api_port=8080 --model_name={self.model.name}-{self.id} --model_dir=/mnt/models" if model_type == ModelType.TENSORFLOW: @@ -1406,7 +1522,9 @@ def _create_launch_command(self): raise ValueError(f"unknown model type: {model_type}") - def _run_standard_model_local_server(self, artifact_path, env_vars, port, build_image): + def _run_standard_model_local_server( + self, artifact_path, env_vars, port, build_image + ): container: Optional[Container] = None # type: ignore try: container_name = self._container_name() @@ -1418,16 +1536,13 @@ def _run_standard_model_local_server(self, artifact_path, env_vars, port, build_ image_tag = f"{self.model.project.name}-{self.model.name}:{self.id}" dockerfile_path = copy_standard_dockerfile(artifact_path) print(f"Building {self.model.type} image: {image_tag}") - logs = apiClient.build(path=artifact_path, - tag=image_tag, - buildargs={ - "BASE_IMAGE": image_name, - "MODEL_PATH": artifact_path - }, - dockerfile=os.path.basename( - dockerfile_path), - decode=True - ) + logs = apiClient.build( + path=artifact_path, + tag=image_tag, + buildargs={"BASE_IMAGE": image_name, "MODEL_PATH": artifact_path}, + dockerfile=os.path.basename(dockerfile_path), + decode=True, + ) self._wait_build_complete(logs) image_name = image_tag @@ -1438,16 +1553,17 @@ def _run_standard_model_local_server(self, artifact_path, env_vars, port, build_ volumes = None client = docker.from_env() - container = client.containers.run(image_name, - name=container_name, - labels={"managed-by": "merlin"}, - command=cmd, - ports={"8080/tcp": port}, - volumes=volumes, - environment=env_vars, - detach=True, - remove=True - ) + container = client.containers.run( + image_name, + name=container_name, + labels={"managed-by": "merlin"}, + command=cmd, + 
ports={"8080/tcp": port},
+            volumes=volumes,
+            environment=env_vars,
+            detach=True,
+            remove=True,
+        )

         # continuously print the docker log until the process is interrupted
         for log in container.logs(stream=True):
@@ -1456,7 +1572,9 @@ def _run_standard_model_local_server(self, artifact_path, env_vars, port, build_
             if container is not None:
                 container.remove(force=True)

-    def _run_pyfunc_local_server(self, artifact_path, env_vars, port, pyfunc_base_image):
+    def _run_pyfunc_local_server(
+        self, artifact_path, env_vars, port, pyfunc_base_image
+    ):
         if pyfunc_base_image is None:
             if "dev" in VERSION:
                 pyfunc_base_image = "ghcr.io/caraml-dev/merlin-pyfunc-base:dev"
@@ -1468,15 +1586,13 @@ def _run_pyfunc_local_server(self, artifact_path, env_vars, port, pyfunc_base_im
         client = docker.from_env()
         apiClient = APIClient()
         print(f"Building pyfunc image: {image_tag}")
-        logs = apiClient.build(path=artifact_path,
-                               tag=image_tag,
-                               buildargs={
-                                   "BASE_IMAGE": pyfunc_base_image,
-                                   "MODEL_PATH": artifact_path
-                               },
-                               dockerfile=os.path.basename(dockerfile_path),
-                               decode=True
-                               )
+        logs = apiClient.build(
+            path=artifact_path,
+            tag=image_tag,
+            buildargs={"BASE_IMAGE": pyfunc_base_image, "MODEL_PATH": artifact_path},
+            dockerfile=os.path.basename(dockerfile_path),
+            decode=True,
+        )
         self._wait_build_complete(logs)

         container: Optional[Container] = None  # type: ignore
@@ -1490,14 +1606,15 @@ def _run_pyfunc_local_server(self, artifact_path, env_vars, port, pyfunc_base_im
             env_vars["MODEL_NAME"] = f"{self.model.name}-{self.id}"
             env_vars["WORKERS"] = 1
             env_vars["PORT"] = 8080
-            container = client.containers.run(image=image_tag,
-                                              name=container_name,
-                                              labels={"managed-by": "merlin"},
-                                              ports={"8080/tcp": port},
-                                              environment=env_vars,
-                                              detach=True,
-                                              remove=True
-                                              )
+            container = client.containers.run(
+                image=image_tag,
+                name=container_name,
+                labels={"managed-by": "merlin"},
+                ports={"8080/tcp": port},
+                environment=env_vars,
+                detach=True,
+                remove=True,
+            )

             # continuously print the docker log until the process is interrupted
             for log in container.logs(stream=True):
@@ -1511,19 +1628,18 @@ def _container_name(self):

     def _wait_build_complete(self, logs):
         for chunk in logs:
-            if 'error' in chunk:
-                raise BuildError(chunk['error'], logs)
-            if 'stream' in chunk:
+            if "error" in chunk:
+                raise BuildError(chunk["error"], logs)
+            if "stream" in chunk:
                 match = re.search(
-                    r'(^Successfully built |sha256:)([0-9a-f]+)$',
-                    chunk['stream']
+                    r"(^Successfully built |sha256:)([0-9a-f]+)$", chunk["stream"]
                 )
                 if match:
                     image_id = match.group(2)
                     last_event = chunk
         if image_id:
             return
-        raise BuildError('Unknown', logs)
+        raise BuildError("Unknown", logs)

     def delete_model_version(self) -> int:
         """
@@ -1533,17 +1649,22 @@ def delete_model_version(self) -> int:
         :return: id of deleted model
         """
         versionApi = VersionApi(self._api_client)
-        return versionApi.models_model_id_versions_version_id_delete(int(self.model.id), int(self.id))
+        return versionApi.models_model_id_versions_version_id_delete(
+            int(self.model.id), int(self.id)
+        )


-def _process_conda_env(conda_env: Union[str, Dict[str, Any]], python_version: str) -> Dict[str, Any]:
+def _process_conda_env(
+    conda_env: Union[str, Dict[str, Any]], python_version: str
+) -> Dict[str, Any]:
     """
-    This function will replace/add python version dependency to the conda environment file.
+        This function replaces or adds the python version dependency in the conda environment file.
-    :param conda_env: Either a dictionary representation of a conda environment or the path to a conda environment yaml file.
-    :param python_version: The python version to replace the conda environment with
-    :return: dict representation of the conda environment with python_version set as per given input
+        :param conda_env: Either a dictionary representation of a conda environment or the path to a conda environment yaml file.
+        :param python_version: The python version to pin in the conda environment
+        :return: dict representation of the conda environment with python_version set to the given value
     """
+
     def match_dependency(spec, name):
         # Using direct match or regex match to match the dependency name,
         # where the regex accounts for the official conda dependency formats:
@@ -1564,11 +1685,14 @@ def match_dependency(spec, name):
     elif isinstance(conda_env, dict):
         new_conda_env = conda_env

-    if 'dependencies' not in new_conda_env:
-        new_conda_env['dependencies'] = []
+    if "dependencies" not in new_conda_env:
+        new_conda_env["dependencies"] = []

     # Replace python dependency to match minor version
-    new_conda_env['dependencies'] = ([f'python={python_version}'] +
-                                     [spec for spec in new_conda_env['dependencies'] if not match_dependency(spec, 'python')])
+    new_conda_env["dependencies"] = [f"python={python_version}"] + [
+        spec
+        for spec in new_conda_env["dependencies"]
+        if not match_dependency(spec, "python")
+    ]

     return new_conda_env

diff --git a/python/sdk/merlin/transformer.py b/python/sdk/merlin/transformer.py
index 86c02f60f..28bd36bac 100644
--- a/python/sdk/merlin/transformer.py
+++ b/python/sdk/merlin/transformer.py
@@ -12,32 +12,39 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Dict, Optional
+from typing import Dict, Optional, Any

+from merlin.protocol import Protocol
 from merlin.resource_request import ResourceRequest
 from merlin.util import autostr
+from merlin import fluent
+
 from enum import Enum
-from client import EnvironmentApi

 import client
 import yaml
 import json

-class TransformerType(Enum):
-    CUSTOM_TRANSFORMER = 'custom'
-    STANDARD_TRANSFORMER = 'standard'

+class TransformerType(Enum):
+    CUSTOM_TRANSFORMER = "custom"
+    STANDARD_TRANSFORMER = "standard"


 @autostr
 class Transformer:
     StandardTransformerConfigKey = "STANDARD_TRANSFORMER_CONFIG"

-    def __init__(self, image: str, enabled: bool = True,
-                 command: str = None, args: str = None,
-                 resource_request: ResourceRequest = None,
-                 env_vars: Dict[str, str] = None,
-                 transformer_type: TransformerType = TransformerType.CUSTOM_TRANSFORMER):
+    def __init__(
+        self,
+        image: str,
+        enabled: bool = True,
+        command: str = None,
+        args: str = None,
+        resource_request: ResourceRequest = None,
+        env_vars: Dict[str, str] = None,
+        transformer_type: TransformerType = TransformerType.CUSTOM_TRANSFORMER,
+    ):
         self._image = image
         self._enabled = enabled
         self._command = command
@@ -76,20 +83,56 @@ def transformer_type(self) -> TransformerType:

 class StandardTransformer(Transformer):
-    def __init__(self, config_file: str, enabled: bool = True,
-                 resource_request: ResourceRequest = None,
-                 env_vars: Dict[str, str] = None):
-
-        transformer_config = self._load_transformer_config(config_file)
+    def __init__(
+        self,
+        config_file: str,
+        enabled: bool = True,
+        resource_request: ResourceRequest = None,
+        env_vars: Dict[str, str] = None,
+    ):
+        self._load_transformer_config(config_file)
+        transformer_env_var = self._transformer_config_env_var()
         merged_env_vars = 
env_vars or {}
-        merged_env_vars = {**merged_env_vars, **transformer_config}
-        super().__init__(image="", enabled=enabled, resource_request=resource_request,
-                         env_vars=merged_env_vars, transformer_type=TransformerType.STANDARD_TRANSFORMER)
+        merged_env_vars = {**merged_env_vars, **transformer_env_var}
+        super().__init__(
+            image="",
+            enabled=enabled,
+            resource_request=resource_request,
+            env_vars=merged_env_vars,
+            transformer_type=TransformerType.STANDARD_TRANSFORMER,
+        )

     def _load_transformer_config(self, config_file: str):
         with open(config_file, "r") as stream:
-            transformer_config = yaml.safe_load(stream)
+            self.transformer_config = yaml.safe_load(stream)

-        config_json_string = json.dumps(transformer_config)
+    def _transformer_config_env_var(self):
+        config_json_string = json.dumps(self.transformer_config)
         return {self.StandardTransformerConfigKey: config_json_string}

+    def simulate(
+        self,
+        payload: Dict,
+        headers: Optional[Dict[Any, Any]] = None,
+        model_prediction_config: Optional[Dict[Any, Any]] = None,
+        protocol: str = "HTTP_JSON",
+        exclude_tracing: bool = False,
+    ) -> Dict:
+        fluent._check_active_client()
+        if not fluent._merlin_client:
+            raise Exception("Merlin client is not initialized")
+
+        response = fluent._merlin_client.standard_transformer_simulate(
+            payload=payload,
+            headers=headers,
+            config=self.transformer_config,
+            model_prediction_config=model_prediction_config,
+            protocol=protocol,
+        )
+
+        # if tracing is excluded, remove the operation_tracing key from the response dict
+        response = response.to_dict()
+        if exclude_tracing:
+            del response["operation_tracing"]
+
+        return response
diff --git a/python/sdk/test/integration_test.py b/python/sdk/test/integration_test.py
index c8455b29d..5427d170f 100644
--- a/python/sdk/test/integration_test.py
+++ b/python/sdk/test/integration_test.py
@@ -1124,3 +1124,60 @@ def test_redeploy_model(integration_test_url, project_name, use_google_oauth, re
 def deployment_mode_suffix(deployment_mode: DeploymentMode):
     return deployment_mode.value.lower()[0:1]
+
+
+@pytest.mark.integration
+def test_standard_transformer_simulate(integration_test_url, use_google_oauth):
+    """
+    Test the `simulate` method of the `StandardTransformer` class.
+ """ + merlin.set_url(integration_test_url, use_google_oauth=use_google_oauth) + + transformer_config_path = os.path.join( + "test/transformer", "standard_transformer_no_feast.yaml" + ) + transformer = StandardTransformer( + config_file=transformer_config_path, enabled=False + ) + + payload = { + "drivers": [ + # 1 Feb 2022, 00:00:00 + { + "id": 1, + "name": "driver-1", + "vehicle": "motorcycle", + "previous_vehicle": "suv", + "rating": 4, + "ep_time": 1643673600, + }, + # 30 Jan 2022, 00:00:00 + { + "id": 2, + "name": "driver-2", + "vehicle": "sedan", + "previous_vehicle": "mpv", + "rating": 3, + "ep_time": 1643500800, + }, + ], + "customer": {"id": 1111}, + } + + resp_wo_tracing = transformer.simulate(payload=payload, exclude_tracing=True) + resp_w_tracing = transformer.simulate(payload=payload, exclude_tracing=False) + + with open("test/transformer/sim_exp_resp_valid_wo_tracing.json", "r") as f: + exp_resp_valid_wo_tracing = json.load(f) + + with open("test/transformer/sim_exp_resp_valid_w_tracing.json", "r") as f: + exp_resp_valid_w_tracing = json.load(f) + + assert isinstance(resp_wo_tracing, dict) + assert isinstance(resp_w_tracing, dict) + assert "response" in resp_wo_tracing.keys() + assert "response" in resp_w_tracing.keys() + assert "operation_tracing" not in resp_wo_tracing.keys() + assert "operation_tracing" in resp_w_tracing.keys() + assert resp_wo_tracing == exp_resp_valid_wo_tracing + assert resp_w_tracing == exp_resp_valid_w_tracing diff --git a/python/sdk/test/transformer/sim_exp_resp_valid_w_tracing.json b/python/sdk/test/transformer/sim_exp_resp_valid_w_tracing.json new file mode 100644 index 000000000..4edba8728 --- /dev/null +++ b/python/sdk/test/transformer/sim_exp_resp_valid_w_tracing.json @@ -0,0 +1,224 @@ +{ + "response": { + "instances": { + "columns": [ + "customer_id", + "name", + "rank", + "rating", + "vehicle", + "previous_vehicle", + "ep_time_x", + "ep_time_y" + ], + "data": [ + [1111, "driver-2", 2.5, 0.5, 2, 3, 1, -4.100007228307977e-13], + [1111, "driver-1", -2.5, 0.75, 0, 1, 1, -6.364838707220068e-12] + ] + } + }, + "operation_tracing": { + "preprocess": [ + { + "input": null, + "output": { "customer_id": 1111 }, + "spec": { "name": "customer_id", "jsonPath": "$.customer.id" }, + "operation_type": "variable_op" + }, + { + "input": null, + "output": { + "driver_table": [ + { + "ep_time": 1643673600, + "id": 1, + "name": "driver-1", + "previous_vehicle": "suv", + "rating": 4, + "row_number": 0, + "vehicle": "motorcycle" + }, + { + "ep_time": 1643500800, + "id": 2, + "name": "driver-2", + "previous_vehicle": "mpv", + "rating": 3, + "row_number": 1, + "vehicle": "sedan" + } + ] + }, + "spec": { + "name": "driver_table", + "baseTable": { + "fromJson": { "jsonPath": "$.drivers[*]", "addRowNumber": true } + } + }, + "operation_type": "create_table_op" + }, + { + "input": null, + "output": { + "vehicle_mapping": "The result of this operation is on the transformer step that use this encoder" + }, + "spec": { + "name": "vehicle_mapping", + "ordinalEncoderConfig": { + "defaultValue": "0", + "targetValueType": "INT", + "mapping": { "mpv": "3", "sedan": "2", "suv": "1" } + } + }, + "operation_type": "encoder_op" + }, + { + "input": null, + "output": { + "daily_cycle": "The result of this operation is on the transformer step that use this encoder" + }, + "spec": { + "name": "daily_cycle", + "cyclicalEncoderConfig": { "byEpochTime": { "periodType": "DAY" } } + }, + "operation_type": "encoder_op" + }, + { + "input": { + "driver_table": [ + { + "ep_time": 1643673600, + 
"id": 1, + "name": "driver-1", + "previous_vehicle": "suv", + "rating": 4, + "row_number": 0, + "vehicle": "motorcycle" + }, + { + "ep_time": 1643500800, + "id": 2, + "name": "driver-2", + "previous_vehicle": "mpv", + "rating": 3, + "row_number": 1, + "vehicle": "sedan" + } + ] + }, + "output": { + "transformed_driver_table": [ + { + "customer_id": 1111, + "ep_time_x": 1, + "ep_time_y": -4.100007228307977e-13, + "name": "driver-2", + "previous_vehicle": 3, + "rank": 2.5, + "rating": 0.5, + "vehicle": 2 + }, + { + "customer_id": 1111, + "ep_time_x": 1, + "ep_time_y": -6.364838707220068e-12, + "name": "driver-1", + "previous_vehicle": 1, + "rank": -2.5, + "rating": 0.75, + "vehicle": 0 + } + ] + }, + "spec": { + "inputTable": "driver_table", + "outputTable": "transformed_driver_table", + "steps": [ + { "dropColumns": ["id"] }, + { "sort": [{ "column": "row_number", "order": "DESC" }] }, + { "renameColumns": { "row_number": "rank" } }, + { + "updateColumns": [ + { "column": "customer_id", "expression": "customer_id" } + ] + }, + { + "scaleColumns": [ + { + "column": "rank", + "standardScalerConfig": { "mean": 0.5, "std": 0.2 } + } + ] + }, + { + "scaleColumns": [ + { + "column": "rating", + "minMaxScalerConfig": { "min": 1, "max": 5 } + } + ] + }, + { + "encodeColumns": [ + { + "columns": ["vehicle", "previous_vehicle"], + "encoder": "vehicle_mapping" + }, + { "columns": ["ep_time"], "encoder": "daily_cycle" } + ] + }, + { + "selectColumns": [ + "customer_id", + "name", + "rank", + "rating", + "vehicle", + "previous_vehicle", + "ep_time_x", + "ep_time_y" + ] + } + ] + }, + "operation_type": "table_transform_op" + }, + { + "input": null, + "output": { + "instances": { + "columns": [ + "customer_id", + "name", + "rank", + "rating", + "vehicle", + "previous_vehicle", + "ep_time_x", + "ep_time_y" + ], + "data": [ + [1111, "driver-2", 2.5, 0.5, 2, 3, 1, -4.100007228307977e-13], + [1111, "driver-1", -2.5, 0.75, 0, 1, 1, -6.364838707220068e-12] + ] + } + }, + "spec": { + "jsonTemplate": { + "fields": [ + { + "fieldName": "instances", + "fromTable": { + "tableName": "transformed_driver_table", + "format": "SPLIT" + } + } + ] + } + }, + "operation_type": "json_output_op" + } + ], + "postprocess": [] + } +} diff --git a/python/sdk/test/transformer/sim_exp_resp_valid_wo_tracing.json b/python/sdk/test/transformer/sim_exp_resp_valid_wo_tracing.json new file mode 100644 index 000000000..119012193 --- /dev/null +++ b/python/sdk/test/transformer/sim_exp_resp_valid_wo_tracing.json @@ -0,0 +1,20 @@ +{ + "response": { + "instances": { + "columns": [ + "customer_id", + "name", + "rank", + "rating", + "vehicle", + "previous_vehicle", + "ep_time_x", + "ep_time_y" + ], + "data": [ + [1111, "driver-2", 2.5, 0.5, 2, 3, 1, -4.100007228307977e-13], + [1111, "driver-1", -2.5, 0.75, 0, 1, 1, -6.364838707220068e-12] + ] + } + } +} diff --git a/python/sdk/test/transformer_test.py b/python/sdk/test/transformer_test.py index 92a041096..24280fd4e 100644 --- a/python/sdk/test/transformer_test.py +++ b/python/sdk/test/transformer_test.py @@ -21,8 +21,12 @@ @pytest.mark.unit def test_feast_enricher(): transformer_config_path = os.path.join("test/transformer", "feast_enricher.yaml") - transformer = StandardTransformer(config_file=transformer_config_path, enabled=False) - assert transformer.env_vars == {'STANDARD_TRANSFORMER_CONFIG': '{"transformerConfig": {"feast": [{"project": "merlin", "entities": [{"name": "merlin_test_driver_id", "valueType": "STRING", "jsonPath": "$.driver_id"}], "features": [{"name": 
"merlin_test_driver_features:test_int32", "valueType": "INT32", "defaultValue": "0"}, {"name": "merlin_test_driver_features:test_float", "valueType": "FLOAT", "defaultValue": "0.0"}, {"name": "merlin_test_driver_features:test_double", "valueType": "DOUBLE", "defaultValue": "0.0"}, {"name": "merlin_test_driver_features:test_string", "valueType": "STRING", "defaultValue": ""}]}]}}'} + transformer = StandardTransformer( + config_file=transformer_config_path, enabled=False + ) + assert transformer.env_vars == { + "STANDARD_TRANSFORMER_CONFIG": '{"transformerConfig": {"feast": [{"project": "merlin", "entities": [{"name": "merlin_test_driver_id", "valueType": "STRING", "jsonPath": "$.driver_id"}], "features": [{"name": "merlin_test_driver_features:test_int32", "valueType": "INT32", "defaultValue": "0"}, {"name": "merlin_test_driver_features:test_float", "valueType": "FLOAT", "defaultValue": "0.0"}, {"name": "merlin_test_driver_features:test_double", "valueType": "DOUBLE", "defaultValue": "0.0"}, {"name": "merlin_test_driver_features:test_string", "valueType": "STRING", "defaultValue": ""}]}]}}' + } assert not transformer.enabled assert transformer.command is None assert transformer.args is None @@ -31,12 +35,19 @@ def test_feast_enricher(): def test_feast_enricher_with_env_vars(): transformer_config_path = os.path.join("test/transformer", "feast_enricher.yaml") - resource = ResourceRequest(min_replica=1, max_replica=2, cpu_request="100m", memory_request="128Mi") - transformer = StandardTransformer(config_file=transformer_config_path, - enabled=True, - resource_request=resource, - env_vars={"MODEL_URL": "http://model.default"}) - assert transformer.env_vars == {'MODEL_URL': "http://model.default", 'STANDARD_TRANSFORMER_CONFIG': '{"transformerConfig": {"feast": [{"project": "merlin", "entities": [{"name": "merlin_test_driver_id", "valueType": "STRING", "jsonPath": "$.driver_id"}], "features": [{"name": "merlin_test_driver_features:test_int32", "valueType": "INT32", "defaultValue": "0"}, {"name": "merlin_test_driver_features:test_float", "valueType": "FLOAT", "defaultValue": "0.0"}, {"name": "merlin_test_driver_features:test_double", "valueType": "DOUBLE", "defaultValue": "0.0"}, {"name": "merlin_test_driver_features:test_string", "valueType": "STRING", "defaultValue": ""}]}]}}'} + resource = ResourceRequest( + min_replica=1, max_replica=2, cpu_request="100m", memory_request="128Mi" + ) + transformer = StandardTransformer( + config_file=transformer_config_path, + enabled=True, + resource_request=resource, + env_vars={"MODEL_URL": "http://model.default"}, + ) + assert transformer.env_vars == { + "MODEL_URL": "http://model.default", + "STANDARD_TRANSFORMER_CONFIG": '{"transformerConfig": {"feast": [{"project": "merlin", "entities": [{"name": "merlin_test_driver_id", "valueType": "STRING", "jsonPath": "$.driver_id"}], "features": [{"name": "merlin_test_driver_features:test_int32", "valueType": "INT32", "defaultValue": "0"}, {"name": "merlin_test_driver_features:test_float", "valueType": "FLOAT", "defaultValue": "0.0"}, {"name": "merlin_test_driver_features:test_double", "valueType": "DOUBLE", "defaultValue": "0.0"}, {"name": "merlin_test_driver_features:test_string", "valueType": "STRING", "defaultValue": ""}]}]}}', + } assert transformer.enabled assert transformer.command is None assert transformer.args is None