Skip to content

Commit

Permalink
Use groupby and add data validation
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Oct 31, 2024
1 parent 3254539 commit 44de4c5
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions sdmetrics/timeseries/statistic_msas.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,21 @@ def compute(real_data, synthetic_data, statistic='mean'):
f' Choose from [{", ".join(statistic_functions.keys())}].'
)

+for data in [real_data, synthetic_data]:
+    if (
+        not isinstance(data, tuple)
+        or len(data) != 2
+        or (not (isinstance(data[0], pd.Series) and isinstance(data[1], pd.Series)))
+    ):
+        raise ValueError('The data must be a tuple of two pandas series.')
+
 real_keys, real_values = real_data
 synthetic_keys, synthetic_values = synthetic_data
 stat_func = statistic_functions[statistic]

 def calculate_statistics(keys, values):
-    statistics = []
-    for key in keys.unique():
-        group_values = values[keys == key].to_numpy()
-        statistics.append(stat_func(group_values))
-    return pd.Series(statistics)
+    df = pd.DataFrame({'keys': keys, 'values': values})
+    return df.groupby('keys')['values'].agg(stat_func)

 real_stats = calculate_statistics(real_keys, real_values)
 synthetic_stats = calculate_statistics(synthetic_keys, synthetic_values)
Expand Down

0 comments on commit 44de4c5

Please sign in to comment.