Skip to content

Commit

Permalink
integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Oct 30, 2023
1 parent fea8ead commit cd64c38
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=No
for relation in metadata.get('relationships', []):
real_columns = self._extract_tuple(real_data, relation)
synthetic_columns = self._extract_tuple(synthetic_data, relation)
for metric in metrics:
for idx, metric in enumerate(metrics):
try:
relation_score = metric.compute(
real_columns,
Expand All @@ -54,7 +54,7 @@ def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=No
relation_score = np.nan
error_message = f'{type(e).__name__}: {e}'
finally:
if progress_bar is not None:
if progress_bar is not None and idx % 2 == 0:
progress_bar.update()

child_tables.append(relation['child_table_name'])
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from unittest.mock import Mock

from tqdm import tqdm

from sdmetrics.demos import load_demo
from sdmetrics.reports.multi_table._properties import RelationshipValidity


class TestRelationshipValidity:

def test_end_to_end(self):
"""Test the ``RelationshipValidity`` multi-table property end to end."""
# Setup
real_data, synthetic_data, metadata = load_demo(modality='multi_table')
relationship_validity = RelationshipValidity()

# Run
result = relationship_validity.get_score(real_data, synthetic_data, metadata)

# Assert
assert result == 1.0

def test_with_progress_bar(self):
"""Test that the progress bar is correctly updated."""
# Setup
real_data, synthetic_data, metadata = load_demo(modality='multi_table')
relationship_validity = RelationshipValidity()
num_relationship = 2

progress_bar = tqdm(total=num_relationship)
mock_update = Mock()
progress_bar.update = mock_update

# Run
result = relationship_validity.get_score(real_data, synthetic_data, metadata, progress_bar)

# Assert
assert result == 1.0
assert mock_update.call_count == num_relationship
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def test_get_score(

assert score == 0.5
progress_bar.update.assert_called()
assert progress_bar.update.call_count == 2
assert progress_bar.update.call_count == 1
mock_compute_average.assert_called_once_with()
pd.testing.assert_frame_equal(relationship_validity.details, expected_details_property)

Expand Down Expand Up @@ -218,7 +218,7 @@ def test_get_score_raises_errors(
assert pd.isna(score)
pd.testing.assert_frame_equal(relationship_validity.details, expected_details_property)
progress_bar.update.assert_called()
assert progress_bar.update.call_count == 2
assert progress_bar.update.call_count == 1

def test_get_details_with_table_name(self):
"""Test the ``get_details`` method.
Expand Down

0 comments on commit cd64c38

Please sign in to comment.