diff --git a/sdmetrics/visualization.py b/sdmetrics/visualization.py index 8799aead..734f5aac 100644 --- a/sdmetrics/visualization.py +++ b/sdmetrics/visualization.py @@ -372,8 +372,7 @@ def _get_max_between_datasets(real_data, synthetic_data): return max(synthetic_data) elif synthetic_data is None: return max(real_data) - else: - return max(max(real_data), max(synthetic_data)) + return max(max(real_data), max(synthetic_data)) def _get_min_between_datasets(real_data, synthetic_data): @@ -383,8 +382,7 @@ def _get_min_between_datasets(real_data, synthetic_data): return min(synthetic_data) elif synthetic_data is None: return min(real_data) - else: - return min(min(real_data), min(synthetic_data)) + return min(min(real_data), min(synthetic_data)) def _generate_cardinality_plot( diff --git a/tests/unit/test_visualization.py b/tests/unit/test_visualization.py index f9507bd5..632a0082 100644 --- a/tests/unit/test_visualization.py +++ b/tests/unit/test_visualization.py @@ -15,6 +15,8 @@ _generate_line_plot, _generate_scatter_plot, _get_cardinality, + _get_max_between_datasets, + _get_min_between_datasets, get_cardinality_plot, get_column_line_plot, get_column_pair_plot, @@ -48,6 +50,54 @@ def test_get_cardinality(): pd.testing.assert_series_equal(result, expected_result) +def test__get_max_between_datasets(): + """Test the ``_get_max_between_datasets`` method""" + # Setup + mock_real_data = pd.Series([1, 1, 2, 2, 2]) + mock_synthetic_data = pd.Series([3, 3, 4]) + + # Run + real_only_val = _get_max_between_datasets(mock_real_data, None) + synth_only_val = _get_max_between_datasets(None, mock_synthetic_data) + all_val = _get_max_between_datasets(mock_real_data, mock_synthetic_data) + + # Assert + expected_real_only_val = 2 + expected_synth_only_val = 4 + expected_all_val = 4 + assert expected_real_only_val == real_only_val + assert expected_synth_only_val == synth_only_val + assert expected_all_val == all_val + + error_msg = re.escape('Cannot get max between two None values.') + with pytest.raises(ValueError, match=error_msg): + _get_max_between_datasets(None, None) + + +def test__get_min_between_datasets(): + """Test the ``_get_min_between_datasets`` method""" + # Setup + mock_real_data = pd.Series([1, 1, 2, 2, 2]) + mock_synthetic_data = pd.Series([3, 3, 4]) + + # Run + real_only_val = _get_min_between_datasets(mock_real_data, None) + synth_only_val = _get_min_between_datasets(None, mock_synthetic_data) + all_val = _get_min_between_datasets(mock_real_data, mock_synthetic_data) + + # Assert + expected_real_only_val = 1 + expected_synth_only_val = 3 + expected_all_val = 1 + assert expected_real_only_val == real_only_val + assert expected_synth_only_val == synth_only_val + assert expected_all_val == all_val + + error_msg = re.escape('Cannot get min between two None values.') + with pytest.raises(ValueError, match=error_msg): + _get_min_between_datasets(None, None) + + @patch('sdmetrics.visualization.px') def test_generate_cardinality_bar_plot(mock_px): """Test the ``_generate_cardinality_plot`` method."""