diff --git a/pyproject.toml b/pyproject.toml index f95571d9..f0e87e4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,9 +30,9 @@ dependencies = [ "scikit-learn>=1.1.0;python_version>='3.10' and python_version<'3.11'", "scikit-learn>=1.1.3;python_version>='3.11' and python_version<'3.12'", "scikit-learn>=1.3.1;python_version>='3.12'", - "scipy>=1.7.3;python_version<'3.10'", - "scipy>=1.9.2;python_version>='3.10' and python_version<'3.12'", - "scipy>=1.12.0;python_version>='3.12'", + "scipy>=1.7.3,<1.14.0;python_version<'3.10'", + "scipy>=1.9.2,<1.14.0;python_version>='3.10' and python_version<'3.12'", + "scipy>=1.12.0,<1.14.0;python_version>='3.12'", 'copulas>=0.11.0', 'tqdm>=4.29', 'plotly>=5.19.0', diff --git a/sdmetrics/visualization.py b/sdmetrics/visualization.py index 20d3d329..4af7fe35 100644 --- a/sdmetrics/visualization.py +++ b/sdmetrics/visualization.py @@ -86,12 +86,20 @@ def _generate_heatmap_plot(all_data, columns): Returns: plotly.graph_objects._figure.Figure """ + unique_values = all_data['Data'].unique() + + if len(columns) != 2: + raise ValueError('Generating a heatmap plot requires exactly two columns for the axis.') + fig = px.density_heatmap( all_data, x=columns[0], y=columns[1], facet_col='Data', histnorm='probability' ) + title = ' vs. 
'.join(unique_values) + title += f" Data for columns '{columns[0]}' and '{columns[1]}'" + fig.update_layout( - title_text=f"Real vs Synthetic Data for columns '{columns[0]}' and '{columns[1]}'", + title_text=title, coloraxis={'colorscale': [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN]}, font={'size': PlotConfig.FONT_SIZE}, ) @@ -147,6 +155,11 @@ def _generate_scatter_plot(all_data, columns): Returns: plotly.graph_objects._figure.Figure """ + + if len(columns) != 2: + raise ValueError('Generating a scatter plot requires exactly two columns for the axis.') + + unique_values = all_data['Data'].unique() fig = px.scatter( all_data, x=columns[0], @@ -159,8 +172,11 @@ def _generate_scatter_plot(all_data, columns): symbol='Data', ) + title = ' vs. '.join(unique_values) + title += f" Data for columns '{columns[0]}' and '{columns[1]}'" + fig.update_layout( - title=f"Real vs. Synthetic Data for columns '{columns[0]}' and '{columns[1]}'", + title=title, plot_bgcolor=PlotConfig.BACKGROUND_COLOR, font={'size': PlotConfig.FONT_SIZE}, ) @@ -448,10 +464,10 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None """Return a plot of the real and synthetic data for a given column pair. Args: - real_data (pandas.DataFrame): - The real table data. - synthetic_column (pandas.Dataframe): - The synthetic table data. + real_data (pandas.DataFrame or None): + The real table data. If None this data will not be graphed. + synthetic_data (pandas.DataFrame or None): + The synthetic table data. If None this data will not be graphed. column_names (list[string]): The names of the two columns to plot. 
plot_type (str or None): @@ -466,16 +482,23 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None if len(column_names) != 2: raise ValueError('Must provide exactly two column names.') - if not set(column_names).issubset(real_data.columns): - raise ValueError( - f'Missing column(s) {set(column_names) - set(real_data.columns)} in real data.' - ) + if real_data is None and synthetic_data is None: + raise ValueError('No data provided to plot. Please provide either real or synthetic data.') - if not set(column_names).issubset(synthetic_data.columns): - raise ValueError( - f'Missing column(s) {set(column_names) - set(synthetic_data.columns)} ' - 'in synthetic data.' - ) + if real_data is not None: + if not set(column_names).issubset(real_data.columns): + raise ValueError( + f'Missing column(s) {set(column_names) - set(real_data.columns)} in real data.' + ) + real_data = real_data[column_names] + + if synthetic_data is not None: + if not set(column_names).issubset(synthetic_data.columns): + raise ValueError( + f'Missing column(s) {set(column_names) - set(synthetic_data.columns)} ' + 'in synthetic data.' + ) + synthetic_data = synthetic_data[column_names] if plot_type not in ['box', 'heatmap', 'scatter', None]: raise ValueError( @@ -483,12 +506,13 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None "['box', 'heatmap', 'scatter', None]." 
) - real_data = real_data[column_names] - synthetic_data = synthetic_data[column_names] if plot_type is None: plot_type = [] for column_name in column_names: - column = real_data[column_name] + if real_data is not None: + column = real_data[column_name] + else: + column = synthetic_data[column_name] dtype = column.dropna().infer_objects().dtype.kind if dtype in ('i', 'f') or is_datetime(column): plot_type.append('scatter') @@ -501,19 +525,22 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None plot_type = plot_type.pop() # Merge the real and synthetic data and add a flag ``Data`` to indicate each one. - columns = list(real_data.columns) - real_data = real_data.copy() - real_data['Data'] = 'Real' - synthetic_data = synthetic_data.copy() - synthetic_data['Data'] = 'Synthetic' - all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True) + all_data = pd.DataFrame() + if real_data is not None: + real_data = real_data.copy() + real_data['Data'] = 'Real' + all_data = pd.concat([all_data, real_data], axis=0, ignore_index=True) + if synthetic_data is not None: + synthetic_data = synthetic_data.copy() + synthetic_data['Data'] = 'Synthetic' + all_data = pd.concat([all_data, synthetic_data], axis=0, ignore_index=True) if plot_type == 'scatter': - return _generate_scatter_plot(all_data, columns) + return _generate_scatter_plot(all_data, column_names) elif plot_type == 'heatmap': - return _generate_heatmap_plot(all_data, columns) + return _generate_heatmap_plot(all_data, column_names) - return _generate_box_plot(all_data, columns) + return _generate_box_plot(all_data, column_names) def _generate_line_plot(real_data, synthetic_data, x_axis, y_axis, marker, annotations=None): diff --git a/tests/unit/test_visualization.py b/tests/unit/test_visualization.py index 788e0714..33f59816 100644 --- a/tests/unit/test_visualization.py +++ b/tests/unit/test_visualization.py @@ -666,6 +666,44 @@ def 
test_get_column_pair_plot_plot_type_none_continuous_data(mock__generate_scat assert fig == mock__generate_scatter_plot.return_value +@patch('sdmetrics.visualization._generate_scatter_plot') +def test_get_column_pair_plot_plot_single_data(mock__generate_scatter_plot): + """Test ``get_column_pair_plot`` with only real or synthetic data""" + # Setup + columns = ['amount', 'price'] + real_data = pd.DataFrame({'amount': [1, 2, 3], 'price': [4, 5, 6]}) + synthetic_data = pd.DataFrame({'amount': [1.0, 2.0, 3.0], 'price': [4.0, 5.0, 6.0]}) + mock__generate_scatter_plot.side_effect = ['mock_return_1', 'mock_return_2'] + + # Run + real_fig = get_column_pair_plot(real_data, None, columns) + synth_fig = get_column_pair_plot(None, synthetic_data, columns) + + # Assert + real_data['Data'] = 'Real' + synthetic_data['Data'] = 'Synthetic' + expected_real_call_data = real_data + expected_synth_call_data = synthetic_data + expected_calls = [ + call(DataFrameMatcher(expected_real_call_data), columns), + call(DataFrameMatcher(expected_synth_call_data), columns), + ] + mock__generate_scatter_plot.assert_has_calls(expected_calls, any_order=False) + assert real_fig == 'mock_return_1' + assert synth_fig == 'mock_return_2' + + +@patch('sdmetrics.visualization._generate_scatter_plot') +def test_get_column_pair_plot_plot_no_data(mock__generate_scatter_plot): + """Test ``get_column_pair_plot`` with neither real or synthetic data""" + # Setup + columns = ['amount', 'price'] + error_msg = re.escape('No data provided to plot. Please provide either real or synthetic data.') + # Run and Assert + with pytest.raises(ValueError, match=error_msg): + get_column_pair_plot(None, None, columns) + + @patch('sdmetrics.visualization._generate_scatter_plot') def test_get_column_pair_plot_plot_type_none_continuous_data_and_date(mock__generate_scatter_plot): """Test ``get_column_pair_plot`` with continuous data and ``plot_type`` ``None``."""