Skip to content

Commit

Permalink
Merge pull request #48 from IBM/watson_nlp
Browse files Browse the repository at this point in the history
Add support for Watson NLP deployments
  • Loading branch information
dhruv5995 authored Nov 10, 2023
2 parents bd03b83 + 0b2b3ee commit 5bff80f
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 5 deletions.
10 changes: 6 additions & 4 deletions mlflow_watsonml/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,17 +194,18 @@ def create_deployment(
)

artifact_name = f"{name}_v1"
environment_variables = get_mlflow_config()

artifact_id, revision_id = store_or_update_artifact(
client=client,
model_uri=model_uri,
artifact_name=artifact_name,
flavor=flavor,
software_spec_id=software_spec_id,
environment_variables=environment_variables,
)

batch = config.get("batch", False)
environment_variables = get_mlflow_config()

hardware_spec_name = config.get("hardware_spec_name")
if hardware_spec_name is not None:
Expand Down Expand Up @@ -309,13 +310,14 @@ def update_deployment(
conda_yaml=conda_yaml,
rewrite=True,
)

environment_variables = get_mlflow_config()
artifact_id, revision_id = store_or_update_artifact(
client=client,
model_uri=model_uri,
artifact_name=new_artifact_name,
flavor=flavor,
software_spec_id=software_spec_id,
environment_variables=environment_variables,
)

deployment_details = update_deployment(
Expand Down Expand Up @@ -425,10 +427,10 @@ def predict(
client.deployments.ScoringMetaNames.INPUT_DATA: [{"values": inputs}]
}

if "custom" in deployment_details["entity"]["asset"].keys():
if "custom" in deployment_details["entity"].keys():
scoring_payload[
client.deployments.ScoringMetaNames.ENVIRONMENT_VARIABLES
] = deployment_details["entity"]["asset"]["custom"]
] = deployment_details["entity"]["custom"]

deployment_id = client.deployments.get_id(deployment_details=deployment_details)

Expand Down
75 changes: 75 additions & 0 deletions mlflow_watsonml/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,3 +279,78 @@ def store_sklearn_artifact(
)

return (model_id, rev_id)


def store_watson_nlp_artifact(
    client: APIClient,
    model_uri: str,
    artifact_name: str,
    software_spec_id: str,
    artifact_id: Optional[str] = None,
    config: Optional[Dict] = None,
) -> Tuple[str, str]:
    """Store a Watson NLP model in WML as a deployable scoring function.

    The model itself is not uploaded. A scorer function is stored instead;
    when invoked it downloads the artifact from `model_uri` through mlflow
    and loads it with `watson_nlp`.

    Parameters
    ----------
    client : APIClient
        WML client
    model_uri : str
        model URI
    artifact_name : str
        name of the artifact
    software_spec_id : str
        id of software specification
    artifact_id : Optional[str], optional
        artifact id of the stored model, by default None
    config : Optional[Dict], optional
        environment variables (name -> value) exported inside the scorer
        before the artifact is downloaded, by default None. Entries whose
        value is None are skipped.

    Returns
    -------
    Tuple[str, str]
        function id, revision id
    """

    # the args have to be passed as default value in the scorer
    def deployable_watson_nlp_scorer(artifact_uri=model_uri, config=config):
        import os
        import tempfile

        import mlflow
        import watson_nlp  # type: ignore

        # `config` is optional and its values may be None (for example when
        # they come from `os.environ.get`); `os.environ` only accepts
        # strings, so unset entries are skipped instead of raising.
        for key, val in (config or {}).items():
            if val is not None:
                os.environ[key] = val

        def score(payload: dict):
            artifact_dir = os.path.join(tempfile.gettempdir(), "artifacts")

            # `download_artifacts` returns the local path if it's already been downloaded
            artifact_file = mlflow.artifacts.download_artifacts(
                artifact_uri=artifact_uri, dst_path=artifact_dir
            )

            model = watson_nlp.load(artifact_file)

            scoring_output = {"predictions": []}

            # one output entry per entry of `input_data`, in the same order
            for data in payload["input_data"]:
                values = data.get("values")
                # fields = data.get("fields")
                predictions = model.run_batch(values)
                predictions = [prediction.to_dict() for prediction in predictions]

                scoring_output["predictions"].append({"values": predictions})

            return scoring_output

        return score

    function_id, rev_id = store_or_update_function(
        client=client,
        deployable_function=deployable_watson_nlp_scorer,
        function_name=artifact_name,
        software_spec_uid=software_spec_id,
        function_id=artifact_id,
    )

    return (function_id, rev_id)
2 changes: 1 addition & 1 deletion mlflow_watsonml/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def refine_conda_yaml(conda_yaml: str) -> str:
# TODO: implement logic to make sure the environment variables are set
def get_mlflow_config() -> Dict:
return {
"MLFLOW_TRACKING_URI": os.environ.get("MLFLOW_TRACKING_URI"),
# "MLFLOW_TRACKING_URI": os.environ.get("MLFLOW_TRACKING_URI", ""),
"MLFLOW_S3_ENDPOINT_URL": os.environ.get("MLFLOW_S3_ENDPOINT_URL"),
"AWS_SECRET_ACCESS_KEY": os.environ.get("AWS_SECRET_ACCESS_KEY"),
"AWS_ACCESS_KEY_ID": os.environ.get("AWS_ACCESS_KEY_ID"),
Expand Down
11 changes: 11 additions & 0 deletions mlflow_watsonml/wml.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def store_or_update_artifact(
flavor: str,
software_spec_id: str,
artifact_id: Optional[str] = None,
environment_variables: Optional[Dict] = None,
) -> Tuple[str, str]:
if flavor == "sklearn":
artifact_id, revision_id = store_sklearn_artifact(
Expand All @@ -183,6 +184,16 @@ def store_or_update_artifact(
artifact_id=artifact_id,
)

elif flavor == "watson_nlp":
artifact_id, revision_id = store_watson_nlp_artifact(
client=client,
model_uri=model_uri,
artifact_name=artifact_name,
software_spec_id=software_spec_id,
artifact_id=artifact_id,
config=environment_variables,
)

else:
raise MlflowException(
f"Flavor {flavor} is invalid or not implemented",
Expand Down

0 comments on commit 5bff80f

Please sign in to comment.