Skip to content

Commit

Permalink
experimental support for LLM endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
ahosler committed Oct 22, 2023
1 parent 652ca51 commit 7cee4ec
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 0 deletions.
11 changes: 11 additions & 0 deletions ads/opctl/operator/lowcode/forecast/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,17 @@ def generate_report(self):

test_eval_metrics = [sec7_text, sec7, sec8_text, sec8]

if self.spec.llm_endpoint is not None:
metric_str = ""
with pd.option_context("display.float_format", "{:0.2f}".format):
metric_str = self.test_eval_metrics.to_string()
llm_description = utils.describe_metrics(
self.spec.llm_endpoint, metric_str
)
llm_header = dp.Text("## Analysis (Powered by Generative AI)")
llm_body = dp.Text(llm_description)
test_eval_metrics = test_eval_metrics + [llm_header, llm_body]

forecast_text = dp.Text(f"## Forecasted Data Overlaying Historical")
forecast_sec = utils.get_forecast_plots(
self.data,
Expand Down
1 change: 1 addition & 0 deletions ads/opctl/operator/lowcode/forecast/operator_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class ForecastOperatorSpec(DataClassSerializable):
confidence_interval_width: float = None
metric: str = None
tuning: Tuning = field(default_factory=Tuning)
llm_endpoint: str = None

def __post_init__(self):
"""Adjusts the specification details."""
Expand Down
4 changes: 4 additions & 0 deletions ads/opctl/operator/lowcode/forecast/schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,10 @@ spec:
meta:
description: "Explainability, both local and global, can be disabled using this flag. Defaults to false."

llm_endpoint:
type: string
required: false

datetime_column:
type: dict
required: true
Expand Down
48 changes: 48 additions & 0 deletions ads/opctl/operator/lowcode/forecast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,3 +465,51 @@ def select_auto_model(columns: List[str]) -> str:
if columns != None and len(columns) > MAX_COLUMNS_AUTOMLX:
return SupportedModels.Arima
return SupportedModels.AutoMLX


def describe_metrics(llm_endpoint: str, metrics_str: str) -> str:
    """
    Formats the metrics string into a query and submits it to an LLM.

    The endpoint is a gradio app: the prompt is serialized to a temporary
    JSON file, the file path is passed to ``Client.predict``, and the
    response comes back as a path to another JSON file on disk.

    Parameters
    ------------
    llm_endpoint: str
        The URL of the LLM (gradio) endpoint that can be invoked from the
        operator.
    metrics_str: str
        A string version of the metrics being described.

    Returns
    --------
    str
        The formatted (pretty-printed JSON) text of the LLM response.
    """
    # Imported lazily: gradio_client is an optional dependency, only needed
    # when the user configures an llm_endpoint in the operator spec.
    from gradio_client import Client
    import tempfile
    import json

    query = "The following table summarises the evaluation metrics for a machine learning forecasting model. Please evaluate the performance of the model across each metric and then summarise the overall strength of the model. Metrics: "
    # The endpoint expects a chat-history-shaped payload: [[user_turn, assistant_turn]]
    prompt = [[query + metrics_str, ""]]

    logger.debug(f"Full prompt is: {prompt}")

    client = Client(llm_endpoint, serialize=False)

    # The remote fn consumes the prompt as a JSON file path, so write the
    # payload to a temp file; flush() ensures it is on disk before predict()
    # reads it. The temp file is removed when the `with` block exits.
    with tempfile.NamedTemporaryFile(mode="w") as temp:
        json.dump(prompt, temp)
        temp.flush()

        result_path = client.predict(
            temp.name,
            1024,  # presumably max new tokens — numeric value between 256 and 4096
            0.2,  # presumably temperature — numeric value between 0.2 and 2.0
            0.1,  # presumably top_p — numeric value between 0.1 and 1.0
            fn_index=2,
        )

    # The response is a path to a JSON file; pretty-print its contents for
    # inclusion in the generated report.
    with open(result_path) as response_file:
        result = json.dumps(
            json.loads(response_file.read()), indent=2, ensure_ascii=False
        )
    logger.debug(f"Output from LLM: {result}")
    # json.dumps already returned a str; no further conversion needed.
    return result

0 comments on commit 7cee4ec

Please sign in to comment.