From c8407b91a66465af302950bedbb233b4aab1905f Mon Sep 17 00:00:00 2001
From: R-Palazzo
Date: Mon, 30 Oct 2023 14:34:13 -0600
Subject: [PATCH] first changes

---
Review notes (text between "---" and the diffstat is ignored by git am):
- relationship_validity.py: _extract_tuple was declared without ``self`` but is
  called as self._extract_tuple(data, relation), which raises TypeError. It is
  now a documented @staticmethod.
- relationship_validity.py: the _generate_details docstring claimed a float
  return value although the method only assigns self.details.
- Hunk sizes and the diffstat are updated accordingly; the blob id on the new
  file's "index" line predates these edits.

 .../reports/multi_table/_properties/base.py   |  16 ++
 .../multi_table/_properties/cardinality.py    |  29 +--
 .../_properties/relationship_validity.py      | 186 ++++++++++++++++++
 3 files changed, 210 insertions(+), 21 deletions(-)
 create mode 100644 sdmetrics/reports/multi_table/_properties/relationship_validity.py

diff --git a/sdmetrics/reports/multi_table/_properties/base.py b/sdmetrics/reports/multi_table/_properties/base.py
index f585a86e..9f97eaa7 100644
--- a/sdmetrics/reports/multi_table/_properties/base.py
+++ b/sdmetrics/reports/multi_table/_properties/base.py
@@ -137,6 +137,22 @@ def get_visualization(self, table_name):
 
         return self._properties[table_name].get_visualization()
 
+    def _get_details_for_table_name_with_relationships(self, table_name):
+        """Return the details for the given table name.
+
+        Args:
+            table_name (str):
+                Table name to get the details for.
+
+        Returns:
+            pandas.DataFrame:
+                The details for the given table name.
+        """
+        if all(column in self.details.columns for column in ['Child Table', 'Parent Table']):
+            is_child = self.details['Child Table'] == table_name
+            is_parent = self.details['Parent Table'] == table_name
+            return self.details[is_child | is_parent].copy()
+
     def get_details(self, table_name=None):
         """Return the details table for the property for the given table.
 
diff --git a/sdmetrics/reports/multi_table/_properties/cardinality.py b/sdmetrics/reports/multi_table/_properties/cardinality.py
index 6601e359..4c5e5f51 100644
--- a/sdmetrics/reports/multi_table/_properties/cardinality.py
+++ b/sdmetrics/reports/multi_table/_properties/cardinality.py
@@ -20,10 +20,12 @@ def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=No
         """Get the average score of cardinality shape similarity in the given tables.
 
         Args:
-            real_data (pandas.DataFrame):
-                The real data.
-            synthetic_data (pandas.DataFrame):
-                The synthetic data.
+            real_data (dict[str, pandas.DataFrame]):
+                The tables from the real dataset, passed as a dictionary of
+                table names and pandas.DataFrames.
+            synthetic_data (dict[str, pandas.DataFrame]):
+                The tables from the synthetic dataset, passed as a dictionary of
+                table names and pandas.DataFrames.
             metadata (dict):
                 The metadata, which contains each column's data type as well as relationships.
             progress_bar (tqdm.tqdm or None):
@@ -65,21 +67,6 @@ def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=No
             'Error': error_messages,
         })
 
-    def _get_details_for_table_name(self, table_name):
-        """Return the details for the given table name.
-
-        Args:
-            table_name (str):
-                Table name to get the details for.
-
-        Returns:
-            pandas.DataFrame:
-                The details for the given table name.
-        """
-        is_child = self.details['Child Table'] == table_name
-        is_parent = self.details['Parent Table'] == table_name
-        return self.details[is_child | is_parent].copy()
-
     def get_details(self, table_name=None):
         """Return the details for the property.
 
@@ -95,7 +82,7 @@ def get_details(self, table_name=None):
         if table_name is None:
             return self.details.copy()
 
-        return self._get_details_for_table_name(table_name)
+        return self._get_details_for_table_name_with_relationships(table_name)
 
     def _get_table_relationships_plot(self, table_name):
         """Get the table relationships plot from the parent child relationship scores for a table.
