Skip to content

Commit

Permalink
Allow get_column_pair_plot to visualize one dataset instead of both r…
Browse files Browse the repository at this point in the history
…eal and synthetic (#595)
  • Loading branch information
lajohn4747 authored Jun 25, 2024
1 parent 7d2f508 commit 9faf45d
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 30 deletions.
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ dependencies = [
"scikit-learn>=1.1.0;python_version>='3.10' and python_version<'3.11'",
"scikit-learn>=1.1.3;python_version>='3.11' and python_version<'3.12'",
"scikit-learn>=1.3.1;python_version>='3.12'",
"scipy>=1.7.3;python_version<'3.10'",
"scipy>=1.9.2;python_version>='3.10' and python_version<'3.12'",
"scipy>=1.12.0;python_version>='3.12'",
"scipy>=1.7.3,<1.14.0;python_version<'3.10'",
"scipy>=1.9.2,<1.14.0;python_version>='3.10' and python_version<'3.12'",
"scipy>=1.12.0,<1.14.0;python_version>='3.12'",
'copulas>=0.11.0',
'tqdm>=4.29',
'plotly>=5.19.0',
Expand Down
81 changes: 54 additions & 27 deletions sdmetrics/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,20 @@ def _generate_heatmap_plot(all_data, columns):
Returns:
plotly.graph_objects._figure.Figure
"""
unique_values = all_data['Data'].unique()

if len(columns) != 2:
raise ValueError('Generating a heatmap plot requires exactly two columns for the axis.')

fig = px.density_heatmap(
all_data, x=columns[0], y=columns[1], facet_col='Data', histnorm='probability'
)

title = ' vs. '.join(unique_values)
title += f" Data for columns '{columns[0]}' and '{columns[1]}"

fig.update_layout(
title_text=f"Real vs Synthetic Data for columns '{columns[0]}' and '{columns[1]}'",
title_text=title,
coloraxis={'colorscale': [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN]},
font={'size': PlotConfig.FONT_SIZE},
)
Expand Down Expand Up @@ -147,6 +155,11 @@ def _generate_scatter_plot(all_data, columns):
Returns:
plotly.graph_objects._figure.Figure
"""

if len(columns) != 2:
raise ValueError('Generating a scatter plot requires exactly two columns for the axis.')

unique_values = all_data['Data'].unique()
fig = px.scatter(
all_data,
x=columns[0],
Expand All @@ -159,8 +172,11 @@ def _generate_scatter_plot(all_data, columns):
symbol='Data',
)

title = ' vs. '.join(unique_values)
title += f" Data for columns '{columns[0]}' and '{columns[1]}'"

fig.update_layout(
title=f"Real vs. Synthetic Data for columns '{columns[0]}' and '{columns[1]}'",
title=title,
plot_bgcolor=PlotConfig.BACKGROUND_COLOR,
font={'size': PlotConfig.FONT_SIZE},
)
Expand Down Expand Up @@ -448,10 +464,10 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None
"""Return a plot of the real and synthetic data for a given column pair.
Args:
real_data (pandas.DataFrame):
The real table data.
synthetic_column (pandas.Dataframe):
The synthetic table data.
real_data (pandas.DataFrame or None):
The real table data. If None this data will not be graphed.
synthetic_column (pandas.Dataframe or None):
The synthetic table data. If None this data will not be graphed.
column_names (list[string]):
The names of the two columns to plot.
plot_type (str or None):
Expand All @@ -466,29 +482,37 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None
if len(column_names) != 2:
raise ValueError('Must provide exactly two column names.')

if not set(column_names).issubset(real_data.columns):
raise ValueError(
f'Missing column(s) {set(column_names) - set(real_data.columns)} in real data.'
)
if real_data is None and synthetic_data is None:
raise ValueError('No data provided to plot. Please provide either real or synthetic data.')

if not set(column_names).issubset(synthetic_data.columns):
raise ValueError(
f'Missing column(s) {set(column_names) - set(synthetic_data.columns)} '
'in synthetic data.'
)
if real_data is not None:
if not set(column_names).issubset(real_data.columns):
raise ValueError(
f'Missing column(s) {set(column_names) - set(real_data.columns)} in real data.'
)
real_data = real_data[column_names]

if synthetic_data is not None:
if not set(column_names).issubset(synthetic_data.columns):
raise ValueError(
f'Missing column(s) {set(column_names) - set(synthetic_data.columns)} '
'in synthetic data.'
)
synthetic_data = synthetic_data[column_names]

if plot_type not in ['box', 'heatmap', 'scatter', None]:
raise ValueError(
f"Invalid plot_type '{plot_type}'. Please use one of "
"['box', 'heatmap', 'scatter', None]."
)

real_data = real_data[column_names]
synthetic_data = synthetic_data[column_names]
if plot_type is None:
plot_type = []
for column_name in column_names:
column = real_data[column_name]
if real_data is not None:
column = real_data[column_name]
else:
column = synthetic_data[column_name]
dtype = column.dropna().infer_objects().dtype.kind
if dtype in ('i', 'f') or is_datetime(column):
plot_type.append('scatter')
Expand All @@ -501,19 +525,22 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None
plot_type = plot_type.pop()

# Merge the real and synthetic data and add a flag ``Data`` to indicate each one.
columns = list(real_data.columns)
real_data = real_data.copy()
real_data['Data'] = 'Real'
synthetic_data = synthetic_data.copy()
synthetic_data['Data'] = 'Synthetic'
all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True)
all_data = pd.DataFrame()
if real_data is not None:
real_data = real_data.copy()
real_data['Data'] = 'Real'
all_data = pd.concat([all_data, real_data], axis=0, ignore_index=True)
if synthetic_data is not None:
synthetic_data = synthetic_data.copy()
synthetic_data['Data'] = 'Synthetic'
all_data = pd.concat([all_data, synthetic_data], axis=0, ignore_index=True)

if plot_type == 'scatter':
return _generate_scatter_plot(all_data, columns)
return _generate_scatter_plot(all_data, column_names)
elif plot_type == 'heatmap':
return _generate_heatmap_plot(all_data, columns)
return _generate_heatmap_plot(all_data, column_names)

return _generate_box_plot(all_data, columns)
return _generate_box_plot(all_data, column_names)


def _generate_line_plot(real_data, synthetic_data, x_axis, y_axis, marker, annotations=None):
Expand Down
38 changes: 38 additions & 0 deletions tests/unit/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,44 @@ def test_get_column_pair_plot_plot_type_none_continuous_data(mock__generate_scat
assert fig == mock__generate_scatter_plot.return_value


@patch('sdmetrics.visualization._generate_scatter_plot')
def test_get_column_pair_plot_plot_single_data(mock__generate_scatter_plot):
"""Test ``get_column_pair_plot`` with only real or synthetic data"""
# Setup
columns = ['amount', 'price']
real_data = pd.DataFrame({'amount': [1, 2, 3], 'price': [4, 5, 6]})
synthetic_data = pd.DataFrame({'amount': [1.0, 2.0, 3.0], 'price': [4.0, 5.0, 6.0]})
mock__generate_scatter_plot.side_effect = ['mock_return_1', 'mock_return_2']

# Run
real_fig = get_column_pair_plot(real_data, None, columns)
synth_fig = get_column_pair_plot(None, synthetic_data, columns)

# Assert
real_data['Data'] = 'Real'
synthetic_data['Data'] = 'Synthetic'
expected_real_call_data = real_data
expected_synth_call_data = synthetic_data
expected_calls = [
call(DataFrameMatcher(expected_real_call_data), columns),
call(DataFrameMatcher(expected_synth_call_data), columns),
]
mock__generate_scatter_plot.assert_has_calls(expected_calls, any_order=False)
assert real_fig == 'mock_return_1'
assert synth_fig == 'mock_return_2'


@patch('sdmetrics.visualization._generate_scatter_plot')
def test_get_column_pair_plot_plot_no_data(mock__generate_scatter_plot):
"""Test ``get_column_pair_plot`` with neither real or synthetic data"""
# Setup
columns = ['amount', 'price']
error_msg = re.escape('No data provided to plot. Please provide either real or synthetic data.')
# Run and Assert
with pytest.raises(ValueError, match=error_msg):
get_column_pair_plot(None, None, columns)


@patch('sdmetrics.visualization._generate_scatter_plot')
def test_get_column_pair_plot_plot_type_none_continuous_data_and_date(mock__generate_scatter_plot):
"""Test ``get_column_pair_plot`` with continuous data and ``plot_type`` ``None``."""
Expand Down

0 comments on commit 9faf45d

Please sign in to comment.