diff --git a/pyproject.toml b/pyproject.toml index f95571d9..f0e87e4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,9 +30,9 @@ dependencies = [ "scikit-learn>=1.1.0;python_version>='3.10' and python_version<'3.11'", "scikit-learn>=1.1.3;python_version>='3.11' and python_version<'3.12'", "scikit-learn>=1.3.1;python_version>='3.12'", - "scipy>=1.7.3;python_version<'3.10'", - "scipy>=1.9.2;python_version>='3.10' and python_version<'3.12'", - "scipy>=1.12.0;python_version>='3.12'", + "scipy>=1.7.3,<1.14.0;python_version<'3.10'", + "scipy>=1.9.2,<1.14.0;python_version>='3.10' and python_version<'3.12'", + "scipy>=1.12.0,<1.14.0;python_version>='3.12'", 'copulas>=0.11.0', 'tqdm>=4.29', 'plotly>=5.19.0', diff --git a/sdmetrics/visualization.py b/sdmetrics/visualization.py index 20d3d329..64630271 100644 --- a/sdmetrics/visualization.py +++ b/sdmetrics/visualization.py @@ -46,10 +46,10 @@ def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}): """Generate a bar plot of the real and synthetic data. Args: - real_column (pandas.Series): - The real data for the desired column. - synthetic_column (pandas.Series): - The synthetic data for the desired column. + real_data (pandas.DataFrame or None): + The real data for the desired column. If None this data will not be graphed. + synthetic_data (pandas.DataFrame or None): + The synthetic data for the desired column. If None this data will not be graphed. plot_kwargs (dict, optional): Dictionary of keyword arguments to pass to px.histogram. Keyword arguments provided this way will overwrite defaults. 
@@ -57,12 +57,20 @@ def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}): Returns: plotly.graph_objects._figure.Figure """ - all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True) + all_data = pd.DataFrame() + color_sequence = [] + if real_data is not None: + all_data = pd.concat([all_data, real_data], axis=0, ignore_index=True) + color_sequence.append(PlotConfig.DATACEBO_DARK) + if synthetic_data is not None: + all_data = pd.concat([all_data, synthetic_data], axis=0, ignore_index=True) + color_sequence.append(PlotConfig.DATACEBO_GREEN) + histogram_kwargs = { 'x': 'values', 'color': 'Data', 'barmode': 'group', - 'color_discrete_sequence': [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN], + 'color_discrete_sequence': color_sequence, 'pattern_shape': 'Data', 'pattern_shape_sequence': ['', '/'], 'histnorm': 'probability density', @@ -86,12 +94,20 @@ def _generate_heatmap_plot(all_data, columns): Returns: plotly.graph_objects._figure.Figure """ + unique_values = all_data['Data'].unique() + + if len(columns) != 2: + raise ValueError('Generating a heatmap plot requires exactly two columns for the axis.') + fig = px.density_heatmap( all_data, x=columns[0], y=columns[1], facet_col='Data', histnorm='probability' ) + title = ' vs. 
'.join(unique_values) + title += f" Data for columns '{columns[0]}' and '{columns[1]}'" + fig.update_layout( - title_text=f"Real vs Synthetic Data for columns '{columns[0]}' and '{columns[1]}'", + title_text=title, coloraxis={'colorscale': [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN]}, font={'size': PlotConfig.FONT_SIZE}, ) @@ -147,6 +163,11 @@ def _generate_scatter_plot(all_data, columns): Returns: plotly.graph_objects._figure.Figure """ + + if len(columns) != 2: + raise ValueError('Generating a scatter plot requires exactly two columns for the axis.') + + unique_values = all_data['Data'].unique() fig = px.scatter( all_data, x=columns[0], @@ -159,8 +180,11 @@ def _generate_scatter_plot(all_data, columns): symbol='Data', ) + title = ' vs. '.join(unique_values) + title += f" Data for columns '{columns[0]}' and '{columns[1]}'" + fig.update_layout( - title=f"Real vs. Synthetic Data for columns '{columns[0]}' and '{columns[1]}'", + title=title, plot_bgcolor=PlotConfig.BACKGROUND_COLOR, font={'size': PlotConfig.FONT_SIZE}, ) @@ -172,10 +196,10 @@ def _generate_column_distplot(real_data, synthetic_data, plot_kwargs={}): """Plot the real and synthetic data as a distplot. Args: - real_data (pandas.DataFrame): - The real data for the desired column. - synthetic_data (pandas.DataFrame): - The synthetic data for the desired column. + real_data (pandas.DataFrame or None): + The real data for the desired column. If None this data will not be graphed. + synthetic_data (pandas.DataFrame or None): + The synthetic data for the desired column. If None this data will not be graphed. plot_kwargs (dict, optional): Dictionary of keyword arguments to pass to px.histogram. Keyword arguments provided this way will overwrite defaults. 
@@ -183,15 +207,27 @@ def _generate_column_distplot(real_data, synthetic_data, plot_kwargs={}): Returns: plotly.graph_objects._figure.Figure """ + hist_data = [] + col_names = [] + colors = [] + if real_data is not None: + hist_data.append(real_data['values']) + col_names.append('Real') + colors.append(PlotConfig.DATACEBO_DARK) + if synthetic_data is not None: + hist_data.append(synthetic_data['values']) + col_names.append('Synthetic') + colors.append(PlotConfig.DATACEBO_GREEN) + default_distplot_kwargs = { 'show_hist': False, 'show_rug': False, - 'colors': [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN], + 'colors': colors, } fig = ff.create_distplot( - [real_data['values'], synthetic_data['values']], - ['Real', 'Synthetic'], + hist_data, + col_names, **{**default_distplot_kwargs, **plot_kwargs}, ) @@ -204,10 +240,10 @@ def _generate_column_plot( """Generate a plot of the real and synthetic data. Args: - real_column (pandas.Series): - The real data for the desired column. - synthetic_column (pandas.Series): - The synthetic data for the desired column. + real_column (pandas.Series or None): + The real data for the desired column. If None this data will not be graphed. + synthetic_column (pandas.Series or None): + The synthetic data for the desired column. If None this data will not be graphed. plot_type (str): The type of plot to use. Must be one of 'bar' or 'distplot'. hist_kwargs (dict, optional): @@ -221,26 +257,55 @@ def _generate_column_plot( Returns: plotly.graph_objects._figure.Figure """ + + if real_column is None and synthetic_column is None: + raise ValueError('No data provided to plot. Please provide either real or synthetic data.') + if plot_type not in ['bar', 'distplot']: raise ValueError( - "Unrecognized plot_type '{plot_type}'. Pleas use one of 'bar' or 'distplot'" + f"Unrecognized plot_type '{plot_type}'. 
Please use one of 'bar' or 'distplot'" ) - column_name = real_column.name if hasattr(real_column, 'name') else '' - - missing_data_real = get_missing_percentage(real_column) - missing_data_synthetic = get_missing_percentage(synthetic_column) - - real_data = pd.DataFrame({'values': real_column.copy().dropna()}) - real_data['Data'] = 'Real' - synthetic_data = pd.DataFrame({'values': synthetic_column.copy().dropna()}) - synthetic_data['Data'] = 'Synthetic' + column_name = '' + missing_data_real = 0 + missing_data_synthetic = 0 + col_dtype = None + col_names = [] + title = '' + if real_column is not None and hasattr(real_column, 'name'): + column_name = real_column.name + elif synthetic_column is not None and hasattr(synthetic_column, 'name'): + column_name = synthetic_column.name + + real_data = None + if real_column is not None: + missing_data_real = get_missing_percentage(real_column) + real_data = pd.DataFrame({'values': real_column.copy().dropna()}) + real_data['Data'] = 'Real' + col_dtype = real_column.dtype + col_names.append('Real') + title += 'Real vs. ' + + synthetic_data = None + if synthetic_column is not None: + missing_data_synthetic = get_missing_percentage(synthetic_column) + synthetic_data = pd.DataFrame({'values': synthetic_column.copy().dropna()}) + synthetic_data['Data'] = 'Synthetic' + col_names.append('Synthetic') + title += 'Synthetic vs. 
' + if col_dtype is None: + col_dtype = synthetic_column.dtype + + title = title[:-4] + title += f"Data for column '{column_name}'" is_datetime_sdtype = False - if is_datetime64_dtype(real_column.dtype): + if is_datetime64_dtype(col_dtype): is_datetime_sdtype = True - real_data['values'] = real_data['values'].astype('int64') - synthetic_data['values'] = synthetic_data['values'].astype('int64') + if real_data is not None: + real_data['values'] = real_data['values'].astype('int64') + if synthetic_data is not None: + synthetic_data['values'] = synthetic_data['values'].astype('int64') trace_args = {} @@ -251,7 +316,7 @@ def _generate_column_plot( fig = _generate_column_distplot(real_data, synthetic_data, plot_kwargs) trace_args = {'fill': 'tozeroy'} - for i, name in enumerate(['Real', 'Synthetic']): + for i, name in enumerate(col_names): fig.update_traces( x=pd.to_datetime(fig.data[i].x) if is_datetime_sdtype else fig.data[i].x, hovertemplate=f'{name}
Frequency: %{{y}}', @@ -260,6 +325,14 @@ def _generate_column_plot( ) show_missing_values = missing_data_real > 0 or missing_data_synthetic > 0 + text = '*Missing Values: ' + if real_column is not None and show_missing_values: + text += f'Real Data ({missing_data_real}%), ' + if synthetic_column is not None and show_missing_values: + text += f'Synthetic Data ({missing_data_synthetic}%), ' + + text = text[:-2] + annotations = ( [] if not show_missing_values @@ -270,16 +343,13 @@ def _generate_column_plot( 'x': 1.0, 'y': 1.05, 'showarrow': False, - 'text': ( - f'*Missing Values: Real Data ({missing_data_real}%), ' - f'Synthetic Data ({missing_data_synthetic}%)' - ), + 'text': text, }, ] ) if not plot_title: - plot_title = f"Real vs. Synthetic Data for column '{column_name}'" + plot_title = title if not x_label: x_label = 'Category' @@ -401,10 +471,10 @@ def get_column_plot(real_data, synthetic_data, column_name, plot_type=None): """Return a plot of the real and synthetic data for a given column. Args: - real_data (pandas.DataFrame): - The real table data. - synthetic_data (pandas.DataFrame): - The synthetic table data. + real_data (pandas.DataFrame or None): + The real table data. If None this data will not be graphed. + synthetic_data (pandas.DataFrame or None): + The synthetic table data. If None this data will not be graphed. column_name (str): The name of the column. plot_type (str or None): @@ -416,28 +486,39 @@ def get_column_plot(real_data, synthetic_data, column_name, plot_type=None): Returns: plotly.graph_objects._figure.Figure """ + + if real_data is None and synthetic_data is None: + raise ValueError('No data provided to plot. Please provide either real or synthetic data.') + if plot_type not in ['bar', 'distplot', None]: raise ValueError( f"Invalid plot_type '{plot_type}'. Please use one of ['bar', 'distplot', None]." 
) - if column_name not in real_data.columns: - raise ValueError(f"Column '{column_name}' not found in real table data.") - if column_name not in synthetic_data.columns: - raise ValueError(f"Column '{column_name}' not found in synthetic table data.") + column = None + real_column = None + synthetic_column = None + if real_data is not None: + if column_name not in real_data.columns: + raise ValueError(f"Column '{column_name}' not found in real table data.") + column = real_data[column_name] + real_column = real_data[column_name] + + if synthetic_data is not None: + if column_name not in synthetic_data.columns: + raise ValueError(f"Column '{column_name}' not found in synthetic table data.") + if column is None: + column = synthetic_data[column_name] + synthetic_column = synthetic_data[column_name] - real_column = real_data[column_name] if plot_type is None: - column_is_datetime = is_datetime(real_data[column_name]) - dtype = real_column.dropna().infer_objects().dtype.kind + column_is_datetime = is_datetime(column) + dtype = column.dropna().infer_objects().dtype.kind if column_is_datetime or dtype in ('i', 'f'): plot_type = 'distplot' else: plot_type = 'bar' - real_column = real_data[column_name] - synthetic_column = synthetic_data[column_name] - fig = _generate_column_plot(real_column, synthetic_column, plot_type) return fig @@ -448,10 +529,10 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None """Return a plot of the real and synthetic data for a given column pair. Args: - real_data (pandas.DataFrame): - The real table data. - synthetic_column (pandas.Dataframe): - The synthetic table data. + real_data (pandas.DataFrame or None): + The real table data. If None this data will not be graphed. + synthetic_data (pandas.DataFrame or None): + The synthetic table data. If None this data will not be graphed. column_names (list[string]): The names of the two columns to plot. 
plot_type (str or None): @@ -466,16 +547,23 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None if len(column_names) != 2: raise ValueError('Must provide exactly two column names.') - if not set(column_names).issubset(real_data.columns): - raise ValueError( - f'Missing column(s) {set(column_names) - set(real_data.columns)} in real data.' - ) + if real_data is None and synthetic_data is None: + raise ValueError('No data provided to plot. Please provide either real or synthetic data.') - if not set(column_names).issubset(synthetic_data.columns): - raise ValueError( - f'Missing column(s) {set(column_names) - set(synthetic_data.columns)} ' - 'in synthetic data.' - ) + if real_data is not None: + if not set(column_names).issubset(real_data.columns): + raise ValueError( + f'Missing column(s) {set(column_names) - set(real_data.columns)} in real data.' + ) + real_data = real_data[column_names] + + if synthetic_data is not None: + if not set(column_names).issubset(synthetic_data.columns): + raise ValueError( + f'Missing column(s) {set(column_names) - set(synthetic_data.columns)} ' + 'in synthetic data.' + ) + synthetic_data = synthetic_data[column_names] if plot_type not in ['box', 'heatmap', 'scatter', None]: raise ValueError( @@ -483,12 +571,13 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None "['box', 'heatmap', 'scatter', None]." 
) - real_data = real_data[column_names] - synthetic_data = synthetic_data[column_names] if plot_type is None: plot_type = [] for column_name in column_names: - column = real_data[column_name] + if real_data is not None: + column = real_data[column_name] + else: + column = synthetic_data[column_name] dtype = column.dropna().infer_objects().dtype.kind if dtype in ('i', 'f') or is_datetime(column): plot_type.append('scatter') @@ -501,19 +590,22 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None plot_type = plot_type.pop() # Merge the real and synthetic data and add a flag ``Data`` to indicate each one. - columns = list(real_data.columns) - real_data = real_data.copy() - real_data['Data'] = 'Real' - synthetic_data = synthetic_data.copy() - synthetic_data['Data'] = 'Synthetic' - all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True) + all_data = pd.DataFrame() + if real_data is not None: + real_data = real_data.copy() + real_data['Data'] = 'Real' + all_data = pd.concat([all_data, real_data], axis=0, ignore_index=True) + if synthetic_data is not None: + synthetic_data = synthetic_data.copy() + synthetic_data['Data'] = 'Synthetic' + all_data = pd.concat([all_data, synthetic_data], axis=0, ignore_index=True) if plot_type == 'scatter': - return _generate_scatter_plot(all_data, columns) + return _generate_scatter_plot(all_data, column_names) elif plot_type == 'heatmap': - return _generate_heatmap_plot(all_data, columns) + return _generate_heatmap_plot(all_data, column_names) - return _generate_box_plot(all_data, columns) + return _generate_box_plot(all_data, column_names) def _generate_line_plot(real_data, synthetic_data, x_axis, y_axis, marker, annotations=None): diff --git a/tests/unit/test_visualization.py b/tests/unit/test_visualization.py index 788e0714..7b3aaaa7 100644 --- a/tests/unit/test_visualization.py +++ b/tests/unit/test_visualization.py @@ -1,12 +1,16 @@ import re -from unittest.mock import Mock, call, 
patch +from unittest.mock import ANY, Mock, call, patch import pandas as pd import pytest +from sdmetrics.reports.utils import PlotConfig from sdmetrics.visualization import ( _generate_box_plot, _generate_cardinality_plot, + _generate_column_bar_plot, + _generate_column_distplot, + _generate_column_plot, _generate_heatmap_plot, _generate_line_plot, _generate_scatter_plot, @@ -268,6 +272,215 @@ def test_get_column_plot_bad_plot_type(): get_column_plot(real_data, synthetic_data, 'valeus', plot_type='bad_type') +def test_get_column_plot_no_data(): + """Test the ``get_column_plot`` method with no data passed in.""" + # Run and assert + error_msg = re.escape('No data provided to plot. Please provide either real or synthetic data.') + with pytest.raises(ValueError, match=error_msg): + get_column_plot(None, None, 'values') + + +@patch('sdmetrics.visualization.px.histogram') +def test__generate_column_bar_plot(mock_histogram): + """Test ``_generate_column_bar_plot`` functionality""" + # Setup + real_data = pd.DataFrame([1, 2, 2, 3, 5]) + synthetic_data = pd.DataFrame([2, 2, 3, 4, 5]) + + # Run + _generate_column_bar_plot(real_data, synthetic_data) + + # Assert + expected_data = pd.DataFrame(pd.concat([real_data, synthetic_data], axis=0, ignore_index=True)) + expected_parameters = { + 'x': 'values', + 'color': 'Data', + 'barmode': 'group', + 'color_discrete_sequence': ['#000036', '#01E0C9'], + 'pattern_shape': 'Data', + 'pattern_shape_sequence': ['', '/'], + 'histnorm': 'probability density', + } + pd.testing.assert_frame_equal(expected_data, mock_histogram.call_args[0][0]) + mock_histogram.assert_called_once_with(ANY, **expected_parameters) + + +@patch('sdmetrics.visualization.ff.create_distplot') +def test__generate_column_distplot(mock_distplot): + """Test ``_generate_column_distplot`` functionality""" + # Setup + real_data = pd.DataFrame({'values': [1, 2, 2, 3, 5]}) + synthetic_data = pd.DataFrame({'values': [2, 2, 3, 4, 5]}) + + # Run + 
_generate_column_distplot(real_data, synthetic_data) + + # Assert + expected_data = [] + expected_data.append(real_data['values']) + expected_data.append(synthetic_data['values']) + expected_data == mock_distplot.call_args[0][0] + expected_col = ['Real', 'Synthetic'] + expected_colors = [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN] + expected_parameters = { + 'show_hist': False, + 'show_rug': False, + 'colors': expected_colors, + } + assert expected_parameters == mock_distplot.call_args[1] + mock_distplot.assert_called_once_with(expected_data, expected_col, **expected_parameters) + + +@patch('sdmetrics.visualization._generate_column_distplot') +def test___generate_column_plot_type_distplot(mock_dist_plot): + """Test ``_generate_column_plot`` with a dist_plot""" + # Setup + real_data = pd.DataFrame({'values': [1, 2, 2, 3, 5]}) + synthetic_data = pd.DataFrame({'values': [2, 2, 3, 4, 5]}) + mock_fig = Mock() + mock_object = Mock() + mock_object.x = [1, 2, 2, 3, 5] + mock_fig.data = [mock_object, mock_object] + mock_dist_plot.return_value = mock_fig + + # Run + _generate_column_plot(real_data['values'], synthetic_data['values'], 'distplot') + + # Assert + expected_real_data = pd.DataFrame({ + 'values': [1, 2, 2, 3, 5], + 'Data': ['Real', 'Real', 'Real', 'Real', 'Real'], + }) + expected_synth_data = pd.DataFrame({ + 'values': [2, 2, 3, 4, 5], + 'Data': ['Synthetic', 'Synthetic', 'Synthetic', 'Synthetic', 'Synthetic'], + }) + pd.testing.assert_frame_equal(mock_dist_plot.call_args[0][0], expected_real_data) + pd.testing.assert_frame_equal(mock_dist_plot.call_args[0][1], expected_synth_data) + mock_dist_plot.assert_called_once_with(ANY, ANY, {}) + + mock_fig.update_layout.assert_called_once_with( + title="Real vs. 
Synthetic Data for column 'values'", + xaxis_title='Value', + yaxis_title='Frequency', + plot_bgcolor=PlotConfig.BACKGROUND_COLOR, + annotations=[], + font={'size': PlotConfig.FONT_SIZE}, + ) + + +@patch('sdmetrics.visualization._generate_column_bar_plot') +def test___generate_column_plot_type_bar(mock_bar_plot): + """Test ``_generate_column_plot`` with a bar plot""" + # Setup + real_data = pd.DataFrame({'values': [1, 2, 2, 3, 5]}) + synthetic_data = pd.DataFrame({'values': [2, 2, 3, 4, 5]}) + mock_fig = Mock() + mock_object = Mock() + mock_object.x = [1, 2, 2, 3, 5] + mock_fig.data = [mock_object, mock_object] + mock_bar_plot.return_value = mock_fig + + # Run + _generate_column_plot(real_data['values'], synthetic_data['values'], 'bar') + + # Assert + expected_real_data = pd.DataFrame({ + 'values': [1, 2, 2, 3, 5], + 'Data': ['Real', 'Real', 'Real', 'Real', 'Real'], + }) + expected_synth_data = pd.DataFrame({ + 'values': [2, 2, 3, 4, 5], + 'Data': ['Synthetic', 'Synthetic', 'Synthetic', 'Synthetic', 'Synthetic'], + }) + pd.testing.assert_frame_equal(mock_bar_plot.call_args[0][0], expected_real_data) + pd.testing.assert_frame_equal(mock_bar_plot.call_args[0][1], expected_synth_data) + mock_bar_plot.assert_called_once_with(ANY, ANY, {}) + mock_fig.update_layout.assert_called_once_with( + title="Real vs. 
Synthetic Data for column 'values'", + xaxis_title='Category', + yaxis_title='Frequency', + plot_bgcolor=PlotConfig.BACKGROUND_COLOR, + annotations=[], + font={'size': PlotConfig.FONT_SIZE}, + ) + + +@patch('sdmetrics.visualization._generate_column_bar_plot') +def test___generate_column_plot_with_datetimes(mock_bar_plot): + """Test ``_generate_column_plot`` using datetimes""" + # Setup + real_data = pd.DataFrame({'values': pd.to_datetime(['2021-01-20', '2022-01-21'])}) + synthetic_data = pd.DataFrame({'values': pd.to_datetime(['2021-01-20', '2022-01-21'])}) + mock_fig = Mock() + mock_object = Mock() + mock_object.x = [1, 2, 2, 3, 5] + mock_fig.data = [mock_object, mock_object] + mock_bar_plot.return_value = mock_fig + + # Run + _generate_column_plot(real_data['values'], synthetic_data['values'], 'bar') + + # Assert + print(mock_bar_plot.call_args[0][1]) + expected_real_data = pd.DataFrame({ + 'values': [1611100800000000000, 1642723200000000000], + 'Data': ['Real', 'Real'], + }) + expected_synth_data = pd.DataFrame({ + 'values': [1611100800000000000, 1642723200000000000], + 'Data': ['Synthetic', 'Synthetic'], + }) + pd.testing.assert_frame_equal(mock_bar_plot.call_args[0][0], expected_real_data) + pd.testing.assert_frame_equal(mock_bar_plot.call_args[0][1], expected_synth_data) + mock_bar_plot.assert_called_once_with(ANY, ANY, {}) + + +def test___generate_column_plot_no_data(): + """Test ``_generate_column_plot`` when no data is passed in.""" + # Run and Assert + error_msg = re.escape('No data provided to plot. 
Please provide either real or synthetic data.') + with pytest.raises(ValueError, match=error_msg): + _generate_column_plot(None, None, 'bar') + + +def test___generate_column_plot_with_bad_plot(): + """Test ``_generate_column_plot`` when an incorrect plot is set.""" + # Setup + real_data = pd.DataFrame({'values': [1, 2, 2, 3, 5]}) + synthetic_data = pd.DataFrame({'values': [2, 2, 3, 4, 5]}) + # Run and Assert + error_msg = re.escape( + "Unrecognized plot_type 'bad_plot'. Please use one of 'bar' or 'distplot'" + ) + with pytest.raises(ValueError, match=error_msg): + _generate_column_plot(real_data, synthetic_data, 'bad_plot') + + +@patch('sdmetrics.visualization._generate_column_plot') +def test_get_column_plot_plot_one_data_set(mock__generate_column_plot): + """Test ``get_column_plot`` for real data and synthetic data individually.""" + # Setup + real_data = pd.DataFrame({'values': [1, 2, 2, 3, 5]}) + synthetic_data = pd.DataFrame({'values': [2, 2, 3, 4, 5]}) + mock__generate_column_plot.side_effect = ['mock_return_1', 'mock_return_2'] + + # Run + fig_real = get_column_plot(real_data, None, 'values') + fig_synth = get_column_plot(None, synthetic_data, 'values') + + # Assert + expected_real_call_data = real_data['values'] + expected_synth_call_data = synthetic_data['values'] + expected_calls = [ + call(SeriesMatcher(expected_real_call_data), None, 'distplot'), + call(None, SeriesMatcher(expected_synth_call_data), 'distplot'), + ] + mock__generate_column_plot.assert_has_calls(expected_calls, any_order=False) + assert fig_real == 'mock_return_1' + assert fig_synth == 'mock_return_2' + + @patch('sdmetrics.visualization._generate_column_plot') def test_get_column_plot_plot_type_none_data_int(mock__generate_column_plot): """Test ``get_column_plot`` when ``plot_type`` is ``None`` and data is ``int``.""" @@ -666,6 +879,44 @@ def test_get_column_pair_plot_plot_type_none_continuous_data(mock__generate_scat assert fig == mock__generate_scatter_plot.return_value 
+@patch('sdmetrics.visualization._generate_scatter_plot') +def test_get_column_pair_plot_plot_single_data(mock__generate_scatter_plot): + """Test ``get_column_pair_plot`` with only real or synthetic data""" + # Setup + columns = ['amount', 'price'] + real_data = pd.DataFrame({'amount': [1, 2, 3], 'price': [4, 5, 6]}) + synthetic_data = pd.DataFrame({'amount': [1.0, 2.0, 3.0], 'price': [4.0, 5.0, 6.0]}) + mock__generate_scatter_plot.side_effect = ['mock_return_1', 'mock_return_2'] + + # Run + real_fig = get_column_pair_plot(real_data, None, columns) + synth_fig = get_column_pair_plot(None, synthetic_data, columns) + + # Assert + real_data['Data'] = 'Real' + synthetic_data['Data'] = 'Synthetic' + expected_real_call_data = real_data + expected_synth_call_data = synthetic_data + expected_calls = [ + call(DataFrameMatcher(expected_real_call_data), columns), + call(DataFrameMatcher(expected_synth_call_data), columns), + ] + mock__generate_scatter_plot.assert_has_calls(expected_calls, any_order=False) + assert real_fig == 'mock_return_1' + assert synth_fig == 'mock_return_2' + + +@patch('sdmetrics.visualization._generate_scatter_plot') +def test_get_column_pair_plot_plot_no_data(mock__generate_scatter_plot): + """Test ``get_column_pair_plot`` with neither real or synthetic data""" + # Setup + columns = ['amount', 'price'] + error_msg = re.escape('No data provided to plot. Please provide either real or synthetic data.') + # Run and Assert + with pytest.raises(ValueError, match=error_msg): + get_column_pair_plot(None, None, columns) + + @patch('sdmetrics.visualization._generate_scatter_plot') def test_get_column_pair_plot_plot_type_none_continuous_data_and_date(mock__generate_scatter_plot): """Test ``get_column_pair_plot`` with continuous data and ``plot_type`` ``None``."""