Skip to content

Commit

Permalink
Add warning
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Nov 1, 2024
1 parent 4cae06b commit 47139b1
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 4 deletions.
9 changes: 8 additions & 1 deletion sdmetrics/single_column/statistical/kscomplement.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Kolmogorov-Smirnov test based Metric."""

import numpy as np
import pandas as pd
from scipy.stats import ks_2samp

Expand Down Expand Up @@ -56,7 +57,13 @@ def compute(real_data, synthetic_data):
real_data = pd.to_numeric(real_data)
synthetic_data = pd.to_numeric(synthetic_data)

statistic, _ = ks_2samp(real_data, synthetic_data)
try:
statistic, _ = ks_2samp(real_data, synthetic_data)
except ValueError as e:
if str(e) == 'Data passed to ks_2samp must not be empty':
return np.nan
else:
raise ValueError(e)

return 1 - statistic

Expand Down
19 changes: 16 additions & 3 deletions sdmetrics/timeseries/inter_row_msas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""InterRowMSAS module."""

import warnings

import numpy as np
import pandas as pd

Expand Down Expand Up @@ -77,17 +79,28 @@ def compute(real_data, synthetic_data, n_rows_diff=1, apply_log=False):
real_values = np.log(real_values)
synthetic_values = np.log(synthetic_values)

def calculate_differences(keys, values, n_rows_diff):
def calculate_differences(keys, values, n_rows_diff, data_name):
group_sizes = values.groupby(keys).size()
num_invalid_groups = group_sizes[group_sizes <= n_rows_diff].count()
if num_invalid_groups > 0:
warnings.warn(
f"n_rows_diff '{n_rows_diff}' is greater than the "
f'size of {num_invalid_groups} sequence keys in {data_name}.'
)

differences = values.groupby(keys).apply(
lambda group: np.mean(
group.to_numpy()[n_rows_diff:] - group.to_numpy()[:-n_rows_diff]
)
if len(group) > n_rows_diff
else np.nan
)

return pd.Series(differences)

real_diff = calculate_differences(real_keys, real_values, n_rows_diff)
synthetic_diff = calculate_differences(synthetic_keys, synthetic_values, n_rows_diff)
real_diff = calculate_differences(real_keys, real_values, n_rows_diff, 'real_data')
synthetic_diff = calculate_differences(
synthetic_keys, synthetic_values, n_rows_diff, 'synthetic_data'
)

return KSComplement.compute(real_diff, synthetic_diff)
20 changes: 20 additions & 0 deletions tests/unit/timeseries/test_inter_row_msas.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,23 @@ def test_compute_invalid_apply_log(self):
n_rows_diff=1,
apply_log='True', # Should be a boolean, not a string
)

def test_compute_warning(self):
"""Test a warning is raised when n_rows_diff is greater than sequence values size."""
# Setup
real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2'])
real_values = pd.Series([1, 2, 3, 4, 5, 6])
synthetic_keys = pd.Series(['id3', 'id3', 'id3', 'id4', 'id4', 'id4'])
synthetic_values = pd.Series([1, 10, 3, 7, 5, 1])

# Run and Assert
warn_msg = "n_rows_diff '10' is greater than the size of 2 sequence keys in real_data."
with pytest.warns(UserWarning, match=warn_msg):
score = InterRowMSAS.compute(
real_data=(real_keys, real_values),
synthetic_data=(synthetic_keys, synthetic_values),
n_rows_diff=10,
)

# Assert
assert pd.isna(score)

0 comments on commit 47139b1

Please sign in to comment.