Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow get_column_pair_plot to visualize one dataset instead of both real and synthetic #595

Merged
merged 9 commits into from
Jun 25, 2024
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ dependencies = [
"scikit-learn>=1.1.0;python_version>='3.10' and python_version<'3.11'",
"scikit-learn>=1.1.3;python_version>='3.11' and python_version<'3.12'",
"scikit-learn>=1.3.1;python_version>='3.12'",
"scipy>=1.7.3;python_version<'3.10'",
"scipy>=1.9.2;python_version>='3.10' and python_version<'3.12'",
"scipy>=1.12.0;python_version>='3.12'",
"scipy>=1.7.3,<1.14.0;python_version<'3.10'",
"scipy>=1.9.2,<1.14.0;python_version>='3.10' and python_version<'3.12'",
"scipy>=1.12.0,<1.14.0;python_version>='3.12'",
'copulas>=0.11.0',
'tqdm>=4.29',
'plotly>=5.19.0',
Expand Down
81 changes: 54 additions & 27 deletions sdmetrics/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,20 @@ def _generate_heatmap_plot(all_data, columns):
Returns:
plotly.graph_objects._figure.Figure
"""
unique_values = all_data['Data'].unique()

if len(columns) != 2:
raise ValueError('Generating a heatmap plot requires exactly two columns for the axis.')

fig = px.density_heatmap(
all_data, x=columns[0], y=columns[1], facet_col='Data', histnorm='probability'
)

title = ' vs. '.join(unique_values)
title += f" Data for columns '{columns[0]}' and '{columns[1]}"

fig.update_layout(
title_text=f"Real vs Synthetic Data for columns '{columns[0]}' and '{columns[1]}'",
title_text=title,
coloraxis={'colorscale': [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN]},
font={'size': PlotConfig.FONT_SIZE},
)
Expand Down Expand Up @@ -147,6 +155,11 @@ def _generate_scatter_plot(all_data, columns):
Returns:
plotly.graph_objects._figure.Figure
"""

if len(columns) != 2:
raise ValueError('Generating a scatter plot requires exactly two columns for the axis.')

unique_values = all_data['Data'].unique()
fig = px.scatter(
all_data,
x=columns[0],
Expand All @@ -159,8 +172,11 @@ def _generate_scatter_plot(all_data, columns):
symbol='Data',
)

title = ' vs. '.join(unique_values)
title += f" Data for columns '{columns[0]}' and '{columns[1]}'"

fig.update_layout(
title=f"Real vs. Synthetic Data for columns '{columns[0]}' and '{columns[1]}'",
title=title,
plot_bgcolor=PlotConfig.BACKGROUND_COLOR,
font={'size': PlotConfig.FONT_SIZE},
)
Expand Down Expand Up @@ -448,10 +464,10 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None
"""Return a plot of the real and synthetic data for a given column pair.

Args:
real_data (pandas.DataFrame):
The real table data.
synthetic_column (pandas.Dataframe):
The synthetic table data.
real_data (pandas.DataFrame or None):
The real table data. If None this data will not be graphed.
synthetic_data (pandas.DataFrame or None):
The synthetic table data. If None this data will not be graphed.
column_names (list[string]):
The names of the two columns to plot.
plot_type (str or None):
Expand All @@ -466,29 +482,37 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None
if len(column_names) != 2:
raise ValueError('Must provide exactly two column names.')

if not set(column_names).issubset(real_data.columns):
raise ValueError(
f'Missing column(s) {set(column_names) - set(real_data.columns)} in real data.'
)
if real_data is None and synthetic_data is None:
raise ValueError('No data provided to plot. Please provide either real or synthetic data.')

if not set(column_names).issubset(synthetic_data.columns):
raise ValueError(
f'Missing column(s) {set(column_names) - set(synthetic_data.columns)} '
'in synthetic data.'
)
if real_data is not None:
if not set(column_names).issubset(real_data.columns):
raise ValueError(
f'Missing column(s) {set(column_names) - set(real_data.columns)} in real data.'
)
real_data = real_data[column_names]

if synthetic_data is not None:
if not set(column_names).issubset(synthetic_data.columns):
raise ValueError(
f'Missing column(s) {set(column_names) - set(synthetic_data.columns)} '
'in synthetic data.'
)
synthetic_data = synthetic_data[column_names]

if plot_type not in ['box', 'heatmap', 'scatter', None]:
raise ValueError(
f"Invalid plot_type '{plot_type}'. Please use one of "
"['box', 'heatmap', 'scatter', None]."
)

real_data = real_data[column_names]
synthetic_data = synthetic_data[column_names]
if plot_type is None:
plot_type = []
for column_name in column_names:
column = real_data[column_name]
if real_data is not None:
column = real_data[column_name]
else:
column = synthetic_data[column_name]
dtype = column.dropna().infer_objects().dtype.kind
if dtype in ('i', 'f') or is_datetime(column):
plot_type.append('scatter')
Expand All @@ -501,19 +525,22 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None
plot_type = plot_type.pop()

# Merge the real and synthetic data and add a flag ``Data`` to indicate each one.
columns = list(real_data.columns)
real_data = real_data.copy()
real_data['Data'] = 'Real'
synthetic_data = synthetic_data.copy()
synthetic_data['Data'] = 'Synthetic'
all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True)
all_data = pd.DataFrame()
if real_data is not None:
real_data = real_data.copy()
real_data['Data'] = 'Real'
all_data = pd.concat([all_data, real_data], axis=0, ignore_index=True)
if synthetic_data is not None:
synthetic_data = synthetic_data.copy()
synthetic_data['Data'] = 'Synthetic'
all_data = pd.concat([all_data, synthetic_data], axis=0, ignore_index=True)

if plot_type == 'scatter':
return _generate_scatter_plot(all_data, columns)
return _generate_scatter_plot(all_data, column_names)
elif plot_type == 'heatmap':
return _generate_heatmap_plot(all_data, columns)
return _generate_heatmap_plot(all_data, column_names)

return _generate_box_plot(all_data, columns)
return _generate_box_plot(all_data, column_names)


def _generate_line_plot(real_data, synthetic_data, x_axis, y_axis, marker, annotations=None):
Expand Down
38 changes: 38 additions & 0 deletions tests/unit/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,44 @@ def test_get_column_pair_plot_plot_type_none_continuous_data(mock__generate_scat
assert fig == mock__generate_scatter_plot.return_value


@patch('sdmetrics.visualization._generate_scatter_plot')
def test_get_column_pair_plot_plot_single_data(mock__generate_scatter_plot):
    """Test ``get_column_pair_plot`` when only one of the two datasets is given.

    The function is called once with real data only and once with synthetic
    data only. Each call must forward a frame tagged with the matching ``Data``
    label to the scatter plot helper and return that helper's result.
    """
    # Setup
    column_pair = ['amount', 'price']
    real_only = pd.DataFrame({'amount': [1, 2, 3], 'price': [4, 5, 6]})
    synthetic_only = pd.DataFrame({'amount': [1.0, 2.0, 3.0], 'price': [4.0, 5.0, 6.0]})
    mock__generate_scatter_plot.side_effect = ['mock_return_1', 'mock_return_2']

    # Run
    figure_from_real = get_column_pair_plot(real_only, None, column_pair)
    figure_from_synthetic = get_column_pair_plot(None, synthetic_only, column_pair)

    # Assert
    # The helper should receive the input frame plus the ``Data`` flag column.
    expected_real_frame = real_only.assign(Data='Real')
    expected_synthetic_frame = synthetic_only.assign(Data='Synthetic')
    mock__generate_scatter_plot.assert_has_calls(
        [
            call(DataFrameMatcher(expected_real_frame), column_pair),
            call(DataFrameMatcher(expected_synthetic_frame), column_pair),
        ],
        any_order=False,
    )
    assert figure_from_real == 'mock_return_1'
    assert figure_from_synthetic == 'mock_return_2'


@patch('sdmetrics.visualization._generate_scatter_plot')
def test_get_column_pair_plot_plot_no_data(mock__generate_scatter_plot):
    """Test ``get_column_pair_plot`` with neither real nor synthetic data.

    Passing ``None`` for both datasets must raise a ``ValueError`` telling the
    caller that at least one dataset is required.
    """
    # Setup
    column_pair = ['amount', 'price']
    expected_message = re.escape(
        'No data provided to plot. Please provide either real or synthetic data.'
    )

    # Run and Assert
    with pytest.raises(ValueError, match=expected_message):
        get_column_pair_plot(None, None, column_pair)


@patch('sdmetrics.visualization._generate_scatter_plot')
def test_get_column_pair_plot_plot_type_none_continuous_data_and_date(mock__generate_scatter_plot):
"""Test ``get_column_pair_plot`` with continuous data and ``plot_type`` ``None``."""
Expand Down
Loading