Skip to content

Commit

Permalink
Merge branch 'issue_581_get_column_plot' into issue_581_get_cardinali…
Browse files Browse the repository at this point in the history
…ty_plot
  • Loading branch information
lajohn4747 committed Jun 24, 2024
2 parents 8f59a9d + c01163f commit a93c05f
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 54 deletions.
4 changes: 2 additions & 2 deletions sdmetrics/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _generate_heatmap_plot(all_data, columns):
unique_values = all_data['Data'].unique()

if len(columns) != 2:
raise ValueError('Generating a heatmap plot requires only two columns for the axis.')
raise ValueError('Generating a heatmap plot requires exactly two columns for the axis.')

fig = px.density_heatmap(
all_data, x=columns[0], y=columns[1], facet_col='Data', histnorm='probability'
Expand Down Expand Up @@ -165,7 +165,7 @@ def _generate_scatter_plot(all_data, columns):
"""

if len(columns) != 2:
raise ValueError('Generating a scatter plot requires only two columns for the axis.')
raise ValueError('Generating a scatter plot requires exactly two columns for the axis.')

unique_values = all_data['Data'].unique()
fig = px.scatter(
Expand Down
52 changes: 0 additions & 52 deletions tests/unit/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,32 +692,6 @@ def test__generate_scatter_plot(px_mock):
assert fig == mock_figure


@patch('sdmetrics.visualization.px')
def test__generate_scatter_plot_one_column_failure(px_mock):
"""Test the ``_generate_scatter_plot`` method."""
# Setup
real_column = pd.DataFrame({
'col1': [1, 2, 3, 4],
'col2': [1.1, 1.2, 1.3, 1.4],
'Data': ['Real'] * 4,
})
synthetic_column = pd.DataFrame({
'col1': [1, 2, 4, 5],
'col2': [1.1, 1.2, 1.3, 1.4],
'Data': ['Synthetic'] * 4,
})

all_data = pd.concat([real_column, synthetic_column], axis=0, ignore_index=True)
columns = ['col1']
mock_figure = Mock()
px_mock.scatter.return_value = mock_figure

# Run and assert
error_msg = re.escape('Generating a scatter plot requires only two columns for the axis.')
with pytest.raises(ValueError, match=error_msg):
_generate_scatter_plot(all_data, columns)


@patch('sdmetrics.visualization.px')
def test__generate_heatmap_plot(px_mock):
"""Test the ``_generate_heatmap_plot`` method."""
Expand Down Expand Up @@ -769,32 +743,6 @@ def test__generate_heatmap_plot(px_mock):
assert fig == mock_figure


@patch('sdmetrics.visualization.px')
def test__generate_heatmap_plot_one_column(px_mock):
"""Test the ``_generate_heatmap_plot`` method."""
# Setup
real_column = pd.DataFrame({
'col1': [1, 2, 3, 4],
'col2': ['a', 'b', 'c', 'd'],
'Data': ['Real'] * 4,
})
synthetic_column = pd.DataFrame({
'col1': [1, 2, 4, 5],
'col2': ['a', 'b', 'c', 'd'],
'Data': ['Synthetic'] * 4,
})
columns = ['col1']
all_data = pd.concat([real_column, synthetic_column], axis=0, ignore_index=True)

mock_figure = Mock()
px_mock.density_heatmap.return_value = mock_figure

# Run and assert
error_msg = re.escape('Generating a heatmap plot requires only two columns for the axis.')
with pytest.raises(ValueError, match=error_msg):
_generate_heatmap_plot(all_data, columns)


@patch('sdmetrics.visualization.px')
def test__generate_line_plot(px_mock):
"""Test the ``_generate_line_plot`` method."""
Expand Down

0 comments on commit a93c05f

Please sign in to comment.