From 6131758885af61516cfd33a269aac79a3a515762 Mon Sep 17 00:00:00 2001
From: R-Palazzo <116157184+R-Palazzo@users.noreply.github.com>
Date: Thu, 26 Oct 2023 13:37:22 -0600
Subject: [PATCH] Add `TableFormat` metric (#479)

---
 sdmetrics/single_table/__init__.py           |   2 +
 sdmetrics/single_table/table_format.py       |  80 +++++++++
 tests/unit/single_table/test_table_format.py | 178 +++++++++++++++++++
 3 files changed, 260 insertions(+)
 create mode 100644 sdmetrics/single_table/table_format.py
 create mode 100644 tests/unit/single_table/test_table_format.py

diff --git a/sdmetrics/single_table/__init__.py b/sdmetrics/single_table/__init__.py
index 35704626..d6a49ac2 100644
--- a/sdmetrics/single_table/__init__.py
+++ b/sdmetrics/single_table/__init__.py
@@ -32,6 +32,7 @@ from sdmetrics.single_table.privacy.numerical_sklearn import (
     NumericalLR, NumericalMLP, NumericalSVR)
 from sdmetrics.single_table.privacy.radius_nearest_neighbor import NumericalRadiusNearestNeighbor
+from sdmetrics.single_table.table_format import TableFormat
 
 __all__ = [
     'bayesian_network',
@@ -90,4 +91,5 @@
     'TVComplement',
     'RangeCoverage',
     'NewRowSynthesis',
+    'TableFormat',
 ]
diff --git a/sdmetrics/single_table/table_format.py b/sdmetrics/single_table/table_format.py
new file mode 100644
index 00000000..428aa7ca
--- /dev/null
+++ b/sdmetrics/single_table/table_format.py
@@ -0,0 +1,80 @@
+"""Table Format metric."""
+from sdmetrics.goal import Goal
+from sdmetrics.single_table.base import SingleTableMetric
+
+
+class TableFormat(SingleTableMetric):
+    """TableFormat Single Table metric.
+
+    This metric computes whether the names and data types of each column are
+    the same in the real and synthetic data.
+
+    Attributes:
+        name (str):
+            Name to use when reports about this metric are printed.
+        goal (sdmetrics.goal.Goal):
+            The goal of this metric.
+        min_value (Union[float, tuple[float]]):
+            Minimum value or values that this metric can take.
+        max_value (Union[float, tuple[float]]):
+            Maximum value or values that this metric can take.
+    """
+
+    name = 'TableFormat'
+    goal = Goal.MAXIMIZE
+    min_value = 0
+    max_value = 1
+
+    @classmethod
+    def compute_breakdown(cls, real_data, synthetic_data, ignore_dtype_columns=None):
+        """Compute the score breakdown of the table format metric.
+
+        Args:
+            real_data (pandas.DataFrame):
+                The real data.
+            synthetic_data (pandas.DataFrame):
+                The synthetic data.
+            ignore_dtype_columns (list[str]):
+                List of column names to ignore when comparing data types.
+                Defaults to ``None``.
+        """
+        ignore_dtype_columns = ignore_dtype_columns or []
+        missing_columns_in_synthetic = set(real_data.columns) - set(synthetic_data.columns)
+        invalid_names = []
+        invalid_sdtypes = []
+        for column in synthetic_data.columns:
+            if column not in real_data.columns:
+                invalid_names.append(column)
+                continue
+
+            if column in ignore_dtype_columns:
+                continue
+
+            if synthetic_data[column].dtype != real_data[column].dtype:
+                invalid_sdtypes.append(column)
+
+        proportion_correct_columns = 1 - len(missing_columns_in_synthetic) / len(real_data.columns)
+        proportion_valid_names = 1 - len(invalid_names) / len(synthetic_data.columns)
+        proportion_valid_sdtypes = 1 - len(invalid_sdtypes) / len(synthetic_data.columns)
+
+        score = proportion_correct_columns * proportion_valid_names * proportion_valid_sdtypes
+        return {'score': score}
+
+    @classmethod
+    def compute(cls, real_data, synthetic_data, ignore_dtype_columns=None):
+        """Compute the table format metric score.
+
+        Args:
+            real_data (pandas.DataFrame):
+                The real data.
+            synthetic_data (pandas.DataFrame):
+                The synthetic data.
+            ignore_dtype_columns (list[str]):
+                List of column names to ignore when comparing data types.
+                Defaults to ``None``.
+
+        Returns:
+            float:
+                The metric score.
+        """
+        return cls.compute_breakdown(real_data, synthetic_data, ignore_dtype_columns)['score']
diff --git a/tests/unit/single_table/test_table_format.py b/tests/unit/single_table/test_table_format.py
new file mode 100644
index 00000000..2195f1a4
--- /dev/null
+++ b/tests/unit/single_table/test_table_format.py
@@ -0,0 +1,178 @@
+from unittest.mock import patch
+
+import pandas as pd
+import pytest
+
+from sdmetrics.single_table import TableFormat
+
+
+@pytest.fixture()
+def real_data():
+    return pd.DataFrame({
+        'col_1': [1, 2, 3, 4, 5],
+        'col_2': ['A', 'B', 'C', 'B', 'A'],
+        'col_3': [True, False, True, False, True],
+        'col_4': pd.to_datetime([
+            '2020-01-01', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05'
+        ]),
+        'col_5': [1.0, 2.0, 3.0, 4.0, 5.0]
+    })
+
+
+class TestTableFormat:
+
+    def test_compute_breakdown(self, real_data):
+        """Test the ``compute_breakdown`` method."""
+        # Setup
+        synthetic_data = pd.DataFrame({
+            'col_1': [3, 2, 1, 4, 5],
+            'col_2': ['A', 'B', 'C', 'D', 'E'],
+            'col_3': [True, False, True, False, True],
+            'col_4': pd.to_datetime([
+                '2020-01-11', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05'
+            ]),
+            'col_5': [4.0, 2.0, 3.0, 4.0, 5.0]
+        })
+
+        metric = TableFormat()
+
+        # Run
+        result = metric.compute_breakdown(real_data, synthetic_data)
+
+        # Assert
+        expected_result = {'score': 1.0}
+        assert result == expected_result
+
+    def test_compute_breakdown_with_missing_columns(self, real_data):
+        """Test the ``compute_breakdown`` method with missing columns."""
+        # Setup
+        synthetic_data = pd.DataFrame({
+            'col_1': [3, 2, 1, 4, 5],
+            'col_2': ['A', 'B', 'C', 'D', 'E'],
+            'col_3': [True, False, True, False, True],
+            'col_4': pd.to_datetime([
+                '2020-01-11', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05'
+            ]),
+        })
+
+        metric = TableFormat()
+
+        # Run
+        result = metric.compute_breakdown(real_data, synthetic_data)
+
+        # Assert
+        expected_result = {'score': 0.8}
+        assert result == expected_result
+
+    def test_compute_breakdown_with_invalid_names(self, real_data):
+        """Test the ``compute_breakdown`` method with invalid names."""
+        # Setup
+        synthetic_data = pd.DataFrame({
+            'col_1': [3, 2, 1, 4, 5],
+            'col_2': ['A', 'B', 'C', 'D', 'E'],
+            'col_3': [True, False, True, False, True],
+            'col_4': pd.to_datetime([
+                '2020-01-11', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05'
+            ]),
+            'col_5': [4.0, 2.0, 3.0, 4.0, 5.0],
+            'col_6': [4.0, 2.0, 3.0, 4.0, 5.0],
+        })
+
+        metric = TableFormat()
+
+        # Run
+        result = metric.compute_breakdown(real_data, synthetic_data)
+
+        # Assert
+        expected_result = {'score': 0.8333333333333334}
+        assert result == expected_result
+
+    def test_compute_breakdown_with_invalid_dtypes(self, real_data):
+        """Test the ``compute_breakdown`` method with invalid dtypes."""
+        # Setup
+        synthetic_data = pd.DataFrame({
+            'col_1': [3.0, 2.0, 1.0, 4.0, 5.0],
+            'col_2': ['A', 'B', 'C', 'D', 'E'],
+            'col_3': [True, False, True, False, True],
+            'col_4': [
+                '2020-01-11', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05'
+            ],
+            'col_5': [4.0, 2.0, 3.0, 4.0, 5.0],
+        })
+
+        metric = TableFormat()
+
+        # Run
+        result = metric.compute_breakdown(real_data, synthetic_data)
+
+        # Assert
+        expected_result = {'score': 0.6}
+        assert result == expected_result
+
+    def test_compute_breakdown_ignore_dtype_columns(self, real_data):
"""Test the ``compute_breakdown`` method when ignore_dtype_columns is set.""" + # Setup + synthetic_data = pd.DataFrame({ + 'col_1': [3.0, 2.0, 1.0, 4.0, 5.0], + 'col_2': ['A', 'B', 'C', 'D', 'E'], + 'col_3': [True, False, True, False, True], + 'col_4': [ + '2020-01-11', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05' + ], + 'col_5': [4.0, 2.0, 3.0, 4.0, 5.0], + }) + + metric = TableFormat() + + # Run + result = metric.compute_breakdown( + real_data, synthetic_data, ignore_dtype_columns=['col_4'] + ) + + # Assert + expected_result = {'score': 0.8} + assert result == expected_result + + def test_compute_breakdown_multiple_error(self, real_data): + """Test the ``compute_breakdown`` method with the different failure modes.""" + synthetic_data = pd.DataFrame({ + 'col_1': [1, 2, 1, 4, 5], + 'col_3': [True, False, True, False, True], + 'col_4': [ + '2020-01-11', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05' + ], + 'col_5': [4.0, 2.0, 3.0, 4.0, 5.0], + 'col_6': [4.0, 2.0, 3.0, 4.0, 5.0], + }) + + metric = TableFormat() + + # Run + result = metric.compute_breakdown(real_data, synthetic_data) + + # Assert + expected_result = {'score': 0.5120000000000001} + assert result == expected_result + + @patch('sdmetrics.single_table.table_format.TableFormat.compute_breakdown') + def test_compute(self, compute_breakdown_mock, real_data): + """Test the ``compute`` method.""" + # Setup + synthetic_data = pd.DataFrame({ + 'col_1': [3, 2, 1, 4, 5], + 'col_2': ['A', 'B', 'C', 'D', 'E'], + 'col_3': [True, False, True, False, True], + 'col_4': pd.to_datetime([ + '2020-01-11', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05' + ]), + 'col_5': [4.0, 2.0, 3.0, 4.0, 5.0] + }) + metric = TableFormat() + compute_breakdown_mock.return_value = {'score': 0.6} + + # Run + result = metric.compute(real_data, synthetic_data) + + # Assert + compute_breakdown_mock.assert_called_once_with(real_data, synthetic_data, None) + assert result == 0.6