Skip to content

Commit

Permalink
Add DisclosureProtectionEstimate metric (#686)
Browse files Browse the repository at this point in the history
  • Loading branch information
frances-h authored Dec 10, 2024
1 parent dafd198 commit 0bf95d9
Show file tree
Hide file tree
Showing 7 changed files with 901 additions and 32 deletions.
6 changes: 5 additions & 1 deletion sdmetrics/single_table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@
CategoricalRF,
CategoricalSVM,
)
from sdmetrics.single_table.privacy.disclosure_protection import DisclosureProtection
from sdmetrics.single_table.privacy.disclosure_protection import (
DisclosureProtection,
DisclosureProtectionEstimate,
)
from sdmetrics.single_table.privacy.ensemble import CategoricalEnsemble
from sdmetrics.single_table.privacy.numerical_sklearn import NumericalLR, NumericalMLP, NumericalSVR
from sdmetrics.single_table.privacy.radius_nearest_neighbor import NumericalRadiusNearestNeighbor
Expand Down Expand Up @@ -111,6 +114,7 @@
'CategoricalZeroCAP',
'CategoricalGeneralizedCAP',
'DisclosureProtection',
'DisclosureProtectionEstimate',
'NumericalMLP',
'NumericalLR',
'NumericalSVR',
Expand Down
6 changes: 5 additions & 1 deletion sdmetrics/single_table/privacy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
CategoricalRF,
CategoricalSVM,
)
from sdmetrics.single_table.privacy.disclosure_protection import DisclosureProtection
from sdmetrics.single_table.privacy.disclosure_protection import (
DisclosureProtection,
DisclosureProtectionEstimate,
)
from sdmetrics.single_table.privacy.ensemble import CategoricalEnsemble
from sdmetrics.single_table.privacy.numerical_sklearn import NumericalLR, NumericalMLP, NumericalSVR
from sdmetrics.single_table.privacy.radius_nearest_neighbor import NumericalRadiusNearestNeighbor
Expand All @@ -28,6 +31,7 @@
'CategoricalSVM',
'CategoricalZeroCAP',
'DisclosureProtection',
'DisclosureProtectionEstimate',
'NumericalLR',
'NumericalMLP',
'NumericalPrivacyMetric',
Expand Down
224 changes: 224 additions & 0 deletions sdmetrics/single_table/privacy/cap.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
"""CAP modules and their attackers."""

import warnings

from sdmetrics.single_table.privacy.base import CategoricalPrivacyMetric, PrivacyAttackerModel
from sdmetrics.single_table.privacy.util import closest_neighbors, count_frequency, majority

DEPRECATION_MSG = (
'Computing CAP metrics directly is deprecated. For improved privacy metrics, '
"please use the 'DisclosureProtection' and 'DisclosureProtectionEstimate' "
'metrics instead.'
)


class CAPAttacker(PrivacyAttackerModel):
"""The CAP (Correct Attribution Probability) privacy attacker.
Expand Down Expand Up @@ -78,6 +86,78 @@ class CategoricalCAP(CategoricalPrivacyMetric):
MODEL = CAPAttacker
ACCURACY_BASE = False

@classmethod
def _compute(
cls,
real_data,
synthetic_data,
metadata=None,
key_fields=None,
sensitive_fields=None,
model_kwargs=None,
):
return super().compute(
real_data=real_data,
synthetic_data=synthetic_data,
metadata=metadata,
key_fields=key_fields,
sensitive_fields=sensitive_fields,
model_kwargs=model_kwargs,
)

@classmethod
def compute(
cls,
real_data,
synthetic_data,
metadata=None,
key_fields=None,
sensitive_fields=None,
model_kwargs=None,
):
"""Compute this metric.
This fits an adversial attacker model on the synthetic data and
then evaluates it making predictions on the real data.
A ``key_fields`` column(s) name must be given, either directly or as a first level
entry in the ``metadata`` dict, which will be used as the key column(s) for the
attack.
A ``sensitive_fields`` column(s) name must be given, either directly or as a first level
entry in the ``metadata`` dict, which will be used as the sensitive_fields column(s)
for the attack.
Args:
real_data (Union[numpy.ndarray, pandas.DataFrame]):
The values from the real dataset.
synthetic_data (Union[numpy.ndarray, pandas.DataFrame]):
The values from the synthetic dataset.
metadata (dict):
Table metadata dict. If not passed, it is build based on the
real_data fields and dtypes.
key_fields (list(str)):
Name of the column(s) to use as the key attributes.
sensitive_fields (list(str)):
Name of the column(s) to use as the sensitive attributes.
model_kwargs (dict):
Key word arguments of the attacker model. cls.MODEL_KWARGS will be used
if none is provided.
Returns:
union[float, tuple[float]]:
Scores obtained by the attackers when evaluated on the real data.
"""
warnings.warn(DEPRECATION_MSG, DeprecationWarning)
return cls._compute(
real_data=real_data,
synthetic_data=synthetic_data,
metadata=metadata,
key_fields=key_fields,
sensitive_fields=sensitive_fields,
model_kwargs=model_kwargs,
)


class ZeroCAPAttacker(CAPAttacker):
"""The 0CAP privacy attacker, which operates in the same way as CAP does.
Expand Down Expand Up @@ -113,6 +193,78 @@ class CategoricalZeroCAP(CategoricalPrivacyMetric):
MODEL = ZeroCAPAttacker
ACCURACY_BASE = False

