diff --git a/cape_privacy/spark/transformations/perturbation.py b/cape_privacy/spark/transformations/perturbation.py
index 605ee0f..816ccd9 100644
--- a/cape_privacy/spark/transformations/perturbation.py
+++ b/cape_privacy/spark/transformations/perturbation.py
@@ -72,29 +72,24 @@ class DatePerturbation(base.Transformation):
     frequencies of timestamps. The amount of noise for each frequency
     is drawn from the internal [min_freq, max_freq).
 
+    Note that seeds are currently not supported.
+
     Attributes:
         frequency (str, str list): one or more frequencies to perturbate
         min (int, int list): the frequency value will be greater or equal to min
         max (int, int list): the frequency value will be less than max
-        seed (int), optional: a seed to initialize the random generator
     """
 
     identifier = "date-perturbation"
     type_signature = "col->col"
 
     def __init__(
-        self,
-        frequency: StrTuple,
-        min: IntTuple,
-        max: IntTuple,
-        seed: Optional[int] = None,
+        self, frequency: StrTuple, min: IntTuple, max: IntTuple,
     ):
-        typecheck.check_arg(seed, (int, type(None)))
         super().__init__(dtypes.Date)
         self._frequency = _check_freq_arg(frequency)
         self._min = _check_minmax_arg(min)
         self._max = _check_minmax_arg(max)
-        self._rng = np.random.default_rng(seed)
         self._perturb_date = None
 
     def __call__(self, x: sql.Column):
@@ -105,9 +100,10 @@ def __call__(self, x: sql.Column):
     def _make_perturb_udf(self):
         @functions.pandas_udf(dtypes.Date, functions.PandasUDFType.SCALAR)
         def perturb_date(x: pd.Series) -> pd.Series:
+            rng = np.random.default_rng()
             for f, mn, mx in zip(self._frequency, self._min, self._max):
                 # TODO can we switch to a lower dtype than np.int64?
-                noise = self._rng.integers(mn, mx, size=x.shape)
+                noise = rng.integers(mn, mx, size=x.shape)
                 delta_fn = _FREQUENCY_TO_DELTA_FN.get(f, None)
                 if delta_fn is None:
                     raise ValueError(
diff --git a/cape_privacy/spark/transformations/perturbation_test.py b/cape_privacy/spark/transformations/perturbation_test.py
index 03cb80b..994d1d1 100644
--- a/cape_privacy/spark/transformations/perturbation_test.py
+++ b/cape_privacy/spark/transformations/perturbation_test.py
@@ -7,16 +7,16 @@
 from cape_privacy.spark.transformations import perturbation as ptb
 
 
-def _make_and_apply_numeric_ptb(sess, df, dtype, min, max, seed=42):
+def _make_and_apply_numeric_ptb(sess, df, dtype, min, max):
     df = sess.createDataFrame(df, schema=["data"])
-    perturb = ptb.NumericPerturbation(dtype, min=min, max=max, seed=seed)
+    perturb = ptb.NumericPerturbation(dtype, min=min, max=max)
     result_df = df.select(perturb(functions.col("data")))
     return result_df.toPandas()
 
 
-def _make_and_apply_date_ptb(sess, df, frequency, min, max, seed=42):
+def _make_and_apply_date_ptb(sess, df, frequency, min, max):
     df = sess.createDataFrame(df, schema=["data"])
-    perturb = ptb.DatePerturbation(frequency, min, max, seed)
+    perturb = ptb.DatePerturbation(frequency, min, max)
     result_df = df.select(perturb(functions.col("data")))
     return result_df.withColumnRenamed("perturb_date(data)", "data").toPandas()
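
For reference, a minimal usage sketch of DatePerturbation under the new signature (no seed argument). The "DAY" frequency string, the "data" column name, and the local Spark session setup are illustrative assumptions, not part of this change.

# Illustrative sketch only: shows the post-change DatePerturbation constructor.
# The "DAY" frequency and "data" column name are assumptions for the example.
import datetime

import pandas as pd
from pyspark import sql
from pyspark.sql import functions

from cape_privacy.spark.transformations import perturbation as ptb

sess = sql.SparkSession.builder.getOrCreate()
df = sess.createDataFrame(
    pd.DataFrame({"data": [datetime.date(2020, 1, 1), datetime.date(2020, 6, 15)]})
)
# Shift each date by a random number of days drawn from [-10, 10); runs are not
# reproducible because seeding is no longer supported.
perturb = ptb.DatePerturbation(frequency="DAY", min=-10, max=10)
result = df.select(perturb(functions.col("data")))
result.show()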