From 05a800e724b33a9d62bb89dd4f7ac59ed2b5fd3f Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Mon, 30 Oct 2023 17:36:09 -0600 Subject: [PATCH] definition 2 --- .../multi_table/_properties/__init__.py | 2 + .../reports/multi_table/_properties/base.py | 24 +++----- .../_properties/relationship_validity.py | 55 +++---------------- 3 files changed, 19 insertions(+), 62 deletions(-) diff --git a/sdmetrics/reports/multi_table/_properties/__init__.py b/sdmetrics/reports/multi_table/_properties/__init__.py index 238b9af5..d1f53331 100644 --- a/sdmetrics/reports/multi_table/_properties/__init__.py +++ b/sdmetrics/reports/multi_table/_properties/__init__.py @@ -7,6 +7,7 @@ from sdmetrics.reports.multi_table._properties.column_shapes import ColumnShapes from sdmetrics.reports.multi_table._properties.coverage import Coverage from sdmetrics.reports.multi_table._properties.inter_table_trends import InterTableTrends +from sdmetrics.reports.multi_table._properties.relationship_validity import RelationshipValidity from sdmetrics.reports.multi_table._properties.synthesis import Synthesis __all__ = [ @@ -18,4 +19,5 @@ 'Coverage', 'InterTableTrends', 'Synthesis', + 'RelationshipValidity', ] diff --git a/sdmetrics/reports/multi_table/_properties/base.py b/sdmetrics/reports/multi_table/_properties/base.py index 9f97eaa7..923fa0ce 100644 --- a/sdmetrics/reports/multi_table/_properties/base.py +++ b/sdmetrics/reports/multi_table/_properties/base.py @@ -45,6 +45,14 @@ def _get_num_iterations(self, metadata): iterations += (len(parent_columns) * len(child_columns)) return iterations + @staticmethod + def _extract_tuple(data, relation): + parent_data = data[relation['parent_table_name']] + child_data = data[relation['child_table_name']] + return ( + parent_data[relation['parent_primary_key']], child_data[relation['child_foreign_key']] + ) + def _compute_average(self): """Average the scores for each column.""" is_dataframe = isinstance(self.details, pd.DataFrame) @@ -137,22 +145,6 @@ def get_visualization(self, table_name): return self._properties[table_name].get_visualization() - def _get_details_for_table_name_with_relationships(self, table_name): - """Return the details for the given table name. - - Args: - table_name (str): - Table name to get the details for. - - Returns: - pandas.DataFrame: - The details for the given table name. - """ - if all(column in self.details.columns for column in ['Child Table', 'Parent Table']): - is_child = self.details['Child Table'] == table_name - is_parent = self.details['Parent Table'] == table_name - return self.details[is_child | is_parent].copy() - def get_details(self, table_name=None): """Return the details table for the property for the given table. diff --git a/sdmetrics/reports/multi_table/_properties/relationship_validity.py b/sdmetrics/reports/multi_table/_properties/relationship_validity.py index 71f47266..2111ef54 100644 --- a/sdmetrics/reports/multi_table/_properties/relationship_validity.py +++ b/sdmetrics/reports/multi_table/_properties/relationship_validity.py @@ -2,14 +2,12 @@ import pandas as pd import plotly.express as px -from sdmetrics.multi_table.statistical import CardinalityShapeSimilarity -from sdmetrics.column_pairs.statistical import ReferentialIntegrity -from sdmetrics.column_pairs.statistical import CardinalityBoundaryAdherence +from sdmetrics.column_pairs.statistical import CardinalityBoundaryAdherence, ReferentialIntegrity from sdmetrics.reports.multi_table._properties.base import BaseMultiTableProperty from sdmetrics.reports.utils import PlotConfig -class Relationship_Validity(BaseMultiTableProperty): +class RelationshipValidity(BaseMultiTableProperty): """``Relationship Validity`` class. This property measures the validity of the relationship @@ -19,13 +17,6 @@ class Relationship_Validity(BaseMultiTableProperty): _num_iteration_case = 'relationship' - def _extract_tuple(data, relation): - parent_data = data[relation['parent_table_name']] - child_data = data[relation['child_table_name']] - return ( - parent_data[relation['parent_primary_key']], child_data[relation['child_foreign_key']] - ) - def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=None): """Get the average score of relationship validity in the given tables. @@ -77,44 +68,15 @@ def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=No self.details = pd.DataFrame({ 'Parent Table': parent_tables, 'Child Table': child_tables, - 'Primary key': primary_key, - 'Foreign key': foreign_key, + 'Primary Key': primary_key, + 'Foreign Key': foreign_key, 'Metric': metric_names, 'Score': scores, 'Error': error_messages, }) - def _get_details_for_table_name_with_relationships(self, table_name): - """Return the details for the given table name. - - Args: - table_name (str): - Table name to get the details for. - - Returns: - pandas.DataFrame: - The details for the given table name. - """ - is_child = self.details['Child Table'] == table_name - is_parent = self.details['Parent Table'] == table_name - return self.details[is_child | is_parent].copy() - - def get_details(self, table_name=None): - """Return the details for the property. - - Args: - table_name (str): - Table name to get the details for. - Defaults to ``None``. - - Returns: - pandas.DataFrame: - The details for the property. - """ - if table_name is None: - return self.details.copy() - - return self._get_details_for_table_name_with_relationships(table_name) + if self.details['Error'].isna().all(): + self.details = self.details.drop('Error', axis=1) def _get_table_relationships_plot(self, table_name): """Get the table relationships plot from the parent child relationship scores for a table. @@ -126,7 +88,7 @@ def _get_table_relationships_plot(self, table_name): Returns: plotly.graph_objects._figure.Figure """ - plot_data = self._get_details_for_table_name_with_relationships(table_name).copy() + plot_data = self.get_details(table_name).copy() column_name = 'Child → Parent Relationship' plot_data[column_name] = plot_data['Child Table'] + ' → ' + plot_data['Parent Table'] plot_data = plot_data.drop(['Child Table', 'Parent Table'], axis=1) @@ -139,13 +101,14 @@ def _get_table_relationships_plot(self, table_name): y='Score', title=f'Table Relationships (Average Score={average_score})', color='Metric', - color_discrete_sequence=[PlotConfig.DATACEBO_DARK], + color_discrete_sequence=[PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_BLUE], hover_name='Child → Parent Relationship', hover_data={ 'Child → Parent Relationship': False, 'Metric': True, 'Score': True, }, + barmode='group' ) fig.update_yaxes(range=[0, 1])