From ee0a447e7812c458c0451d7c3ce3cfeb0df9f2f2 Mon Sep 17 00:00:00 2001
From: Felipe <fealho@gmail.com>
Date: Tue, 19 Nov 2024 09:08:21 -0800
Subject: [PATCH] Add warning

---
 .../statistical/inter_row_msas.py             | 28 +++++++++++++------
 .../statistical/test_inter_row_msas.py        | 25 +++++++++++++++++
 2 files changed, 45 insertions(+), 8 deletions(-)

diff --git a/sdmetrics/column_pairs/statistical/inter_row_msas.py b/sdmetrics/column_pairs/statistical/inter_row_msas.py
index eea77f06..0fbcccb7 100644
--- a/sdmetrics/column_pairs/statistical/inter_row_msas.py
+++ b/sdmetrics/column_pairs/statistical/inter_row_msas.py
@@ -76,8 +76,17 @@ def compute(real_data, synthetic_data, n_rows_diff=1, apply_log=False):
         synthetic_keys, synthetic_values = synthetic_data
 
         if apply_log:
-            real_values = np.log(real_values)
-            synthetic_values = np.log(synthetic_values)
+            num_invalid = sum(x <= 0 for x in pd.concat((real_values, synthetic_values)))
+            if num_invalid:
+                warnings.warn(
+                    f'There are {num_invalid} non-positive values in your data, which cannot be '
+                    "used with log. Consider changing 'apply_log' to False for a better result."
+                )
+            with warnings.catch_warnings():
+                warnings.filterwarnings('ignore', message='divide by zero encountered in log')
+                warnings.filterwarnings('ignore', message='invalid value encountered in log')
+                real_values = np.log(real_values)
+                synthetic_values = np.log(synthetic_values)
 
         def calculate_differences(keys, values, n_rows_diff, data_name):
             group_sizes = values.groupby(keys).size()
@@ -88,13 +97,16 @@ def calculate_differences(keys, values, n_rows_diff, data_name):
                     f'size of {num_invalid_groups} sequence keys in {data_name}.'
                 )
 
-            differences = values.groupby(keys).apply(
-                lambda group: np.mean(
-                    group.to_numpy()[n_rows_diff:] - group.to_numpy()[:-n_rows_diff]
+            with warnings.catch_warnings():
+                warnings.filterwarnings('ignore', message='invalid value encountered in subtract')
+                warnings.filterwarnings('ignore', message='invalid value encountered in reduce')
+                differences = values.groupby(keys).apply(
+                    lambda group: np.mean(
+                        group.to_numpy()[n_rows_diff:] - group.to_numpy()[:-n_rows_diff]
+                    )
+                    if len(group) > n_rows_diff
+                    else np.nan
                 )
-                if len(group) > n_rows_diff
-                else np.nan
-            )
 
             return pd.Series(differences)
 
diff --git a/tests/unit/column_pairs/statistical/test_inter_row_msas.py b/tests/unit/column_pairs/statistical/test_inter_row_msas.py
index 9a3552db..a88e375f 100644
--- a/tests/unit/column_pairs/statistical/test_inter_row_msas.py
+++ b/tests/unit/column_pairs/statistical/test_inter_row_msas.py
@@ -71,6 +71,31 @@ def test_compute_with_log(self):
         # Assert
         assert score == 1
 
+    def test_compute_with_log_warning(self):
+        """Test it warns when negative values are present and apply_log is True."""
+        # Setup
+        real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2'])
+        real_values = pd.Series([1, 1.4, 4, -1, 16, -10])
+        synthetic_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2'])
+        synthetic_values = pd.Series([1, 2, -4, 8, 16, 30])
+
+        # Run
+        with pytest.warns(UserWarning) as warning_info:
+            score = InterRowMSAS.compute(
+                real_data=(real_keys, real_values),
+                synthetic_data=(synthetic_keys, synthetic_values),
+                apply_log=True,
+            )
+
+        # Assert
+        expected_message = (
+            'There are 3 non-positive values in your data, which cannot be used with log. '
+            "Consider changing 'apply_log' to False for a better result."
+        )
+        assert len(warning_info) == 1
+        assert str(warning_info[0].message) == expected_message
+        assert score == 0
+
     def test_compute_different_n_rows_diff(self):
         """Test it with different n_rows_diff."""
         # Setup