Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow get_column_plot to graph synthetic and real data individually #596

Merged
merged 23 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
5030641
Allow get_column_pair_plot to visualize one dataset instead of both r…
lajohn4747 Jun 20, 2024
58e0989
Adding ability to individually graph synthetic or real data
lajohn4747 Jun 20, 2024
9c7dbc7
Fix title naming
lajohn4747 Jun 20, 2024
10026b1
Merge branch 'issue_581_get_column_pair_plot' into issue_581_get_colu…
lajohn4747 Jun 20, 2024
412b10b
Fixed title
lajohn4747 Jun 20, 2024
122855c
Update message
lajohn4747 Jun 20, 2024
f83279a
Merge branch 'issue_581_get_column_pair_plot' into issue_581_get_colu…
lajohn4747 Jun 20, 2024
2add684
Keep colors
lajohn4747 Jun 20, 2024
9fec05b
Update error message
lajohn4747 Jun 20, 2024
964008c
Merge branch 'main' into issue_581_get_column_pair_plot
lajohn4747 Jun 21, 2024
ceca02e
Merge branch 'issue_581_get_column_pair_plot' into issue_581_get_colu…
lajohn4747 Jun 21, 2024
b80af9f
Add tests for previous visualization methods
lajohn4747 Jun 24, 2024
a917652
Use assert_called_once_with with ANY
lajohn4747 Jun 24, 2024
87178a1
Fix failing minimum test
lajohn4747 Jun 24, 2024
34b66ae
Address comments and added check for single column test for internal …
lajohn4747 Jun 24, 2024
d694819
Fix
lajohn4747 Jun 24, 2024
8e97290
Merge branch 'issue_581_get_column_pair_plot' into issue_581_get_colu…
lajohn4747 Jun 24, 2024
673c1d9
Remove unneeded tests
lajohn4747 Jun 24, 2024
c01163f
Merge branch 'issue_581_get_column_pair_plot' into issue_581_get_colu…
lajohn4747 Jun 24, 2024
dc2eed0
Merge branch 'main' into issue_581_get_column_pair_plot
lajohn4747 Jun 25, 2024
42bfd87
Cap scipy from new release to avoid test failures
lajohn4747 Jun 25, 2024
8e94caa
Cap all scipy versions
lajohn4747 Jun 25, 2024
990eb74
Merge branch 'issue_581_get_column_pair_plot' into issue_581_get_colu…
lajohn4747 Jun 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdmetrics/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def _generate_column_plot(

if plot_type not in ['bar', 'distplot']:
raise ValueError(
"Unrecognized plot_type '{plot_type}'. Pleas use one of 'bar' or 'distplot'"
f"Unrecognized plot_type '{plot_type}'. Please use one of 'bar' or 'distplot'"
)

column_name = ''
Expand Down
189 changes: 189 additions & 0 deletions tests/unit/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@
import pandas as pd
import pytest

from sdmetrics.reports.utils import PlotConfig
from sdmetrics.visualization import (
_generate_box_plot,
_generate_cardinality_plot,
_generate_column_bar_plot,
_generate_column_distplot,
_generate_column_plot,
_generate_heatmap_plot,
_generate_line_plot,
_generate_scatter_plot,
Expand Down Expand Up @@ -276,6 +280,191 @@ def test_get_column_plot_no_data():
get_column_plot(None, None, 'values')


@patch('sdmetrics.visualization.px.histogram')
def test__generate_column_bar_plot(mock_histogram):
    """Check the data and keyword arguments ``_generate_column_bar_plot`` hands to ``px.histogram``."""
    # Setup
    real_column = pd.Series([1, 2, 2, 3, 5])
    synthetic_column = pd.Series([2, 2, 3, 4, 5])

    # Run
    _generate_column_bar_plot(real_column, synthetic_column)

    # Assert
    # Real values followed by synthetic values, re-indexed and cast to float.
    stacked_values = pd.concat([real_column, synthetic_column], axis=0, ignore_index=True)
    expected_frame = pd.DataFrame(stacked_values.astype('float64'))
    expected_kwargs = dict(
        x='values',
        color='Data',
        barmode='group',
        color_discrete_sequence=['#000036', '#01E0C9'],
        pattern_shape='Data',
        pattern_shape_sequence=['', '/'],
        histnorm='probability density',
    )
    positional_args = mock_histogram.call_args[0]
    keyword_args = mock_histogram.call_args[1]
    pd.testing.assert_frame_equal(expected_frame, positional_args[0])
    assert expected_kwargs == keyword_args
    mock_histogram.assert_called_once()


@patch('sdmetrics.visualization.ff.create_distplot')
def test__generate_column_distplot(mock_distplot):
    """Test ``_generate_column_distplot`` functionality.

    Verifies the positional arguments (the list of value series and the
    group labels) and the keyword arguments forwarded to
    ``ff.create_distplot``.
    """
    # Setup
    real_data = pd.DataFrame({'values': [1, 2, 2, 3, 5]})
    synthetic_data = pd.DataFrame({'values': [2, 2, 3, 4, 5]})

    # Run
    _generate_column_distplot(real_data, synthetic_data)

    # Assert
    mock_distplot.assert_called_once()
    positional_args = mock_distplot.call_args[0]

    # ``==`` between pandas Series is element-wise and cannot be used in a
    # plain ``assert``; compare each series with the pandas testing helper.
    hist_data = positional_args[0]
    assert len(hist_data) == 2
    pd.testing.assert_series_equal(hist_data[0], real_data['values'])
    pd.testing.assert_series_equal(hist_data[1], synthetic_data['values'])

    assert positional_args[1] == ['Real', 'Synthetic']

    expected_colors = [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN]
    expected_parameters = {
        'show_hist': False,
        'show_rug': False,
        'colors': expected_colors,
    }
    assert expected_parameters == mock_distplot.call_args[1]
amontanez24 marked this conversation as resolved.
Show resolved Hide resolved


@patch('sdmetrics.visualization._generate_column_distplot')
def test___generate_column_plot_type_distplot(mock_dist_plot):
    """Check that ``_generate_column_plot`` delegates to the distplot helper for ``'distplot'``."""
    # Setup
    real_frame = pd.DataFrame({'values': [1, 2, 2, 3, 5]})
    synthetic_frame = pd.DataFrame({'values': [2, 2, 3, 4, 5]})
    trace = Mock()
    trace.x = [1, 2, 2, 3, 5]
    figure = Mock()
    figure.data = [trace, trace]
    mock_dist_plot.return_value = figure

    # Run
    _generate_column_plot(real_frame['values'], synthetic_frame['values'], 'distplot')

    # Assert
    # The helper should receive each column tagged with its origin label.
    expected_real = pd.DataFrame({
        'values': [1, 2, 2, 3, 5],
        'Data': ['Real'] * 5,
    })
    expected_synthetic = pd.DataFrame({
        'values': [2, 2, 3, 4, 5],
        'Data': ['Synthetic'] * 5,
    })
    helper_args = mock_dist_plot.call_args[0]
    pd.testing.assert_frame_equal(helper_args[0], expected_real)
    pd.testing.assert_frame_equal(helper_args[1], expected_synthetic)
    assert helper_args[2] == {}
    mock_dist_plot.assert_called_once()

    expected_layout = {
        'title': "Real vs. Synthetic Data for column 'values'",
        'xaxis_title': 'Value',
        'yaxis_title': 'Frequency',
        'plot_bgcolor': PlotConfig.BACKGROUND_COLOR,
        'annotations': [],
        'font': {'size': PlotConfig.FONT_SIZE},
    }
    figure.update_layout.assert_called_once_with(**expected_layout)


@patch('sdmetrics.visualization._generate_column_bar_plot')
def test___generate_column_plot_type_bar(mock_bar_plot):
    """Check that ``_generate_column_plot`` delegates to the bar plot helper for ``'bar'``."""
    # Setup
    real_frame = pd.DataFrame({'values': [1, 2, 2, 3, 5]})
    synthetic_frame = pd.DataFrame({'values': [2, 2, 3, 4, 5]})
    trace = Mock()
    trace.x = [1, 2, 2, 3, 5]
    figure = Mock()
    figure.data = [trace, trace]
    mock_bar_plot.return_value = figure

    # Run
    _generate_column_plot(real_frame['values'], synthetic_frame['values'], 'bar')

    # Assert
    # The helper should receive each column tagged with its origin label.
    expected_real = pd.DataFrame({
        'values': [1, 2, 2, 3, 5],
        'Data': ['Real'] * 5,
    })
    expected_synthetic = pd.DataFrame({
        'values': [2, 2, 3, 4, 5],
        'Data': ['Synthetic'] * 5,
    })
    helper_args = mock_bar_plot.call_args[0]
    pd.testing.assert_frame_equal(helper_args[0], expected_real)
    pd.testing.assert_frame_equal(helper_args[1], expected_synthetic)
    assert helper_args[2] == {}
    mock_bar_plot.assert_called_once()

    expected_layout = {
        'title': "Real vs. Synthetic Data for column 'values'",
        'xaxis_title': 'Category',
        'yaxis_title': 'Frequency',
        'plot_bgcolor': PlotConfig.BACKGROUND_COLOR,
        'annotations': [],
        'font': {'size': PlotConfig.FONT_SIZE},
    }
    figure.update_layout.assert_called_once_with(**expected_layout)


@patch('sdmetrics.visualization._generate_column_bar_plot')
def test___generate_column_plot_with_datetimes(mock_bar_plot):
    """Test ``_generate_column_plot`` using datetimes.

    The bar plot helper is expected to receive the datetime columns as
    integer nanosecond timestamps, each labelled with its data origin.
    """
    # Setup
    real_data = pd.DataFrame({'values': pd.to_datetime(['2021-01-20', '2022-01-21'])})
    synthetic_data = pd.DataFrame({'values': pd.to_datetime(['2021-01-20', '2022-01-21'])})
    mock_fig = Mock()
    mock_object = Mock()
    mock_object.x = [1, 2, 2, 3, 5]
    mock_fig.data = [mock_object, mock_object]
    mock_bar_plot.return_value = mock_fig

    # Run
    _generate_column_plot(real_data['values'], synthetic_data['values'], 'bar')

    # Assert
    # 2021-01-20 and 2022-01-21 expressed as nanoseconds since the Unix epoch.
    expected_real_data = pd.DataFrame({
        'values': [1611100800000000000, 1642723200000000000],
        'Data': ['Real', 'Real'],
    })
    expected_synth_data = pd.DataFrame({
        'values': [1611100800000000000, 1642723200000000000],
        'Data': ['Synthetic', 'Synthetic'],
    })
    pd.testing.assert_frame_equal(mock_bar_plot.call_args[0][0], expected_real_data)
    pd.testing.assert_frame_equal(mock_bar_plot.call_args[0][1], expected_synth_data)
    assert mock_bar_plot.call_args[0][2] == {}
    mock_bar_plot.assert_called_once()


def test___generate_column_plot_no_data():
    """Check that ``_generate_column_plot`` raises when both datasets are ``None``."""
    # Setup
    expected_message = 'No data provided to plot. Please provide either real or synthetic data.'
    expected_pattern = re.escape(expected_message)

    # Run and Assert
    with pytest.raises(ValueError, match=expected_pattern):
        _generate_column_plot(None, None, 'bar')


def test___generate_column_plot_with_bad_plot():
    """Check that ``_generate_column_plot`` rejects an unknown ``plot_type``."""
    # Setup
    real_frame = pd.DataFrame({'values': [1, 2, 2, 3, 5]})
    synthetic_frame = pd.DataFrame({'values': [2, 2, 3, 4, 5]})
    expected_message = "Unrecognized plot_type 'bad_plot'. Please use one of 'bar' or 'distplot'"
    expected_pattern = re.escape(expected_message)

    # Run and Assert
    with pytest.raises(ValueError, match=expected_pattern):
        _generate_column_plot(real_frame, synthetic_frame, 'bad_plot')


@patch('sdmetrics.visualization._generate_column_plot')
def test_get_column_plot_plot_one_data_set(mock__generate_column_plot):
"""Test ``get_column_plot`` for real data and synthetic data individually."""
Expand Down
Loading