From ee0a447e7812c458c0451d7c3ce3cfeb0df9f2f2 Mon Sep 17 00:00:00 2001 From: Felipe <fealho@gmail.com> Date: Tue, 19 Nov 2024 09:08:21 -0800 Subject: [PATCH] Add warning --- .../statistical/inter_row_msas.py | 28 +++++++++++++------ .../statistical/test_inter_row_msas.py | 25 +++++++++++++++++ 2 files changed, 45 insertions(+), 8 deletions(-) diff --git a/sdmetrics/column_pairs/statistical/inter_row_msas.py b/sdmetrics/column_pairs/statistical/inter_row_msas.py index eea77f06..0fbcccb7 100644 --- a/sdmetrics/column_pairs/statistical/inter_row_msas.py +++ b/sdmetrics/column_pairs/statistical/inter_row_msas.py @@ -76,8 +76,17 @@ def compute(real_data, synthetic_data, n_rows_diff=1, apply_log=False): synthetic_keys, synthetic_values = synthetic_data if apply_log: - real_values = np.log(real_values) - synthetic_values = np.log(synthetic_values) + num_invalid = sum(x <= 0 for x in pd.concat((real_values, synthetic_values))) + if num_invalid: + warnings.warn( + f'There are {num_invalid} non-positive values in your data, which cannot be ' + "used with log. Consider changing 'apply_log' to False for a better result." + ) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', message='divide by zero encountered in log') + warnings.filterwarnings('ignore', message='invalid value encountered in log') + real_values = np.log(real_values) + synthetic_values = np.log(synthetic_values) def calculate_differences(keys, values, n_rows_diff, data_name): group_sizes = values.groupby(keys).size() @@ -88,13 +97,16 @@ def calculate_differences(keys, values, n_rows_diff, data_name): f'size of {num_invalid_groups} sequence keys in {data_name}.' ) - differences = values.groupby(keys).apply( - lambda group: np.mean( - group.to_numpy()[n_rows_diff:] - group.to_numpy()[:-n_rows_diff] + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', message='invalid value encountered in subtract') + warnings.filterwarnings('ignore', message='invalid value encountered in reduce') + differences = values.groupby(keys).apply( + lambda group: np.mean( + group.to_numpy()[n_rows_diff:] - group.to_numpy()[:-n_rows_diff] + ) + if len(group) > n_rows_diff + else np.nan ) - if len(group) > n_rows_diff - else np.nan - ) return pd.Series(differences) diff --git a/tests/unit/column_pairs/statistical/test_inter_row_msas.py b/tests/unit/column_pairs/statistical/test_inter_row_msas.py index 9a3552db..a88e375f 100644 --- a/tests/unit/column_pairs/statistical/test_inter_row_msas.py +++ b/tests/unit/column_pairs/statistical/test_inter_row_msas.py @@ -71,6 +71,31 @@ def test_compute_with_log(self): # Assert assert score == 1 + def test_compute_with_log_warning(self): + """Test it warns when negative values are present and apply_log is True.""" + # Setup + real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2']) + real_values = pd.Series([1, 1.4, 4, -1, 16, -10]) + synthetic_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2']) + synthetic_values = pd.Series([1, 2, -4, 8, 16, 30]) + + # Run + with pytest.warns(UserWarning) as warning_info: + score = InterRowMSAS.compute( + real_data=(real_keys, real_values), + synthetic_data=(synthetic_keys, synthetic_values), + apply_log=True, + ) + + # Assert + expected_message = ( + 'There are 3 non-positive values in your data, which cannot be used with log. ' + "Consider changing 'apply_log' to False for a better result." + ) + assert len(warning_info) == 1 + assert str(warning_info[0].message) == expected_message + assert score == 0 + def test_compute_different_n_rows_diff(self): """Test it with different n_rows_diff.""" # Setup