Skip to content

Commit

Permalink
Fix ordering of the metric
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Oct 28, 2024
1 parent b9fab0e commit 0e707c2
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 6 deletions.
2 changes: 2 additions & 0 deletions sdmetrics/timeseries/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sdmetrics.timeseries.detection import LSTMDetection, TimeSeriesDetectionMetric
from sdmetrics.timeseries.efficacy import TimeSeriesEfficacyMetric
from sdmetrics.timeseries.efficacy.classification import LSTMClassifierEfficacy
from sdmetrics.timeseries.sequence_length_similarity import SequenceLengthSimilarity

__all__ = [
'base',
Expand All @@ -16,4 +17,5 @@
'LSTMDetection',
'TimeSeriesEfficacyMetric',
'LSTMClassifierEfficacy',
'SequenceLengthSimilarity',
]
13 changes: 10 additions & 3 deletions sdmetrics/timeseries/sequence_length_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class SequenceLengthSimilarity:
Maximum value or values that this metric can take.
"""

name = 'BayesianNetwork Likelihood'
name = 'Sequence Length Similarity'
goal = Goal.MAXIMIZE
min_value = 0.0
max_value = 1.0
Expand Down Expand Up @@ -50,7 +50,14 @@ def compute(real_data: pd.Series, synthetic_data: pd.Series) -> float:
float:
The KSComplement score computed on the per-sequence row counts of the real and synthetic data.
"""
real_lengths = real_data.value_counts().to_numpy()
synthetic_lengths = synthetic_data.value_counts().to_numpy()
real_lengths = real_data.value_counts(sort=False)
synthetic_lengths = synthetic_data.value_counts(sort=False)

all_indexes = real_lengths.index.union(synthetic_lengths.index)
real_lengths = real_lengths.reindex(all_indexes, fill_value=0)
synthetic_lengths = synthetic_lengths.reindex(all_indexes, fill_value=0)

real_lengths = real_lengths.sort_index()
synthetic_lengths = synthetic_lengths.sort_index()

return KSComplement.compute(real_lengths, synthetic_lengths)
18 changes: 15 additions & 3 deletions tests/unit/timeseries/test_sequence_length_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,23 @@ def test_compute_one(self):
def test_compute_low_score(self):
"""Test it for distinct distributions."""
# Setup
real_data = pd.Series(['id1', 'id1', 'id2'])
synthetic_data = pd.Series(['id1', 'id2', 'id3'])
real_data = pd.Series([f'id{i}' for i in range(100)])
synthetic_data = pd.Series(['id1'] * 100)

# Run
score = SequenceLengthSimilarity.compute(real_data, synthetic_data)

# Assert
assert score == 0.5
assert score == 0.010000000000000009

def test_compute_one_difference_sequences(self):
    """Test the score is one when real and synthetic data use different sequence keys."""
    # Setup: a single sequence of three rows on each side, keyed differently.
    real_keys = pd.Series(['id1'] * 3)
    synthetic_keys = pd.Series(['id2'] * 3)

    # Run
    result = SequenceLengthSimilarity.compute(real_keys, synthetic_keys)

    # Assert
    assert result == 1

0 comments on commit 0e707c2

Please sign in to comment.