Commit: Add warnings

frances-h committed Dec 6, 2024
1 parent 6234fcb commit f4a47e2
Showing 4 changed files with 316 additions and 7 deletions.
224 changes: 224 additions & 0 deletions sdmetrics/single_table/privacy/cap.py
@@ -1,8 +1,16 @@
"""CAP modules and their attackers."""

import warnings

from sdmetrics.single_table.privacy.base import CategoricalPrivacyMetric, PrivacyAttackerModel
from sdmetrics.single_table.privacy.util import closest_neighbors, count_frequency, majority

DEPRECATION_MSG = (
'Computing CAP metrics directly is deprecated. For improved privacy metrics, '
"please use the 'DisclosureProtection' and 'DisclosureProtectionEstimate' "
'metrics instead.'
)


class CAPAttacker(PrivacyAttackerModel):
"""The CAP (Correct Attribution Probability) privacy attacker.
@@ -78,6 +86,78 @@ class CategoricalCAP(CategoricalPrivacyMetric):
MODEL = CAPAttacker
ACCURACY_BASE = False

@classmethod
def _compute(
cls,
real_data,
synthetic_data,
metadata=None,
key_fields=None,
sensitive_fields=None,
model_kwargs=None,
):
return super().compute(
real_data=real_data,
synthetic_data=synthetic_data,
metadata=metadata,
key_fields=key_fields,
sensitive_fields=sensitive_fields,
model_kwargs=model_kwargs,
)

@classmethod
def compute(
cls,
real_data,
synthetic_data,
metadata=None,
key_fields=None,
sensitive_fields=None,
model_kwargs=None,
):
"""Compute this metric.
This fits an adversial attacker model on the synthetic data and
then evaluates it making predictions on the real data.
A ``key_fields`` column(s) name must be given, either directly or as a first level
entry in the ``metadata`` dict, which will be used as the key column(s) for the
attack.
A ``sensitive_fields`` column(s) name must be given, either directly or as a first level
entry in the ``metadata`` dict, which will be used as the sensitive_fields column(s)
for the attack.
Args:
real_data (Union[numpy.ndarray, pandas.DataFrame]):
The values from the real dataset.
synthetic_data (Union[numpy.ndarray, pandas.DataFrame]):
The values from the synthetic dataset.
metadata (dict):
Table metadata dict. If not passed, it is build based on the
real_data fields and dtypes.
key_fields (list(str)):
Name of the column(s) to use as the key attributes.
sensitive_fields (list(str)):
Name of the column(s) to use as the sensitive attributes.
model_kwargs (dict):
Key word arguments of the attacker model. cls.MODEL_KWARGS will be used
if none is provided.
Returns:
union[float, tuple[float]]:
Scores obtained by the attackers when evaluated on the real data.
"""
warnings.warn(DEPRECATION_MSG, DeprecationWarning)
return cls._compute(
real_data=real_data,
synthetic_data=synthetic_data,
metadata=metadata,
key_fields=key_fields,
sensitive_fields=sensitive_fields,
model_kwargs=model_kwargs,
)


class ZeroCAPAttacker(CAPAttacker):
"""The 0CAP privacy attacker, which operates in the same way as CAP does.
@@ -113,6 +193,78 @@ class CategoricalZeroCAP(CategoricalPrivacyMetric):
MODEL = ZeroCAPAttacker
ACCURACY_BASE = False

@classmethod
def _compute(
cls,
real_data,
synthetic_data,
metadata=None,
key_fields=None,
sensitive_fields=None,
model_kwargs=None,
):
return super().compute(
real_data=real_data,
synthetic_data=synthetic_data,
metadata=metadata,
key_fields=key_fields,
sensitive_fields=sensitive_fields,
model_kwargs=model_kwargs,
)

@classmethod
def compute(
cls,
real_data,
synthetic_data,
metadata=None,
key_fields=None,
sensitive_fields=None,
model_kwargs=None,
):
"""Compute this metric.
This fits an adversial attacker model on the synthetic data and
then evaluates it making predictions on the real data.
A ``key_fields`` column(s) name must be given, either directly or as a first level
entry in the ``metadata`` dict, which will be used as the key column(s) for the
attack.
A ``sensitive_fields`` column(s) name must be given, either directly or as a first level
entry in the ``metadata`` dict, which will be used as the sensitive_fields column(s)
for the attack.
Args:
real_data (Union[numpy.ndarray, pandas.DataFrame]):
The values from the real dataset.
synthetic_data (Union[numpy.ndarray, pandas.DataFrame]):
The values from the synthetic dataset.
metadata (dict):
Table metadata dict. If not passed, it is build based on the
real_data fields and dtypes.
key_fields (list(str)):
Name of the column(s) to use as the key attributes.
sensitive_fields (list(str)):
Name of the column(s) to use as the sensitive attributes.
model_kwargs (dict):
Key word arguments of the attacker model. cls.MODEL_KWARGS will be used
if none is provided.
Returns:
union[float, tuple[float]]:
Scores obtained by the attackers when evaluated on the real data.
"""
warnings.warn(DEPRECATION_MSG, DeprecationWarning)
return cls._compute(
real_data=real_data,
synthetic_data=synthetic_data,
metadata=metadata,
key_fields=key_fields,
sensitive_fields=sensitive_fields,
model_kwargs=model_kwargs,
)


class GeneralizedCAPAttacker(CAPAttacker):
"""The GeneralizedCAP privacy attacker.
@@ -169,3 +321,75 @@ class CategoricalGeneralizedCAP(CategoricalPrivacyMetric):
name = 'Categorical GeneralizedCAP'
MODEL = GeneralizedCAPAttacker
ACCURACY_BASE = False

@classmethod
def _compute(
cls,
real_data,
synthetic_data,
metadata=None,
key_fields=None,
sensitive_fields=None,
model_kwargs=None,
):
return super().compute(
real_data=real_data,
synthetic_data=synthetic_data,
metadata=metadata,
key_fields=key_fields,
sensitive_fields=sensitive_fields,
model_kwargs=model_kwargs,
)

@classmethod
def compute(
cls,
real_data,
synthetic_data,
metadata=None,
key_fields=None,
sensitive_fields=None,
model_kwargs=None,
):
"""Compute this metric.
This fits an adversial attacker model on the synthetic data and
then evaluates it making predictions on the real data.
A ``key_fields`` column(s) name must be given, either directly or as a first level
entry in the ``metadata`` dict, which will be used as the key column(s) for the
attack.
A ``sensitive_fields`` column(s) name must be given, either directly or as a first level
entry in the ``metadata`` dict, which will be used as the sensitive_fields column(s)
for the attack.
Args:
real_data (Union[numpy.ndarray, pandas.DataFrame]):
The values from the real dataset.
synthetic_data (Union[numpy.ndarray, pandas.DataFrame]):
The values from the synthetic dataset.
metadata (dict):
Table metadata dict. If not passed, it is build based on the
real_data fields and dtypes.
key_fields (list(str)):
Name of the column(s) to use as the key attributes.
sensitive_fields (list(str)):
Name of the column(s) to use as the sensitive attributes.
model_kwargs (dict):
Key word arguments of the attacker model. cls.MODEL_KWARGS will be used
if none is provided.
Returns:
union[float, tuple[float]]:
Scores obtained by the attackers when evaluated on the real data.
"""
warnings.warn(DEPRECATION_MSG, DeprecationWarning)
return cls._compute(
real_data=real_data,
synthetic_data=synthetic_data,
metadata=metadata,
key_fields=key_fields,
sensitive_fields=sensitive_fields,
model_kwargs=model_kwargs,
)
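The pattern is identical in all three classes: the public compute entry point becomes a thin shim that emits the deprecation warning and delegates to a private _compute, which internal callers can use without triggering it. A minimal sketch of what downstream callers now see when calling the public API directly (the toy DataFrames here are illustrative, not from the repository):

import warnings

import pandas as pd

from sdmetrics.single_table.privacy.cap import CategoricalCAP

real = pd.DataFrame({'zip': ['A', 'A', 'B', 'B'], 'diagnosis': ['x', 'y', 'x', 'y']})
synth = pd.DataFrame({'zip': ['A', 'B', 'A', 'B'], 'diagnosis': ['y', 'x', 'x', 'y']})

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    # The public entry point warns; the returned score is unchanged.
    score = CategoricalCAP.compute(
        real, synth, key_fields=['zip'], sensitive_fields=['diagnosis']
    )

assert any(issubclass(w.category, DeprecationWarning) for w in caught)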
15 changes: 13 additions & 2 deletions sdmetrics/single_table/privacy/disclosure_protection.py
@@ -1,5 +1,7 @@
"""Disclosure protection metrics."""

import warnings

import numpy as np
import pandas as pd
import tqdm
@@ -12,6 +14,8 @@
CategoricalZeroCAP,
)

MAX_NUM_ROWS = 50000

CAP_METHODS = {
'CAP': CategoricalCAP,
'ZERO_CAP': CategoricalZeroCAP,
@@ -204,7 +208,14 @@ def compute_breakdown(
continuous_column_names,
num_discrete_bins,
)

computation_method = computation_method.upper()
if len(real_data) > MAX_NUM_ROWS or len(synthetic_data) > MAX_NUM_ROWS:
warnings.warn(
f'Data exceeds {MAX_NUM_ROWS} rows, performance may be slow. '
'Consider using the `DisclosureProtectionEstimate` for faster computation.'
)

real_data, synthetic_data = cls._discretize_and_fillna(
real_data,
synthetic_data,
@@ -219,7 +230,7 @@

# Compute CAP metric
cap_metric = CAP_METHODS.get(computation_method)
cap_protection = cap_metric.compute(
cap_protection = cap_metric._compute(
real_data,
synthetic_data,
key_fields=known_column_names,
Expand Down Expand Up @@ -343,7 +354,7 @@ def _compute_estimated_cap_metric(
real_data_samp = real_data.sample(min(num_rows_subsample, len(real_data)))
synth_data_samp = synthetic_data.sample(min(num_rows_subsample, len(synthetic_data)))

estimated_cap_protection = cap_metric.compute(
estimated_cap_protection = cap_metric._compute(
real_data_samp,
synth_data_samp,
key_fields=known_column_names,
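Note that the MAX_NUM_ROWS guard only warns; compute_breakdown still runs on the full tables. The switch from cap_metric.compute to cap_metric._compute is what keeps DisclosureProtection and DisclosureProtectionEstimate from re-emitting the CAP deprecation warning on every internal call. For large tables, the warning points users at the estimated variant; a hedged sketch of that path, assuming its public compute accepts the known/sensitive column arguments and the num_rows_subsample parameter that _compute_estimated_cap_metric uses above (the DataFrames and column names are illustrative):

from sdmetrics.single_table.privacy.disclosure_protection import (
    DisclosureProtectionEstimate,
)

# real_df and synth_df are large pandas DataFrames (hypothetical).
score = DisclosureProtectionEstimate.compute(
    real_df,
    synth_df,
    known_column_names=['age', 'zip'],
    sensitive_column_names=['diagnosis'],
    num_rows_subsample=10000,  # evaluate on subsamples instead of all rows
)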
27 changes: 27 additions & 0 deletions tests/unit/single_table/privacy/test_cap.py
@@ -0,0 +1,27 @@
import re

import pandas as pd
import pytest

from sdmetrics.single_table.privacy.cap import (
CategoricalCAP,
CategoricalGeneralizedCAP,
CategoricalZeroCAP,
)


@pytest.mark.parametrize('metric', [CategoricalCAP, CategoricalZeroCAP, CategoricalGeneralizedCAP])
def test_CAP_deprecation_message(metric):
"""Test deprecation warning is raised when running the metric directly."""
# Setup
real_data = pd.DataFrame({'col1': range(5), 'col2': ['A', 'B', 'C', 'A', 'B']})
synthetic_data = pd.DataFrame({'col1': range(5), 'col2': ['C', 'A', 'A', 'B', 'C']})

# Run and Assert
expected_warning = re.escape(
'Computing CAP metrics directly is deprecated. For improved privacy metrics, '
"please use the 'DisclosureProtection' and 'DisclosureProtectionEstimate' "
'metrics instead.'
)
with pytest.warns(DeprecationWarning, match=expected_warning):
metric.compute(real_data, synthetic_data, key_fields=['col1'], sensitive_fields=['col2'])
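The test wraps the message in re.escape because pytest.warns treats its match argument as a regular expression, and the message contains literal periods. Callers that cannot migrate yet can silence the warning explicitly; a small sketch, reusing the DataFrames from the test above:

import warnings

from sdmetrics.single_table.privacy.cap import CategoricalCAP

with warnings.catch_warnings():
    warnings.simplefilter('ignore', DeprecationWarning)
    score = CategoricalCAP.compute(
        real_data, synthetic_data, key_fields=['col1'], sensitive_fields=['col2']
    )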