Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve warning handling for non-positive values when apply_log = True for InterRowMSAS #671

Merged
merged 4 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 59 additions & 40 deletions sdmetrics/column_pairs/statistical/inter_row_msas.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,61 @@
max_value = 1.0

@staticmethod
def compute(real_data, synthetic_data, n_rows_diff=1, apply_log=False):
def _validate_inputs(real_data, synthetic_data, n_rows_diff, apply_log):
for data in [real_data, synthetic_data]:
if (

Check warning on line 34 in sdmetrics/column_pairs/statistical/inter_row_msas.py

View check run for this annotation

Codecov / codecov/patch

sdmetrics/column_pairs/statistical/inter_row_msas.py#L33-L34

Added lines #L33 - L34 were not covered by tests
not isinstance(data, tuple)
or len(data) != 2
or (not (isinstance(data[0], pd.Series) and isinstance(data[1], pd.Series)))
):
raise ValueError('The data must be a tuple of two pandas series.')

Check warning on line 39 in sdmetrics/column_pairs/statistical/inter_row_msas.py

View check run for this annotation

Codecov / codecov/patch

sdmetrics/column_pairs/statistical/inter_row_msas.py#L39

Added line #L39 was not covered by tests
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we added tests for these 3 exceptions?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests have already been implemented. This change simply reorganizes the function for readability.


if not isinstance(n_rows_diff, int) or n_rows_diff < 1:
raise ValueError("'n_rows_diff' must be an integer greater than zero.")

Check warning on line 42 in sdmetrics/column_pairs/statistical/inter_row_msas.py

View check run for this annotation

Codecov / codecov/patch

sdmetrics/column_pairs/statistical/inter_row_msas.py#L41-L42

Added lines #L41 - L42 were not covered by tests

if not isinstance(apply_log, bool):
raise ValueError("'apply_log' must be a boolean.")

Check warning on line 45 in sdmetrics/column_pairs/statistical/inter_row_msas.py

View check run for this annotation

Codecov / codecov/patch

sdmetrics/column_pairs/statistical/inter_row_msas.py#L44-L45

Added lines #L44 - L45 were not covered by tests

@staticmethod
def _apply_log(real_values, synthetic_values, apply_log):
if apply_log:
num_invalid = sum(x <= 0 for x in pd.concat((real_values, synthetic_values)))
if num_invalid:
warnings.warn(

Check warning on line 52 in sdmetrics/column_pairs/statistical/inter_row_msas.py

View check run for this annotation

Codecov / codecov/patch

sdmetrics/column_pairs/statistical/inter_row_msas.py#L49-L52

Added lines #L49 - L52 were not covered by tests
f'There are {num_invalid} non-positive values in your data, which cannot be '
"used with log. Consider changing 'apply_log' to False for a better result."
)
with warnings.catch_warnings():
warnings.filterwarnings('ignore', message='.*encountered in log')
real_values = np.log(real_values)
synthetic_values = np.log(synthetic_values)

Check warning on line 59 in sdmetrics/column_pairs/statistical/inter_row_msas.py

View check run for this annotation

Codecov / codecov/patch

sdmetrics/column_pairs/statistical/inter_row_msas.py#L56-L59

Added lines #L56 - L59 were not covered by tests

return real_values, synthetic_values

Check warning on line 61 in sdmetrics/column_pairs/statistical/inter_row_msas.py

View check run for this annotation

Codecov / codecov/patch

sdmetrics/column_pairs/statistical/inter_row_msas.py#L61

Added line #L61 was not covered by tests

@staticmethod
def _calculate_differences(keys, values, n_rows_diff, data_name):
grouped = values.groupby(keys)
group_sizes = grouped.size()

Check warning on line 66 in sdmetrics/column_pairs/statistical/inter_row_msas.py

View check run for this annotation

Codecov / codecov/patch

sdmetrics/column_pairs/statistical/inter_row_msas.py#L65-L66

Added lines #L65 - L66 were not covered by tests

num_invalid_groups = len(group_sizes[group_sizes <= n_rows_diff])
if num_invalid_groups > 0:
warnings.warn(

Check warning on line 70 in sdmetrics/column_pairs/statistical/inter_row_msas.py

View check run for this annotation

Codecov / codecov/patch

sdmetrics/column_pairs/statistical/inter_row_msas.py#L68-L70

Added lines #L68 - L70 were not covered by tests
f"n_rows_diff '{n_rows_diff}' is greater than the "
f'size of {num_invalid_groups} sequence keys in {data_name}.'
)

def diff_func(group):
if len(group) <= n_rows_diff:
return np.nan
group = group.to_numpy()
return np.mean(group[n_rows_diff:] - group[:-n_rows_diff])

Check warning on line 79 in sdmetrics/column_pairs/statistical/inter_row_msas.py

