Skip to content

Commit

Permalink
bug: fix publisher failure when prediction output is in one dimension…
Browse files Browse the repository at this point in the history
…al format
  • Loading branch information
khorshuheng committed Apr 2, 2024
1 parent 5f70094 commit 67c3f71
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 4 deletions.
41 changes: 41 additions & 0 deletions python/observation-publisher/publisher/prediction_log_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,35 @@


class PredictionLogConsumer(abc.ABC):
"""
Abstract class for consuming prediction logs from a streaming source, then write to one or multiple sinks
"""
def __init__(self, buffer_capacity: int, buffer_max_duration_seconds: int):
self.buffer_capacity = buffer_capacity
self.buffer_max_duration_seconds = buffer_max_duration_seconds

@abc.abstractmethod
def poll_new_logs(self) -> List[PredictionLog]:
"""
Poll new logs from the source
:return:
"""
raise NotImplementedError

@abc.abstractmethod
def commit(self):
"""
Commit the current offset after the logs have been written to all the sinks
:return:
"""
raise NotImplementedError

@abc.abstractmethod
def close(self):
"""
Clean up the resources when the polling process run into error unexpectedly.
:return:
"""
raise NotImplementedError

def start_polling(
Expand All @@ -43,6 +58,13 @@ def start_polling(
inference_schema: InferenceSchema,
model_version: str,
):
"""
Start polling new logs from the source, then write to the sinks. The prediction logs are written to each sink asynchronously.
:param observation_sinks:
:param inference_schema:
:param model_version:
:return:
"""
try:
buffered_logs = []
buffer_start_time = datetime.now()
Expand Down Expand Up @@ -155,6 +177,11 @@ def new_consumer(config: ObservationSourceConfig) -> PredictionLogConsumer:


def parse_message_to_prediction_log(msg: str) -> PredictionLog:
"""
Parse the message from the Kafka consumer to a PredictionLog object
:param msg:
:return:
"""
log = PredictionLog()
log.ParseFromString(msg)
return log
Expand All @@ -163,6 +190,13 @@ def parse_message_to_prediction_log(msg: str) -> PredictionLog:
def log_to_records(
log: PredictionLog, inference_schema: InferenceSchema, model_version: str
) -> Tuple[List[List[np.int64 | np.float64 | np.bool_ | np.str_]], List[str]]:
"""
Convert a PredictionLog object to a list of records and column names
:param log: Prediction log.
:param inference_schema: Inference schema.
:param model_version: Model version.
:return:
"""
request_timestamp = log.request_timestamp.ToDatetime()
feature_table = PredictionLogFeatureTable.from_struct(
log.input.features_table, inference_schema
Expand Down Expand Up @@ -199,6 +233,13 @@ def log_to_records(
def log_batch_to_dataframe(
logs: List[PredictionLog], inference_schema: InferenceSchema, model_version: str
) -> pd.DataFrame:
"""
Combines several logs into a single DataFrame
:param logs: List of prediction logs.
:param inference_schema: Inference schema.
:param model_version: Model version.
:return:
"""
combined_records = []
column_names: List[str] = []
for log in logs:
Expand Down
40 changes: 38 additions & 2 deletions python/observation-publisher/publisher/prediction_log_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ class PredictionLogFeatureTable:
def from_struct(
cls, table_struct: Struct, inference_schema: InferenceSchema
) -> Self:
"""
Create a PredictionLogFeatureTable object from a Protobuf Struct object
:param table_struct: A Protobuf Struct object that represents a feature table.
:param inference_schema: Model inference schema.
:return: Instance of PredictionLogFeatureTable.
"""
if inference_schema.feature_orders is not None:
columns = inference_schema.feature_orders
else:
Expand All @@ -36,6 +42,11 @@ def from_struct(


def prediction_columns(inference_schema: InferenceSchema) -> List[str]:
"""
Get the column name for the prediction output
:param inference_schema: Model inference schema
:return: List of column names
"""
if isinstance(inference_schema.model_prediction_output, BinaryClassificationOutput):
return [inference_schema.model_prediction_output.prediction_score_column]
elif isinstance(inference_schema.model_prediction_output, RankingOutput):
Expand All @@ -56,6 +67,12 @@ class PredictionLogResultsTable:
def from_struct(
cls, table_struct: Struct, inference_schema: InferenceSchema
) -> Self:
"""
Create a PredictionLogResultsTable object from a Protobuf Struct object
:param table_struct: Protobuf Struct object that represents a prediction result table.
:param inference_schema: Model InferenceSchema.
:return: PredictionLogResultsTable instnace.
"""
if "columns" in table_struct.keys():
assert isinstance(table_struct["columns"], ListValue)
columns = list_value_as_string_list(table_struct["columns"])
Expand Down Expand Up @@ -102,6 +119,9 @@ def convert_to_numpy_value(


def list_value_as_string_list(list_value: ListValue) -> List[str]:
"""
Convert protobuf string list to it's native python type counterpart.
"""
string_list: List[str] = []
for v in list_value.items():
assert isinstance(v, str)
Expand All @@ -110,17 +130,33 @@ def list_value_as_string_list(list_value: ListValue) -> List[str]:


def list_value_as_rows(list_value: ListValue) -> List[ListValue]:
"""
Convert a ListValue object to a list of ListValue objects
:param list_value: Representation of a two dimensional matrix
:return: List of ListValue objects
"""
rows: List[ListValue] = []
for d in list_value.items():
assert isinstance(d, ListValue)
rows.append(d)
if isinstance(d, ListValue):
rows.append(d)
else:
nd = ListValue()
nd.append(d)
rows.append(nd)

return rows


def list_value_as_numpy_list(
list_value: ListValue, column_names: List[str], column_types: Dict[str, ValueType]
) -> List[np.int64 | np.float64 | np.bool_ | np.str_]:
"""
Convert a ListValue representing a row, to it's native python type counterpart.
:param list_value: ListValue object representing a row.
:param column_names: Column names corresponds to each column in a row.
:param column_types: Map of column name to type.
:return: List of numpy types.
"""
column_values: List[int | str | float | bool | None] = []
for v in list_value.items():
assert isinstance(v, (int, str, float, bool, NoneType))
Expand Down
65 changes: 63 additions & 2 deletions python/observation-publisher/tests/test_prediction_log_consumer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Any, List
from typing import Any, List, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -57,7 +57,7 @@ def new_standard_model_log(
session_id: str,
row_ids: List[str],
input_data: List[List[Any]],
output_data: List[List[Any]],
output_data: Union[List[List[Any]], List[Any]],
request_timestamp: datetime,
):
prediction_log = PredictionLog()
Expand Down Expand Up @@ -303,3 +303,64 @@ def test_standard_model_log_to_dataframe():
],
)
assert_frame_equal(prediction_logs_df, expected_df, check_like=True)


def test_one_dimensional_prediction_output_to_dataframe():
model_id = "test_model"
model_version = "0.1.0"
inference_schema = InferenceSchema(
feature_types={
"acceptance_rate": ValueType.FLOAT64,
"minutes_since_last_order": ValueType.INT64,
"service_type": ValueType.STRING,
},
feature_orders=["acceptance_rate", "minutes_since_last_order", "service_type"],
model_prediction_output=BinaryClassificationOutput(
prediction_score_column="prediction_score",
actual_score_column="actual_score",
positive_class_label="fraud",
negative_class_label="non fraud",
score_threshold=0.5,
),
session_id_column="order_id",
row_id_column="driver_id"
)
request_timestamp = datetime(2021, 1, 1, 0, 0, 0)
prediction_logs = [
new_standard_model_log(
session_id="1234",
model_id=model_id,
model_version=model_version,
input_data=[
[0.8, 24, "FOOD"],
[0.5, 2, "RIDE"],
],
output_data=[
0.9,
0.5,
],
request_timestamp=request_timestamp,
row_ids=["a", "b"],
),
]
prediction_logs_df = log_batch_to_dataframe(
prediction_logs, inference_schema, model_version
)
expected_df = pd.DataFrame.from_records(
[
[0.8, 24, "FOOD", 0.9, "fraud", "1234", "a", request_timestamp, model_version],
[0.5, 2, "RIDE", 0.5, "fraud", "1234", "b", request_timestamp, model_version],
],
columns=[
"acceptance_rate",
"minutes_since_last_order",
"service_type",
"prediction_score",
"_prediction_label",
"order_id",
"driver_id",
"request_timestamp",
"model_version",
],
)
assert_frame_equal(prediction_logs_df, expected_df, check_like=True)

0 comments on commit 67c3f71

Please sign in to comment.