diff --git a/sdmetrics/column_pairs/statistical/inter_row_msas.py b/sdmetrics/column_pairs/statistical/inter_row_msas.py index 0fbcccb7..0817902c 100644 --- a/sdmetrics/column_pairs/statistical/inter_row_msas.py +++ b/sdmetrics/column_pairs/statistical/inter_row_msas.py @@ -1,6 +1,7 @@ """InterRowMSAS module.""" import warnings +from datetime import datetime import numpy as np import pandas as pd @@ -76,6 +77,13 @@ def compute(real_data, synthetic_data, n_rows_diff=1, apply_log=False): synthetic_keys, synthetic_values = synthetic_data if apply_log: + if (len(real_values) > 0 and isinstance(real_values[0], datetime)) or ( + len(synthetic_values) > 0 and isinstance(synthetic_values[0], datetime) + ): + raise TypeError( + 'Cannot compute log for datetime columns. ' + "Please set 'apply_log' to False to use this metric." + ) num_invalid = sum(x <= 0 for x in pd.concat((real_values, synthetic_values))) if num_invalid: warnings.warn( diff --git a/tests/unit/column_pairs/statistical/test_inter_row_msas.py b/tests/unit/column_pairs/statistical/test_inter_row_msas.py index a88e375f..647b6569 100644 --- a/tests/unit/column_pairs/statistical/test_inter_row_msas.py +++ b/tests/unit/column_pairs/statistical/test_inter_row_msas.py @@ -1,3 +1,5 @@ +from datetime import datetime + import pandas as pd import pytest @@ -96,6 +98,26 @@ def test_compute_with_log_warning(self): assert str(warning_info[0].message) == expected_message assert score == 0 + def test_compute_with_log_datetime(self): + """Test it crashes for logs of datetime values.""" + # Setup + real_keys = pd.Series(['id1', 'id1']) + real_values = pd.Series([datetime(2020, 10, 1), datetime(2020, 10, 1)]) + synthetic_keys = pd.Series(['id2', 'id2']) + synthetic_values = pd.Series([datetime(2020, 10, 1), datetime(2020, 10, 1)]) + + # Run and Assert + err_msg = ( + 'Cannot compute log for datetime columns. ' + "Please set 'apply_log' to False to use this metric." + ) + with pytest.raises(TypeError, match=err_msg): + InterRowMSAS.compute( + real_data=(real_keys, real_values), + synthetic_data=(synthetic_keys, synthetic_values), + apply_log=True, + ) + def test_compute_different_n_rows_diff(self): """Test it with different n_rows_diff.""" # Setup