Skip to content

Commit

Permalink
Add StatisticMSAS
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Oct 30, 2024
1 parent f58f443 commit 783e68d
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 1 deletion.
3 changes: 2 additions & 1 deletion sdmetrics/timeseries/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sdmetrics.timeseries.detection import LSTMDetection, TimeSeriesDetectionMetric
from sdmetrics.timeseries.efficacy import TimeSeriesEfficacyMetric
from sdmetrics.timeseries.efficacy.classification import LSTMClassifierEfficacy

from sdmetrics.timeseries.statistic_msas import StatisticMSAS
__all__ = [
'base',
'detection',
Expand All @@ -16,4 +16,5 @@
'LSTMDetection',
'TimeSeriesEfficacyMetric',
'LSTMClassifierEfficacy',
'StatisticMSAS',
]
91 changes: 91 additions & 0 deletions sdmetrics/timeseries/statistic_msas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""StatisticMSAS module."""

import numpy as np
import pandas as pd

from sdmetrics.goal import Goal
from sdmetrics.single_column.statistical.kscomplement import KSComplement


class StatisticMSAS:
"""Statistic Multi-Sequence Aggregate Similarity (MSAS) metric.
Attributes:
name (str):
Name to use when reports about this metric are printed.
goal (sdmetrics.goal.Goal):
The goal of this metric.
min_value (Union[float, tuple[float]]):
Minimum value or values that this metric can take.
max_value (Union[float, tuple[float]]):
Maximum value or values that this metric can take.
"""

name = 'Statistic Multi-Sequence Aggregate Similarity'
goal = Goal.MAXIMIZE
min_value = 0.0
max_value = 1.0

@staticmethod
def compute(real_data, synthetic_data, statistic='mean'):
"""Compute this metric.
This metric compares the distribution of a given statistic across sequences
in the real data vs. the synthetic data.
It works as follows:
- Calculate the specified statistic for each sequence in the real data
- Form a distribution D_r from these statistics
- Do the same for the synthetic data to form a new distribution D_s
- Apply the KSComplement metric to compare the similarities of (D_r, D_s)
- Return this score
Args:
real_data (tuple[pd.Series, pd.Series]):
A tuple of 2 pandas.Series objects. The first represents the sequence key
of the real data and the second represents a continuous column of data.
synthetic_data (tuple[pd.Series, pd.Series]):
A tuple of 2 pandas.Series objects. The first represents the sequence key
of the synthetic data and the second represents a continuous column of data.
statistic (str):
A string representing the statistic function to use when computing MSAS.
Available options are:
- 'mean': The arithmetic mean of the sequence
- 'median': The median value of the sequence
- 'std': The standard deviation of the sequence
- 'min': The minimum value in the sequence
- 'max': The maximum value in the sequence
Returns:
float:
The similarity score between the real and synthetic data distributions.
"""
statistic_functions = {
'mean': np.mean,
'median': np.median,
'std': np.std,
'min': np.min,
'max': np.max,
}
if statistic not in statistic_functions:
raise ValueError(
f'Invalid statistic: {statistic}.'
f" Choose from [{', '.join(statistic_functions.keys())}]."
)

real_keys, real_values = real_data
synthetic_keys, synthetic_values = synthetic_data
stat_func = statistic_functions[statistic]

def calculate_statistics(keys, values):
statistics = []
for key in keys.unique():
group_values = values[keys == key].to_numpy()
statistics.append(stat_func(group_values))
return pd.Series(statistics)

real_stats = calculate_statistics(real_keys, real_values)
synthetic_stats = calculate_statistics(synthetic_keys, synthetic_values)

return KSComplement.compute(real_stats, synthetic_stats)
96 changes: 96 additions & 0 deletions tests/unit/timeseries/test_statistic_msas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import re
import pandas as pd
import pytest

from sdmetrics.timeseries import StatisticMSAS


class TestStatisticMSAS:
def test_compute_identical_sequences(self):
"""Test it returns 1 when real and synthetic data are identical."""
# Setup
real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2'])
real_values = pd.Series([1, 2, 3, 4, 5, 6])
synthetic_keys = pd.Series(['id3', 'id3', 'id3', 'id4', 'id4', 'id4'])
synthetic_values = pd.Series([1, 2, 3, 4, 5, 6])

# Run and Assert
for statistic in ['mean', 'median', 'std', 'min', 'max']:
score = StatisticMSAS.compute(
real_data=(real_keys, real_values),
synthetic_data=(synthetic_keys, synthetic_values),
statistic=statistic,
)
assert score == 1

def test_compute_different_sequences(self):
"""Test it for distinct distributions."""
# Setup
# Setup
real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2'])
real_values = pd.Series([1, 2, 3, 4, 5, 6])
synthetic_keys = pd.Series(['id3', 'id3', 'id3', 'id4', 'id4', 'id4'])
synthetic_values = pd.Series([10, 20, 30, 40, 50, 60])

# Run and Assert
for statistic in ['mean', 'median', 'std', 'min', 'max']:
score = StatisticMSAS.compute(
real_data=(real_keys, real_values),
synthetic_data=(synthetic_keys, synthetic_values),
statistic=statistic,
)
assert score == 0

def test_compute_with_single_sequence(self):
"""Test it with a single sequence."""
# Setup
real_keys = pd.Series(['id1', 'id1', 'id1'])
real_values = pd.Series([1, 2, 3])
synthetic_keys = pd.Series(['id2', 'id2', 'id2'])
synthetic_values = pd.Series([1, 2, 3])

# Run
score = StatisticMSAS.compute(
real_data=(real_keys, real_values),
synthetic_data=(synthetic_keys, synthetic_values),
statistic='mean',
)

# Assert
assert score == 1

def test_compute_with_different_sequence_lengths(self):
"""Test it with different sequence lengths."""
# Setup
real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2'])
real_values = pd.Series([1, 2, 3, 4, 5])
synthetic_keys = pd.Series(['id2', 'id2', 'id3', 'id4', 'id5'])
synthetic_values = pd.Series([1, 2, 3, 4, 5])

# Run
score = StatisticMSAS.compute(
real_data=(real_keys, real_values),
synthetic_data=(synthetic_keys, synthetic_values),
statistic='mean',
)

# Assert
assert score == .75

def test_compute_with_invalid_statistic(self):
"""Test it raises ValueError for invalid statistic."""
# Setup
real_keys = pd.Series(['id1', 'id1', 'id1'])
real_values = pd.Series([1, 2, 3])
synthetic_keys = pd.Series(['id2', 'id2', 'id2'])
synthetic_values = pd.Series([1, 2, 3])

# Run and Assert
err_msg = re.escape(
"Invalid statistic: invalid. Choose from [mean, median, std, min, max].")
with pytest.raises(ValueError, match=err_msg):
StatisticMSAS.compute(
real_data=(real_keys, real_values),
synthetic_data=(synthetic_keys, synthetic_values),
statistic='invalid',
)

0 comments on commit 783e68d

Please sign in to comment.