Skip to content

Commit

Permalink
tests + lint
Browse files Browse the repository at this point in the history
  • Loading branch information
frances-h committed Sep 25, 2023
1 parent e41364a commit 728feb6
Show file tree
Hide file tree
Showing 6 changed files with 435 additions and 3 deletions.
4 changes: 2 additions & 2 deletions sdmetrics/reports/multi_table/_properties/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _get_num_iterations(self, metadata):
elif self._num_iteration_case == 'column_pair':
num_columns = [len(table['columns']) for table in metadata['tables'].values()]
return sum([(n_cols * (n_cols - 1)) // 2 for n_cols in num_columns])
elif self._num_iteration_case == 'inter_table_columns':
elif self._num_iteration_case == 'inter_table_column_pair':
iterations = 0
for relationship in metadata['relationships']:
parent_columns = metadata['tables'][relationship['parent_table_name']]['columns']
Expand Down Expand Up @@ -146,7 +146,7 @@ def get_details(self, table_name=None):
if table_name is None:
return self.details.copy()

if self._num_iteration_case in ['relationship', 'inter_table_columns']:
if self._num_iteration_case in ['relationship', 'inter_table_column_pair']:
table_rows = (
(self.details['Parent Table'] == table_name) |
(self.details['Child Table'] == table_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class InterTableTrends(BaseMultiTableProperty):
all column pairs
"""

_num_iteration_case = 'inter_table_columns'
_num_iteration_case = 'inter_table_column_pair'

def get_score(self, real_data, synthetic_data, metadata, progress_bar=None):
"""Get the average score of all the individual metric scores computed.
Expand Down Expand Up @@ -82,6 +82,8 @@ def get_score(self, real_data, synthetic_data, metadata, progress_bar=None):
f'{parent}.{col}': col_meta for col, col_meta in parent_meta['columns'].items()
}
merged_metadata['columns'] = {**child_cols, **parent_cols}
if 'primary_key' in merged_metadata:
merged_metadata['primary_key'] = f'{child}.{merged_metadata["primary_key"]}'

parent_child_pairs = itertools.product(parent_cols.keys(), child_cols.keys())

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from unittest.mock import Mock

from tqdm import tqdm

from sdmetrics.demos import load_demo
from sdmetrics.reports.multi_table._properties import InterTableTrends


class TestInterTableTrends:
    """Integration tests for the ``InterTableTrends`` multi-table property."""

    # Score that ``get_score`` is expected to return for the multi-table demo data.
    EXPECTED_SCORE = 0.48240740740740734
    # Absolute tolerance for the score comparison, so that last-digit floating-point
    # differences do not fail the tests.
    SCORE_TOLERANCE = 1e-9

    def test_end_to_end(self):
        """Test ``InterTableTrends`` multi-table property end to end."""
        # Setup
        real_data, synthetic_data, metadata = load_demo(modality='multi_table')
        inter_table_trends = InterTableTrends()

        # Run
        result = inter_table_trends.get_score(real_data, synthetic_data, metadata)

        # Assert
        assert abs(result - self.EXPECTED_SCORE) <= self.SCORE_TOLERANCE

    def test_with_progress_bar(self):
        """Test that the progress bar is correctly updated."""
        # Setup
        real_data, synthetic_data, metadata = load_demo(modality='multi_table')
        inter_table_trends = InterTableTrends()

        # One update is expected per (parent column, child column) pair of each relationship.
        num_iter = sum(
            len(metadata['tables'][relationship['parent_table_name']]['columns'])
            * len(metadata['tables'][relationship['child_table_name']]['columns'])  # noqa: W503
            for relationship in metadata['relationships']
        )

        progress_bar = tqdm(total=num_iter)
        mock_update = Mock()
        progress_bar.update = mock_update

        # Run
        try:
            result = inter_table_trends.get_score(
                real_data, synthetic_data, metadata, progress_bar
            )
        finally:
            # Always release the bar so it does not leak into later test output.
            progress_bar.close()

        # Assert
        assert abs(result - self.EXPECTED_SCORE) <= self.SCORE_TOLERANCE
        assert mock_update.call_count == num_iter
70 changes: 70 additions & 0 deletions tests/unit/reports/multi_table/_properties/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ def test__get_num_iterations(self):
base_property._num_iteration_case = 'column_pair'
assert base_property._get_num_iterations(metadata) == 10

base_property._num_iteration_case = 'inter_table_column_pair'
assert base_property._get_num_iterations(metadata) == 11

def test__generate_details_property(self):
"""Test the ``_generate_details`` method."""
# Setup
Expand Down Expand Up @@ -221,3 +224,70 @@ def test_get_visualization_raises_error(self):
)
with pytest.raises(ValueError, match=expected_message):
base_property.get_visualization('table')

def test_get_details(self):
    """Check ``get_details`` both without and with a table name.

    Without an argument the returned frame must equal the stored details.
    With ``'table2'`` only the ``table2`` row is expected, keeping its
    original index label ``1``.
    """
    # Setup
    stored_details = pd.DataFrame({
        'Table': ['table1', 'table2', 'table3'],
        'Column 1': ['col1', 'col2', 'col3'],
        'Column 2': ['colA', 'colB', 'colC'],
        'Score': [0, 0.5, 1.0],
        'Error': [None, None, None]
    })
    expected_filtered = pd.DataFrame(
        {
            'Table': ['table2'],
            'Column 1': ['col2'],
            'Column 2': ['colB'],
            'Score': [0.5],
            'Error': [None]
        },
        index=[1]
    )

    instance = BaseMultiTableProperty()
    instance.details = stored_details

    # Run
    unfiltered = instance.get_details()
    filtered = instance.get_details('table2')

    # Assert
    pd.testing.assert_frame_equal(stored_details, unfiltered)
    pd.testing.assert_frame_equal(filtered, expected_filtered)

def test_get_details_with_parent_child(self):
    """Check ``get_details`` for the iteration cases that use parent/child tables.

    For both the ``'relationship'`` and ``'inter_table_column_pair'`` cases,
    asking for ``'table2'`` is expected to return every row where ``table2``
    appears as either the parent or the child table.
    """
    # Setup
    stored_details = pd.DataFrame({
        'Parent Table': ['table1', 'table3', 'table3'],
        'Child Table': ['table2', 'table2', 'table4'],
        'Column 1': ['col1', 'col2', 'col3'],
        'Column 2': ['colA', 'colB', 'colC'],
        'Score': [0, 0.5, 1.0],
        'Error': [None, None, None]
    })
    expected_filtered = pd.DataFrame({
        'Parent Table': ['table1', 'table3'],
        'Child Table': ['table2', 'table2'],
        'Column 1': ['col1', 'col2'],
        'Column 2': ['colA', 'colB'],
        'Score': [0.0, 0.5],
        'Error': [None, None]
    })

    instance = BaseMultiTableProperty()
    instance.details = stored_details

    # Run
    # Collect an (unfiltered, filtered) pair of results for each iteration case.
    results = {}
    for case in ('relationship', 'inter_table_column_pair'):
        instance._num_iteration_case = case
        results[case] = (instance.get_details(), instance.get_details('table2'))

    # Assert
    for unfiltered, _ in results.values():
        pd.testing.assert_frame_equal(unfiltered, stored_details)

    for _, filtered in results.values():
        pd.testing.assert_frame_equal(filtered, expected_filtered)
Loading

0 comments on commit 728feb6

Please sign in to comment.