@@ -107,7 +94,7 @@ def _get_table_relationships_plot(self, table_name):
         Returns:
             plotly.graph_objects._figure.Figure
         """
-        plot_data = self._get_details_for_table_name(table_name).copy()
+        plot_data = self._get_details_for_table_name_with_relationships(table_name).copy()
         column_name = 'Child → Parent Relationship'
         plot_data[column_name] = plot_data['Child Table'] + ' → ' + plot_data['Parent Table']
         plot_data = plot_data.drop(['Child Table', 'Parent Table'], axis=1)
diff --git a/sdmetrics/reports/multi_table/_properties/relationship_validity.py b/sdmetrics/reports/multi_table/_properties/relationship_validity.py
new file mode 100644
index 00000000..71f47266
--- /dev/null
+++ b/sdmetrics/reports/multi_table/_properties/relationship_validity.py
@@ -0,0 +1,186 @@
+import numpy as np
+import pandas as pd
+import plotly.express as px
+
+from sdmetrics.multi_table.statistical import CardinalityShapeSimilarity
+from sdmetrics.column_pairs.statistical import ReferentialIntegrity
+from sdmetrics.column_pairs.statistical import CardinalityBoundaryAdherence
+from sdmetrics.reports.multi_table._properties.base import BaseMultiTableProperty
+from sdmetrics.reports.utils import PlotConfig
+
+
+class Relationship_Validity(BaseMultiTableProperty):
+    """``Relationship Validity`` class.
+
+    This property measures the validity of the relationship
+    from the primary key and the foreign key perspective.
+
+    """
+
+    _num_iteration_case = 'relationship'
+
+    @staticmethod
+    def _extract_tuple(data, relation):
+        """Extract the primary key and foreign key columns of a relationship.
+
+        Args:
+            data (dict[str, pandas.DataFrame]):
+                The tables, passed as a dictionary of table names and pandas.DataFrames.
+            relation (dict):
+                The relationship entry, with the keys ``parent_table_name``,
+                ``child_table_name``, ``parent_primary_key`` and ``child_foreign_key``.
+
+        Returns:
+            tuple:
+                The parent primary key column and the child foreign key column, in that order.
+        """
+        parent_data = data[relation['parent_table_name']]
+        child_data = data[relation['child_table_name']]
+        return (
+            parent_data[relation['parent_primary_key']], child_data[relation['child_foreign_key']]
+        )
+
+    def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=None):
+        """Get the average score of relationship validity in the given tables.
+
+        Args:
+            real_data (dict[str, pandas.DataFrame]):
+                The tables from the real dataset, passed as a dictionary of
+                table names and pandas.DataFrames.
+            synthetic_data (dict[str, pandas.DataFrame]):
+                The tables from the synthetic dataset, passed as a dictionary of
+                table names and pandas.DataFrames.
+            metadata (dict):
+                The metadata, which contains each column's data type as well as relationships.
+            progress_bar (tqdm.tqdm or None):
+                The progress bar object. Defaults to ``None``.
+
+        Returns:
+            None:
+                The per-relationship, per-metric scores are stored in ``self.details``.
+        """
+        child_tables, parent_tables = [], []
+        primary_key, foreign_key = [], []
+        metric_names, scores, error_messages = [], [], []
+        metrics = [ReferentialIntegrity, CardinalityBoundaryAdherence]
+        for relation in metadata.get('relationships', []):
+            real_columns = self._extract_tuple(real_data, relation)
+            synthetic_columns = self._extract_tuple(synthetic_data, relation)
+            for metric in metrics:
+                try:
+                    relation_score = metric.compute(
+                        real_columns,
+                        synthetic_columns,
+                    )
+                    error_message = None
+                except Exception as e:
+                    relation_score = np.nan
+                    error_message = f'{type(e).__name__}: {e}'
+                finally:
+                    if progress_bar is not None:
+                        progress_bar.update()
+
+                child_tables.append(relation['child_table_name'])
+                parent_tables.append(relation['parent_table_name'])
+                primary_key.append(relation['parent_primary_key'])
+                foreign_key.append(relation['child_foreign_key'])
+                metric_names.append(metric.__name__)
+                scores.append(relation_score)
+                error_messages.append(error_message)
+
+        self.details = pd.DataFrame({
+            'Parent Table': parent_tables,
+            'Child Table': child_tables,
+            'Primary key': primary_key,
+            'Foreign key': foreign_key,
+            'Metric': metric_names,
+            'Score': scores,
+            'Error': error_messages,
+        })
+
+    def _get_details_for_table_name_with_relationships(self, table_name):
+        """Return the details for the given table name.
+
+        Args:
+            table_name (str):
+                Table name to get the details for.
+
+        Returns:
+            pandas.DataFrame:
+                The details for the given table name.
+        """
+        is_child = self.details['Child Table'] == table_name
+        is_parent = self.details['Parent Table'] == table_name
+        return self.details[is_child | is_parent].copy()
+
+    def get_details(self, table_name=None):
+        """Return the details for the property.
+
+        Args:
+            table_name (str):
+                Table name to get the details for.
+                Defaults to ``None``.
+
+        Returns:
+            pandas.DataFrame:
+                The details for the property.
+        """
+        if table_name is None:
+            return self.details.copy()
+
+        return self._get_details_for_table_name_with_relationships(table_name)
+
+    def _get_table_relationships_plot(self, table_name):
+        """Get the table relationships plot from the parent child relationship scores for a table.
+
+        Args:
+            table_name (str):
+                Table name to get details table for.
+
+        Returns:
+            plotly.graph_objects._figure.Figure
+        """
+        plot_data = self._get_details_for_table_name_with_relationships(table_name).copy()
+        column_name = 'Child → Parent Relationship'
+        plot_data[column_name] = plot_data['Child Table'] + ' → ' + plot_data['Parent Table']
+        plot_data = plot_data.drop(['Child Table', 'Parent Table'], axis=1)
+
+        average_score = round(plot_data['Score'].mean(), 2)
+
+        fig = px.bar(
+            plot_data,
+            x='Child → Parent Relationship',
+            y='Score',
+            title=f'Table Relationships (Average Score={average_score})',
+            color='Metric',
+            color_discrete_sequence=[PlotConfig.DATACEBO_DARK],
+            hover_name='Child → Parent Relationship',
+            hover_data={
+                'Child → Parent Relationship': False,
+                'Metric': True,
+                'Score': True,
+            },
+        )
+
+        fig.update_yaxes(range=[0, 1])
+
+        fig.update_layout(
+            xaxis_categoryorder='total ascending',
+            plot_bgcolor=PlotConfig.BACKGROUND_COLOR,
+            font={'size': PlotConfig.FONT_SIZE}
+        )
+
+        return fig
+
+    def get_visualization(self, table_name):
+        """Return a visualization for each score in the property.
+
+        Args:
+            table_name (str):
+                Table name to get the visualization for.
+
+        Returns:
+            plotly.graph_objects._figure.Figure
+                The visualization for the property.
+        """
+        return self._get_table_relationships_plot(table_name)