Skip to content

Commit

Permalink
Add TableFormat metric (#479)
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Nov 27, 2023
1 parent 78fb427 commit 6131758
Show file tree
Hide file tree
Showing 3 changed files with 260 additions and 0 deletions.
2 changes: 2 additions & 0 deletions sdmetrics/single_table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from sdmetrics.single_table.privacy.numerical_sklearn import (
NumericalLR, NumericalMLP, NumericalSVR)
from sdmetrics.single_table.privacy.radius_nearest_neighbor import NumericalRadiusNearestNeighbor
from sdmetrics.single_table.table_format import TableFormat

__all__ = [
'bayesian_network',
Expand Down Expand Up @@ -90,4 +91,5 @@
'TVComplement',
'RangeCoverage',
'NewRowSynthesis',
'TableFormat',
]
80 changes: 80 additions & 0 deletions sdmetrics/single_table/table_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Table Format metric."""
from sdmetrics.goal import Goal
from sdmetrics.single_table.base import SingleTableMetric


class TableFormat(SingleTableMetric):
"""TableFormat Single Table metric.
This metric computes whether the names and data types of each column are
the same in the real and synthetic data.
Attributes:
name (str):
Name to use when reports about this metric are printed.
goal (sdmetrics.goal.Goal):
The goal of this metric.
min_value (Union[float, tuple[float]]):
Minimum value or values that this metric can take.
max_value (Union[float, tuple[float]]):
Maximum value or values that this metric can take.
"""

name = 'TableFormat'
goal = Goal.MAXIMIZE
min_value = 0
max_value = 1

@classmethod
def compute_breakdown(cls, real_data, synthetic_data, ignore_dtype_columns=None):
"""Compute the score breakdown of the table format metric.
Args:
real_data (pandas.DataFrame):
The real data.
synthetic_data (pandas.DataFrame):
The synthetic data.
ignore_dtype_columns (list[str]):
List of column names to ignore when comparing data types.
Defaults to ``None``.
"""
ignore_dtype_columns = ignore_dtype_columns or []
missing_columns_in_synthetic = set(real_data.columns) - set(synthetic_data.columns)
invalid_names = []
invalid_sdtypes = []
for column in synthetic_data.columns:
if column not in real_data.columns:
invalid_names.append(column)
continue

if column in ignore_dtype_columns:
continue

if synthetic_data[column].dtype != real_data[column].dtype:
invalid_sdtypes.append(column)

proportion_correct_columns = 1 - len(missing_columns_in_synthetic) / len(real_data.columns)
proportion_valid_names = 1 - len(invalid_names) / len(synthetic_data.columns)
proportion_valid_sdtypes = 1 - len(invalid_sdtypes) / len(synthetic_data.columns)

score = proportion_correct_columns * proportion_valid_names * proportion_valid_sdtypes
return {'score': score}

@classmethod
def compute(cls, real_data, synthetic_data, ignore_dtype_columns=None):
"""Compute the table format metric score.
Args:
real_data (pandas.DataFrame):
The real data.
synthetic_data (pandas.DataFrame):
The synthetic data.
ignore_dtype_columns (list[str]):
List of column names to ignore when comparing data types.
Defaults to ``None``.
Returns:
float:
The metric score.
"""
return cls.compute_breakdown(real_data, synthetic_data, ignore_dtype_columns)['score']
178 changes: 178 additions & 0 deletions tests/unit/single_table/test_table_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
from unittest.mock import patch

import pandas as pd
import pytest

from sdmetrics.single_table import TableFormat


@pytest.fixture()
def real_data():
return pd.DataFrame({
'col_1': [1, 2, 3, 4, 5],
'col_2': ['A', 'B', 'C', 'B', 'A'],
'col_3': [True, False, True, False, True],
'col_4': pd.to_datetime([
'2020-01-01', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05'
]),
'col_5': [1.0, 2.0, 3.0, 4.0, 5.0]
})


class TestTableFormat:

def test_compute_breakdown(self, real_data):
"""Test the ``compute_breakdown`` method."""
# Setup
synthetic_data = pd.DataFrame({
'col_1': [3, 2, 1, 4, 5],
'col_2': ['A', 'B', 'C', 'D', 'E'],
'col_3': [True, False, True, False, True],
'col_4': pd.to_datetime([
'2020-01-11', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05'
]),
'col_5': [4.0, 2.0, 3.0, 4.0, 5.0]
})

metric = TableFormat()

# Run
result = metric.compute_breakdown(real_data, synthetic_data)

# Assert
expected_result = {'score': 1.0}
assert result == expected_result

def test_compute_breakdown_with_missing_columns(self, real_data):
"""Test the ``compute_breakdown`` method with missing columns."""
# Setup
synthetic_data = pd.DataFrame({
'col_1': [3, 2, 1, 4, 5],
'col_2': ['A', 'B', 'C', 'D', 'E'],
'col_3': [True, False, True, False, True],
'col_4': pd.to_datetime([
'2020-01-11', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05'
]),
})

metric = TableFormat()

# Run
result = metric.compute_breakdown(real_data, synthetic_data)

# Assert
expected_result = {'score': 0.8}
assert result == expected_result

def test_compute_breakdown_with_invalid_names(self, real_data):
"""Test the ``compute_breakdown`` method with invalid names."""
# Setup
synthetic_data = pd.DataFrame({
'col_1': [3, 2, 1, 4, 5],
'col_2': ['A', 'B', 'C', 'D', 'E'],
'col_3': [True, False, True, False, True],
'col_4': pd.to_datetime([
'2020-01-11', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05'
]),
'col_5': [4.0, 2.0, 3.0, 4.0, 5.0],
'col_6': [4.0, 2.0, 3.0, 4.0, 5.0],
})

metric = TableFormat()

# Run
result = metric.compute_breakdown(real_data, synthetic_data)

# Assert
expected_result = {'score': 0.8333333333333334}
assert result == expected_result

def test_compute_breakdown_with_invalid_dtypes(self, real_data):
"""Test the ``compute_breakdown`` method with invalid dtypes."""
# Setup
synthetic_data = pd.DataFrame({
'col_1': [3.0, 2.0, 1.0, 4.0, 5.0],
'col_2': ['A', 'B', 'C', 'D', 'E'],
'col_3': [True, False, True, False, True],
'col_4': [
'2020-01-11', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05'
],
'col_5': [4.0, 2.0, 3.0, 4.0, 5.0],
})

metric = TableFormat()

# Run
result = metric.compute_breakdown(real_data, synthetic_data)

# Assert
expected_result = {'score': 0.6}
assert result == expected_result

def test_compute_breakdown_ignore_dtype_columns(self, real_data):
"""Test the ``compute_breakdown`` method when ignore_dtype_columns is set."""
# Setup
synthetic_data = pd.DataFrame({
'col_1': [3.0, 2.0, 1.0, 4.0, 5.0],
'col_2': ['A', 'B', 'C', 'D', 'E'],
'col_3': [True, False, True, False, True],
'col_4': [
'2020-01-11', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05'
],
'col_5': [4.0, 2.0, 3.0, 4.0, 5.0],
})

metric = TableFormat()

# Run
result = metric.compute_breakdown(
real_data, synthetic_data, ignore_dtype_columns=['col_4']
)

# Assert
expected_result = {'score': 0.8}
assert result == expected_result

def test_compute_breakdown_multiple_error(self, real_data):
"""Test the ``compute_breakdown`` method with the different failure modes."""
synthetic_data = pd.DataFrame({
'col_1': [1, 2, 1, 4, 5],
'col_3': [True, False, True, False, True],
'col_4': [
'2020-01-11', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05'
],
'col_5': [4.0, 2.0, 3.0, 4.0, 5.0],
'col_6': [4.0, 2.0, 3.0, 4.0, 5.0],
})

metric = TableFormat()

# Run
result = metric.compute_breakdown(real_data, synthetic_data)

# Assert
expected_result = {'score': 0.5120000000000001}
assert result == expected_result

@patch('sdmetrics.single_table.table_format.TableFormat.compute_breakdown')
def test_compute(self, compute_breakdown_mock, real_data):
"""Test the ``compute`` method."""
# Setup
synthetic_data = pd.DataFrame({
'col_1': [3, 2, 1, 4, 5],
'col_2': ['A', 'B', 'C', 'D', 'E'],
'col_3': [True, False, True, False, True],
'col_4': pd.to_datetime([
'2020-01-11', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05'
]),
'col_5': [4.0, 2.0, 3.0, 4.0, 5.0]
})
metric = TableFormat()
compute_breakdown_mock.return_value = {'score': 0.6}

# Run
result = metric.compute(real_data, synthetic_data)

# Assert
compute_breakdown_mock.assert_called_once_with(real_data, synthetic_data, None)
assert result == 0.6

0 comments on commit 6131758

Please sign in to comment.