Skip to content

Commit

Permalink
Add error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Nov 19, 2024
1 parent 835283a commit f5fdd37
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
8 changes: 8 additions & 0 deletions sdmetrics/column_pairs/statistical/inter_row_msas.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""InterRowMSAS module."""

import warnings
from datetime import datetime

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -76,6 +77,13 @@ def compute(real_data, synthetic_data, n_rows_diff=1, apply_log=False):
synthetic_keys, synthetic_values = synthetic_data

if apply_log:
if (len(real_values) > 0 and isinstance(real_values[0], datetime)) or (
len(synthetic_values) > 0 and isinstance(synthetic_values[0], datetime)
):
raise TypeError(
'Cannot compute log for datetime columns. '
"Please set 'apply_log' to False to use this metric."
)
num_invalid = sum(x <= 0 for x in pd.concat((real_values, synthetic_values)))
if num_invalid:
warnings.warn(
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/column_pairs/statistical/test_inter_row_msas.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import datetime

import pandas as pd
import pytest

Expand Down Expand Up @@ -96,6 +98,26 @@ def test_compute_with_log_warning(self):
assert str(warning_info[0].message) == expected_message
assert score == 0

def test_compute_with_log_datetime(self):
    """Test it crashes for logs of datetime values.

    Calls ``InterRowMSAS.compute`` with ``apply_log=True`` on real and
    synthetic value columns holding ``datetime`` objects and asserts that a
    ``TypeError`` with the expected message is raised.
    """
    # Local import: only this test needs ``re``.
    import re

    # Setup
    real_keys = pd.Series(['id1', 'id1'])
    real_values = pd.Series([datetime(2020, 10, 1), datetime(2020, 10, 1)])
    synthetic_keys = pd.Series(['id2', 'id2'])
    synthetic_values = pd.Series([datetime(2020, 10, 1), datetime(2020, 10, 1)])

    # Run and Assert
    err_msg = (
        'Cannot compute log for datetime columns. '
        "Please set 'apply_log' to False to use this metric."
    )
    # ``match`` is treated as a regular expression, so escape the message to
    # make the '.' characters match literally instead of any character.
    with pytest.raises(TypeError, match=re.escape(err_msg)):
        InterRowMSAS.compute(
            real_data=(real_keys, real_values),
            synthetic_data=(synthetic_keys, synthetic_values),
            apply_log=True,
        )

def test_compute_different_n_rows_diff(self):
"""Test it with different n_rows_diff."""
# Setup
Expand Down

0 comments on commit f5fdd37

Please sign in to comment.