Skip to content

Commit

Permalink
Merge branch 'main' into issue-670-warnings-log
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho authored Nov 19, 2024
2 parents ee0a447 + f43d2a7 commit 835283a
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 17 deletions.
19 changes: 4 additions & 15 deletions sdmetrics/column_pairs/statistical/statistic_msas.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""StatisticMSAS module."""

import numpy as np
import pandas as pd

from sdmetrics.goal import Goal
Expand Down Expand Up @@ -61,18 +60,9 @@ def compute(real_data, synthetic_data, statistic='mean'):
float:
The similarity score between the real and synthetic data distributions.
"""
statistic_functions = {
'mean': np.mean,
'median': np.median,
'std': np.std,
'min': np.min,
'max': np.max,
}
if statistic not in statistic_functions:
raise ValueError(
f'Invalid statistic: {statistic}.'
f' Choose from [{", ".join(statistic_functions.keys())}].'
)
valid_statistics = ['mean', 'median', 'std', 'min', 'max']
if statistic not in valid_statistics:
raise ValueError(f'Invalid statistic: {statistic}. Choose from {valid_statistics}.')

for data in [real_data, synthetic_data]:
if (
Expand All @@ -84,11 +74,10 @@ def compute(real_data, synthetic_data, statistic='mean'):

real_keys, real_values = real_data
synthetic_keys, synthetic_values = synthetic_data
stat_func = statistic_functions[statistic]

def calculate_statistics(keys, values):
df = pd.DataFrame({'keys': keys, 'values': values})
return df.groupby('keys')['values'].agg(stat_func)
return df.groupby('keys')['values'].agg(statistic)

real_stats = calculate_statistics(real_keys, real_values)
synthetic_stats = calculate_statistics(synthetic_keys, synthetic_values)
Expand Down
7 changes: 5 additions & 2 deletions tests/unit/column_pairs/statistical/test_statistic_msas.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class TestStatisticMSAS:
def test_compute_identical_sequences(self):
def test_compute_identical_sequences(self, recwarn):
"""Test it returns 1 when real and synthetic data are identical."""
# Setup
real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2'])
Expand All @@ -24,6 +24,9 @@ def test_compute_identical_sequences(self):
)
assert score == 1

# Ensure GH#665 is fixed
assert len(recwarn) == 0

def test_compute_different_sequences(self):
"""Test it for distinct distributions."""
# Setup
Expand Down Expand Up @@ -87,7 +90,7 @@ def test_compute_with_invalid_statistic(self):

# Run and Assert
err_msg = re.escape(
'Invalid statistic: invalid. Choose from [mean, median, std, min, max].'
"Invalid statistic: invalid. Choose from ['mean', 'median', 'std', 'min', 'max']."
)
with pytest.raises(ValueError, match=err_msg):
StatisticMSAS.compute(
Expand Down

0 comments on commit 835283a

Please sign in to comment.