From 325453925b7a0e9821b17c06f479d5bbd6c53ba8 Mon Sep 17 00:00:00 2001
From: Felipe
Date: Wed, 30 Oct 2024 09:37:41 -0700
Subject: [PATCH] Add StatisticMSAS

---
 sdmetrics/timeseries/__init__.py             |   2 +
 sdmetrics/timeseries/statistic_msas.py       |  99 ++++++++++++++++
 tests/unit/timeseries/test_statistic_msas.py | 114 +++++++++++++++++++
 3 files changed, 215 insertions(+)
 create mode 100644 sdmetrics/timeseries/statistic_msas.py
 create mode 100644 tests/unit/timeseries/test_statistic_msas.py

diff --git a/sdmetrics/timeseries/__init__.py b/sdmetrics/timeseries/__init__.py
index 6a09b529..c78dc54d 100644
--- a/sdmetrics/timeseries/__init__.py
+++ b/sdmetrics/timeseries/__init__.py
@@ -5,6 +5,7 @@
 from sdmetrics.timeseries.detection import LSTMDetection, TimeSeriesDetectionMetric
 from sdmetrics.timeseries.efficacy import TimeSeriesEfficacyMetric
 from sdmetrics.timeseries.efficacy.classification import LSTMClassifierEfficacy
+from sdmetrics.timeseries.statistic_msas import StatisticMSAS
 
 __all__ = [
     'base',
@@ -16,4 +17,5 @@
     'LSTMDetection',
     'TimeSeriesEfficacyMetric',
     'LSTMClassifierEfficacy',
+    'StatisticMSAS',
 ]
