
remove seeds from Spark DatePerturbation
mortendahl authored and justin1121 committed Jun 23, 2020
1 parent fca7673 commit bbf395d
Showing 2 changed files with 9 additions and 13 deletions.
14 changes: 5 additions & 9 deletions cape_privacy/spark/transformations/perturbation.py
@@ -72,29 +72,24 @@ class DatePerturbation(base.Transformation):
     frequencies of timestamps. The amount of noise for each frequency
     is drawn from the internal [min_freq, max_freq).
+    Note that seeds are currently not supported.
     Attributes:
         frequency (str, str list): one or more frequencies to perturbate
         min (int, int list): the frequency value will be greater or equal to min
         max (int, int list): the frequency value will be less than max
-        seed (int), optional: a seed to initialize the random generator
     """
 
     identifier = "date-perturbation"
     type_signature = "col->col"
 
     def __init__(
-        self,
-        frequency: StrTuple,
-        min: IntTuple,
-        max: IntTuple,
-        seed: Optional[int] = None,
+        self, frequency: StrTuple, min: IntTuple, max: IntTuple,
     ):
-        typecheck.check_arg(seed, (int, type(None)))
         super().__init__(dtypes.Date)
         self._frequency = _check_freq_arg(frequency)
         self._min = _check_minmax_arg(min)
         self._max = _check_minmax_arg(max)
-        self._rng = np.random.default_rng(seed)
         self._perturb_date = None
 
     def __call__(self, x: sql.Column):
@@ -105,9 +100,10 @@ def __call__(self, x: sql.Column):
     def _make_perturb_udf(self):
         @functions.pandas_udf(dtypes.Date, functions.PandasUDFType.SCALAR)
         def perturb_date(x: pd.Series) -> pd.Series:
+            rng = np.random.default_rng()
             for f, mn, mx in zip(self._frequency, self._min, self._max):
                 # TODO can we switch to a lower dtype than np.int64?
-                noise = self._rng.integers(mn, mx, size=x.shape)
+                noise = rng.integers(mn, mx, size=x.shape)
                 delta_fn = _FREQUENCY_TO_DELTA_FN.get(f, None)
                 if delta_fn is None:
                     raise ValueError(
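For context, a minimal usage sketch of the now seedless DatePerturbation (not part of this commit). It assumes a local SparkSession and that "DAY" is among the supported frequency strings; because the random generator is now created inside the pandas UDF without a seed, the perturbed dates differ from run to run.

import datetime

import pandas as pd
from pyspark import sql
from pyspark.sql import functions

from cape_privacy.spark.transformations import perturbation as ptb

# Local Spark session for the sketch.
sess = sql.SparkSession.builder.appName("date-perturbation-sketch").getOrCreate()

# A tiny date column to perturb.
pdf = pd.DataFrame({"data": [datetime.date(2020, 1, 1), datetime.date(2020, 6, 15)]})
df = sess.createDataFrame(pdf, schema=["data"])

# Shift each date by a whole number of days drawn uniformly from [-10, 10).
# There is no seed argument anymore, so the output is not reproducible.
perturb = ptb.DatePerturbation(frequency="DAY", min=-10, max=10)
result = df.select(perturb(functions.col("data")))
result.show()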
8 changes: 4 additions & 4 deletions cape_privacy/spark/transformations/perturbation_test.py
@@ -7,16 +7,16 @@
 from cape_privacy.spark.transformations import perturbation as ptb
 
 
-def _make_and_apply_numeric_ptb(sess, df, dtype, min, max, seed=42):
+def _make_and_apply_numeric_ptb(sess, df, dtype, min, max):
     df = sess.createDataFrame(df, schema=["data"])
-    perturb = ptb.NumericPerturbation(dtype, min=min, max=max, seed=seed)
+    perturb = ptb.NumericPerturbation(dtype, min=min, max=max)
     result_df = df.select(perturb(functions.col("data")))
     return result_df.toPandas()
 
 
-def _make_and_apply_date_ptb(sess, df, frequency, min, max, seed=42):
+def _make_and_apply_date_ptb(sess, df, frequency, min, max):
     df = sess.createDataFrame(df, schema=["data"])
-    perturb = ptb.DatePerturbation(frequency, min, max, seed)
+    perturb = ptb.DatePerturbation(frequency, min, max)
     result_df = df.select(perturb(functions.col("data")))
     return result_df.withColumnRenamed("perturb_date(data)", "data").toPandas()

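With the seed gone, tests can no longer compare against fixed expected values. Below is a possible bounds-based check instead, sketched under assumptions and not part of this commit: it would live in perturbation_test.py next to the helper above, relies on a pytest-style sess fixture as used elsewhere in that module, and assumes "DAY" is a valid frequency string.

import datetime

import pandas as pd


def test_date_perturbation_stays_within_bounds(sess):
    # Perturb a single date by up to +/- 10 days and assert only the bound,
    # rather than comparing against a seeded, fixed expectation.
    base = datetime.date(2020, 1, 1)
    df = pd.DataFrame({"data": [base]})
    result = _make_and_apply_date_ptb(sess, df, frequency="DAY", min=-10, max=10)
    shift_days = (pd.Timestamp(result["data"].iloc[0]) - pd.Timestamp(base)).days
    # Noise is drawn from [min, max), i.e. whole-day shifts in [-10, 10).
    assert -10 <= shift_days < 10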
