Skip to content

Commit

Permalink
Add support for numpy 2.0.0 (#599)
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo authored Jun 26, 2024
1 parent 85f221f commit 8e79aa2
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 9 deletions.
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ license = { text = 'MIT license' }
requires-python = ">=3.8,<3.13"
readme = 'README.md'
dependencies = [
"numpy>=1.21.0,<2.0.0;python_version<'3.10'",
"numpy>=1.23.3,<2.0.0;python_version>='3.10' and python_version<'3.12'",
"numpy>=1.26.0,<2.0.0;python_version>='3.12'",
"numpy>=1.21.0;python_version<'3.10'",
"numpy>=1.23.3;python_version>='3.10' and python_version<'3.12'",
"numpy>=1.26.0;python_version>='3.12'",
"pandas>=1.4.0;python_version<'3.11'",
"pandas>=1.5.0;python_version>='3.11' and python_version<'3.12'",
"pandas>=2.1.1;python_version>='3.12'",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from unittest.mock import Mock

import numpy as np
from tqdm import tqdm

from sdmetrics.demos import load_demo
Expand All @@ -17,7 +18,7 @@ def test_end_to_end(self):
result = column_pair_trends.get_score(real_data, synthetic_data, metadata)

# Assert
assert result == 0.45654629583521095
assert np.isclose(result, 0.45654629583521095, atol=1e-8)

def test_with_progress_bar(self):
"""Test that the progress bar is correctly updated."""
Expand All @@ -37,5 +38,5 @@ def test_with_progress_bar(self):
result = column_pair_trends.get_score(real_data, synthetic_data, metadata, progress_bar)

# Assert
assert result == 0.45654629583521095
assert np.isclose(result, 0.45654629583521095, atol=1e-8)
assert mock_update.call_count == num_iter
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import pandas as pd
import pytest

from sdmetrics.multi_table.statistical import CardinalityStatisticSimilarity
from sdmetrics.warnings import ConstantInputWarning
Expand Down Expand Up @@ -57,7 +58,7 @@ def test__compute_statistic_constant_input(self):
)

# Run
with np.testing.assert_warns(ConstantInputWarning, match=expected_warn_msg):
with pytest.warns(ConstantInputWarning, match=expected_warn_msg):
result = CardinalityStatisticSimilarity._compute_statistic(
real_distribution, synthetic_distribution, 'mean'
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import pandas as pd
import pytest

from sdmetrics.single_column.statistical import StatisticSimilarity
from sdmetrics.warnings import ConstantInputWarning
Expand Down Expand Up @@ -54,13 +55,13 @@ def test_compute_breakdown_constant_input(self):
'score': np.nan,
}
expected_warn_msg = (
'The real data input array is constant. '
'The StatisticSimilarity metric is either undefined or infinte.'
'The real data input array is constant. The StatisticSimilarity '
'metric is either undefined or infinite.'
)

# Run
metric = StatisticSimilarity()
with np.testing.assert_warns(ConstantInputWarning, match=expected_warn_msg):
with pytest.warns(ConstantInputWarning, match=expected_warn_msg):
result = metric.compute_breakdown(real_data, synthetic_data, statistic='mean')

# Assert
Expand Down

0 comments on commit 8e79aa2

Please sign in to comment.