diff --git a/sdmetrics/column_pairs/statistical/inter_row_msas.py b/sdmetrics/column_pairs/statistical/inter_row_msas.py index eea77f06..200b58fa 100644 --- a/sdmetrics/column_pairs/statistical/inter_row_msas.py +++ b/sdmetrics/column_pairs/statistical/inter_row_msas.py @@ -29,7 +29,61 @@ class InterRowMSAS: max_value = 1.0 @staticmethod - def compute(real_data, synthetic_data, n_rows_diff=1, apply_log=False): + def _validate_inputs(real_data, synthetic_data, n_rows_diff, apply_log): + for data in [real_data, synthetic_data]: + if ( + not isinstance(data, tuple) + or len(data) != 2 + or (not (isinstance(data[0], pd.Series) and isinstance(data[1], pd.Series))) + ): + raise ValueError('The data must be a tuple of two pandas series.') + + if not isinstance(n_rows_diff, int) or n_rows_diff < 1: + raise ValueError("'n_rows_diff' must be an integer greater than zero.") + + if not isinstance(apply_log, bool): + raise ValueError("'apply_log' must be a boolean.") + + @staticmethod + def _apply_log(real_values, synthetic_values, apply_log): + if apply_log: + num_invalid = sum(x <= 0 for x in pd.concat((real_values, synthetic_values))) + if num_invalid: + warnings.warn( + f'There are {num_invalid} non-positive values in your data, which cannot be ' + "used with log. Consider changing 'apply_log' to False for a better result." + ) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', message='.*encountered in log') + real_values = np.log(real_values) + synthetic_values = np.log(synthetic_values) + + return real_values, synthetic_values + + @staticmethod + def _calculate_differences(keys, values, n_rows_diff, data_name): + grouped = values.groupby(keys) + group_sizes = grouped.size() + + num_invalid_groups = len(group_sizes[group_sizes <= n_rows_diff]) + if num_invalid_groups > 0: + warnings.warn( + f"n_rows_diff '{n_rows_diff}' is greater than the " + f'size of {num_invalid_groups} sequence keys in {data_name}.' + ) + + def diff_func(group): + if len(group) <= n_rows_diff: + return np.nan + group = group.to_numpy() + return np.mean(group[n_rows_diff:] - group[:-n_rows_diff]) + + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', message='invalid value encountered in.*') + return grouped.apply(diff_func) + + @classmethod + def compute(cls, real_data, synthetic_data, n_rows_diff=1, apply_log=False): """Compute this metric. This metric compares the inter-row differences of sequences in the real data @@ -58,48 +112,13 @@ def compute(real_data, synthetic_data, n_rows_diff=1, apply_log=False): float: The similarity score between the real and synthetic data distributions. """ - for data in [real_data, synthetic_data]: - if ( - not isinstance(data, tuple) - or len(data) != 2 - or (not (isinstance(data[0], pd.Series) and isinstance(data[1], pd.Series))) - ): - raise ValueError('The data must be a tuple of two pandas series.') - - if not isinstance(n_rows_diff, int) or n_rows_diff < 1: - raise ValueError("'n_rows_diff' must be an integer greater than zero.") - - if not isinstance(apply_log, bool): - raise ValueError("'apply_log' must be a boolean.") - + cls._validate_inputs(real_data, synthetic_data, n_rows_diff, apply_log) real_keys, real_values = real_data synthetic_keys, synthetic_values = synthetic_data + real_values, synthetic_values = cls._apply_log(real_values, synthetic_values, apply_log) - if apply_log: - real_values = np.log(real_values) - synthetic_values = np.log(synthetic_values) - - def calculate_differences(keys, values, n_rows_diff, data_name): - group_sizes = values.groupby(keys).size() - num_invalid_groups = group_sizes[group_sizes <= n_rows_diff].count() - if num_invalid_groups > 0: - warnings.warn( - f"n_rows_diff '{n_rows_diff}' is greater than the " - f'size of {num_invalid_groups} sequence keys in {data_name}.' - ) - - differences = values.groupby(keys).apply( - lambda group: np.mean( - group.to_numpy()[n_rows_diff:] - group.to_numpy()[:-n_rows_diff] - ) - if len(group) > n_rows_diff - else np.nan - ) - - return pd.Series(differences) - - real_diff = calculate_differences(real_keys, real_values, n_rows_diff, 'real_data') - synthetic_diff = calculate_differences( + real_diff = cls._calculate_differences(real_keys, real_values, n_rows_diff, 'real_data') + synthetic_diff = cls._calculate_differences( synthetic_keys, synthetic_values, n_rows_diff, 'synthetic_data' ) diff --git a/tests/unit/column_pairs/statistical/test_inter_row_msas.py b/tests/unit/column_pairs/statistical/test_inter_row_msas.py index 9a3552db..a88e375f 100644 --- a/tests/unit/column_pairs/statistical/test_inter_row_msas.py +++ b/tests/unit/column_pairs/statistical/test_inter_row_msas.py @@ -71,6 +71,31 @@ def test_compute_with_log(self): # Assert assert score == 1 + def test_compute_with_log_warning(self): + """Test it warns when negative values are present and apply_log is True.""" + # Setup + real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2']) + real_values = pd.Series([1, 1.4, 4, -1, 16, -10]) + synthetic_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2']) + synthetic_values = pd.Series([1, 2, -4, 8, 16, 30]) + + # Run + with pytest.warns(UserWarning) as warning_info: + score = InterRowMSAS.compute( + real_data=(real_keys, real_values), + synthetic_data=(synthetic_keys, synthetic_values), + apply_log=True, + ) + + # Assert + expected_message = ( + 'There are 3 non-positive values in your data, which cannot be used with log. ' + "Consider changing 'apply_log' to False for a better result." + ) + assert len(warning_info) == 1 + assert str(warning_info[0].message) == expected_message + assert score == 0 + def test_compute_different_n_rows_diff(self): """Test it with different n_rows_diff.""" # Setup