diff --git a/sdmetrics/reports/multi_table/_properties/base.py b/sdmetrics/reports/multi_table/_properties/base.py index df4d9f1d..1d831888 100644 --- a/sdmetrics/reports/multi_table/_properties/base.py +++ b/sdmetrics/reports/multi_table/_properties/base.py @@ -147,8 +147,10 @@ def get_details(self, table_name=None): return self.details.copy() if self._num_iteration_case in ['relationship', 'inter_table_columns']: - table_rows = ((self.details['Parent Table'] == table_name) | - (self.details['Child Table'] == table_name)) + table_rows = ( + (self.details['Parent Table'] == table_name) | + (self.details['Child Table'] == table_name) + ) else: table_rows = self.details['Table'] == table_name diff --git a/sdmetrics/reports/multi_table/_properties/inter_table_trends.py b/sdmetrics/reports/multi_table/_properties/inter_table_trends.py index f35c64e5..e31eeed7 100644 --- a/sdmetrics/reports/multi_table/_properties/inter_table_trends.py +++ b/sdmetrics/reports/multi_table/_properties/inter_table_trends.py @@ -19,6 +19,7 @@ class InterTableTrends(BaseMultiTableProperty): calculated and the final score represents the average of these measures across all column pairs """ + _num_iteration_case = 'inter_table_columns' def get_score(self, real_data, synthetic_data, metadata, progress_bar=None): @@ -104,6 +105,21 @@ def get_score(self, real_data, synthetic_data, metadata, progress_bar=None): return self._compute_average() def get_visualization(self, table_name=None): + """Create a plot to show the inter table trends data. + + Returns: + plotly.graph_objects._figure.Figure + + Args: + table_name (str, optional): + Table to plot. Defaults to None. + + Raises: + - ``ValueError`` if property has not been computed. + + Returns: + plotly.graph_objects._figure.Figure + """ if not self.is_computed: raise ValueError( 'The property must be computed before getting a visualization.' @@ -112,8 +128,10 @@ def get_visualization(self, table_name=None): to_plot = self.details.copy() if table_name is not None: - to_plot = to_plot[(to_plot['Parent Table'] == table_name) | - (to_plot['Child Table'] == table_name)] + to_plot = to_plot[ + (to_plot['Parent Table'] == table_name) | + (to_plot['Child Table'] == table_name) + ] parent_cols = to_plot['Parent Table'] + '.' + to_plot['Column 1'] child_cols = to_plot['Child Table'] + '.' + to_plot['Column 2'] @@ -145,20 +163,21 @@ def get_visualization(self, table_name=None): 'Metric', 'Score', 'Real Correlation', - 'Synthetic Correlation'] + 'Synthetic Correlation' + ] ) fig.update_yaxes(range=[0, 1]) fig.update_traces( - hovertemplate="
".join([ - "%{x}", - "%{customdata[0]}", - "", - "Metric=%{customdata[1]}", - "Score=%{customdata[2]}", - "Real Correlation=%{customdata[3]}", - "Synthetic Correlation=%{customdata[4]}" + hovertemplate='
'.join([ + '%{x}', + '%{customdata[0]}', + '', + 'Metric=%{customdata[1]}', + 'Score=%{customdata[2]}', + 'Real Correlation=%{customdata[3]}', + 'Synthetic Correlation=%{customdata[4]}' ]) )