From f43d2a71812a18a8072ce3e66d59b286a52d8334 Mon Sep 17 00:00:00 2001 From: Felipe Alex Hofmann Date: Tue, 19 Nov 2024 09:10:25 -0800 Subject: [PATCH] Remove undesirable `FutureWarning` in `StatisticMSAS` (#666) --- .../statistical/statistic_msas.py | 19 ++++--------------- .../statistical/test_statistic_msas.py | 7 +++++-- 2 files changed, 9 insertions(+), 17 deletions(-) diff --git a/sdmetrics/column_pairs/statistical/statistic_msas.py b/sdmetrics/column_pairs/statistical/statistic_msas.py index 8afab764..8440618d 100644 --- a/sdmetrics/column_pairs/statistical/statistic_msas.py +++ b/sdmetrics/column_pairs/statistical/statistic_msas.py @@ -1,6 +1,5 @@ """StatisticMSAS module.""" -import numpy as np import pandas as pd from sdmetrics.goal import Goal @@ -61,18 +60,9 @@ def compute(real_data, synthetic_data, statistic='mean'): float: The similarity score between the real and synthetic data distributions. """ - statistic_functions = { - 'mean': np.mean, - 'median': np.median, - 'std': np.std, - 'min': np.min, - 'max': np.max, - } - if statistic not in statistic_functions: - raise ValueError( - f'Invalid statistic: {statistic}.' - f' Choose from [{", ".join(statistic_functions.keys())}].' - ) + valid_statistics = ['mean', 'median', 'std', 'min', 'max'] + if statistic not in valid_statistics: + raise ValueError(f'Invalid statistic: {statistic}. Choose from {valid_statistics}.') for data in [real_data, synthetic_data]: if ( @@ -84,11 +74,10 @@ def compute(real_data, synthetic_data, statistic='mean'): real_keys, real_values = real_data synthetic_keys, synthetic_values = synthetic_data - stat_func = statistic_functions[statistic] def calculate_statistics(keys, values): df = pd.DataFrame({'keys': keys, 'values': values}) - return df.groupby('keys')['values'].agg(stat_func) + return df.groupby('keys')['values'].agg(statistic) real_stats = calculate_statistics(real_keys, real_values) synthetic_stats = calculate_statistics(synthetic_keys, synthetic_values) diff --git a/tests/unit/column_pairs/statistical/test_statistic_msas.py b/tests/unit/column_pairs/statistical/test_statistic_msas.py index 9e8813eb..52338844 100644 --- a/tests/unit/column_pairs/statistical/test_statistic_msas.py +++ b/tests/unit/column_pairs/statistical/test_statistic_msas.py @@ -7,7 +7,7 @@ class TestStatisticMSAS: - def test_compute_identical_sequences(self): + def test_compute_identical_sequences(self, recwarn): """Test it returns 1 when real and synthetic data are identical.""" # Setup real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2']) @@ -24,6 +24,9 @@ def test_compute_identical_sequences(self): ) assert score == 1 + # Ensure GH#665 is fixed + assert len(recwarn) == 0 + def test_compute_different_sequences(self): """Test it for distinct distributions.""" # Setup @@ -87,7 +90,7 @@ def test_compute_with_invalid_statistic(self): # Run and Assert err_msg = re.escape( - 'Invalid statistic: invalid. Choose from [mean, median, std, min, max].' + "Invalid statistic: invalid. Choose from ['mean', 'median', 'std', 'min', 'max']." ) with pytest.raises(ValueError, match=err_msg): StatisticMSAS.compute(