Skip to content

Commit

Permalink
Cast values to datetime
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Jul 10, 2024
1 parent af28309 commit 2bc5506
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 0 deletions.
10 changes: 10 additions & 0 deletions sdmetrics/timeseries/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from operator import attrgetter

import pandas as pd

from sdmetrics.base import BaseMetric
from sdmetrics.utils import get_columns_from_metadata

Expand Down Expand Up @@ -61,6 +63,14 @@ def _validate_inputs(cls, real_data, synthetic_data, metadata=None, sequence_key
if field not in real_data.columns:
raise ValueError(f'Field {field} not found in data')

for column, sdtype in metadata['columns'].items():
if sdtype['sdtype'] == 'datetime':
try:
real_data[column] = pd.to_datetime(real_data[column])
synthetic_data[column] = pd.to_datetime(synthetic_data[column])
except ValueError:
raise ValueError(f"Column '{column}' is not a valid datetime")

else:
dtype_kinds = real_data.dtypes.apply(attrgetter('kind'))
metadata = {'columns': dtype_kinds.apply(cls._DTYPES_TO_TYPES.get).to_dict()}
Expand Down
1 change: 1 addition & 0 deletions sdmetrics/timeseries/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def compute(cls, real_data, synthetic_data, metadata=None, sequence_key=None):
Union[float, tuple[float]]:
Metric output.
"""
real_data, synthetic_data = real_data.copy(), synthetic_data.copy()
_, sequence_key = cls._validate_inputs(real_data, synthetic_data, metadata, sequence_key)

ht = HyperTransformer()
Expand Down
35 changes: 35 additions & 0 deletions tests/integration/timeseries/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,38 @@ def test_compute_lstmdetection_multiple_categorical_columns():
# Assert
assert not pd.isna(output)
assert LSTMDetection.min_value <= output <= LSTMDetection.max_value


def test_compute_lstmdetection_mismatching_datetime_columns():
    """Run LSTMDetection when the datetime column representations differ.

    The real data stores ``visits`` as ``datetime.date`` objects while the
    synthetic data stores it as strings; the metadata declares the column
    as a datetime.
    """
    # Setup
    sequence_ids = [1, 2, 3, 4, 5]
    real_visits = pd.to_datetime(['1/1/2019', '1/2/2019', '1/3/2019', '1/4/2019', '1/5/2019'])
    real_data = pd.DataFrame({'s_key': sequence_ids, 'visits': real_visits})
    # Downgrade the datetime64 column to plain ``date`` objects (object dtype).
    real_data['visits'] = real_data['visits'].dt.date

    synthetic_visits = ['1/2/2019', '1/2/2019', '1/3/2019', '1/4/2019', '1/5/2019']
    synthetic_data = pd.DataFrame({'s_key': sequence_ids, 'visits': synthetic_visits})

    column_metadata = {
        's_key': {'sdtype': 'numerical'},
        'visits': {'sdtype': 'datetime'},
    }
    metadata = {'columns': column_metadata, 'sequence_key': 's_key'}

    # Run
    score = LSTMDetection.compute(
        real_data=real_data,
        synthetic_data=synthetic_data,
        sequence_key=['s_key'],
        metadata=metadata,
    )

    # Assert
    assert score == 0.6666666666666667

0 comments on commit 2bc5506

Please sign in to comment.