diff --git a/sdmetrics/single_column/statistical/kscomplement.py b/sdmetrics/single_column/statistical/kscomplement.py index 3be01330..525e85c7 100644 --- a/sdmetrics/single_column/statistical/kscomplement.py +++ b/sdmetrics/single_column/statistical/kscomplement.py @@ -1,5 +1,6 @@ """Kolmogorov-Smirnov test based Metric.""" +import numpy as np import pandas as pd from scipy.stats import ks_2samp @@ -56,7 +57,13 @@ def compute(real_data, synthetic_data): real_data = pd.to_numeric(real_data) synthetic_data = pd.to_numeric(synthetic_data) - statistic, _ = ks_2samp(real_data, synthetic_data) + try: + statistic, _ = ks_2samp(real_data, synthetic_data) + except ValueError as e: + if str(e) == 'Data passed to ks_2samp must not be empty': + return np.nan + else: + raise ValueError(e) return 1 - statistic diff --git a/sdmetrics/timeseries/inter_row_msas.py b/sdmetrics/timeseries/inter_row_msas.py index c819a188..eea77f06 100644 --- a/sdmetrics/timeseries/inter_row_msas.py +++ b/sdmetrics/timeseries/inter_row_msas.py @@ -1,5 +1,7 @@ """InterRowMSAS module.""" +import warnings + import numpy as np import pandas as pd @@ -77,7 +79,15 @@ def compute(real_data, synthetic_data, n_rows_diff=1, apply_log=False): real_values = np.log(real_values) synthetic_values = np.log(synthetic_values) - def calculate_differences(keys, values, n_rows_diff): + def calculate_differences(keys, values, n_rows_diff, data_name): + group_sizes = values.groupby(keys).size() + num_invalid_groups = group_sizes[group_sizes <= n_rows_diff].count() + if num_invalid_groups > 0: + warnings.warn( + f"n_rows_diff '{n_rows_diff}' is greater than the " + f'size of {num_invalid_groups} sequence keys in {data_name}.' + ) + differences = values.groupby(keys).apply( lambda group: np.mean( group.to_numpy()[n_rows_diff:] - group.to_numpy()[:-n_rows_diff] @@ -85,9 +95,12 @@ def calculate_differences(keys, values, n_rows_diff): if len(group) > n_rows_diff else np.nan ) + return pd.Series(differences) - real_diff = calculate_differences(real_keys, real_values, n_rows_diff) - synthetic_diff = calculate_differences(synthetic_keys, synthetic_values, n_rows_diff) + real_diff = calculate_differences(real_keys, real_values, n_rows_diff, 'real_data') + synthetic_diff = calculate_differences( + synthetic_keys, synthetic_values, n_rows_diff, 'synthetic_data' + ) return KSComplement.compute(real_diff, synthetic_diff) diff --git a/tests/unit/timeseries/test_inter_row_msas.py b/tests/unit/timeseries/test_inter_row_msas.py index d9082a94..14101079 100644 --- a/tests/unit/timeseries/test_inter_row_msas.py +++ b/tests/unit/timeseries/test_inter_row_msas.py @@ -154,3 +154,23 @@ def test_compute_invalid_apply_log(self): n_rows_diff=1, apply_log='True', # Should be a boolean, not a string ) + + def test_compute_warning(self): + """Test a warning is raised when n_rows_diff is greater than sequence values size.""" + # Setup + real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2']) + real_values = pd.Series([1, 2, 3, 4, 5, 6]) + synthetic_keys = pd.Series(['id3', 'id3', 'id3', 'id4', 'id4', 'id4']) + synthetic_values = pd.Series([1, 10, 3, 7, 5, 1]) + + # Run and Assert + warn_msg = "n_rows_diff '10' is greater than the size of 2 sequence keys in real_data." + with pytest.warns(UserWarning, match=warn_msg): + score = InterRowMSAS.compute( + real_data=(real_keys, real_values), + synthetic_data=(synthetic_keys, synthetic_values), + n_rows_diff=10, + ) + + # Assert + assert pd.isna(score)