From 50306410eb142e1ec41b2f8c536bcccfed43231b Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Thu, 20 Jun 2024 11:15:19 -0500 Subject: [PATCH 01/15] Allow get_column_pair_plot to visualize one dataset instead of both real and synthetic --- sdmetrics/visualization.py | 61 +++++++++++++++++++------------- tests/unit/test_visualization.py | 38 ++++++++++++++++++++ 2 files changed, 74 insertions(+), 25 deletions(-) diff --git a/sdmetrics/visualization.py b/sdmetrics/visualization.py index 20d3d329..fe52a74d 100644 --- a/sdmetrics/visualization.py +++ b/sdmetrics/visualization.py @@ -448,10 +448,10 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None """Return a plot of the real and synthetic data for a given column pair. Args: - real_data (pandas.DataFrame): - The real table data. - synthetic_column (pandas.Dataframe): - The synthetic table data. + real_data (pandas.DataFrame or None): + The real table data. If None this data will not be graphed. + synthetic_column (pandas.Dataframe or None): + The synthetic table data. If None this data will not be graphed. column_names (list[string]): The names of the two columns to plot. plot_type (str or None): @@ -466,16 +466,23 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None if len(column_names) != 2: raise ValueError('Must provide exactly two column names.') - if not set(column_names).issubset(real_data.columns): - raise ValueError( - f'Missing column(s) {set(column_names) - set(real_data.columns)} in real data.' - ) + if real_data is None and synthetic_data is None: + raise ValueError('Must provide at least one dataset to plot.') - if not set(column_names).issubset(synthetic_data.columns): - raise ValueError( - f'Missing column(s) {set(column_names) - set(synthetic_data.columns)} ' - 'in synthetic data.' 
- ) + if real_data is not None: + if not set(column_names).issubset(real_data.columns): + raise ValueError( + f'Missing column(s) {set(column_names) - set(real_data.columns)} in real data.' + ) + real_data = real_data[column_names] + + if synthetic_data is not None: + if not set(column_names).issubset(synthetic_data.columns): + raise ValueError( + f'Missing column(s) {set(column_names) - set(synthetic_data.columns)} ' + 'in synthetic data.' + ) + synthetic_data = synthetic_data[column_names] if plot_type not in ['box', 'heatmap', 'scatter', None]: raise ValueError( @@ -483,12 +490,13 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None "['box', 'heatmap', 'scatter', None]." ) - real_data = real_data[column_names] - synthetic_data = synthetic_data[column_names] if plot_type is None: plot_type = [] for column_name in column_names: - column = real_data[column_name] + if real_data is not None: + column = real_data[column_name] + else: + column = synthetic_data[column_name] dtype = column.dropna().infer_objects().dtype.kind if dtype in ('i', 'f') or is_datetime(column): plot_type.append('scatter') @@ -501,19 +509,22 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None plot_type = plot_type.pop() # Merge the real and synthetic data and add a flag ``Data`` to indicate each one. 
- columns = list(real_data.columns) - real_data = real_data.copy() - real_data['Data'] = 'Real' - synthetic_data = synthetic_data.copy() - synthetic_data['Data'] = 'Synthetic' - all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True) + all_data = pd.DataFrame() + if real_data is not None: + real_data = real_data.copy() + real_data['Data'] = 'Real' + all_data = pd.concat([all_data, real_data], axis=0, ignore_index=True) + if synthetic_data is not None: + synthetic_data = synthetic_data.copy() + synthetic_data['Data'] = 'Synthetic' + all_data = pd.concat([all_data, synthetic_data], axis=0, ignore_index=True) if plot_type == 'scatter': - return _generate_scatter_plot(all_data, columns) + return _generate_scatter_plot(all_data, column_names) elif plot_type == 'heatmap': - return _generate_heatmap_plot(all_data, columns) + return _generate_heatmap_plot(all_data, column_names) - return _generate_box_plot(all_data, columns) + return _generate_box_plot(all_data, column_names) def _generate_line_plot(real_data, synthetic_data, x_axis, y_axis, marker, annotations=None): diff --git a/tests/unit/test_visualization.py b/tests/unit/test_visualization.py index 788e0714..71abb985 100644 --- a/tests/unit/test_visualization.py +++ b/tests/unit/test_visualization.py @@ -666,6 +666,44 @@ def test_get_column_pair_plot_plot_type_none_continuous_data(mock__generate_scat assert fig == mock__generate_scatter_plot.return_value +@patch('sdmetrics.visualization._generate_scatter_plot') +def test_get_column_pair_plot_plot_single_data(mock__generate_scatter_plot): + """Test ``get_column_pair_plot`` with only real or synthetic data""" + # Setup + columns = ['amount', 'price'] + real_data = pd.DataFrame({'amount': [1, 2, 3], 'price': [4, 5, 6]}) + synthetic_data = pd.DataFrame({'amount': [1.0, 2.0, 3.0], 'price': [4.0, 5.0, 6.0]}) + mock__generate_scatter_plot.side_effect = ['mock_return_1', 'mock_return_2'] + + # Run + real_fig = get_column_pair_plot(real_data, None, 
columns) + synth_fig = get_column_pair_plot(None, synthetic_data, columns) + + # Assert + real_data['Data'] = 'Real' + synthetic_data['Data'] = 'Synthetic' + expected_real_call_data = real_data + expected_synth_call_data = synthetic_data + expected_calls = [ + call(DataFrameMatcher(expected_real_call_data), columns), + call(DataFrameMatcher(expected_synth_call_data), columns), + ] + mock__generate_scatter_plot.assert_has_calls(expected_calls, any_order=False) + assert real_fig == 'mock_return_1' + assert synth_fig == 'mock_return_2' + + +@patch('sdmetrics.visualization._generate_scatter_plot') +def test_get_column_pair_plot_plot_no_data(mock__generate_scatter_plot): + """Test ``get_column_pair_plot`` with neither real or synthetic data""" + # Setup + columns = ['amount', 'price'] + error_msg = re.escape('Must provide at least one dataset to plot.') + # Run and Aassert + with pytest.raises(ValueError, match=error_msg): + get_column_pair_plot(None, None, columns) + + @patch('sdmetrics.visualization._generate_scatter_plot') def test_get_column_pair_plot_plot_type_none_continuous_data_and_date(mock__generate_scatter_plot): """Test ``get_column_pair_plot`` with continuous data and ``plot_type`` ``None``.""" From 58e0989f63d1ded43877d546f40d99d953ada200 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Thu, 20 Jun 2024 13:15:20 -0500 Subject: [PATCH 02/15] Adding ability to individually graph synthetic or real data --- sdmetrics/visualization.py | 143 +++++++++++++++++++++---------- tests/unit/test_visualization.py | 32 +++++++ 2 files changed, 130 insertions(+), 45 deletions(-) diff --git a/sdmetrics/visualization.py b/sdmetrics/visualization.py index fe52a74d..bb12f7fd 100644 --- a/sdmetrics/visualization.py +++ b/sdmetrics/visualization.py @@ -46,10 +46,10 @@ def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}): """Generate a bar plot of the real and synthetic data. Args: - real_column (pandas.Series): - The real data for the desired column. 
- synthetic_column (pandas.Series): - The synthetic data for the desired column. + real_column (pandas.Series or None): + The real data for the desired column. If None this data will not be graphed. + synthetic_column (pandas.Series or None): + The synthetic data for the desired column. If None this data will not be graphed. plot_kwargs (dict, optional): Dictionary of keyword arguments to pass to px.histogram. Keyword arguments provided this way will overwrite defaults. @@ -57,7 +57,12 @@ def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}): Returns: plotly.graph_objects._figure.Figure """ - all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True) + all_data = pd.DataFrame() + if real_data is not None: + all_data = pd.concat([all_data, real_data], axis=0, ignore_index=True) + if synthetic_data is not None: + all_data = pd.concat([all_data, synthetic_data], axis=0, ignore_index=True) + histogram_kwargs = { 'x': 'values', 'color': 'Data', @@ -172,10 +177,10 @@ def _generate_column_distplot(real_data, synthetic_data, plot_kwargs={}): """Plot the real and synthetic data as a distplot. Args: - real_data (pandas.DataFrame): - The real data for the desired column. - synthetic_data (pandas.DataFrame): - The synthetic data for the desired column. + real_data (pandas.DataFrame or None): + The real data for the desired column. If None this data will not be graphed. + synthetic_data (pandas.DataFrame or None): + The synthetic data for the desired column. If None this data will not be graphed. plot_kwargs (dict, optional): Dictionary of keyword arguments to pass to px.histogram. Keyword arguments provided this way will overwrite defaults. 
@@ -189,9 +194,18 @@ def _generate_column_distplot(real_data, synthetic_data, plot_kwargs={}): 'colors': [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN], } + hist_data = [] + col_names = [] + if real_data is not None: + hist_data.append(real_data['values']) + col_names.append('Real') + if synthetic_data is not None: + hist_data.append(synthetic_data['values']) + col_names.append('Synthetic') + fig = ff.create_distplot( - [real_data['values'], synthetic_data['values']], - ['Real', 'Synthetic'], + hist_data, + col_names, **{**default_distplot_kwargs, **plot_kwargs}, ) @@ -204,10 +218,10 @@ def _generate_column_plot( """Generate a plot of the real and synthetic data. Args: - real_column (pandas.Series): - The real data for the desired column. - synthetic_column (pandas.Series): - The synthetic data for the desired column. + real_column (pandas.Series or None): + The real data for the desired column. If None this data will not be graphed. + synthetic_column (pandas.Series or None) + The synthetic data for the desired column. If None this data will not be graphed. plot_type (str): The type of plot to use. Must be one of 'bar' or 'distplot'. hist_kwargs (dict, optional): @@ -221,26 +235,49 @@ def _generate_column_plot( Returns: plotly.graph_objects._figure.Figure """ + + if real_column is None and synthetic_column is None: + raise ValueError('Must provide at least one dataset to plot.') + if plot_type not in ['bar', 'distplot']: raise ValueError( "Unrecognized plot_type '{plot_type}'. 
Pleas use one of 'bar' or 'distplot'" ) - column_name = real_column.name if hasattr(real_column, 'name') else '' - - missing_data_real = get_missing_percentage(real_column) - missing_data_synthetic = get_missing_percentage(synthetic_column) + column_name = '' + missing_data_real = 0 + missing_data_synthetic = 0 + col_dtype = None + col_names = [] + if real_column is not None and hasattr(real_column, 'name'): + column_name = real_column.name + elif synthetic_column is not None and hasattr(synthetic_column, 'name'): + column_name = synthetic_column.name + + real_data = None + if real_column is not None: + missing_data_real = get_missing_percentage(real_column) + real_data = pd.DataFrame({'values': real_column.copy().dropna()}) + real_data['Data'] = 'Real' + col_dtype = real_column.dtype + col_names.append('Real') - real_data = pd.DataFrame({'values': real_column.copy().dropna()}) - real_data['Data'] = 'Real' - synthetic_data = pd.DataFrame({'values': synthetic_column.copy().dropna()}) - synthetic_data['Data'] = 'Synthetic' + synthetic_data = None + if synthetic_column is not None: + missing_data_synthetic = get_missing_percentage(synthetic_column) + synthetic_data = pd.DataFrame({'values': synthetic_column.copy().dropna()}) + synthetic_data['Data'] = 'Synthetic' + col_names.append('Synthetic') + if col_dtype is None: + col_dtype = synthetic_column.dtype is_datetime_sdtype = False - if is_datetime64_dtype(real_column.dtype): + if is_datetime64_dtype(col_dtype): is_datetime_sdtype = True - real_data['values'] = real_data['values'].astype('int64') - synthetic_data['values'] = synthetic_data['values'].astype('int64') + if real_data is not None: + real_data['values'] = real_data['values'].astype('int64') + if synthetic_data is not None: + synthetic_data['values'] = synthetic_data['values'].astype('int64') trace_args = {} @@ -251,7 +288,7 @@ def _generate_column_plot( fig = _generate_column_distplot(real_data, synthetic_data, plot_kwargs) trace_args = {'fill': 'tozeroy'} - 
for i, name in enumerate(['Real', 'Synthetic']): + for i, name in enumerate(col_names): fig.update_traces( x=pd.to_datetime(fig.data[i].x) if is_datetime_sdtype else fig.data[i].x, hovertemplate=f'{name}<br>
Frequency: %{{y}}', @@ -260,6 +297,14 @@ def _generate_column_plot( ) show_missing_values = missing_data_real > 0 or missing_data_synthetic > 0 + text = '*Missing Values:' + if real_column is not None and show_missing_values: + text += f' Real Data ({missing_data_real}%), ' + if synthetic_column is not None and show_missing_values: + text += f'Synthetic Data ({missing_data_synthetic}%), ' + + text = text[:-2] + annotations = ( [] if not show_missing_values @@ -270,10 +315,7 @@ def _generate_column_plot( 'x': 1.0, 'y': 1.05, 'showarrow': False, - 'text': ( - f'*Missing Values: Real Data ({missing_data_real}%), ' - f'Synthetic Data ({missing_data_synthetic}%)' - ), + 'text': text, }, ] ) @@ -401,10 +443,10 @@ def get_column_plot(real_data, synthetic_data, column_name, plot_type=None): """Return a plot of the real and synthetic data for a given column. Args: - real_data (pandas.DataFrame): - The real table data. - synthetic_data (pandas.DataFrame): - The synthetic table data. + real_data (pandas.DataFrame or None): + The real table data. If None this data will not be graphed. + synthetic_data (pandas.DataFrame or None): + The synthetic table data. If None this data will not be graphed. column_name (str): The name of the column. plot_type (str or None): @@ -416,28 +458,39 @@ def get_column_plot(real_data, synthetic_data, column_name, plot_type=None): Returns: plotly.graph_objects._figure.Figure """ + + if real_data is None and synthetic_data is None: + raise ValueError('Must provide at least one dataset to plot.') + if plot_type not in ['bar', 'distplot', None]: raise ValueError( f"Invalid plot_type '{plot_type}'. Please use one of ['bar', 'distplot', None]." 
) - if column_name not in real_data.columns: - raise ValueError(f"Column '{column_name}' not found in real table data.") - if column_name not in synthetic_data.columns: - raise ValueError(f"Column '{column_name}' not found in synthetic table data.") + column = None + real_column = None + synthetic_column = None + if real_data is not None: + if column_name not in real_data.columns: + raise ValueError(f"Column '{column_name}' not found in real table data.") + column = real_data[column_name] + real_column = real_data[column_name] + + if synthetic_data is not None: + if column_name not in synthetic_data.columns: + raise ValueError(f"Column '{column_name}' not found in synthetic table data.") + if column is None: + column = synthetic_data[column_name] + synthetic_column = synthetic_data[column_name] - real_column = real_data[column_name] if plot_type is None: - column_is_datetime = is_datetime(real_data[column_name]) - dtype = real_column.dropna().infer_objects().dtype.kind + column_is_datetime = is_datetime(column) + dtype = column.dropna().infer_objects().dtype.kind if column_is_datetime or dtype in ('i', 'f'): plot_type = 'distplot' else: plot_type = 'bar' - real_column = real_data[column_name] - synthetic_column = synthetic_data[column_name] - fig = _generate_column_plot(real_column, synthetic_column, plot_type) return fig diff --git a/tests/unit/test_visualization.py b/tests/unit/test_visualization.py index 71abb985..ad6529ef 100644 --- a/tests/unit/test_visualization.py +++ b/tests/unit/test_visualization.py @@ -268,6 +268,38 @@ def test_get_column_plot_bad_plot_type(): get_column_plot(real_data, synthetic_data, 'valeus', plot_type='bad_type') +def test_get_column_plot_no_data(): + """Test the ``get_column_plot`` method with no data passed in.""" + # Run and assert + error_msg = re.escape('Must provide at least one dataset to plot.') + with pytest.raises(ValueError, match=error_msg): + get_column_plot(None, None, 'values') + + 
+@patch('sdmetrics.visualization._generate_column_plot') +def test_get_column_plot_plot_one_data_set(mock__generate_column_plot): + """Test ``get_column_plot`` for real data and synthetic data individually.""" + # Setup + real_data = pd.DataFrame({'values': [1, 2, 2, 3, 5]}) + synthetic_data = pd.DataFrame({'values': [2, 2, 3, 4, 5]}) + mock__generate_column_plot.side_effect = ['mock_return_1', 'mock_return_2'] + + # Run + fig_real = get_column_plot(real_data, None, 'values') + fig_synth = get_column_plot(None, synthetic_data, 'values') + + # Assert + expected_real_call_data = real_data['values'] + expected_synth_call_data = synthetic_data['values'] + expected_calls = [ + call(SeriesMatcher(expected_real_call_data), None, 'distplot'), + call(None, SeriesMatcher(expected_synth_call_data), 'distplot'), + ] + mock__generate_column_plot.assert_has_calls(expected_calls, any_order=False) + assert fig_real == 'mock_return_1' + assert fig_synth == 'mock_return_2' + + @patch('sdmetrics.visualization._generate_column_plot') def test_get_column_plot_plot_type_none_data_int(mock__generate_column_plot): """Test ``get_column_plot`` when ``plot_type`` is ``None`` and data is ``int``.""" From 9c7dbc77371163ee5e26b2b738e1ba9b6a1a5612 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Thu, 20 Jun 2024 13:32:39 -0500 Subject: [PATCH 03/15] Fix title naming --- sdmetrics/visualization.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/sdmetrics/visualization.py b/sdmetrics/visualization.py index fe52a74d..32d8a95e 100644 --- a/sdmetrics/visualization.py +++ b/sdmetrics/visualization.py @@ -86,12 +86,20 @@ def _generate_heatmap_plot(all_data, columns): Returns: plotly.graph_objects._figure.Figure """ + unique_values = all_data['Data'].unique() + fig = px.density_heatmap( all_data, x=columns[0], y=columns[1], facet_col='Data', histnorm='probability' ) + title = '' + for name in unique_values: + title += f'{name} vs. 
' + title = title[:-4] + title += f"Data for columns '{columns[0]}' and '{columns[1]}" + fig.update_layout( - title_text=f"Real vs Synthetic Data for columns '{columns[0]}' and '{columns[1]}'", + title_text=title, coloraxis={'colorscale': [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN]}, font={'size': PlotConfig.FONT_SIZE}, ) @@ -147,6 +155,7 @@ def _generate_scatter_plot(all_data, columns): Returns: plotly.graph_objects._figure.Figure """ + unique_values = all_data['Data'].unique() fig = px.scatter( all_data, x=columns[0], @@ -159,8 +168,14 @@ def _generate_scatter_plot(all_data, columns): symbol='Data', ) + title = '' + for name in unique_values: + title += f'{name} vs. ' + title = title[:-4] + title += f"Data for columns '{columns[0]}' and '{columns[1]}'" + fig.update_layout( - title=f"Real vs. Synthetic Data for columns '{columns[0]}' and '{columns[1]}'", + title=title, plot_bgcolor=PlotConfig.BACKGROUND_COLOR, font={'size': PlotConfig.FONT_SIZE}, ) From 412b10b8a908d8081816c8a1d24896760bbd2be7 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Thu, 20 Jun 2024 15:25:50 -0500 Subject: [PATCH 04/15] Fixed title --- sdmetrics/visualization.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/sdmetrics/visualization.py b/sdmetrics/visualization.py index 357c429c..6c2ed81b 100644 --- a/sdmetrics/visualization.py +++ b/sdmetrics/visualization.py @@ -264,6 +264,7 @@ def _generate_column_plot( missing_data_synthetic = 0 col_dtype = None col_names = [] + title = '' if real_column is not None and hasattr(real_column, 'name'): column_name = real_column.name elif synthetic_column is not None and hasattr(synthetic_column, 'name'): @@ -276,6 +277,7 @@ def _generate_column_plot( real_data['Data'] = 'Real' col_dtype = real_column.dtype col_names.append('Real') + title += 'Real vs. 
' synthetic_data = None if synthetic_column is not None: @@ -283,9 +285,13 @@ def _generate_column_plot( synthetic_data = pd.DataFrame({'values': synthetic_column.copy().dropna()}) synthetic_data['Data'] = 'Synthetic' col_names.append('Synthetic') + title += 'Synthetic vs. ' if col_dtype is None: col_dtype = synthetic_column.dtype + title = title[:-4] + title += f"Data for column '{column_name}'" + is_datetime_sdtype = False if is_datetime64_dtype(col_dtype): is_datetime_sdtype = True @@ -336,7 +342,7 @@ def _generate_column_plot( ) if not plot_title: - plot_title = f"Real vs. Synthetic Data for column '{column_name}'" + plot_title = title if not x_label: x_label = 'Category' @@ -427,7 +433,8 @@ def get_cardinality_plot( plotly.graph_objects._figure.Figure """ if plot_type not in ['bar', 'distplot']: - raise ValueError(f"Invalid plot_type '{plot_type}'. Please use one of ['bar', 'distplot'].") + raise ValueError( + f"Invalid plot_type '{plot_type}'. Please use one of ['bar', 'distplot'].") real_cardinality = _get_cardinality( real_data[parent_table_name], From 122855c622535c841107dd03c18bbd52a6a87b48 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Thu, 20 Jun 2024 16:21:47 -0500 Subject: [PATCH 05/15] Update message --- sdmetrics/visualization.py | 2 +- tests/unit/test_visualization.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sdmetrics/visualization.py b/sdmetrics/visualization.py index 32d8a95e..c688092e 100644 --- a/sdmetrics/visualization.py +++ b/sdmetrics/visualization.py @@ -482,7 +482,7 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None raise ValueError('Must provide exactly two column names.') if real_data is None and synthetic_data is None: - raise ValueError('Must provide at least one dataset to plot.') + raise ValueError('No data provided to plot. 
Please provide either real or synthetic data.') if real_data is not None: if not set(column_names).issubset(real_data.columns): diff --git a/tests/unit/test_visualization.py b/tests/unit/test_visualization.py index 71abb985..33f59816 100644 --- a/tests/unit/test_visualization.py +++ b/tests/unit/test_visualization.py @@ -698,8 +698,8 @@ def test_get_column_pair_plot_plot_no_data(mock__generate_scatter_plot): """Test ``get_column_pair_plot`` with neither real or synthetic data""" # Setup columns = ['amount', 'price'] - error_msg = re.escape('Must provide at least one dataset to plot.') - # Run and Aassert + error_msg = re.escape('No data provided to plot. Please provide either real or synthetic data.') + # Run and Assert with pytest.raises(ValueError, match=error_msg): get_column_pair_plot(None, None, columns) From 2add684d291240c2850ae069aafdd9d359213b79 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Thu, 20 Jun 2024 16:30:17 -0500 Subject: [PATCH 06/15] Keep colors --- sdmetrics/visualization.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/sdmetrics/visualization.py b/sdmetrics/visualization.py index 04397597..cdf0f563 100644 --- a/sdmetrics/visualization.py +++ b/sdmetrics/visualization.py @@ -58,16 +58,19 @@ def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}): plotly.graph_objects._figure.Figure """ all_data = pd.DataFrame() + color_sequence = [] if real_data is not None: all_data = pd.concat([all_data, real_data], axis=0, ignore_index=True) + color_sequence.append(PlotConfig.DATACEBO_DARK) if synthetic_data is not None: all_data = pd.concat([all_data, synthetic_data], axis=0, ignore_index=True) + color_sequence.append(PlotConfig.DATACEBO_GREEN) histogram_kwargs = { 'x': 'values', 'color': 'Data', 'barmode': 'group', - 'color_discrete_sequence': [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN], + 'color_discrete_sequence': color_sequence, 'pattern_shape': 'Data', 'pattern_shape_sequence': 
['', '/'], 'histnorm': 'probability density', @@ -203,20 +206,23 @@ def _generate_column_distplot(real_data, synthetic_data, plot_kwargs={}): Returns: plotly.graph_objects._figure.Figure """ - default_distplot_kwargs = { - 'show_hist': False, - 'show_rug': False, - 'colors': [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN], - } - hist_data = [] col_names = [] + colors = [] if real_data is not None: hist_data.append(real_data['values']) col_names.append('Real') + colors.append(PlotConfig.DATACEBO_DARK) if synthetic_data is not None: hist_data.append(synthetic_data['values']) col_names.append('Synthetic') + colors.append(PlotConfig.DATACEBO_GREEN) + + default_distplot_kwargs = { + 'show_hist': False, + 'show_rug': False, + 'colors': colors, + } fig = ff.create_distplot( hist_data, @@ -433,8 +439,7 @@ def get_cardinality_plot( plotly.graph_objects._figure.Figure """ if plot_type not in ['bar', 'distplot']: - raise ValueError( - f"Invalid plot_type '{plot_type}'. Please use one of ['bar', 'distplot'].") + raise ValueError(f"Invalid plot_type '{plot_type}'. Please use one of ['bar', 'distplot'].") real_cardinality = _get_cardinality( real_data[parent_table_name], From 9fec05b42a52b3516016d0785df7e11c18b2381c Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Thu, 20 Jun 2024 16:39:27 -0500 Subject: [PATCH 07/15] Update error message --- sdmetrics/visualization.py | 4 ++-- tests/unit/test_visualization.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sdmetrics/visualization.py b/sdmetrics/visualization.py index cdf0f563..0679c2e3 100644 --- a/sdmetrics/visualization.py +++ b/sdmetrics/visualization.py @@ -258,7 +258,7 @@ def _generate_column_plot( """ if real_column is None and synthetic_column is None: - raise ValueError('Must provide at least one dataset to plot.') + raise ValueError('No data provided to plot. 
Please provide either real or synthetic data.') if plot_type not in ['bar', 'distplot']: raise ValueError( @@ -487,7 +487,7 @@ def get_column_plot(real_data, synthetic_data, column_name, plot_type=None): """ if real_data is None and synthetic_data is None: - raise ValueError('Must provide at least one dataset to plot.') + raise ValueError('No data provided to plot. Please provide either real or synthetic data.') if plot_type not in ['bar', 'distplot', None]: raise ValueError( diff --git a/tests/unit/test_visualization.py b/tests/unit/test_visualization.py index 00daf583..8b66bc37 100644 --- a/tests/unit/test_visualization.py +++ b/tests/unit/test_visualization.py @@ -271,7 +271,7 @@ def test_get_column_plot_bad_plot_type(): def test_get_column_plot_no_data(): """Test the ``get_column_plot`` method with no data passed in.""" # Run and assert - error_msg = re.escape('Must provide at least one dataset to plot.') + error_msg = re.escape('No data provided to plot. Please provide either real or synthetic data.') with pytest.raises(ValueError, match=error_msg): get_column_plot(None, None, 'values') From b80af9f36ccefffb11651ba1ec4c6f0bfffc373a Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Mon, 24 Jun 2024 12:25:52 -0500 Subject: [PATCH 08/15] Add tests for previous visualization methods --- sdmetrics/visualization.py | 2 +- tests/unit/test_visualization.py | 189 +++++++++++++++++++++++++++++++ 2 files changed, 190 insertions(+), 1 deletion(-) diff --git a/sdmetrics/visualization.py b/sdmetrics/visualization.py index 0679c2e3..f1360b47 100644 --- a/sdmetrics/visualization.py +++ b/sdmetrics/visualization.py @@ -262,7 +262,7 @@ def _generate_column_plot( if plot_type not in ['bar', 'distplot']: raise ValueError( - "Unrecognized plot_type '{plot_type}'. Pleas use one of 'bar' or 'distplot'" + f"Unrecognized plot_type '{plot_type}'. 
Please use one of 'bar' or 'distplot'" ) column_name = '' diff --git a/tests/unit/test_visualization.py b/tests/unit/test_visualization.py index 8b66bc37..6bb82d23 100644 --- a/tests/unit/test_visualization.py +++ b/tests/unit/test_visualization.py @@ -4,9 +4,13 @@ import pandas as pd import pytest +from sdmetrics.reports.utils import PlotConfig from sdmetrics.visualization import ( _generate_box_plot, _generate_cardinality_plot, + _generate_column_bar_plot, + _generate_column_distplot, + _generate_column_plot, _generate_heatmap_plot, _generate_line_plot, _generate_scatter_plot, @@ -276,6 +280,191 @@ def test_get_column_plot_no_data(): get_column_plot(None, None, 'values') +@patch('sdmetrics.visualization.px.histogram') +def test__generate_column_bar_plot(mock_histogram): + """Test ``_generate_column_bar_plot`` functionality""" + # Setup + real_data = pd.Series([1, 2, 2, 3, 5]) + synthetic_data = pd.Series([2, 2, 3, 4, 5]) + + # Run + _generate_column_bar_plot(real_data, synthetic_data) + + # Assert + expected_data = pd.DataFrame( + pd.concat([real_data, synthetic_data], axis=0, ignore_index=True).astype('float64') + ) + expected_parameters = { + 'x': 'values', + 'color': 'Data', + 'barmode': 'group', + 'color_discrete_sequence': ['#000036', '#01E0C9'], + 'pattern_shape': 'Data', + 'pattern_shape_sequence': ['', '/'], + 'histnorm': 'probability density', + } + pd.testing.assert_frame_equal(expected_data, mock_histogram.call_args[0][0]) + assert expected_parameters == mock_histogram.call_args[1] + mock_histogram.assert_called_once() + + +@patch('sdmetrics.visualization.ff.create_distplot') +def test__generate_column_distplot(mock_distplot): + """Test ``_generate_column_distplot`` functionality""" + # Setup + real_data = pd.DataFrame({'values': [1, 2, 2, 3, 5]}) + synthetic_data = pd.DataFrame({'values': [2, 2, 3, 4, 5]}) + + # Run + _generate_column_distplot(real_data, synthetic_data) + + # Assert + expected_data = [] + expected_data.append(real_data['values']) + 
expected_data.append(synthetic_data['values']) + expected_data == mock_distplot.call_args[0][0] + + ['Real', 'Synthetic'] == mock_distplot.call_args[0][1] + + expected_colors = [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN] + expected_parameters = { + 'show_hist': False, + 'show_rug': False, + 'colors': expected_colors, + } + assert expected_parameters == mock_distplot.call_args[1] + mock_distplot.assert_called_once() + + +@patch('sdmetrics.visualization._generate_column_distplot') +def test___generate_column_plot_type_distplot(mock_dist_plot): + """Test ``_generate_column_plot`` with a dist_plot""" + # Setup + real_data = pd.DataFrame({'values': [1, 2, 2, 3, 5]}) + synthetic_data = pd.DataFrame({'values': [2, 2, 3, 4, 5]}) + mock_fig = Mock() + mock_object = Mock() + mock_object.x = [1, 2, 2, 3, 5] + mock_fig.data = [mock_object, mock_object] + mock_dist_plot.return_value = mock_fig + + # Run + _generate_column_plot(real_data['values'], synthetic_data['values'], 'distplot') + + # Assert + expected_real_data = pd.DataFrame({ + 'values': [1, 2, 2, 3, 5], + 'Data': ['Real', 'Real', 'Real', 'Real', 'Real'], + }) + expected_synth_data = pd.DataFrame({ + 'values': [2, 2, 3, 4, 5], + 'Data': ['Synthetic', 'Synthetic', 'Synthetic', 'Synthetic', 'Synthetic'], + }) + pd.testing.assert_frame_equal(mock_dist_plot.call_args[0][0], expected_real_data) + pd.testing.assert_frame_equal(mock_dist_plot.call_args[0][1], expected_synth_data) + assert mock_dist_plot.call_args[0][2] == {} + mock_dist_plot.assert_called_once() + + mock_fig.update_layout.assert_called_once_with( + title="Real vs. 
Synthetic Data for column 'values'", + xaxis_title='Value', + yaxis_title='Frequency', + plot_bgcolor=PlotConfig.BACKGROUND_COLOR, + annotations=[], + font={'size': PlotConfig.FONT_SIZE}, + ) + + +@patch('sdmetrics.visualization._generate_column_bar_plot') +def test___generate_column_plot_type_bar(mock_bar_plot): + """Test ``_generate_column_plot`` with a bar plot""" + # Setup + real_data = pd.DataFrame({'values': [1, 2, 2, 3, 5]}) + synthetic_data = pd.DataFrame({'values': [2, 2, 3, 4, 5]}) + mock_fig = Mock() + mock_object = Mock() + mock_object.x = [1, 2, 2, 3, 5] + mock_fig.data = [mock_object, mock_object] + mock_bar_plot.return_value = mock_fig + + # Run + _generate_column_plot(real_data['values'], synthetic_data['values'], 'bar') + + # Assert + expected_real_data = pd.DataFrame({ + 'values': [1, 2, 2, 3, 5], + 'Data': ['Real', 'Real', 'Real', 'Real', 'Real'], + }) + expected_synth_data = pd.DataFrame({ + 'values': [2, 2, 3, 4, 5], + 'Data': ['Synthetic', 'Synthetic', 'Synthetic', 'Synthetic', 'Synthetic'], + }) + pd.testing.assert_frame_equal(mock_bar_plot.call_args[0][0], expected_real_data) + pd.testing.assert_frame_equal(mock_bar_plot.call_args[0][1], expected_synth_data) + assert mock_bar_plot.call_args[0][2] == {} + mock_bar_plot.assert_called_once() + mock_fig.update_layout.assert_called_once_with( + title="Real vs. 
Synthetic Data for column 'values'", + xaxis_title='Category', + yaxis_title='Frequency', + plot_bgcolor=PlotConfig.BACKGROUND_COLOR, + annotations=[], + font={'size': PlotConfig.FONT_SIZE}, + ) + + +@patch('sdmetrics.visualization._generate_column_bar_plot') +def test___generate_column_plot_with_datetimes(mock_bar_plot): + """Test ``_generate_column_plot`` using datetimes""" + # Setup + real_data = pd.DataFrame({'values': pd.to_datetime(['2021-01-20', '2022-01-21'])}) + synthetic_data = pd.DataFrame({'values': pd.to_datetime(['2021-01-20', '2022-01-21'])}) + mock_fig = Mock() + mock_object = Mock() + mock_object.x = [1, 2, 2, 3, 5] + mock_fig.data = [mock_object, mock_object] + mock_bar_plot.return_value = mock_fig + + # Run + _generate_column_plot(real_data['values'], synthetic_data['values'], 'bar') + + # Assert + print(mock_bar_plot.call_args[0][1]) + expected_real_data = pd.DataFrame({ + 'values': [1611100800000000000, 1642723200000000000], + 'Data': ['Real', 'Real'], + }) + expected_synth_data = pd.DataFrame({ + 'values': [1611100800000000000, 1642723200000000000], + 'Data': ['Synthetic', 'Synthetic'], + }) + pd.testing.assert_frame_equal(mock_bar_plot.call_args[0][0], expected_real_data) + pd.testing.assert_frame_equal(mock_bar_plot.call_args[0][1], expected_synth_data) + assert mock_bar_plot.call_args[0][2] == {} + mock_bar_plot.assert_called_once() + + +def test___generate_column_plot_no_data(): + """Test ``_generate_column_plot`` when no data is passed in.""" + # Run and Assert + error_msg = re.escape('No data provided to plot. 
Please provide either real or synthetic data.') + with pytest.raises(ValueError, match=error_msg): + _generate_column_plot(None, None, 'bar') + + +def test___generate_column_plot_with_bad_plot(): + """Test ``_generate_column_plot`` when an incorrect plot is set.""" + # Setup + real_data = pd.DataFrame({'values': [1, 2, 2, 3, 5]}) + synthetic_data = pd.DataFrame({'values': [2, 2, 3, 4, 5]}) + # Run and Assert + error_msg = re.escape( + "Unrecognized plot_type 'bad_plot'. Please use one of 'bar' or 'distplot'" + ) + with pytest.raises(ValueError, match=error_msg): + _generate_column_plot(real_data, synthetic_data, 'bad_plot') + + @patch('sdmetrics.visualization._generate_column_plot') def test_get_column_plot_plot_one_data_set(mock__generate_column_plot): """Test ``get_column_plot`` for real data and synthetic data individually.""" From a917652673690b65ae8a5acdde79f9bf2e16d7cd Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Mon, 24 Jun 2024 13:36:12 -0500 Subject: [PATCH 09/15] Use assert_called_once_with with ANY --- tests/unit/test_visualization.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/tests/unit/test_visualization.py b/tests/unit/test_visualization.py index 6bb82d23..7a4832c9 100644 --- a/tests/unit/test_visualization.py +++ b/tests/unit/test_visualization.py @@ -1,5 +1,5 @@ import re -from unittest.mock import Mock, call, patch +from unittest.mock import ANY, Mock, call, patch import pandas as pd import pytest @@ -304,8 +304,7 @@ def test__generate_column_bar_plot(mock_histogram): 'histnorm': 'probability density', } pd.testing.assert_frame_equal(expected_data, mock_histogram.call_args[0][0]) - assert expected_parameters == mock_histogram.call_args[1] - mock_histogram.assert_called_once() + mock_histogram.assert_called_once_with(ANY, **expected_parameters) @patch('sdmetrics.visualization.ff.create_distplot') @@ -323,9 +322,7 @@ def test__generate_column_distplot(mock_distplot): 
expected_data.append(real_data['values']) expected_data.append(synthetic_data['values']) expected_data == mock_distplot.call_args[0][0] - - ['Real', 'Synthetic'] == mock_distplot.call_args[0][1] - + expected_col = ['Real', 'Synthetic'] expected_colors = [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN] expected_parameters = { 'show_hist': False, @@ -333,7 +330,7 @@ def test__generate_column_distplot(mock_distplot): 'colors': expected_colors, } assert expected_parameters == mock_distplot.call_args[1] - mock_distplot.assert_called_once() + mock_distplot.assert_called_once_with(expected_data, expected_col, **expected_parameters) @patch('sdmetrics.visualization._generate_column_distplot') @@ -362,8 +359,7 @@ def test___generate_column_plot_type_distplot(mock_dist_plot): }) pd.testing.assert_frame_equal(mock_dist_plot.call_args[0][0], expected_real_data) pd.testing.assert_frame_equal(mock_dist_plot.call_args[0][1], expected_synth_data) - assert mock_dist_plot.call_args[0][2] == {} - mock_dist_plot.assert_called_once() + mock_dist_plot.assert_called_once_with(ANY, ANY, {}) mock_fig.update_layout.assert_called_once_with( title="Real vs. Synthetic Data for column 'values'", @@ -401,8 +397,7 @@ def test___generate_column_plot_type_bar(mock_bar_plot): }) pd.testing.assert_frame_equal(mock_bar_plot.call_args[0][0], expected_real_data) pd.testing.assert_frame_equal(mock_bar_plot.call_args[0][1], expected_synth_data) - assert mock_bar_plot.call_args[0][2] == {} - mock_bar_plot.assert_called_once() + mock_bar_plot.assert_called_once_with(ANY, ANY, {}) mock_fig.update_layout.assert_called_once_with( title="Real vs. 
Synthetic Data for column 'values'", xaxis_title='Category', @@ -440,8 +435,7 @@ def test___generate_column_plot_with_datetimes(mock_bar_plot): }) pd.testing.assert_frame_equal(mock_bar_plot.call_args[0][0], expected_real_data) pd.testing.assert_frame_equal(mock_bar_plot.call_args[0][1], expected_synth_data) - assert mock_bar_plot.call_args[0][2] == {} - mock_bar_plot.assert_called_once() + mock_bar_plot.assert_called_once_with(ANY, ANY, {}) def test___generate_column_plot_no_data(): From 87178a16dc32fdfe526d5560bac1834534319f0d Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Mon, 24 Jun 2024 14:54:15 -0500 Subject: [PATCH 10/15] Fix failing minimum test --- tests/unit/test_visualization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_visualization.py b/tests/unit/test_visualization.py index 7a4832c9..3edf8cdc 100644 --- a/tests/unit/test_visualization.py +++ b/tests/unit/test_visualization.py @@ -284,15 +284,15 @@ def test_get_column_plot_no_data(): def test__generate_column_bar_plot(mock_histogram): """Test ``_generate_column_bar_plot`` functionality""" # Setup - real_data = pd.Series([1, 2, 2, 3, 5]) - synthetic_data = pd.Series([2, 2, 3, 4, 5]) + real_data = pd.DataFrame([1, 2, 2, 3, 5]) + synthetic_data = pd.DataFrame([2, 2, 3, 4, 5]) # Run _generate_column_bar_plot(real_data, synthetic_data) # Assert expected_data = pd.DataFrame( - pd.concat([real_data, synthetic_data], axis=0, ignore_index=True).astype('float64') + pd.concat([real_data, synthetic_data], axis=0, ignore_index=True) ) expected_parameters = { 'x': 'values', From 34b66aea0cb3a1b83a393f519fc8f66316f032f0 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Mon, 24 Jun 2024 15:11:39 -0500 Subject: [PATCH 11/15] Address comments and added check for single column test for internal functions --- sdmetrics/visualization.py | 21 +++++++------ tests/unit/test_visualization.py | 52 ++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 10 deletions(-) 
diff --git a/sdmetrics/visualization.py b/sdmetrics/visualization.py index c688092e..d616c90f 100644 --- a/sdmetrics/visualization.py +++ b/sdmetrics/visualization.py @@ -88,15 +88,15 @@ def _generate_heatmap_plot(all_data, columns): """ unique_values = all_data['Data'].unique() + if len(columns) != 2: + raise ValueError('Generating a heatmap plot requires only two columns for the axis.') + fig = px.density_heatmap( all_data, x=columns[0], y=columns[1], facet_col='Data', histnorm='probability' ) - title = '' - for name in unique_values: - title += f'{name} vs. ' - title = title[:-4] - title += f"Data for columns '{columns[0]}' and '{columns[1]}" + title = ' vs. '.join(unique_values) + title += f" Data for columns '{columns[0]}' and '{columns[1]}" fig.update_layout( title_text=title, @@ -155,6 +155,10 @@ def _generate_scatter_plot(all_data, columns): Returns: plotly.graph_objects._figure.Figure """ + + if len(columns) != 2: + raise ValueError('Generating a scatter plot requires only two columns for the axis.') + unique_values = all_data['Data'].unique() fig = px.scatter( all_data, @@ -168,11 +172,8 @@ def _generate_scatter_plot(all_data, columns): symbol='Data', ) - title = '' - for name in unique_values: - title += f'{name} vs. ' - title = title[:-4] - title += f"Data for columns '{columns[0]}' and '{columns[1]}'" + title = ' vs. 
'.join(unique_values) + title += f" Data for columns '{columns[0]}' and '{columns[1]}'" fig.update_layout( title=title, diff --git a/tests/unit/test_visualization.py b/tests/unit/test_visualization.py index 33f59816..9e547a00 100644 --- a/tests/unit/test_visualization.py +++ b/tests/unit/test_visualization.py @@ -421,6 +421,32 @@ def test__generate_scatter_plot(px_mock): assert fig == mock_figure +@patch('sdmetrics.visualization.px') +def test__generate_scatter_plot_one_column_failure(px_mock): + """Test the ``_generate_scatter_plot`` method.""" + # Setup + real_column = pd.DataFrame({ + 'col1': [1, 2, 3, 4], + 'col2': [1.1, 1.2, 1.3, 1.4], + 'Data': ['Real'] * 4, + }) + synthetic_column = pd.DataFrame({ + 'col1': [1, 2, 4, 5], + 'col2': [1.1, 1.2, 1.3, 1.4], + 'Data': ['Synthetic'] * 4, + }) + + all_data = pd.concat([real_column, synthetic_column], axis=0, ignore_index=True) + columns = ['col1'] + mock_figure = Mock() + px_mock.scatter.return_value = mock_figure + + # Run and assert + error_msg = re.escape('Generating a scatter plot requires only two columns for the axis.') + with pytest.raises(ValueError, match=error_msg): + _generate_scatter_plot(all_data, columns) + + @patch('sdmetrics.visualization.px') def test__generate_heatmap_plot(px_mock): """Test the ``_generate_heatmap_plot`` method.""" @@ -472,6 +498,32 @@ def test__generate_heatmap_plot(px_mock): assert fig == mock_figure +@patch('sdmetrics.visualization.px') +def test__generate_heatmap_plot_one_column(px_mock): + """Test the ``_generate_heatmap_plot`` method.""" + # Setup + real_column = pd.DataFrame({ + 'col1': [1, 2, 3, 4], + 'col2': ['a', 'b', 'c', 'd'], + 'Data': ['Real'] * 4, + }) + synthetic_column = pd.DataFrame({ + 'col1': [1, 2, 4, 5], + 'col2': ['a', 'b', 'c', 'd'], + 'Data': ['Synthetic'] * 4, + }) + columns = ['col1'] + all_data = pd.concat([real_column, synthetic_column], axis=0, ignore_index=True) + + mock_figure = Mock() + px_mock.density_heatmap.return_value = mock_figure + + # Run 
and assert + error_msg = re.escape('Generating a heatmap plot requires only two columns for the axis.') + with pytest.raises(ValueError, match=error_msg): + _generate_heatmap_plot(all_data, columns) + + @patch('sdmetrics.visualization.px') def test__generate_line_plot(px_mock): """Test the ``_generate_line_plot`` method.""" From d694819a51faf8cad84512c48225e5982e84a163 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Mon, 24 Jun 2024 15:14:26 -0500 Subject: [PATCH 12/15] Fix --- tests/unit/test_visualization.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/unit/test_visualization.py b/tests/unit/test_visualization.py index 3edf8cdc..7b3aaaa7 100644 --- a/tests/unit/test_visualization.py +++ b/tests/unit/test_visualization.py @@ -291,9 +291,7 @@ def test__generate_column_bar_plot(mock_histogram): _generate_column_bar_plot(real_data, synthetic_data) # Assert - expected_data = pd.DataFrame( - pd.concat([real_data, synthetic_data], axis=0, ignore_index=True) - ) + expected_data = pd.DataFrame(pd.concat([real_data, synthetic_data], axis=0, ignore_index=True)) expected_parameters = { 'x': 'values', 'color': 'Data', From 673c1d9edc48b9e00d624179197033ec2f51274b Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Mon, 24 Jun 2024 15:55:25 -0500 Subject: [PATCH 13/15] Remove unneeded tests --- sdmetrics/visualization.py | 4 +-- tests/unit/test_visualization.py | 52 -------------------------------- 2 files changed, 2 insertions(+), 54 deletions(-) diff --git a/sdmetrics/visualization.py b/sdmetrics/visualization.py index d616c90f..4af7fe35 100644 --- a/sdmetrics/visualization.py +++ b/sdmetrics/visualization.py @@ -89,7 +89,7 @@ def _generate_heatmap_plot(all_data, columns): unique_values = all_data['Data'].unique() if len(columns) != 2: - raise ValueError('Generating a heatmap plot requires only two columns for the axis.') + raise ValueError('Generating a heatmap plot requires exactly two columns for the axis.') fig = px.density_heatmap( all_data, 
x=columns[0], y=columns[1], facet_col='Data', histnorm='probability' @@ -157,7 +157,7 @@ def _generate_scatter_plot(all_data, columns): """ if len(columns) != 2: - raise ValueError('Generating a scatter plot requires only two columns for the axis.') + raise ValueError('Generating a scatter plot requires exactly two columns for the axis.') unique_values = all_data['Data'].unique() fig = px.scatter( diff --git a/tests/unit/test_visualization.py b/tests/unit/test_visualization.py index 9e547a00..33f59816 100644 --- a/tests/unit/test_visualization.py +++ b/tests/unit/test_visualization.py @@ -421,32 +421,6 @@ def test__generate_scatter_plot(px_mock): assert fig == mock_figure -@patch('sdmetrics.visualization.px') -def test__generate_scatter_plot_one_column_failure(px_mock): - """Test the ``_generate_scatter_plot`` method.""" - # Setup - real_column = pd.DataFrame({ - 'col1': [1, 2, 3, 4], - 'col2': [1.1, 1.2, 1.3, 1.4], - 'Data': ['Real'] * 4, - }) - synthetic_column = pd.DataFrame({ - 'col1': [1, 2, 4, 5], - 'col2': [1.1, 1.2, 1.3, 1.4], - 'Data': ['Synthetic'] * 4, - }) - - all_data = pd.concat([real_column, synthetic_column], axis=0, ignore_index=True) - columns = ['col1'] - mock_figure = Mock() - px_mock.scatter.return_value = mock_figure - - # Run and assert - error_msg = re.escape('Generating a scatter plot requires only two columns for the axis.') - with pytest.raises(ValueError, match=error_msg): - _generate_scatter_plot(all_data, columns) - - @patch('sdmetrics.visualization.px') def test__generate_heatmap_plot(px_mock): """Test the ``_generate_heatmap_plot`` method.""" @@ -498,32 +472,6 @@ def test__generate_heatmap_plot(px_mock): assert fig == mock_figure -@patch('sdmetrics.visualization.px') -def test__generate_heatmap_plot_one_column(px_mock): - """Test the ``_generate_heatmap_plot`` method.""" - # Setup - real_column = pd.DataFrame({ - 'col1': [1, 2, 3, 4], - 'col2': ['a', 'b', 'c', 'd'], - 'Data': ['Real'] * 4, - }) - synthetic_column = pd.DataFrame({ - 
'col1': [1, 2, 4, 5], - 'col2': ['a', 'b', 'c', 'd'], - 'Data': ['Synthetic'] * 4, - }) - columns = ['col1'] - all_data = pd.concat([real_column, synthetic_column], axis=0, ignore_index=True) - - mock_figure = Mock() - px_mock.density_heatmap.return_value = mock_figure - - # Run and assert - error_msg = re.escape('Generating a heatmap plot requires only two columns for the axis.') - with pytest.raises(ValueError, match=error_msg): - _generate_heatmap_plot(all_data, columns) - - @patch('sdmetrics.visualization.px') def test__generate_line_plot(px_mock): """Test the ``_generate_line_plot`` method.""" From 42bfd87f9943622e6a778d7d11a45f626b373d6f Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Tue, 25 Jun 2024 11:16:18 -0500 Subject: [PATCH 14/15] Cap scipy from new release to avoid test failures --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f95571d9..306b7e6e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ "scikit-learn>=1.3.1;python_version>='3.12'", "scipy>=1.7.3;python_version<'3.10'", "scipy>=1.9.2;python_version>='3.10' and python_version<'3.12'", - "scipy>=1.12.0;python_version>='3.12'", + "scipy>=1.12.0,<1.14.0;python_version>='3.12'", 'copulas>=0.11.0', 'tqdm>=4.29', 'plotly>=5.19.0', From 8e94caa31ac1ed0c8002b99ba08a5ffe518f08ab Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Tue, 25 Jun 2024 11:23:40 -0500 Subject: [PATCH 15/15] Cap all scipy versions --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 306b7e6e..f0e87e4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,8 +30,8 @@ dependencies = [ "scikit-learn>=1.1.0;python_version>='3.10' and python_version<'3.11'", "scikit-learn>=1.1.3;python_version>='3.11' and python_version<'3.12'", "scikit-learn>=1.3.1;python_version>='3.12'", - "scipy>=1.7.3;python_version<'3.10'", - "scipy>=1.9.2;python_version>='3.10' and 
python_version<'3.12'", + "scipy>=1.7.3,<1.14.0;python_version<'3.10'", + "scipy>=1.9.2,<1.14.0;python_version>='3.10' and python_version<'3.12'", "scipy>=1.12.0,<1.14.0;python_version>='3.12'", 'copulas>=0.11.0', 'tqdm>=4.29',