diff --git a/sdmetrics/timeseries/statistic_msas.py b/sdmetrics/timeseries/statistic_msas.py
new file mode 100644
index 00000000..de404db0
--- /dev/null
+++ b/sdmetrics/timeseries/statistic_msas.py
@@ -0,0 +1,99 @@
+"""StatisticMSAS module."""
+
+import numpy as np
+import pandas as pd
+
+from sdmetrics.goal import Goal
+from sdmetrics.single_column.statistical.kscomplement import KSComplement
+
+
+class StatisticMSAS:
+    """Statistic Multi-Sequence Aggregate Similarity (MSAS) metric.
+
+    Attributes:
+        name (str):
+            Name to use when reports about this metric are printed.
+        goal (sdmetrics.goal.Goal):
+            The goal of this metric.
+        min_value (Union[float, tuple[float]]):
+            Minimum value or values that this metric can take.
+        max_value (Union[float, tuple[float]]):
+            Maximum value or values that this metric can take.
+    """
+
+    name = 'Statistic Multi-Sequence Aggregate Similarity'
+    goal = Goal.MAXIMIZE
+    min_value = 0.0
+    max_value = 1.0
+
+    @staticmethod
+    def compute(real_data, synthetic_data, statistic='mean'):
+        """Compute this metric.
+
+        This metric compares the distribution of a given statistic across sequences
+        in the real data vs. the synthetic data.
+
+        It works as follows:
+            - Calculate the specified statistic for each sequence in the real data
+            - Form a distribution D_r from these statistics
+            - Do the same for the synthetic data to form a new distribution D_s
+            - Apply the KSComplement metric to compare the similarities of (D_r, D_s)
+            - Return this score
+
+        Rows whose sequence key is missing (NaN/None) do not belong to any sequence
+        and are ignored.
+
+        Args:
+            real_data (tuple[pd.Series, pd.Series]):
+                A tuple of 2 pandas.Series objects. The first represents the sequence key
+                of the real data and the second represents a continuous column of data.
+            synthetic_data (tuple[pd.Series, pd.Series]):
+                A tuple of 2 pandas.Series objects. The first represents the sequence key
+                of the synthetic data and the second represents a continuous column of data.
+            statistic (str):
+                A string representing the statistic function to use when computing MSAS.
+
+                Available options are:
+                    - 'mean': The arithmetic mean of the sequence
+                    - 'median': The median value of the sequence
+                    - 'std': The population standard deviation of the sequence (``numpy.std``)
+                    - 'min': The minimum value in the sequence
+                    - 'max': The maximum value in the sequence
+
+        Returns:
+            float:
+                The similarity score between the real and synthetic data distributions.
+
+        Raises:
+            ValueError:
+                If ``statistic`` is not one of the available options.
+        """
+        statistic_functions = {
+            'mean': np.mean,
+            'median': np.median,
+            'std': np.std,
+            'min': np.min,
+            'max': np.max,
+        }
+        if statistic not in statistic_functions:
+            raise ValueError(
+                f'Invalid statistic: {statistic}.'
+                f' Choose from [{", ".join(statistic_functions.keys())}].'
+            )
+
+        real_keys, real_values = real_data
+        synthetic_keys, synthetic_values = synthetic_data
+        stat_func = statistic_functions[statistic]
+
+        def calculate_statistics(keys, values):
+            # Group in a single pass instead of re-scanning ``keys`` once per unique key.
+            # ``groupby`` drops rows whose key is NaN/None, so the statistic never
+            # receives an empty array (``np.min``/``np.max`` raise on empty input).
+            # ``sort=False`` keeps the sequences in order of first appearance.
+            grouped = values.groupby(keys, sort=False, observed=True)
+            return pd.Series([stat_func(group.to_numpy()) for _, group in grouped])
+
+        real_stats = calculate_statistics(real_keys, real_values)
+        synthetic_stats = calculate_statistics(synthetic_keys, synthetic_values)
+
+        return KSComplement.compute(real_stats, synthetic_stats)
diff --git a/tests/unit/timeseries/test_statistic_msas.py b/tests/unit/timeseries/test_statistic_msas.py
new file mode 100644
index 00000000..c35b5b1e
--- /dev/null
+++ b/tests/unit/timeseries/test_statistic_msas.py
@@ -0,0 +1,114 @@
+import re
+
+import pandas as pd
+import pytest
+
+from sdmetrics.timeseries import StatisticMSAS
+
+
+class TestStatisticMSAS:
+    def test_compute_identical_sequences(self):
+        """Test it returns 1 when real and synthetic data are identical."""
+        # Setup
+        real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2'])
+        real_values = pd.Series([1, 2, 3, 4, 5, 6])
+        synthetic_keys = pd.Series(['id3', 'id3', 'id3', 'id4', 'id4', 'id4'])
+        synthetic_values = pd.Series([1, 2, 3, 4, 5, 6])
+
+        # Run and Assert
+        for statistic in ['mean', 'median', 'std', 'min', 'max']:
+            score = StatisticMSAS.compute(
+                real_data=(real_keys, real_values),
+                synthetic_data=(synthetic_keys, synthetic_values),
+                statistic=statistic,
+            )
+            assert score == 1
+
+    def test_compute_different_sequences(self):
+        """Test it for distinct distributions."""
+        # Setup
+        real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2'])
+        real_values = pd.Series([1, 2, 3, 4, 5, 6])
+        synthetic_keys = pd.Series(['id3', 'id3', 'id3', 'id4', 'id4', 'id4'])
+        synthetic_values = pd.Series([10, 20, 30, 40, 50, 60])
+
+        # Run and Assert
+        for statistic in ['mean', 'median', 'std', 'min', 'max']:
+            score = StatisticMSAS.compute(
+                real_data=(real_keys, real_values),
+                synthetic_data=(synthetic_keys, synthetic_values),
+                statistic=statistic,
+            )
+            assert score == 0
+
+    def test_compute_with_single_sequence(self):
+        """Test it with a single sequence."""
+        # Setup
+        real_keys = pd.Series(['id1', 'id1', 'id1'])
+        real_values = pd.Series([1, 2, 3])
+        synthetic_keys = pd.Series(['id2', 'id2', 'id2'])
+        synthetic_values = pd.Series([1, 2, 3])
+
+        # Run
+        score = StatisticMSAS.compute(
+            real_data=(real_keys, real_values),
+            synthetic_data=(synthetic_keys, synthetic_values),
+            statistic='mean',
+        )
+
+        # Assert
+        assert score == 1
+
+    def test_compute_with_different_sequence_lengths(self):
+        """Test it with different sequence lengths."""
+        # Setup
+        real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2'])
+        real_values = pd.Series([1, 2, 3, 4, 5])
+        synthetic_keys = pd.Series(['id2', 'id2', 'id3', 'id4', 'id5'])
+        synthetic_values = pd.Series([1, 2, 3, 4, 5])
+
+        # Run
+        score = StatisticMSAS.compute(
+            real_data=(real_keys, real_values),
+            synthetic_data=(synthetic_keys, synthetic_values),
+            statistic='mean',
+        )
+
+        # Assert
+        assert score == 0.75
+
+    def test_compute_with_invalid_statistic(self):
+        """Test it raises ValueError for invalid statistic."""
+        # Setup
+        real_keys = pd.Series(['id1', 'id1', 'id1'])
+        real_values = pd.Series([1, 2, 3])
+        synthetic_keys = pd.Series(['id2', 'id2', 'id2'])
+        synthetic_values = pd.Series([1, 2, 3])
+
+        # Run and Assert
+        err_msg = re.escape(
+            'Invalid statistic: invalid. Choose from [mean, median, std, min, max].'
+        )
+        with pytest.raises(ValueError, match=err_msg):
+            StatisticMSAS.compute(
+                real_data=(real_keys, real_values),
+                synthetic_data=(synthetic_keys, synthetic_values),
+                statistic='invalid',
+            )
+
+    def test_compute_with_missing_sequence_key(self):
+        """Test that rows whose sequence key is missing are ignored."""
+        # Setup
+        real_keys = pd.Series(['id1', 'id1', None, 'id2', 'id2'])
+        real_values = pd.Series([1, 2, 100, 4, 5])
+        synthetic_keys = pd.Series(['id3', 'id3', 'id4', 'id4'])
+        synthetic_values = pd.Series([1, 2, 4, 5])
+
+        # Run and Assert
+        for statistic in ['mean', 'median', 'std', 'min', 'max']:
+            score = StatisticMSAS.compute(
+                real_data=(real_keys, real_values),
+                synthetic_data=(synthetic_keys, synthetic_values),
+                statistic=statistic,
+            )
+            assert score == 1