View check run for this annotation

Codecov / codecov/patch

sdmetrics/column_pairs/statistical/inter_row_msas.py#L75-L79

Added lines #L75 - L79 were not covered by tests

with warnings.catch_warnings():
warnings.filterwarnings('ignore', message='invalid value encountered in.*')
return grouped.apply(diff_func)

Check warning on line 83 in sdmetrics/column_pairs/statistical/inter_row_msas.py

View check run for this annotation

Codecov / codecov/patch

sdmetrics/column_pairs/statistical/inter_row_msas.py#L81-L83

Added lines #L81 - L83 were not covered by tests

@classmethod
def compute(cls, real_data, synthetic_data, n_rows_diff=1, apply_log=False):
"""Compute this metric.

This metric compares the inter-row differences of sequences in the real data
Expand Down Expand Up @@ -58,48 +112,13 @@
float:
The similarity score between the real and synthetic data distributions.
"""
for data in [real_data, synthetic_data]:
if (
not isinstance(data, tuple)
or len(data) != 2
or (not (isinstance(data[0], pd.Series) and isinstance(data[1], pd.Series)))
):
raise ValueError('The data must be a tuple of two pandas series.')

if not isinstance(n_rows_diff, int) or n_rows_diff < 1:
raise ValueError("'n_rows_diff' must be an integer greater than zero.")

if not isinstance(apply_log, bool):
raise ValueError("'apply_log' must be a boolean.")

cls._validate_inputs(real_data, synthetic_data, n_rows_diff, apply_log)

Check warning on line 115 in sdmetrics/column_pairs/statistical/inter_row_msas.py

View check run for this annotation

Codecov / codecov/patch

sdmetrics/column_pairs/statistical/inter_row_msas.py#L115

Added line #L115 was not covered by tests
real_keys, real_values = real_data
synthetic_keys, synthetic_values = synthetic_data
real_values, synthetic_values = cls._apply_log(real_values, synthetic_values, apply_log)

Check warning on line 118 in sdmetrics/column_pairs/statistical/inter_row_msas.py

View check run for this annotation

Codecov / codecov/patch

sdmetrics/column_pairs/statistical/inter_row_msas.py#L118

Added line #L118 was not covered by tests

if apply_log:
real_values = np.log(real_values)
synthetic_values = np.log(synthetic_values)

def calculate_differences(keys, values, n_rows_diff, data_name):
group_sizes = values.groupby(keys).size()
num_invalid_groups = group_sizes[group_sizes <= n_rows_diff].count()
if num_invalid_groups > 0:
warnings.warn(
f"n_rows_diff '{n_rows_diff}' is greater than the "
f'size of {num_invalid_groups} sequence keys in {data_name}.'
)

differences = values.groupby(keys).apply(
lambda group: np.mean(
group.to_numpy()[n_rows_diff:] - group.to_numpy()[:-n_rows_diff]
)
if len(group) > n_rows_diff
else np.nan
)

return pd.Series(differences)

real_diff = calculate_differences(real_keys, real_values, n_rows_diff, 'real_data')
synthetic_diff = calculate_differences(
real_diff = cls._calculate_differences(real_keys, real_values, n_rows_diff, 'real_data')
synthetic_diff = cls._calculate_differences(

Check warning on line 121 in sdmetrics/column_pairs/statistical/inter_row_msas.py

View check run for this annotation

Codecov / codecov/patch

sdmetrics/column_pairs/statistical/inter_row_msas.py#L120-L121

Added lines #L120 - L121 were not covered by tests
synthetic_keys, synthetic_values, n_rows_diff, 'synthetic_data'
)

Expand Down
25 changes: 25 additions & 0 deletions tests/unit/column_pairs/statistical/test_inter_row_msas.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,31 @@ def test_compute_with_log(self):
# Assert
assert score == 1

def test_compute_with_log_warning(self):
"""Test it warns when negative values are present and apply_log is True."""
# Setup
real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2'])
real_values = pd.Series([1, 1.4, 4, -1, 16, -10])
synthetic_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2'])
synthetic_values = pd.Series([1, 2, -4, 8, 16, 30])

# Run
with pytest.warns(UserWarning) as warning_info:
score = InterRowMSAS.compute(
real_data=(real_keys, real_values),
synthetic_data=(synthetic_keys, synthetic_values),
apply_log=True,
)

# Assert
expected_message = (
'There are 3 non-positive values in your data, which cannot be used with log. '
"Consider changing 'apply_log' to False for a better result."
)
assert len(warning_info) == 1
assert str(warning_info[0].message) == expected_message
assert score == 0

def test_compute_different_n_rows_diff(self):
"""Test it with different n_rows_diff."""
# Setup
Expand Down
Loading