diff --git a/pyproject.toml b/pyproject.toml
index f95571d9..f0e87e4d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,9 +30,9 @@ dependencies = [
"scikit-learn>=1.1.0;python_version>='3.10' and python_version<'3.11'",
"scikit-learn>=1.1.3;python_version>='3.11' and python_version<'3.12'",
"scikit-learn>=1.3.1;python_version>='3.12'",
- "scipy>=1.7.3;python_version<'3.10'",
- "scipy>=1.9.2;python_version>='3.10' and python_version<'3.12'",
- "scipy>=1.12.0;python_version>='3.12'",
+ "scipy>=1.7.3,<1.14.0;python_version<'3.10'",
+ "scipy>=1.9.2,<1.14.0;python_version>='3.10' and python_version<'3.12'",
+ "scipy>=1.12.0,<1.14.0;python_version>='3.12'",
'copulas>=0.11.0',
'tqdm>=4.29',
'plotly>=5.19.0',
diff --git a/sdmetrics/visualization.py b/sdmetrics/visualization.py
index 20d3d329..64630271 100644
--- a/sdmetrics/visualization.py
+++ b/sdmetrics/visualization.py
@@ -46,10 +46,10 @@ def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}):
"""Generate a bar plot of the real and synthetic data.
Args:
- real_column (pandas.Series):
- The real data for the desired column.
- synthetic_column (pandas.Series):
- The synthetic data for the desired column.
+ real_data (pandas.DataFrame or None):
+ The real data for the desired column. If None this data will not be graphed.
+ synthetic_data (pandas.DataFrame or None):
+ The synthetic data for the desired column. If None this data will not be graphed.
plot_kwargs (dict, optional):
Dictionary of keyword arguments to pass to px.histogram. Keyword arguments
provided this way will overwrite defaults.
@@ -57,12 +57,20 @@ def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}):
Returns:
plotly.graph_objects._figure.Figure
"""
- all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True)
+ all_data = pd.DataFrame()
+ color_sequence = []
+ if real_data is not None:
+ all_data = pd.concat([all_data, real_data], axis=0, ignore_index=True)
+ color_sequence.append(PlotConfig.DATACEBO_DARK)
+ if synthetic_data is not None:
+ all_data = pd.concat([all_data, synthetic_data], axis=0, ignore_index=True)
+ color_sequence.append(PlotConfig.DATACEBO_GREEN)
+
histogram_kwargs = {
'x': 'values',
'color': 'Data',
'barmode': 'group',
- 'color_discrete_sequence': [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN],
+ 'color_discrete_sequence': color_sequence,
'pattern_shape': 'Data',
'pattern_shape_sequence': ['', '/'],
'histnorm': 'probability density',
@@ -86,12 +94,20 @@ def _generate_heatmap_plot(all_data, columns):
Returns:
plotly.graph_objects._figure.Figure
"""
+ unique_values = all_data['Data'].unique()
+
+ if len(columns) != 2:
+ raise ValueError('Generating a heatmap plot requires exactly two columns for the axis.')
+
fig = px.density_heatmap(
all_data, x=columns[0], y=columns[1], facet_col='Data', histnorm='probability'
)
+ title = ' vs. '.join(unique_values)
+ title += f" Data for columns '{columns[0]}' and '{columns[1]}'"
+
fig.update_layout(
- title_text=f"Real vs Synthetic Data for columns '{columns[0]}' and '{columns[1]}'",
+ title_text=title,
coloraxis={'colorscale': [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN]},
font={'size': PlotConfig.FONT_SIZE},
)
@@ -147,6 +163,11 @@ def _generate_scatter_plot(all_data, columns):
Returns:
plotly.graph_objects._figure.Figure
"""
+
+ if len(columns) != 2:
+ raise ValueError('Generating a scatter plot requires exactly two columns for the axis.')
+
+ unique_values = all_data['Data'].unique()
fig = px.scatter(
all_data,
x=columns[0],
@@ -159,8 +180,11 @@ def _generate_scatter_plot(all_data, columns):
symbol='Data',
)
+ title = ' vs. '.join(unique_values)
+ title += f" Data for columns '{columns[0]}' and '{columns[1]}'"
+
fig.update_layout(
- title=f"Real vs. Synthetic Data for columns '{columns[0]}' and '{columns[1]}'",
+ title=title,
plot_bgcolor=PlotConfig.BACKGROUND_COLOR,
font={'size': PlotConfig.FONT_SIZE},
)
@@ -172,10 +196,10 @@ def _generate_column_distplot(real_data, synthetic_data, plot_kwargs={}):
"""Plot the real and synthetic data as a distplot.
Args:
- real_data (pandas.DataFrame):
- The real data for the desired column.
- synthetic_data (pandas.DataFrame):
- The synthetic data for the desired column.
+ real_data (pandas.DataFrame or None):
+ The real data for the desired column. If None this data will not be graphed.
+ synthetic_data (pandas.DataFrame or None):
+ The synthetic data for the desired column. If None this data will not be graphed.
plot_kwargs (dict, optional):
Dictionary of keyword arguments to pass to px.histogram. Keyword arguments
provided this way will overwrite defaults.
@@ -183,15 +207,27 @@ def _generate_column_distplot(real_data, synthetic_data, plot_kwargs={}):
Returns:
plotly.graph_objects._figure.Figure
"""
+ hist_data = []
+ col_names = []
+ colors = []
+ if real_data is not None:
+ hist_data.append(real_data['values'])
+ col_names.append('Real')
+ colors.append(PlotConfig.DATACEBO_DARK)
+ if synthetic_data is not None:
+ hist_data.append(synthetic_data['values'])
+ col_names.append('Synthetic')
+ colors.append(PlotConfig.DATACEBO_GREEN)
+
default_distplot_kwargs = {
'show_hist': False,
'show_rug': False,
- 'colors': [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN],
+ 'colors': colors,
}
fig = ff.create_distplot(
- [real_data['values'], synthetic_data['values']],
- ['Real', 'Synthetic'],
+ hist_data,
+ col_names,
**{**default_distplot_kwargs, **plot_kwargs},
)
@@ -204,10 +240,10 @@ def _generate_column_plot(
"""Generate a plot of the real and synthetic data.
Args:
- real_column (pandas.Series):
- The real data for the desired column.
- synthetic_column (pandas.Series):
- The synthetic data for the desired column.
+ real_column (pandas.Series or None):
+ The real data for the desired column. If None this data will not be graphed.
+ synthetic_column (pandas.Series or None):
+ The synthetic data for the desired column. If None this data will not be graphed.
plot_type (str):
The type of plot to use. Must be one of 'bar' or 'distplot'.
hist_kwargs (dict, optional):
@@ -221,26 +257,55 @@ def _generate_column_plot(
Returns:
plotly.graph_objects._figure.Figure
"""
+
+ if real_column is None and synthetic_column is None:
+ raise ValueError('No data provided to plot. Please provide either real or synthetic data.')
+
if plot_type not in ['bar', 'distplot']:
raise ValueError(
- "Unrecognized plot_type '{plot_type}'. Pleas use one of 'bar' or 'distplot'"
+ f"Unrecognized plot_type '{plot_type}'. Please use one of 'bar' or 'distplot'"
)
- column_name = real_column.name if hasattr(real_column, 'name') else ''
-
- missing_data_real = get_missing_percentage(real_column)
- missing_data_synthetic = get_missing_percentage(synthetic_column)
-
- real_data = pd.DataFrame({'values': real_column.copy().dropna()})
- real_data['Data'] = 'Real'
- synthetic_data = pd.DataFrame({'values': synthetic_column.copy().dropna()})
- synthetic_data['Data'] = 'Synthetic'
+ column_name = ''
+ missing_data_real = 0
+ missing_data_synthetic = 0
+ col_dtype = None
+ col_names = []
+ title = ''
+ if real_column is not None and hasattr(real_column, 'name'):
+ column_name = real_column.name
+ elif synthetic_column is not None and hasattr(synthetic_column, 'name'):
+ column_name = synthetic_column.name
+
+ real_data = None
+ if real_column is not None:
+ missing_data_real = get_missing_percentage(real_column)
+ real_data = pd.DataFrame({'values': real_column.copy().dropna()})
+ real_data['Data'] = 'Real'
+ col_dtype = real_column.dtype
+ col_names.append('Real')
+ title += 'Real vs. '
+
+ synthetic_data = None
+ if synthetic_column is not None:
+ missing_data_synthetic = get_missing_percentage(synthetic_column)
+ synthetic_data = pd.DataFrame({'values': synthetic_column.copy().dropna()})
+ synthetic_data['Data'] = 'Synthetic'
+ col_names.append('Synthetic')
+ title += 'Synthetic vs. '
+ if col_dtype is None:
+ col_dtype = synthetic_column.dtype
+
+ title = title[:-4]
+ title += f"Data for column '{column_name}'"
is_datetime_sdtype = False
- if is_datetime64_dtype(real_column.dtype):
+ if is_datetime64_dtype(col_dtype):
is_datetime_sdtype = True
- real_data['values'] = real_data['values'].astype('int64')
- synthetic_data['values'] = synthetic_data['values'].astype('int64')
+ if real_data is not None:
+ real_data['values'] = real_data['values'].astype('int64')
+ if synthetic_data is not None:
+ synthetic_data['values'] = synthetic_data['values'].astype('int64')
trace_args = {}
@@ -251,7 +316,7 @@ def _generate_column_plot(
fig = _generate_column_distplot(real_data, synthetic_data, plot_kwargs)
trace_args = {'fill': 'tozeroy'}
- for i, name in enumerate(['Real', 'Synthetic']):
+ for i, name in enumerate(col_names):
fig.update_traces(
x=pd.to_datetime(fig.data[i].x) if is_datetime_sdtype else fig.data[i].x,
hovertemplate=f'{name}
Frequency: %{{y}}',
@@ -260,6 +325,14 @@ def _generate_column_plot(
)
show_missing_values = missing_data_real > 0 or missing_data_synthetic > 0
+ text = '*Missing Values: '
+ if real_column is not None and show_missing_values:
+ text += f'Real Data ({missing_data_real}%), '
+ if synthetic_column is not None and show_missing_values:
+ text += f'Synthetic Data ({missing_data_synthetic}%), '
+
+ text = text[:-2]
+
annotations = (
[]
if not show_missing_values
@@ -270,16 +343,13 @@ def _generate_column_plot(
'x': 1.0,
'y': 1.05,
'showarrow': False,
- 'text': (
- f'*Missing Values: Real Data ({missing_data_real}%), '
- f'Synthetic Data ({missing_data_synthetic}%)'
- ),
+ 'text': text,
},
]
)
if not plot_title:
- plot_title = f"Real vs. Synthetic Data for column '{column_name}'"
+ plot_title = title
if not x_label:
x_label = 'Category'
@@ -401,10 +471,10 @@ def get_column_plot(real_data, synthetic_data, column_name, plot_type=None):
"""Return a plot of the real and synthetic data for a given column.
Args:
- real_data (pandas.DataFrame):
- The real table data.
- synthetic_data (pandas.DataFrame):
- The synthetic table data.
+ real_data (pandas.DataFrame or None):
+ The real table data. If None this data will not be graphed.
+ synthetic_data (pandas.DataFrame or None):
+ The synthetic table data. If None this data will not be graphed.
column_name (str):
The name of the column.
plot_type (str or None):
@@ -416,28 +486,39 @@ def get_column_plot(real_data, synthetic_data, column_name, plot_type=None):
Returns:
plotly.graph_objects._figure.Figure
"""
+
+ if real_data is None and synthetic_data is None:
+ raise ValueError('No data provided to plot. Please provide either real or synthetic data.')
+
if plot_type not in ['bar', 'distplot', None]:
raise ValueError(
f"Invalid plot_type '{plot_type}'. Please use one of ['bar', 'distplot', None]."
)
- if column_name not in real_data.columns:
- raise ValueError(f"Column '{column_name}' not found in real table data.")
- if column_name not in synthetic_data.columns:
- raise ValueError(f"Column '{column_name}' not found in synthetic table data.")
+ column = None
+ real_column = None
+ synthetic_column = None
+ if real_data is not None:
+ if column_name not in real_data.columns:
+ raise ValueError(f"Column '{column_name}' not found in real table data.")
+ column = real_data[column_name]
+ real_column = real_data[column_name]
+
+ if synthetic_data is not None:
+ if column_name not in synthetic_data.columns:
+ raise ValueError(f"Column '{column_name}' not found in synthetic table data.")
+ if column is None:
+ column = synthetic_data[column_name]
+ synthetic_column = synthetic_data[column_name]
- real_column = real_data[column_name]
if plot_type is None:
- column_is_datetime = is_datetime(real_data[column_name])
- dtype = real_column.dropna().infer_objects().dtype.kind
+ column_is_datetime = is_datetime(column)
+ dtype = column.dropna().infer_objects().dtype.kind
if column_is_datetime or dtype in ('i', 'f'):
plot_type = 'distplot'
else:
plot_type = 'bar'
- real_column = real_data[column_name]
- synthetic_column = synthetic_data[column_name]
-
fig = _generate_column_plot(real_column, synthetic_column, plot_type)
return fig
@@ -448,10 +529,10 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None
"""Return a plot of the real and synthetic data for a given column pair.
Args:
- real_data (pandas.DataFrame):
- The real table data.
- synthetic_column (pandas.Dataframe):
- The synthetic table data.
+ real_data (pandas.DataFrame or None):
+ The real table data. If None this data will not be graphed.
+ synthetic_data (pandas.DataFrame or None):
+ The synthetic table data. If None this data will not be graphed.
column_names (list[string]):
The names of the two columns to plot.
plot_type (str or None):
@@ -466,16 +547,23 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None
if len(column_names) != 2:
raise ValueError('Must provide exactly two column names.')
- if not set(column_names).issubset(real_data.columns):
- raise ValueError(
- f'Missing column(s) {set(column_names) - set(real_data.columns)} in real data.'
- )
+ if real_data is None and synthetic_data is None:
+ raise ValueError('No data provided to plot. Please provide either real or synthetic data.')
- if not set(column_names).issubset(synthetic_data.columns):
- raise ValueError(
- f'Missing column(s) {set(column_names) - set(synthetic_data.columns)} '
- 'in synthetic data.'
- )
+ if real_data is not None:
+ if not set(column_names).issubset(real_data.columns):
+ raise ValueError(
+ f'Missing column(s) {set(column_names) - set(real_data.columns)} in real data.'
+ )
+ real_data = real_data[column_names]
+
+ if synthetic_data is not None:
+ if not set(column_names).issubset(synthetic_data.columns):
+ raise ValueError(
+ f'Missing column(s) {set(column_names) - set(synthetic_data.columns)} '
+ 'in synthetic data.'
+ )
+ synthetic_data = synthetic_data[column_names]
if plot_type not in ['box', 'heatmap', 'scatter', None]:
raise ValueError(
@@ -483,12 +571,13 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None
"['box', 'heatmap', 'scatter', None]."
)
- real_data = real_data[column_names]
- synthetic_data = synthetic_data[column_names]
if plot_type is None:
plot_type = []
for column_name in column_names:
- column = real_data[column_name]
+ if real_data is not None:
+ column = real_data[column_name]
+ else:
+ column = synthetic_data[column_name]
dtype = column.dropna().infer_objects().dtype.kind
if dtype in ('i', 'f') or is_datetime(column):
plot_type.append('scatter')
@@ -501,19 +590,22 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None
plot_type = plot_type.pop()
# Merge the real and synthetic data and add a flag ``Data`` to indicate each one.
- columns = list(real_data.columns)
- real_data = real_data.copy()
- real_data['Data'] = 'Real'
- synthetic_data = synthetic_data.copy()
- synthetic_data['Data'] = 'Synthetic'
- all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True)
+ all_data = pd.DataFrame()
+ if real_data is not None:
+ real_data = real_data.copy()
+ real_data['Data'] = 'Real'
+ all_data = pd.concat([all_data, real_data], axis=0, ignore_index=True)
+ if synthetic_data is not None:
+ synthetic_data = synthetic_data.copy()
+ synthetic_data['Data'] = 'Synthetic'
+ all_data = pd.concat([all_data, synthetic_data], axis=0, ignore_index=True)
if plot_type == 'scatter':
- return _generate_scatter_plot(all_data, columns)
+ return _generate_scatter_plot(all_data, column_names)
elif plot_type == 'heatmap':
- return _generate_heatmap_plot(all_data, columns)
+ return _generate_heatmap_plot(all_data, column_names)
- return _generate_box_plot(all_data, columns)
+ return _generate_box_plot(all_data, column_names)
def _generate_line_plot(real_data, synthetic_data, x_axis, y_axis, marker, annotations=None):
diff --git a/tests/unit/test_visualization.py b/tests/unit/test_visualization.py
index 788e0714..7b3aaaa7 100644
--- a/tests/unit/test_visualization.py
+++ b/tests/unit/test_visualization.py
@@ -1,12 +1,16 @@
import re
-from unittest.mock import Mock, call, patch
+from unittest.mock import ANY, Mock, call, patch
import pandas as pd
import pytest
+from sdmetrics.reports.utils import PlotConfig
from sdmetrics.visualization import (
_generate_box_plot,
_generate_cardinality_plot,
+ _generate_column_bar_plot,
+ _generate_column_distplot,
+ _generate_column_plot,
_generate_heatmap_plot,
_generate_line_plot,
_generate_scatter_plot,
@@ -268,6 +272,215 @@ def test_get_column_plot_bad_plot_type():
get_column_plot(real_data, synthetic_data, 'valeus', plot_type='bad_type')
+def test_get_column_plot_no_data():
+ """Test the ``get_column_plot`` method with no data passed in."""
+ # Run and assert
+ error_msg = re.escape('No data provided to plot. Please provide either real or synthetic data.')
+ with pytest.raises(ValueError, match=error_msg):
+ get_column_plot(None, None, 'values')
+
+
+@patch('sdmetrics.visualization.px.histogram')
+def test__generate_column_bar_plot(mock_histogram):
+ """Test ``_generate_column_bar_plot`` functionality"""
+ # Setup
+ real_data = pd.DataFrame([1, 2, 2, 3, 5])
+ synthetic_data = pd.DataFrame([2, 2, 3, 4, 5])
+
+ # Run
+ _generate_column_bar_plot(real_data, synthetic_data)
+
+ # Assert
+ expected_data = pd.DataFrame(pd.concat([real_data, synthetic_data], axis=0, ignore_index=True))
+ expected_parameters = {
+ 'x': 'values',
+ 'color': 'Data',
+ 'barmode': 'group',
+ 'color_discrete_sequence': ['#000036', '#01E0C9'],
+ 'pattern_shape': 'Data',
+ 'pattern_shape_sequence': ['', '/'],
+ 'histnorm': 'probability density',
+ }
+ pd.testing.assert_frame_equal(expected_data, mock_histogram.call_args[0][0])
+ mock_histogram.assert_called_once_with(ANY, **expected_parameters)
+
+
+@patch('sdmetrics.visualization.ff.create_distplot')
+def test__generate_column_distplot(mock_distplot):
+ """Test ``_generate_column_distplot`` functionality"""
+ # Setup
+ real_data = pd.DataFrame({'values': [1, 2, 2, 3, 5]})
+ synthetic_data = pd.DataFrame({'values': [2, 2, 3, 4, 5]})
+
+ # Run
+ _generate_column_distplot(real_data, synthetic_data)
+
+ # Assert
+ expected_data = []
+ expected_data.append(real_data['values'])
+ expected_data.append(synthetic_data['values'])
+ assert all(map(pd.Series.equals, expected_data, mock_distplot.call_args[0][0]))
+ expected_col = ['Real', 'Synthetic']
+ expected_colors = [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN]
+ expected_parameters = {
+ 'show_hist': False,
+ 'show_rug': False,
+ 'colors': expected_colors,
+ }
+ assert expected_parameters == mock_distplot.call_args[1]
+ mock_distplot.assert_called_once_with(expected_data, expected_col, **expected_parameters)
+
+
+@patch('sdmetrics.visualization._generate_column_distplot')
+def test___generate_column_plot_type_distplot(mock_dist_plot):
+ """Test ``_generate_column_plot`` with a dist_plot"""
+ # Setup
+ real_data = pd.DataFrame({'values': [1, 2, 2, 3, 5]})
+ synthetic_data = pd.DataFrame({'values': [2, 2, 3, 4, 5]})
+ mock_fig = Mock()
+ mock_object = Mock()
+ mock_object.x = [1, 2, 2, 3, 5]
+ mock_fig.data = [mock_object, mock_object]
+ mock_dist_plot.return_value = mock_fig
+
+ # Run
+ _generate_column_plot(real_data['values'], synthetic_data['values'], 'distplot')
+
+ # Assert
+ expected_real_data = pd.DataFrame({
+ 'values': [1, 2, 2, 3, 5],
+ 'Data': ['Real', 'Real', 'Real', 'Real', 'Real'],
+ })
+ expected_synth_data = pd.DataFrame({
+ 'values': [2, 2, 3, 4, 5],
+ 'Data': ['Synthetic', 'Synthetic', 'Synthetic', 'Synthetic', 'Synthetic'],
+ })
+ pd.testing.assert_frame_equal(mock_dist_plot.call_args[0][0], expected_real_data)
+ pd.testing.assert_frame_equal(mock_dist_plot.call_args[0][1], expected_synth_data)
+ mock_dist_plot.assert_called_once_with(ANY, ANY, {})
+
+ mock_fig.update_layout.assert_called_once_with(
+ title="Real vs. Synthetic Data for column 'values'",
+ xaxis_title='Value',
+ yaxis_title='Frequency',
+ plot_bgcolor=PlotConfig.BACKGROUND_COLOR,
+ annotations=[],
+ font={'size': PlotConfig.FONT_SIZE},
+ )
+
+
+@patch('sdmetrics.visualization._generate_column_bar_plot')
+def test___generate_column_plot_type_bar(mock_bar_plot):
+ """Test ``_generate_column_plot`` with a bar plot"""
+ # Setup
+ real_data = pd.DataFrame({'values': [1, 2, 2, 3, 5]})
+ synthetic_data = pd.DataFrame({'values': [2, 2, 3, 4, 5]})
+ mock_fig = Mock()
+ mock_object = Mock()
+ mock_object.x = [1, 2, 2, 3, 5]
+ mock_fig.data = [mock_object, mock_object]
+ mock_bar_plot.return_value = mock_fig
+
+ # Run
+ _generate_column_plot(real_data['values'], synthetic_data['values'], 'bar')
+
+ # Assert
+ expected_real_data = pd.DataFrame({
+ 'values': [1, 2, 2, 3, 5],
+ 'Data': ['Real', 'Real', 'Real', 'Real', 'Real'],
+ })
+ expected_synth_data = pd.DataFrame({
+ 'values': [2, 2, 3, 4, 5],
+ 'Data': ['Synthetic', 'Synthetic', 'Synthetic', 'Synthetic', 'Synthetic'],
+ })
+ pd.testing.assert_frame_equal(mock_bar_plot.call_args[0][0], expected_real_data)
+ pd.testing.assert_frame_equal(mock_bar_plot.call_args[0][1], expected_synth_data)
+ mock_bar_plot.assert_called_once_with(ANY, ANY, {})
+ mock_fig.update_layout.assert_called_once_with(
+ title="Real vs. Synthetic Data for column 'values'",
+ xaxis_title='Category',
+ yaxis_title='Frequency',
+ plot_bgcolor=PlotConfig.BACKGROUND_COLOR,
+ annotations=[],
+ font={'size': PlotConfig.FONT_SIZE},
+ )
+
+
+@patch('sdmetrics.visualization._generate_column_bar_plot')
+def test___generate_column_plot_with_datetimes(mock_bar_plot):
+ """Test ``_generate_column_plot`` using datetimes"""
+ # Setup
+ real_data = pd.DataFrame({'values': pd.to_datetime(['2021-01-20', '2022-01-21'])})
+ synthetic_data = pd.DataFrame({'values': pd.to_datetime(['2021-01-20', '2022-01-21'])})
+ mock_fig = Mock()
+ mock_object = Mock()
+ mock_object.x = [1, 2, 2, 3, 5]
+ mock_fig.data = [mock_object, mock_object]
+ mock_bar_plot.return_value = mock_fig
+
+ # Run
+ _generate_column_plot(real_data['values'], synthetic_data['values'], 'bar')
+
+ # Assert
+ # Datetime values are expected to reach the bar plot helper cast to int64.
+ expected_real_data = pd.DataFrame({
+ 'values': [1611100800000000000, 1642723200000000000],
+ 'Data': ['Real', 'Real'],
+ })
+ expected_synth_data = pd.DataFrame({
+ 'values': [1611100800000000000, 1642723200000000000],
+ 'Data': ['Synthetic', 'Synthetic'],
+ })
+ pd.testing.assert_frame_equal(mock_bar_plot.call_args[0][0], expected_real_data)
+ pd.testing.assert_frame_equal(mock_bar_plot.call_args[0][1], expected_synth_data)
+ mock_bar_plot.assert_called_once_with(ANY, ANY, {})
+
+
+def test___generate_column_plot_no_data():
+ """Test ``_generate_column_plot`` when no data is passed in."""
+ # Run and Assert
+ error_msg = re.escape('No data provided to plot. Please provide either real or synthetic data.')
+ with pytest.raises(ValueError, match=error_msg):
+ _generate_column_plot(None, None, 'bar')
+
+
+def test___generate_column_plot_with_bad_plot():
+ """Test ``_generate_column_plot`` when an incorrect plot is set."""
+ # Setup
+ real_data = pd.DataFrame({'values': [1, 2, 2, 3, 5]})
+ synthetic_data = pd.DataFrame({'values': [2, 2, 3, 4, 5]})
+ # Run and Assert
+ error_msg = re.escape(
+ "Unrecognized plot_type 'bad_plot'. Please use one of 'bar' or 'distplot'"
+ )
+ with pytest.raises(ValueError, match=error_msg):
+ _generate_column_plot(real_data, synthetic_data, 'bad_plot')
+
+
+@patch('sdmetrics.visualization._generate_column_plot')
+def test_get_column_plot_plot_one_data_set(mock__generate_column_plot):
+ """Test ``get_column_plot`` for real data and synthetic data individually."""
+ # Setup
+ real_data = pd.DataFrame({'values': [1, 2, 2, 3, 5]})
+ synthetic_data = pd.DataFrame({'values': [2, 2, 3, 4, 5]})
+ mock__generate_column_plot.side_effect = ['mock_return_1', 'mock_return_2']
+
+ # Run
+ fig_real = get_column_plot(real_data, None, 'values')
+ fig_synth = get_column_plot(None, synthetic_data, 'values')
+
+ # Assert
+ expected_real_call_data = real_data['values']
+ expected_synth_call_data = synthetic_data['values']
+ expected_calls = [
+ call(SeriesMatcher(expected_real_call_data), None, 'distplot'),
+ call(None, SeriesMatcher(expected_synth_call_data), 'distplot'),
+ ]
+ mock__generate_column_plot.assert_has_calls(expected_calls, any_order=False)
+ assert fig_real == 'mock_return_1'
+ assert fig_synth == 'mock_return_2'
+
+
@patch('sdmetrics.visualization._generate_column_plot')
def test_get_column_plot_plot_type_none_data_int(mock__generate_column_plot):
"""Test ``get_column_plot`` when ``plot_type`` is ``None`` and data is ``int``."""
@@ -666,6 +879,44 @@ def test_get_column_pair_plot_plot_type_none_continuous_data(mock__generate_scat
assert fig == mock__generate_scatter_plot.return_value
+@patch('sdmetrics.visualization._generate_scatter_plot')
+def test_get_column_pair_plot_plot_single_data(mock__generate_scatter_plot):
+ """Test ``get_column_pair_plot`` with only real or synthetic data"""
+ # Setup
+ columns = ['amount', 'price']
+ real_data = pd.DataFrame({'amount': [1, 2, 3], 'price': [4, 5, 6]})
+ synthetic_data = pd.DataFrame({'amount': [1.0, 2.0, 3.0], 'price': [4.0, 5.0, 6.0]})
+ mock__generate_scatter_plot.side_effect = ['mock_return_1', 'mock_return_2']
+
+ # Run
+ real_fig = get_column_pair_plot(real_data, None, columns)
+ synth_fig = get_column_pair_plot(None, synthetic_data, columns)
+
+ # Assert
+ real_data['Data'] = 'Real'
+ synthetic_data['Data'] = 'Synthetic'
+ expected_real_call_data = real_data
+ expected_synth_call_data = synthetic_data
+ expected_calls = [
+ call(DataFrameMatcher(expected_real_call_data), columns),
+ call(DataFrameMatcher(expected_synth_call_data), columns),
+ ]
+ mock__generate_scatter_plot.assert_has_calls(expected_calls, any_order=False)
+ assert real_fig == 'mock_return_1'
+ assert synth_fig == 'mock_return_2'
+
+
+@patch('sdmetrics.visualization._generate_scatter_plot')
+def test_get_column_pair_plot_plot_no_data(mock__generate_scatter_plot):
+ """Test ``get_column_pair_plot`` with neither real or synthetic data"""
+ # Setup
+ columns = ['amount', 'price']
+ error_msg = re.escape('No data provided to plot. Please provide either real or synthetic data.')
+ # Run and Assert
+ with pytest.raises(ValueError, match=error_msg):
+ get_column_pair_plot(None, None, columns)
+
+
@patch('sdmetrics.visualization._generate_scatter_plot')
def test_get_column_pair_plot_plot_type_none_continuous_data_and_date(mock__generate_scatter_plot):
"""Test ``get_column_pair_plot`` with continuous data and ``plot_type`` ``None``."""