diff --git a/sdmetrics/reports/multi_table/_properties/relationship_validity.py b/sdmetrics/reports/multi_table/_properties/relationship_validity.py index 2111ef54..d22d8cb2 100644 --- a/sdmetrics/reports/multi_table/_properties/relationship_validity.py +++ b/sdmetrics/reports/multi_table/_properties/relationship_validity.py @@ -43,7 +43,7 @@ def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=No for relation in metadata.get('relationships', []): real_columns = self._extract_tuple(real_data, relation) synthetic_columns = self._extract_tuple(synthetic_data, relation) - for metric in metrics: + for idx, metric in enumerate(metrics): try: relation_score = metric.compute( real_columns, @@ -54,7 +54,7 @@ def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=No relation_score = np.nan error_message = f'{type(e).__name__}: {e}' finally: - if progress_bar is not None: + if progress_bar is not None and idx % 2 == 0: progress_bar.update() child_tables.append(relation['child_table_name']) diff --git a/tests/integration/reports/multi_table/_properties/test_relationship_validity.py b/tests/integration/reports/multi_table/_properties/test_relationship_validity.py new file mode 100644 index 00000000..634590a5 --- /dev/null +++ b/tests/integration/reports/multi_table/_properties/test_relationship_validity.py @@ -0,0 +1,39 @@ +from unittest.mock import Mock + +from tqdm import tqdm + +from sdmetrics.demos import load_demo +from sdmetrics.reports.multi_table._properties import RelationshipValidity + + +class TestRelationshipValidity: + + def test_end_to_end(self): + """Test the ``RelationshipValidity`` multi-table property end to end.""" + # Setup + real_data, synthetic_data, metadata = load_demo(modality='multi_table') + relationship_validity = RelationshipValidity() + + # Run + result = relationship_validity.get_score(real_data, synthetic_data, metadata) + + # Assert + assert result == 1.0 + + def test_with_progress_bar(self): + """Test that the progress bar is correctly updated.""" + # Setup + real_data, synthetic_data, metadata = load_demo(modality='multi_table') + relationship_validity = RelationshipValidity() + num_relationship = 2 + + progress_bar = tqdm(total=num_relationship) + mock_update = Mock() + progress_bar.update = mock_update + + # Run + result = relationship_validity.get_score(real_data, synthetic_data, metadata, progress_bar) + + # Assert + assert result == 1.0 + assert mock_update.call_count == num_relationship diff --git a/tests/unit/reports/multi_table/_properties/test_relationship_validity.py b/tests/unit/reports/multi_table/_properties/test_relationship_validity.py index fa70fd5d..40790a0d 100644 --- a/tests/unit/reports/multi_table/_properties/test_relationship_validity.py +++ b/tests/unit/reports/multi_table/_properties/test_relationship_validity.py @@ -173,7 +173,7 @@ def test_get_score( assert score == 0.5 progress_bar.update.assert_called() - assert progress_bar.update.call_count == 2 + assert progress_bar.update.call_count == 1 mock_compute_average.assert_called_once_with() pd.testing.assert_frame_equal(relationship_validity.details, expected_details_property) @@ -218,7 +218,7 @@ def test_get_score_raises_errors( assert pd.isna(score) pd.testing.assert_frame_equal(relationship_validity.details, expected_details_property) progress_bar.update.assert_called() - assert progress_bar.update.call_count == 2 + assert progress_bar.update.call_count == 1 def test_get_details_with_table_name(self): """Test the ``get_details`` method.