Skip to content

Commit

Permalink
Fix tests failing due to scipy 1.14 (#601)
Browse files Browse the repository at this point in the history
  • Loading branch information
lajohn4747 authored Jun 26, 2024
1 parent 8e79aa2 commit 7b4fc9f
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 26 deletions.
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ dependencies = [
"scikit-learn>=1.1.0;python_version>='3.10' and python_version<'3.11'",
"scikit-learn>=1.1.3;python_version>='3.11' and python_version<'3.12'",
"scikit-learn>=1.3.1;python_version>='3.12'",
"scipy>=1.7.3,<1.14.0;python_version<'3.10'",
"scipy>=1.9.2,<1.14.0;python_version>='3.10' and python_version<'3.12'",
"scipy>=1.12.0,<1.14.0;python_version>='3.12'",
"scipy>=1.7.3;python_version<'3.10'",
"scipy>=1.9.2;python_version>='3.10' and python_version<'3.12'",
"scipy>=1.12.0;python_version>='3.12'",
'copulas>=0.11.0',
'tqdm>=4.29',
'plotly>=5.19.0',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def test_multi_table_quality_report():
details.append(report.get_details(property_))

# Assert score
assert score == 0.649582127409184
assert round(score, 15) == 0.649582127409184
pd.testing.assert_frame_equal(
properties,
pd.DataFrame({
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pandas as pd
from tests.utils import get_error_type

from sdmetrics.demos import load_demo
from sdmetrics.reports.single_table._properties.column_pair_trends import ColumnPairTrends
Expand Down Expand Up @@ -79,17 +80,25 @@ def test_get_score_warnings(self, recwarn):
# Run
column_pair_trends = ColumnPairTrends()

exp_message_1 = "ValueError: could not convert string to float: 'a'"
exp_message_1 = 'ValueError'

exp_message_2 = "TypeError: '<=' not supported between instances of 'float' and 'str'"
exp_message_2 = 'TypeError'

exp_error_serie = pd.Series([exp_message_1, None, None, exp_message_2, exp_message_2, None])
exp_error_series = pd.Series([
exp_message_1,
None,
None,
exp_message_2,
exp_message_2,
None,
])

score = column_pair_trends.get_score(real_data, synthetic_data, metadata)

# Assert
details = column_pair_trends.details
pd.testing.assert_series_equal(details['Error'], exp_error_serie, check_names=False)
details['Error'] = details['Error'].apply(get_error_type)
pd.testing.assert_series_equal(details['Error'], exp_error_series, check_names=False)
assert score == 0.7751937984496124

def test_only_categorical_columns(self):
Expand Down
44 changes: 26 additions & 18 deletions tests/integration/reports/single_table/test_quality_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from sdmetrics.demos import load_demo
from sdmetrics.reports.single_table import QualityReport
from tests.utils import get_error_type


class TestQualityReport:
Expand Down Expand Up @@ -262,7 +263,7 @@ def test_report_end_to_end_with_errors(self):
'Score': [0.6621621621621622, np.nan, 0.9953488372093023, 0.9395348837209302],
'Error': [
None,
"TypeError: '<' not supported between instances of 'str' and 'float'",
'TypeError',
None,
None,
],
Expand Down Expand Up @@ -304,23 +305,25 @@ def test_report_end_to_end_with_errors(self):
'Real Correlation': [np.nan] * 6,
'Synthetic Correlation': [np.nan] * 6,
'Error': [
"ValueError: could not convert string to float: 'a'",
'ValueError',
None,
None,
"TypeError: '<=' not supported between instances of 'float' and 'str'",
"TypeError: '<=' not supported between instances of 'float' and 'str'",
'TypeError',
'TypeError',
None,
],
}
expected_details_column_shapes = pd.DataFrame(expected_details_column_shapes_dict)
expected_details_cpt = pd.DataFrame(expected_details_cpt__dict)

pd.testing.assert_frame_equal(
report.get_details('Column Shapes'), expected_details_column_shapes
)
pd.testing.assert_frame_equal(
report.get_details('Column Pair Trends'), expected_details_cpt
)
# Errors may change based on versions of scipy installed.
col_shape_report = report.get_details('Column Shapes')
col_pair_report = report.get_details('Column Pair Trends')
col_shape_report['Error'] = col_shape_report['Error'].apply(get_error_type)
col_pair_report['Error'] = col_pair_report['Error'].apply(get_error_type)

pd.testing.assert_frame_equal(col_shape_report, expected_details_column_shapes)
pd.testing.assert_frame_equal(col_pair_report, expected_details_cpt)
assert report.get_score() == 0.8204378797402054

def test_report_with_column_nan(self):
Expand Down Expand Up @@ -446,10 +449,10 @@ def test_report_with_column_nan(self):
None,
None,
None,
'ValueError: x and y must have length at least 2.',
'ValueError',
None,
None,
'ValueError: x and y must have length at least 2.',
'ValueError',
None,
None,
None,
Expand All @@ -458,12 +461,17 @@ def test_report_with_column_nan(self):
expected_details_column_shapes = pd.DataFrame(expected_details_column_shapes_dict)
expected_details_cpt = pd.DataFrame(expected_details_cpt__dict)

pd.testing.assert_frame_equal(
report.get_details('Column Shapes'), expected_details_column_shapes
)
pd.testing.assert_frame_equal(
report.get_details('Column Pair Trends'), expected_details_cpt
)
col_shape_report = report.get_details('Column Shapes')
if 'Error' not in col_shape_report:
# Errors may not occur in certain scipy versions
expected_details_column_shapes.drop(columns=['Error'], inplace=True)

# Errors may change based on versions of library installed.
col_pair_report = report.get_details('Column Pair Trends')
col_pair_report['Error'] = col_pair_report['Error'].apply(get_error_type)

pd.testing.assert_frame_equal(col_shape_report, expected_details_column_shapes)
pd.testing.assert_frame_equal(col_pair_report, expected_details_cpt)

def test_report_with_verbose(self, capsys):
"""Test the report with verbose.
Expand Down
7 changes: 7 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,10 @@ def __eq__(self, other):
"""Assert equality by expanding the iterator."""
assert all(x == y for x, y in zip(self.iterator, other))
return True


def get_error_type(error):
if error is not None:
colon_index = error.find(':')
return error[:colon_index]
return None

0 comments on commit 7b4fc9f

Please sign in to comment.