diff --git a/sdmetrics/single_table/table_structure.py b/sdmetrics/single_table/table_structure.py index edbc511e..06a97ede 100644 --- a/sdmetrics/single_table/table_structure.py +++ b/sdmetrics/single_table/table_structure.py @@ -36,11 +36,12 @@ def compute_breakdown(cls, real_data, synthetic_data): synthetic_data (pandas.DataFrame): The synthetic data. """ - synthetic_columns = set(synthetic_data.columns) - real_columns = set(real_data.columns) - intersection_columns = real_columns & synthetic_columns - union_columns = real_columns | synthetic_columns - score = len(intersection_columns) / len(union_columns) + real_columns_dtypes = set(zip(real_data.columns, map(str, real_data.dtypes))) + synthetic_columns_dtypes = set(zip(synthetic_data.columns, map(str, synthetic_data.dtypes))) + + intersection = real_columns_dtypes & synthetic_columns_dtypes + union = real_columns_dtypes | synthetic_columns_dtypes + score = len(intersection) / len(union) return {'score': score} diff --git a/tests/integration/reports/multi_table/test_diagnostic_report.py b/tests/integration/reports/multi_table/test_diagnostic_report.py index 57ff4cc6..f006915c 100644 --- a/tests/integration/reports/multi_table/test_diagnostic_report.py +++ b/tests/integration/reports/multi_table/test_diagnostic_report.py @@ -47,6 +47,7 @@ def test_end_to_end_with_metrics_failing(self): """Test the ``DiagnosticReport`` report when some metrics crash. This test makes fail the 'Boundary' property to check that the report still works. + The TableStructure should no longer be 1.0 since there is some dtype mismatch. """ real_data, synthetic_data, metadata = load_demo(modality='multi_table') real_data['users']['age'].iloc[0] = 'error_1' @@ -62,7 +63,7 @@ def test_end_to_end_with_metrics_failing(self): # Assert expected_properties = pd.DataFrame({ 'Property': ['Data Validity', 'Data Structure', 'Relationship Validity'], - 'Score': [1.0, 1.0, 1.0], + 'Score': [1.0, 0.6761904761904761, 1.0], }) expected_details = pd.DataFrame({ 'Table': [ @@ -119,7 +120,7 @@ def test_end_to_end_with_metrics_failing(self): None, ], }) - assert results == 1.0 + assert results == 0.892063492063492 pd.testing.assert_frame_equal( report.get_properties(), expected_properties, check_exact=False, atol=2e-2 ) diff --git a/tests/unit/single_table/test_table_structure.py b/tests/unit/single_table/test_table_structure.py index 4d2e4c61..3d6db5e2 100644 --- a/tests/unit/single_table/test_table_structure.py +++ b/tests/unit/single_table/test_table_structure.py @@ -50,6 +50,28 @@ def test_compute_breakdown(self, real_data): expected_result = {'score': 1.0} assert result == expected_result + def test_compute_breakdown_with_different_dtypes(self, real_data): + """Test the ``compute_breakdown`` method with different data types. + + - Real data has 5 columns. + - Synthetic data is identical except 'col_1' has a different data type. + - Total unique (column_name, dtype) combinations: 6. + - Matching combinations: 4. + - Expected score: 4 / 6 = 2 / 3. + """ + # Setup + synthetic_data = real_data.copy() + synthetic_data['col_1'] = synthetic_data['col_1'].astype('float') + + metric = TableStructure() + + # Run + result = metric.compute_breakdown(real_data, synthetic_data) + + # Assert + expected_result = {'score': 2 / 3} + assert result == expected_result + def test_compute_breakdown_with_missing_columns(self, real_data): """Test the ``compute_breakdown`` method with missing columns.""" # Setup