Skip to content

Commit

Permalink
Address comments and added check for single column test for internal …
Browse files Browse the repository at this point in the history
…functions
  • Loading branch information
lajohn4747 committed Jun 24, 2024
1 parent 964008c commit 34b66ae
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 10 deletions.
21 changes: 11 additions & 10 deletions sdmetrics/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,15 @@ def _generate_heatmap_plot(all_data, columns):
"""
unique_values = all_data['Data'].unique()

if len(columns) != 2:
raise ValueError('Generating a heatmap plot requires only two columns for the axis.')

fig = px.density_heatmap(
all_data, x=columns[0], y=columns[1], facet_col='Data', histnorm='probability'
)

title = ''
for name in unique_values:
title += f'{name} vs. '
title = title[:-4]
title += f"Data for columns '{columns[0]}' and '{columns[1]}"
title = ' vs. '.join(unique_values)
title += f" Data for columns '{columns[0]}' and '{columns[1]}"

fig.update_layout(
title_text=title,
Expand Down Expand Up @@ -155,6 +155,10 @@ def _generate_scatter_plot(all_data, columns):
Returns:
plotly.graph_objects._figure.Figure
"""

if len(columns) != 2:
raise ValueError('Generating a scatter plot requires only two columns for the axis.')

unique_values = all_data['Data'].unique()
fig = px.scatter(
all_data,
Expand All @@ -168,11 +172,8 @@ def _generate_scatter_plot(all_data, columns):
symbol='Data',
)

title = ''
for name in unique_values:
title += f'{name} vs. '
title = title[:-4]
title += f"Data for columns '{columns[0]}' and '{columns[1]}'"
title = ' vs. '.join(unique_values)
title += f" Data for columns '{columns[0]}' and '{columns[1]}'"

fig.update_layout(
title=title,
Expand Down
52 changes: 52 additions & 0 deletions tests/unit/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,32 @@ def test__generate_scatter_plot(px_mock):
assert fig == mock_figure


@patch('sdmetrics.visualization.px')
def test__generate_scatter_plot_one_column_failure(px_mock):
"""Test the ``_generate_scatter_plot`` method."""
# Setup
real_column = pd.DataFrame({
'col1': [1, 2, 3, 4],
'col2': [1.1, 1.2, 1.3, 1.4],
'Data': ['Real'] * 4,
})
synthetic_column = pd.DataFrame({
'col1': [1, 2, 4, 5],
'col2': [1.1, 1.2, 1.3, 1.4],
'Data': ['Synthetic'] * 4,
})

all_data = pd.concat([real_column, synthetic_column], axis=0, ignore_index=True)
columns = ['col1']
mock_figure = Mock()
px_mock.scatter.return_value = mock_figure

# Run and assert
error_msg = re.escape('Generating a scatter plot requires only two columns for the axis.')
with pytest.raises(ValueError, match=error_msg):
_generate_scatter_plot(all_data, columns)


@patch('sdmetrics.visualization.px')
def test__generate_heatmap_plot(px_mock):
"""Test the ``_generate_heatmap_plot`` method."""
Expand Down Expand Up @@ -472,6 +498,32 @@ def test__generate_heatmap_plot(px_mock):
assert fig == mock_figure


@patch('sdmetrics.visualization.px')
def test__generate_heatmap_plot_one_column(px_mock):
"""Test the ``_generate_heatmap_plot`` method."""
# Setup
real_column = pd.DataFrame({
'col1': [1, 2, 3, 4],
'col2': ['a', 'b', 'c', 'd'],
'Data': ['Real'] * 4,
})
synthetic_column = pd.DataFrame({
'col1': [1, 2, 4, 5],
'col2': ['a', 'b', 'c', 'd'],
'Data': ['Synthetic'] * 4,
})
columns = ['col1']
all_data = pd.concat([real_column, synthetic_column], axis=0, ignore_index=True)

mock_figure = Mock()
px_mock.density_heatmap.return_value = mock_figure

# Run and assert
error_msg = re.escape('Generating a heatmap plot requires only two columns for the axis.')
with pytest.raises(ValueError, match=error_msg):
_generate_heatmap_plot(all_data, columns)


@patch('sdmetrics.visualization.px')
def test__generate_line_plot(px_mock):
"""Test the ``_generate_line_plot`` method."""
Expand Down

0 comments on commit 34b66ae

Please sign in to comment.