Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stats representation via plotly in to_viz interface #335

Merged
merged 13 commits into from
Sep 21, 2023
19 changes: 10 additions & 9 deletions ads/feature_store/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def __init__(self, spec: Dict = None, **kwargs) -> None:
self.oci_dataset = self._to_oci_dataset(**kwargs)
self.lineage = OCILineage(**kwargs)

def _to_oci_dataset(self, **kwargs):
def _to_oci_dataset(self, **kwargs) -> OCIDataset:
"""Creates an `OCIDataset` instance from the `Dataset`.

kwargs
Expand Down Expand Up @@ -235,8 +235,8 @@ def name(self) -> str:
return self.get_spec(self.CONST_NAME)

@name.setter
def name(self, name: str) -> "Dataset":
return self.with_name(name)
def name(self, name: str):
self.with_name(name)

def with_name(self, name: str) -> "Dataset":
"""Sets the name.
Expand Down Expand Up @@ -866,9 +866,8 @@ def _update_from_oci_dataset_model(self, oci_dataset: OCIDataset) -> "Dataset":

value = {self.CONST_ITEMS: features_list}
else:
value = getattr(self.oci_dataset, dsc_attr)
value = dataset_details[infra_attr]
self.set_spec(infra_attr, value)

return self

def materialise(
Expand Down Expand Up @@ -1206,12 +1205,14 @@ def to_dict(self) -> Dict:
for key, value in spec.items():
if hasattr(value, "to_dict"):
value = value.to_dict()
if hasattr(value, "attribute_map"):
value = self.oci_dataset.client.base_client.sanitize_for_serialization(
if key == self.CONST_FEATURE_GROUP:
spec[
key
] = self.oci_dataset.client.base_client.sanitize_for_serialization(
value
)
spec[key] = value

else:
spec[key] = value
return {
"kind": self.kind,
"type": self.type,
Expand Down
20 changes: 12 additions & 8 deletions ads/feature_store/execution_strategy/spark/spark_execution.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#!/usr/bin/env python
# -*- coding: utf-8; -*-
import json

# Copyright (c) 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

Expand Down Expand Up @@ -29,8 +27,6 @@
raise

from ads.feature_store.common.enums import (
FeatureStoreJobType,
LifecycleState,
EntityType,
ExpectationType,
)
Expand All @@ -47,6 +43,11 @@

from ads.feature_store.feature_statistics.statistics_service import StatisticsService
from ads.feature_store.common.utils.utility import validate_input_feature_details
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from ads.feature_store.feature_group import FeatureGroup
from ads.feature_store.dataset import Dataset

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -76,7 +77,10 @@ def __init__(self, metastore_id: str = None):
self._jvm = self._spark_context._jvm

def ingest_feature_definition(
self, feature_group, feature_group_job: FeatureGroupJob, dataframe
self,
feature_group: "FeatureGroup",
feature_group_job: FeatureGroupJob,
dataframe,
):
try:
self._save_offline_dataframe(dataframe, feature_group, feature_group_job)
Expand All @@ -90,7 +94,7 @@ def ingest_dataset(self, dataset, dataset_job: DatasetJob):
raise SparkExecutionException(e).with_traceback(e.__traceback__)

def delete_feature_definition(
self, feature_group, feature_group_job: FeatureGroupJob
self, feature_group: "FeatureGroup", feature_group_job: FeatureGroupJob
):
"""
Deletes a feature definition from the system.
Expand Down Expand Up @@ -122,7 +126,7 @@ def delete_feature_definition(
output_details=output_details,
)

def delete_dataset(self, dataset, dataset_job: DatasetJob):
def delete_dataset(self, dataset: "Dataset", dataset_job: DatasetJob):
"""
Deletes a dataset from the system.

Expand Down Expand Up @@ -154,7 +158,7 @@ def delete_dataset(self, dataset, dataset_job: DatasetJob):
)

@staticmethod
def _validate_expectation(expectation_type, validation_output):
def _validate_expectation(expectation_type, validation_output: dict):
"""
Validates the expectation based on the given expectation type and the validation output.

Expand Down
8 changes: 3 additions & 5 deletions ads/feature_store/feature_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,8 @@ def name(self) -> str:
return self.get_spec(self.CONST_NAME)

@name.setter
def name(self, name: str) -> "FeatureGroup":
return self.with_name(name)
def name(self, name: str):
self.with_name(name)

def with_name(self, name: str) -> "FeatureGroup":
"""Sets the name.
Expand Down Expand Up @@ -338,7 +338,7 @@ def transformation_kwargs(self, value: Dict):
self.with_transformation_kwargs(value)

def with_transformation_kwargs(
self, transformation_kwargs: Dict = {}
self, transformation_kwargs: Dict = ()
) -> "FeatureGroup":
"""Sets the primary keys of the feature group.

Expand Down Expand Up @@ -604,7 +604,6 @@ def with_statistics_config(
FeatureGroup
The FeatureGroup instance (self).
"""
statistics_config_in = None
if isinstance(statistics_config, StatisticsConfig):
statistics_config_in = statistics_config
elif isinstance(statistics_config, bool):
Expand Down Expand Up @@ -1108,7 +1107,6 @@ def restore(self, version_number: int = None, timestamp: datetime = None):
f"RESTORE TABLE {target_table} TO VERSION AS OF {version_number}"
)
else:
iso_timestamp = timestamp.isoformat(" ", "seconds").__str__()
sql_query = f"RESTORE TABLE {target_table} TO TIMESTAMP AS OF {timestamp}"

restore_output = self.spark_engine.sql(sql_query)
Expand Down
243 changes: 243 additions & 0 deletions ads/feature_store/feature_stat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
from abc import abstractmethod

from ads.common.decorator.runtime_dependency import OptionalDependency

from typing import List

try:
import plotly
from plotly.graph_objs import Figure
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might also need to add integration tests to ensure that the library is present in the conda pack.

import plotly.graph_objects as go
from plotly.subplots import make_subplots
except ModuleNotFoundError:
raise ModuleNotFoundError(
f"The `plotly` module was not found. Please run `pip install "
f"{OptionalDependency.FEATURE_STORE}`."
)


class FeatureStat:
    """Common interface for a single statistic that can be drawn on a figure.

    Subclasses are expected to implement ``add_to_figure`` and ``from_json``.

    NOTE: this class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` markers document intent only; they do not prevent
    instantiation.
    """

    @abstractmethod
    def add_to_figure(self, fig: "Figure", xaxis: int, yaxis: int):
        """Draws this statistic into ``fig``.

        Parameters
        ----------
        fig: Figure
            The figure to draw into.
        xaxis: int
            Zero-based index used to derive the x-axis names
            (see ``get_x_y_str_axes``).
        yaxis: int
            Zero-based index used to derive the y-axis names
            (see ``get_x_y_str_axes``).
        """
        pass

    @classmethod
    @abstractmethod
    def from_json(cls, json_dict: dict):
        """Builds an instance from its JSON (dict) representation.

        Parameters
        ----------
        json_dict: dict
            The dict to read the statistic from.
        """
        pass

    @staticmethod
    def get_x_y_str_axes(xaxis: int, yaxis: int) -> tuple:
        """Converts zero-based axis indices into one-based axis names.

        Parameters
        ----------
        xaxis: int
            Zero-based x-axis index.
        yaxis: int
            Zero-based y-axis index.

        Returns
        -------
        tuple
            ``(xaxis_str, yaxis_str, x_str, y_str)``; for indices ``(0, 0)``
            this is ``("xaxis1", "yaxis1", "x1", "y1")``.
        """
        # The emitted names are one-based, hence the +1 on both indices.
        x_number = xaxis + 1
        y_number = yaxis + 1
        return (
            f"xaxis{x_number}",
            f"yaxis{y_number}",
            f"x{x_number}",
            f"y{y_number}",
        )


class FrequencyDistribution(FeatureStat):
    """Frequency-per-bin statistic of a feature, drawn as a bar chart."""

    # Keys read from the JSON dict in `from_json`.
    CONST_FREQUENCY = "frequency"
    CONST_BINS = "bins"
    # Text written into the subplot's title annotation.
    CONST_FREQUENCY_DISTRIBUTION_TITLE = "Frequency Distribution"

    def __init__(self, frequency: List, bins: List):
        """Stores the per-bin frequencies and the bins they belong to.

        Parameters
        ----------
        frequency: List
            Bar heights, one per bin.
        bins: List
            Bar positions on the x axis.
        """
        self.frequency = frequency
        self.bins = bins

    @classmethod
    def from_json(cls, json_dict: dict) -> "FrequencyDistribution":
        """Builds the statistic from its JSON (dict) representation.

        Parameters
        ----------
        json_dict: dict
            Dict holding the ``"frequency"`` and ``"bins"`` entries; missing
            entries become ``None``.

        Returns
        -------
        FrequencyDistribution
            The new instance, or ``None`` when ``json_dict`` is ``None``.
        """
        if json_dict is None:
            return None
        # Use `cls` so subclasses get an instance of their own type.
        return cls(
            frequency=json_dict.get(cls.CONST_FREQUENCY),
            bins=json_dict.get(cls.CONST_BINS),
        )

    def add_to_figure(self, fig: Figure, xaxis: int, yaxis: int):
        """Adds the distribution to ``fig`` as a bar trace and titles the axes.

        Nothing is drawn unless ``frequency`` and ``bins`` are non-empty lists
        of equal length.

        Parameters
        ----------
        fig: Figure
            The figure to draw into.
        xaxis: int
            Zero-based x-axis index; also used to index
            ``fig.layout.annotations`` for the subplot title.
        yaxis: int
            Zero-based y-axis index.
        """
        xaxis_str, yaxis_str, x_str, y_str = self.get_x_y_str_axes(xaxis, yaxis)
        drawable = (
            isinstance(self.frequency, list)
            and isinstance(self.bins, list)
            and len(self.frequency) > 0
            and len(self.frequency) == len(self.bins)
        )
        if not drawable:
            return

        fig.add_bar(x=self.bins, y=self.frequency, xaxis=x_str, yaxis=y_str, name="")
        fig.layout.annotations[xaxis].text = self.CONST_FREQUENCY_DISTRIBUTION_TITLE
        fig.layout[xaxis_str]["title"] = "Bins"
        fig.layout[yaxis_str]["title"] = "Frequency"


class ProbabilityDistribution(FeatureStat):
    """Density-per-bin statistic of a feature, drawn as a bar chart."""

    # Keys read from the JSON dict in `from_json`.
    CONST_DENSITY = "density"
    CONST_BINS = "bins"
    # Text written into the subplot's title annotation.
    CONST_PROBABILITY_DISTRIBUTION_TITLE = "Probability Distribution"

    def __init__(self, density: List, bins: List):
        """Stores the per-bin densities and the bins they belong to.

        Parameters
        ----------
        density: List
            Bar heights, one per bin.
        bins: List
            Bar positions on the x axis.
        """
        self.density = density
        self.bins = bins

    @classmethod
    def from_json(cls, json_dict: dict):
        """Builds the statistic from its JSON (dict) representation.

        Parameters
        ----------
        json_dict: dict
            Dict holding the ``"density"`` and ``"bins"`` entries; missing
            entries become ``None``.

        Returns
        -------
        ProbabilityDistribution
            The new instance, or ``None`` when ``json_dict`` is ``None``.
        """
        if json_dict is None:
            return None
        return cls(
            density=json_dict.get(cls.CONST_DENSITY),
            bins=json_dict.get(cls.CONST_BINS),
        )

    def add_to_figure(self, fig: Figure, xaxis: int, yaxis: int):
        """Adds the distribution to ``fig`` as a bar trace and titles the axes.

        Nothing is drawn unless ``density`` and ``bins`` are non-empty lists
        of equal length.

        Parameters
        ----------
        fig: Figure
            The figure to draw into.
        xaxis: int
            Zero-based x-axis index; also used to index
            ``fig.layout.annotations`` for the subplot title.
        yaxis: int
            Zero-based y-axis index.

        Returns
        -------
        plotly.graph_objects.Bar
            A standalone bar trace built from ``bins`` and ``density``. It is
            not the trace added to ``fig``; the return value is kept only so
            existing callers keep working.
        """
        xaxis_str, yaxis_str, x_str, y_str = self.get_x_y_str_axes(xaxis, yaxis)
        drawable = (
            isinstance(self.density, list)
            and isinstance(self.bins, list)
            and len(self.density) > 0
            and len(self.density) == len(self.bins)
        )
        if drawable:
            fig.add_bar(
                x=self.bins,
                y=self.density,
                xaxis=x_str,
                yaxis=y_str,
                name="",
            )
            fig.layout.annotations[
                xaxis
            ].text = self.CONST_PROBABILITY_DISTRIBUTION_TITLE
            fig.layout[xaxis_str]["title"] = "Bins"
            fig.layout[yaxis_str]["title"] = "Density"

        return go.Bar(x=self.bins, y=self.density)


class TopKFrequentElements(FeatureStat):
    """The most frequent values of a feature, drawn as a bar chart."""

    # Key under which the list of elements is stored in the JSON dict.
    CONST_VALUE = "value"
    # Text written into the subplot's title annotation.
    CONST_TOP_K_FREQUENT_TITLE = "Top K Frequent Elements"

    class TopKFrequentElement:
        """One frequent value together with its count estimate and bounds."""

        # Keys read from the JSON dict in `from_json`.
        CONST_VALUE = "value"
        CONST_ESTIMATE = "estimate"
        CONST_LOWER_BOUND = "lower_bound"
        CONST_UPPER_BOUND = "upper_bound"

        def __init__(
            self, value: str, estimate: int, lower_bound: int, upper_bound: int
        ):
            """Stores the value and its estimate/bounds as given."""
            self.value = value
            self.estimate = estimate
            self.lower_bound = lower_bound
            self.upper_bound = upper_bound

        @classmethod
        def from_json(cls, json_dict: dict):
            """Builds an element from its JSON (dict) representation.

            Returns
            -------
            TopKFrequentElement
                The new instance, or ``None`` when ``json_dict`` is ``None``.
                Missing keys become ``None`` attributes.
            """
            if json_dict is None:
                return None
            return cls(
                value=json_dict.get(cls.CONST_VALUE),
                estimate=json_dict.get(cls.CONST_ESTIMATE),
                lower_bound=json_dict.get(cls.CONST_LOWER_BOUND),
                upper_bound=json_dict.get(cls.CONST_UPPER_BOUND),
            )

    def __init__(self, elements: List[TopKFrequentElement]):
        """Stores the list of frequent elements.

        Parameters
        ----------
        elements: List[TopKFrequentElement]
            Elements to plot; ``None`` entries are tolerated and skipped when
            drawing.
        """
        self.elements = elements

    @classmethod
    def from_json(cls, json_dict: dict):
        """Builds the statistic from its JSON (dict) representation.

        Parameters
        ----------
        json_dict: dict
            Dict whose ``"value"`` entry is an iterable of element dicts.

        Returns
        -------
        TopKFrequentElements
            The new instance, or ``None`` when ``json_dict`` is ``None`` or
            has no ``"value"`` entry.
        """
        raw_elements = None if json_dict is None else json_dict.get(cls.CONST_VALUE)
        if raw_elements is None:
            return None
        return cls(
            [cls.TopKFrequentElement.from_json(element) for element in raw_elements]
        )

    def add_to_figure(self, fig: Figure, xaxis: int, yaxis: int):
        """Adds the elements to ``fig`` as a bar trace and titles the axes.

        Bars use each element's ``value`` on the x axis and ``estimate`` on
        the y axis. Nothing is drawn unless ``elements`` is a list containing
        at least one non-``None`` element.

        Parameters
        ----------
        fig: Figure
            The figure to draw into.
        xaxis: int
            Zero-based x-axis index; also used to index
            ``fig.layout.annotations`` for the subplot title.
        yaxis: int
            Zero-based y-axis index.
        """
        xaxis_str, yaxis_str, x_str, y_str = self.get_x_y_str_axes(xaxis, yaxis)
        if not isinstance(self.elements, list):
            return
        # The nested `from_json` yields None for a None entry; skip those
        # instead of failing on attribute access.
        present = [element for element in self.elements if element is not None]
        if not present:
            return

        fig.add_bar(
            x=[element.value for element in present],
            y=[element.estimate for element in present],
            xaxis=x_str,
            yaxis=y_str,
            name="",
        )
        fig.layout.annotations[xaxis].text = self.CONST_TOP_K_FREQUENT_TITLE
        fig.layout[yaxis_str]["title"] = "Count"
        fig.layout[xaxis_str]["title"] = "Element"


class FeatureStatistics:
Copy link
Member

@KshitizLohia KshitizLohia Sep 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to have a box plot for the other numerical variables? If not, we need to settle for a Scatter plot of min, max and median, added to to_viz() for each numerical feature.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is!!

# Key of the frequency-distribution entry in the JSON dict read by `from_json`.
CONST_FREQUENCY_DISTRIBUTION = "FrequencyDistribution"
# Format applied to the feature name to build the figure title in `to_viz`.
CONST_TITLE_FORMAT = "<b>{}</b>"
# Format applied to the feature name to build the `filename` passed to
# `plotly.offline.iplot` in `to_viz`.
CONST_PLOT_FORMAT = "{}_plot"
# Key of the probability-distribution entry in the JSON dict read by `from_json`.
CONST_PROBABILITY_DISTRIBUTION = "ProbabilityDistribution"
# Key of the top-k-frequent-elements entry in the JSON dict read by `from_json`.
CONST_TOP_K_FREQUENT = "TopKFrequentElements"

def __init__(
    self,
    feature_name: str,
    top_k_frequent_elements: TopKFrequentElements,
    frequency_distribution: FrequencyDistribution,
    probability_distribution: ProbabilityDistribution,
):
    """Bundles the statistics available for one feature.

    Parameters
    ----------
    feature_name: str
        Name of the feature the statistics describe.
    top_k_frequent_elements: TopKFrequentElements
        Top-k statistic, or ``None`` when not available.
    frequency_distribution: FrequencyDistribution
        Frequency distribution, or ``None`` when not available.
    probability_distribution: ProbabilityDistribution
        Probability distribution, or ``None`` when not available.
    """
    # Each statistic is stored as given; None marks a statistic as absent.
    self.probability_distribution = probability_distribution
    self.frequency_distribution = frequency_distribution
    self.top_k_frequent_elements = top_k_frequent_elements
    self.feature_name: str = feature_name

@classmethod
def from_json(cls, feature_name: str, json_dict: dict) -> "FeatureStatistics":
    """Builds the statistics of one feature from their JSON (dict) form.

    Parameters
    ----------
    feature_name: str
        Name of the feature the statistics describe.
    json_dict: dict
        Dict holding one entry per statistic; a missing entry leaves the
        corresponding statistic as ``None``.

    Returns
    -------
    FeatureStatistics
        The new instance, or ``None`` when ``json_dict`` is ``None``.
    """
    if json_dict is None:
        return None

    top_k = TopKFrequentElements.from_json(json_dict.get(cls.CONST_TOP_K_FREQUENT))
    frequency = FrequencyDistribution.from_json(
        json_dict.get(cls.CONST_FREQUENCY_DISTRIBUTION)
    )
    probability = ProbabilityDistribution.from_json(
        json_dict.get(cls.CONST_PROBABILITY_DISTRIBUTION)
    )
    return cls(feature_name, top_k, frequency, probability)

@property
def __stat_count__(self):
    """Number of statistics on this instance that are not ``None``."""
    candidates = (
        self.top_k_frequent_elements,
        self.probability_distribution,
        self.frequency_distribution,
    )
    return sum(1 for stat in candidates if stat is not None)

@property
def __feature_stat_objects__(self) -> List[FeatureStat]:
    """The statistics that are not ``None``.

    Order is: top-k elements, frequency distribution, probability
    distribution.
    """
    available = []
    for stat in (
        self.top_k_frequent_elements,
        self.frequency_distribution,
        self.probability_distribution,
    ):
        if stat is not None:
            available.append(stat)
    return available

def to_viz(self):
    """Renders the available statistics side by side in one plotly figure.

    Each non-``None`` statistic gets its own subplot column. Does nothing
    when no statistic is available.
    """
    stats = self.__feature_stat_objects__
    if not stats:
        return

    # Every column gets a placeholder title; each statistic's
    # `add_to_figure` overwrites `fig.layout.annotations[index].text`
    # with its real title.
    fig = make_subplots(cols=len(stats), column_titles=["title"] * len(stats))
    for index, stat in enumerate(stats):
        stat.add_to_figure(fig, index, index)

    fig.layout.title = self.CONST_TITLE_FORMAT.format(self.feature_name)
    # The legend is hidden because, per the review discussion on this change,
    # it would otherwise show trace names in subplots for graphs which
    # aren't related.
    fig.update_layout(title_font_size=20, title_x=0.5, showlegend=False)
    plotly.offline.iplot(
        fig,
        filename=self.CONST_PLOT_FORMAT.format(self.feature_name),
    )
Loading