From 44de4c5b07226cb8b2d18ebfc2e1a041c6339da4 Mon Sep 17 00:00:00 2001
From: Felipe
Date: Thu, 31 Oct 2024 03:52:14 -0700
Subject: [PATCH] Use groupby and add data validation

---
Reviewer notes (this area is ignored by `git am`; not part of the commit message):

  What the patch does
  * Adds a check, after the `statistic` lookup error, that both `real_data`
    and `synthetic_data` are 2-tuples whose elements are `pd.Series`;
    otherwise `ValueError('The data must be a tuple of two pandas series.')`
    is raised before the tuples are unpacked.
  * Rewrites the nested helper `calculate_statistics(keys, values)`: the
    explicit loop over `keys.unique()` is replaced by building a two-column
    DataFrame and calling `groupby('keys')['values'].agg(stat_func)`.

  Points to confirm before merging
  * NOTE(review): the old helper returned `pd.Series(statistics)` (positional
    index, one entry per `keys.unique()` value in that order). The new helper
    returns the groupby result, which is indexed by group key. Confirm the
    code that consumes `real_stats` / `synthetic_stats` does not depend on
    the index or ordering.
  * NOTE(review): pandas `groupby` defaults (`sort=True`, `dropna=True`)
    presumably sort the keys and skip NaN keys, which the old loop did not
    do - verify against the pandas version in use.
  * NOTE(review): `stat_func` used to be called on `.to_numpy()` output and
    is now passed to `.agg`; confirm every entry of `statistic_functions`
    accepts what `.agg` hands it.
  * NOTE(review): this copy of the patch had lost its line breaks and leading
    whitespace. Indentation below was restored assuming the body of `compute`
    sits at 8 spaces - TODO confirm against statistic_msas.py, since context
    lines must match the file exactly for the patch to apply.

 sdmetrics/timeseries/statistic_msas.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/sdmetrics/timeseries/statistic_msas.py b/sdmetrics/timeseries/statistic_msas.py
index de404db0..8afab764 100644
--- a/sdmetrics/timeseries/statistic_msas.py
+++ b/sdmetrics/timeseries/statistic_msas.py
@@ -74,16 +74,21 @@ def compute(real_data, synthetic_data, statistic='mean'):
                 f' Choose from [{", ".join(statistic_functions.keys())}].'
             )
 
+        for data in [real_data, synthetic_data]:
+            if (
+                not isinstance(data, tuple)
+                or len(data) != 2
+                or (not (isinstance(data[0], pd.Series) and isinstance(data[1], pd.Series)))
+            ):
+                raise ValueError('The data must be a tuple of two pandas series.')
+
         real_keys, real_values = real_data
         synthetic_keys, synthetic_values = synthetic_data
         stat_func = statistic_functions[statistic]
 
         def calculate_statistics(keys, values):
-            statistics = []
-            for key in keys.unique():
-                group_values = values[keys == key].to_numpy()
-                statistics.append(stat_func(group_values))
-            return pd.Series(statistics)
+            df = pd.DataFrame({'keys': keys, 'values': values})
+            return df.groupby('keys')['values'].agg(stat_func)
 
         real_stats = calculate_statistics(real_keys, real_values)
         synthetic_stats = calculate_statistics(synthetic_keys, synthetic_values)