Skip to content

Commit

Permalink
Add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lajohn4747 committed Jun 25, 2024
1 parent 5f4d071 commit ab546b6
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 4 deletions.
6 changes: 2 additions & 4 deletions sdmetrics/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,8 +372,7 @@ def _get_max_between_datasets(real_data, synthetic_data):
return max(synthetic_data)
elif synthetic_data is None:
return max(real_data)
else:
return max(max(real_data), max(synthetic_data))
return max(max(real_data), max(synthetic_data))


def _get_min_between_datasets(real_data, synthetic_data):
Expand All @@ -383,8 +382,7 @@ def _get_min_between_datasets(real_data, synthetic_data):
return min(synthetic_data)
elif synthetic_data is None:
return min(real_data)
else:
return min(min(real_data), min(synthetic_data))
return min(min(real_data), min(synthetic_data))


def _generate_cardinality_plot(
Expand Down
50 changes: 50 additions & 0 deletions tests/unit/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
_generate_line_plot,
_generate_scatter_plot,
_get_cardinality,
_get_max_between_datasets,
_get_min_between_datasets,
get_cardinality_plot,
get_column_line_plot,
get_column_pair_plot,
Expand Down Expand Up @@ -48,6 +50,54 @@ def test_get_cardinality():
pd.testing.assert_series_equal(result, expected_result)


def test__get_max_between_datasets():
"""Test the ``_get_max_between_datasets`` method"""
# Setup
mock_real_data = pd.Series([1, 1, 2, 2, 2])
mock_synthetic_data = pd.Series([3, 3, 4])

# Run
real_only_val = _get_max_between_datasets(mock_real_data, None)
synth_only_val = _get_max_between_datasets(None, mock_synthetic_data)
all_val = _get_max_between_datasets(mock_real_data, mock_synthetic_data)

# Assert
expected_real_only_val = 2
expected_synth_only_val = 4
expected_all_val = 4
assert expected_real_only_val == real_only_val
assert expected_synth_only_val == synth_only_val
assert expected_all_val == all_val

error_msg = re.escape('Cannot get max between two None values.')
with pytest.raises(ValueError, match=error_msg):
_get_max_between_datasets(None, None)


def test__get_min_between_datasets():
"""Test the ``_get_min_between_datasets`` method"""
# Setup
mock_real_data = pd.Series([1, 1, 2, 2, 2])
mock_synthetic_data = pd.Series([3, 3, 4])

# Run
real_only_val = _get_min_between_datasets(mock_real_data, None)
synth_only_val = _get_min_between_datasets(None, mock_synthetic_data)
all_val = _get_min_between_datasets(mock_real_data, mock_synthetic_data)

# Assert
expected_real_only_val = 1
expected_synth_only_val = 3
expected_all_val = 1
assert expected_real_only_val == real_only_val
assert expected_synth_only_val == synth_only_val
assert expected_all_val == all_val

error_msg = re.escape('Cannot get min between two None values.')
with pytest.raises(ValueError, match=error_msg):
_get_min_between_datasets(None, None)


@patch('sdmetrics.visualization.px')
def test_generate_cardinality_bar_plot(mock_px):
"""Test the ``_generate_cardinality_plot`` method."""
Expand Down

0 comments on commit ab546b6

Please sign in to comment.