Skip to content

Commit

Permalink
Add warning
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Nov 19, 2024
1 parent 0752a8b commit ee0a447
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 8 deletions.
28 changes: 20 additions & 8 deletions sdmetrics/column_pairs/statistical/inter_row_msas.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,17 @@ def compute(real_data, synthetic_data, n_rows_diff=1, apply_log=False):
synthetic_keys, synthetic_values = synthetic_data

if apply_log:
real_values = np.log(real_values)
synthetic_values = np.log(synthetic_values)
num_invalid = sum(x <= 0 for x in pd.concat((real_values, synthetic_values)))
if num_invalid:
warnings.warn(
f'There are {num_invalid} non-positive values in your data, which cannot be '
"used with log. Consider changing 'apply_log' to False for a better result."
)
with warnings.catch_warnings():
warnings.filterwarnings('ignore', message='divide by zero encountered in log')
warnings.filterwarnings('ignore', message='invalid value encountered in log')
real_values = np.log(real_values)
synthetic_values = np.log(synthetic_values)

def calculate_differences(keys, values, n_rows_diff, data_name):
group_sizes = values.groupby(keys).size()
Expand All @@ -88,13 +97,16 @@ def calculate_differences(keys, values, n_rows_diff, data_name):
f'size of {num_invalid_groups} sequence keys in {data_name}.'
)

differences = values.groupby(keys).apply(
lambda group: np.mean(
group.to_numpy()[n_rows_diff:] - group.to_numpy()[:-n_rows_diff]
with warnings.catch_warnings():
warnings.filterwarnings('ignore', message='invalid value encountered in subtract')
warnings.filterwarnings('ignore', message='invalid value encountered in reduce')
differences = values.groupby(keys).apply(
lambda group: np.mean(
group.to_numpy()[n_rows_diff:] - group.to_numpy()[:-n_rows_diff]
)
if len(group) > n_rows_diff
else np.nan
)
if len(group) > n_rows_diff
else np.nan
)

return pd.Series(differences)

Expand Down
25 changes: 25 additions & 0 deletions tests/unit/column_pairs/statistical/test_inter_row_msas.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,31 @@ def test_compute_with_log(self):
# Assert
assert score == 1

def test_compute_with_log_warning(self):
    """Check that ``apply_log=True`` emits one warning for non-positive values.

    The real series holds two non-positive entries and the synthetic series
    holds one, so the warning is expected to report three in total.
    """
    # Setup
    sequence_keys = ['id1', 'id1', 'id1', 'id2', 'id2', 'id2']
    real_data = (pd.Series(sequence_keys), pd.Series([1, 1.4, 4, -1, 16, -10]))
    synthetic_data = (pd.Series(sequence_keys), pd.Series([1, 2, -4, 8, 16, 30]))
    expected_message = (
        'There are 3 non-positive values in your data, which cannot be used with log. '
        "Consider changing 'apply_log' to False for a better result."
    )

    # Run
    with pytest.warns(UserWarning) as warning_info:
        score = InterRowMSAS.compute(
            real_data=real_data,
            synthetic_data=synthetic_data,
            apply_log=True,
        )

    # Assert
    # Exactly one warning must surface: the numpy runtime warnings raised by
    # taking the log of non-positive numbers are expected to be silenced.
    assert len(warning_info) == 1
    assert str(warning_info[0].message) == expected_message
    assert score == 0

def test_compute_different_n_rows_diff(self):
"""Test it with different n_rows_diff."""
# Setup
Expand Down

0 comments on commit ee0a447

Please sign in to comment.