@classmethod
def _compute(
cls,
real_data,
synthetic_data,
metadata=None,
key_fields=None,
sensitive_fields=None,
model_kwargs=None,
):
return super().compute(
real_data=real_data,
synthetic_data=synthetic_data,
metadata=metadata,
key_fields=key_fields,
sensitive_fields=sensitive_fields,
model_kwargs=model_kwargs,
)

@classmethod
def compute(
cls,
real_data,
synthetic_data,
metadata=None,
key_fields=None,
sensitive_fields=None,
model_kwargs=None,
):
"""Compute this metric.
This fits an adversial attacker model on the synthetic data and
then evaluates it making predictions on the real data.
A ``key_fields`` column(s) name must be given, either directly or as a first level
entry in the ``metadata`` dict, which will be used as the key column(s) for the
attack.
A ``sensitive_fields`` column(s) name must be given, either directly or as a first level
entry in the ``metadata`` dict, which will be used as the sensitive_fields column(s)
for the attack.
Args:
real_data (Union[numpy.ndarray, pandas.DataFrame]):
The values from the real dataset.
synthetic_data (Union[numpy.ndarray, pandas.DataFrame]):
The values from the synthetic dataset.
metadata (dict):
Table metadata dict. If not passed, it is build based on the
real_data fields and dtypes.
key_fields (list(str)):
Name of the column(s) to use as the key attributes.
sensitive_fields (list(str)):
Name of the column(s) to use as the sensitive attributes.
model_kwargs (dict):
Key word arguments of the attacker model. cls.MODEL_KWARGS will be used
if none is provided.
Returns:
union[float, tuple[float]]:
Scores obtained by the attackers when evaluated on the real data.
"""
warnings.warn(DEPRECATION_MSG, DeprecationWarning)
return cls._compute(
real_data=real_data,
synthetic_data=synthetic_data,
metadata=metadata,
key_fields=key_fields,
sensitive_fields=sensitive_fields,
model_kwargs=model_kwargs,
)


class GeneralizedCAPAttacker(CAPAttacker):
"""The GeneralizedCAP privacy attacker.
Expand Down Expand Up @@ -169,3 +321,75 @@ class CategoricalGeneralizedCAP(CategoricalPrivacyMetric):
name = 'Categorical GeneralizedCAP'
MODEL = GeneralizedCAPAttacker
ACCURACY_BASE = False

@classmethod
def _compute(
cls,
real_data,
synthetic_data,
metadata=None,
key_fields=None,
sensitive_fields=None,
model_kwargs=None,
):
return super().compute(
real_data=real_data,
synthetic_data=synthetic_data,
metadata=metadata,
key_fields=key_fields,
sensitive_fields=sensitive_fields,
model_kwargs=model_kwargs,
)

@classmethod
def compute(
cls,
real_data,
synthetic_data,
metadata=None,
key_fields=None,
sensitive_fields=None,
model_kwargs=None,
):
"""Compute this metric.
This fits an adversial attacker model on the synthetic data and
then evaluates it making predictions on the real data.
A ``key_fields`` column(s) name must be given, either directly or as a first level
entry in the ``metadata`` dict, which will be used as the key column(s) for the
attack.
A ``sensitive_fields`` column(s) name must be given, either directly or as a first level
entry in the ``metadata`` dict, which will be used as the sensitive_fields column(s)
for the attack.
Args:
real_data (Union[numpy.ndarray, pandas.DataFrame]):
The values from the real dataset.
synthetic_data (Union[numpy.ndarray, pandas.DataFrame]):
The values from the synthetic dataset.
metadata (dict):
Table metadata dict. If not passed, it is build based on the
real_data fields and dtypes.
key_fields (list(str)):
Name of the column(s) to use as the key attributes.
sensitive_fields (list(str)):
Name of the column(s) to use as the sensitive attributes.
model_kwargs (dict):
Key word arguments of the attacker model. cls.MODEL_KWARGS will be used
if none is provided.
Returns:
union[float, tuple[float]]:
Scores obtained by the attackers when evaluated on the real data.
"""
warnings.warn(DEPRECATION_MSG, DeprecationWarning)
return cls._compute(
real_data=real_data,
synthetic_data=synthetic_data,
metadata=metadata,
key_fields=key_fields,
sensitive_fields=sensitive_fields,
model_kwargs=model_kwargs,
)
Loading

0 comments on commit 0bf95d9

Please sign in to comment.