Skip to content

Commit

Permalink
def + test
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Oct 23, 2023
1 parent 99cb1e4 commit bf7709e
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 0 deletions.
2 changes: 2 additions & 0 deletions sdmetrics/column_pairs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from sdmetrics.column_pairs.statistical.correlation_similarity import CorrelationSimilarity
from sdmetrics.column_pairs.statistical.kl_divergence import (
ContinuousKLDivergence, DiscreteKLDivergence)
from sdmetrics.column_pairs.statistical.referential_integrity import ReferentialIntegrity

__all__ = [
'ColumnPairsMetric',
'ContingencySimilarity',
'ContinuousKLDivergence',
'CorrelationSimilarity',
'DiscreteKLDivergence',
'ReferentialIntegrity',
]
2 changes: 2 additions & 0 deletions sdmetrics/column_pairs/statistical/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
from sdmetrics.column_pairs.statistical.correlation_similarity import CorrelationSimilarity
from sdmetrics.column_pairs.statistical.kl_divergence import (
ContinuousKLDivergence, DiscreteKLDivergence)
from sdmetrics.column_pairs.statistical.referential_integrity import ReferentialIntegrity

__all__ = [
'ContingencySimilarity',
'ContinuousKLDivergence',
'CorrelationSimilarity',
'DiscreteKLDivergence',
'ReferentialIntegrity',
]
84 changes: 84 additions & 0 deletions sdmetrics/column_pairs/statistical/referential_integrity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""Referential Integrity Metric."""
import logging

from sdmetrics.column_pairs.base import ColumnPairsMetric
from sdmetrics.goal import Goal

LOGGER = logging.getLogger(__name__)


class ReferentialIntegrity(ColumnPairsMetric):
"""Referential Integrity metric.
Compute the fraction of foreign key values that reference a value in the primary key column
in the synthetic data.
Attributes:
name (str):
Name to use when reports about this metric are printed.
goal (sdmetrics.goal.Goal):
The goal of this metric.
min_value (Union[float, tuple[float]]):
Minimum value or values that this metric can take.
max_value (Union[float, tuple[float]]):
Maximum value or values that this metric can take.
"""

name = 'ReferentialIntegrity'
goal = Goal.MAXIMIZE
min_value = 0.0
max_value = 1.0

@classmethod
def compute_breakdown(cls, real_data, synthetic_data):
"""Compute the score breakdown of the referential integrity metric.
Args:
real_data (tuple of 2 pandas.Series):
(primary_key, foreign_key) columns from the real data.
synthetic_data (tuple of 2 pandas.Series):
(primary_key, foreign_key) columns from the synthetic data.
Returns:
dict:
The score breakdown of the key uniqueness metric.
"""
missing_parents = not real_data[1].isin(real_data[0]).all()
if missing_parents:
LOGGER.info(
"The real data has foreign keys that don't reference any primary key."
)

score = synthetic_data[1].isin(synthetic_data[0]).mean()

return {'score': score}

@classmethod
def compute(cls, real_data, synthetic_data):
"""Compute the referential integrity of two columns.
Args:
real_data (tuple of 2 pandas.Series):
(primary_key, foreign_key) columns from the real data.
synthetic_data (tuple of 2 pandas.Series):
(primary_key, foreign_key) columns from the synthetic data.
Returns:
float:
The key uniqueness of the two columns.
"""
return cls.compute_breakdown(real_data, synthetic_data)['score']

@classmethod
def normalize(cls, raw_score):
"""Return the `raw_score` as is, since it is already normalized.
Args:
raw_score (float):
The value of the metric from `compute`.
Returns:
float:
The normalized value of the metric
"""
return super().normalize(raw_score)
87 changes: 87 additions & 0 deletions tests/unit/column_pairs/statistical/test_referential_integrity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from unittest.mock import patch

import pandas as pd

from sdmetrics.column_pairs.statistical import ReferentialIntegrity


class TestReferentialIntegrity:

def test_compute_breakdown(self):
"""Test the ``compute_breakdown`` method."""
# Setup
real_data = pd.DataFrame({
'primary_key': [1, 2, 3, 4, 5],
'foreign_key': [1, 2, 3, 2, 1]
})
synthetic_data = pd.DataFrame({
'primary_key': [1, 2, 3, 4, 5],
'foreign_key': [1, 6, 3, 4, 5]
})

metric = ReferentialIntegrity()
tuple_real = (real_data['primary_key'], real_data['foreign_key'])
tuple_synthetic = (synthetic_data['primary_key'], synthetic_data['foreign_key'])

# Run
result = metric.compute_breakdown(tuple_real, tuple_synthetic)

# Assert
assert result == {'score': 0.8}

@patch('sdmetrics.column_pairs.statistical.referential_integrity.LOGGER')
def test_compute_breakdown_with_missing_relations_real_data(self, logger_mock):
"""Test the ``compute_breakdown`` when there is missing relationships in the real data."""
# Setup
real_data = pd.DataFrame({
'primary_key': [1, 2, 3, 4, 5],
'foreign_key': [1, 2, 6, 2, 1]
})
synthetic_data = pd.DataFrame({
'primary_key': [1, 2, 3, 4, 5],
'foreign_key': [1, 6, 3, 4, 5]
})

metric = ReferentialIntegrity()
tuple_real = (real_data['primary_key'], real_data['foreign_key'])
tuple_synthetic = (synthetic_data['primary_key'], synthetic_data['foreign_key'])

# Run
result = metric.compute_breakdown(tuple_real, tuple_synthetic)

# Assert
expected_message = "The real data has foreign keys that don't reference any primary key."
assert result == {'score': 0.8}
logger_mock.info.assert_called_once_with(expected_message)

@patch('sdmetrics.column_pairs.statistical.referential_integrity.'
'ReferentialIntegrity.compute_breakdown')
def test_compute(self, compute_breakdown_mock):
"""Test the ``compute`` method."""
# Setup
real_data = pd.Series(['A', 'B', 'C', 'B', 'A'])
synthetic_data = pd.Series(['A', 'B', 'C', 'D', 'E'])
metric = ReferentialIntegrity()
compute_breakdown_mock.return_value = {'score': 0.6}

# Run
result = metric.compute(real_data, synthetic_data)

# Assert
compute_breakdown_mock.assert_called_once_with(real_data, synthetic_data)
assert result == 0.6

@patch('sdmetrics.column_pairs.statistical.referential_integrity.'
'ColumnPairsMetric.normalize')
def test_normalize(self, normalize_mock):
"""Test the ``normalize`` method."""
# Setup
metric = ReferentialIntegrity()
raw_score = 0.9

# Run
result = metric.normalize(raw_score)

# Assert
normalize_mock.assert_called_once_with(raw_score)
assert result == normalize_mock.return_value

0 comments on commit bf7709e

Please sign in to comment.