diff --git a/sdmetrics/reports/multi_table/_properties/base.py b/sdmetrics/reports/multi_table/_properties/base.py
index df4d9f1d..1d831888 100644
--- a/sdmetrics/reports/multi_table/_properties/base.py
+++ b/sdmetrics/reports/multi_table/_properties/base.py
@@ -147,8 +147,10 @@ def get_details(self, table_name=None):
return self.details.copy()
if self._num_iteration_case in ['relationship', 'inter_table_columns']:
- table_rows = ((self.details['Parent Table'] == table_name) |
- (self.details['Child Table'] == table_name))
+ table_rows = (
+ (self.details['Parent Table'] == table_name) |
+ (self.details['Child Table'] == table_name)
+ )
else:
table_rows = self.details['Table'] == table_name
diff --git a/sdmetrics/reports/multi_table/_properties/inter_table_trends.py b/sdmetrics/reports/multi_table/_properties/inter_table_trends.py
index f35c64e5..e31eeed7 100644
--- a/sdmetrics/reports/multi_table/_properties/inter_table_trends.py
+++ b/sdmetrics/reports/multi_table/_properties/inter_table_trends.py
@@ -19,6 +19,7 @@ class InterTableTrends(BaseMultiTableProperty):
calculated and the final score represents the average of these measures across
all column pairs
"""
+
_num_iteration_case = 'inter_table_columns'
def get_score(self, real_data, synthetic_data, metadata, progress_bar=None):
@@ -104,6 +105,21 @@ def get_score(self, real_data, synthetic_data, metadata, progress_bar=None):
return self._compute_average()
def get_visualization(self, table_name=None):
+ """Create a plot to show the inter table trends data.
+
+ Returns:
+ plotly.graph_objects._figure.Figure
+
+ Args:
+ table_name (str, optional):
+ Table to plot. Defaults to None.
+
+ Raises:
+ - ``ValueError`` if property has not been computed.
+
+ Returns:
+ plotly.graph_objects._figure.Figure
+ """
if not self.is_computed:
raise ValueError(
'The property must be computed before getting a visualization.'
@@ -112,8 +128,10 @@ def get_visualization(self, table_name=None):
to_plot = self.details.copy()
if table_name is not None:
- to_plot = to_plot[(to_plot['Parent Table'] == table_name) |
- (to_plot['Child Table'] == table_name)]
+ to_plot = to_plot[
+ (to_plot['Parent Table'] == table_name) |
+ (to_plot['Child Table'] == table_name)
+ ]
parent_cols = to_plot['Parent Table'] + '.' + to_plot['Column 1']
child_cols = to_plot['Child Table'] + '.' + to_plot['Column 2']
@@ -145,20 +163,21 @@ def get_visualization(self, table_name=None):
'Metric',
'Score',
'Real Correlation',
- 'Synthetic Correlation']
+ 'Synthetic Correlation'
+ ]
)
fig.update_yaxes(range=[0, 1])
fig.update_traces(
- hovertemplate="
".join([
- "%{x}",
- "%{customdata[0]}",
- "",
- "Metric=%{customdata[1]}",
- "Score=%{customdata[2]}",
- "Real Correlation=%{customdata[3]}",
- "Synthetic Correlation=%{customdata[4]}"
+ hovertemplate='
'.join([
+ '%{x}',
+ '%{customdata[0]}',
+ '',
+ 'Metric=%{customdata[1]}',
+ 'Score=%{customdata[2]}',
+ 'Real Correlation=%{customdata[3]}',
+ 'Synthetic Correlation=%{customdata[4]}'
])
)