From b3c63de6c91a89a14aeb4a4bf61cf05873e721fc Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Mon, 25 Sep 2023 11:27:02 +0200 Subject: [PATCH] Add generate_column_plot --- sdmetrics/visualization.py | 70 ++++++++++++--- tests/unit/test_visualization.py | 144 ++++++++++++++++++++++++++++++- 2 files changed, 203 insertions(+), 11 deletions(-) diff --git a/sdmetrics/visualization.py b/sdmetrics/visualization.py index f8606332..57c3570b 100644 --- a/sdmetrics/visualization.py +++ b/sdmetrics/visualization.py @@ -6,7 +6,7 @@ from pandas.api.types import is_datetime64_dtype from sdmetrics.reports.utils import PlotConfig -from sdmetrics.utils import get_missing_percentage +from sdmetrics.utils import get_missing_percentage, is_datetime def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}): @@ -25,7 +25,7 @@ def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}): plotly.graph_objects._figure.Figure """ all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True) - default_histogram_kwargs = { + histogram_kwargs = { 'x': 'values', 'color': 'Data', 'barmode': 'group', @@ -34,9 +34,10 @@ def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}): 'pattern_shape_sequence': ['', '/'], 'histnorm': 'probability density', } + histogram_kwargs.update(plot_kwargs) fig = px.histogram( all_data, - **{**default_histogram_kwargs, **plot_kwargs} + **histogram_kwargs ) return fig @@ -105,25 +106,28 @@ def _generate_column_plot(real_column, column_name = real_column.name if hasattr(real_column, 'name') else '' - real_data = pd.DataFrame({'values': real_column.copy()}) + missing_data_real = get_missing_percentage(real_column) + missing_data_synthetic = get_missing_percentage(synthetic_column) + + real_data = pd.DataFrame({'values': real_column.copy().dropna()}) real_data['Data'] = 'Real' - synthetic_data = pd.DataFrame({'values': synthetic_column.copy()}) + synthetic_data = pd.DataFrame({'values': synthetic_column.copy().dropna()}) synthetic_data['Data'] = 'Synthetic' is_datetime_sdtype = False if is_datetime64_dtype(real_column.dtype): is_datetime_sdtype = True - real_data = real_data.astype('int64') - synthetic_data = synthetic_data.astype('int64') - - missing_data_real = get_missing_percentage(real_column) - missing_data_synthetic = get_missing_percentage(synthetic_column) + real_data['values'] = real_data['values'].astype('int64') + synthetic_data['values'] = synthetic_data['values'].astype('int64') trace_args = {} if plot_type == 'bar': fig = _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs) elif plot_type == 'distplot': + if x_label is None: + x_label = 'Value' + fig = _generate_column_distplot(real_data, synthetic_data, plot_kwargs) trace_args = {'fill': 'tozeroy'} @@ -259,3 +263,49 @@ def get_cardinality_plot(real_data, synthetic_data, child_table_name, parent_tab ) return fig + + +def get_column_plot(real_data, synthetic_data, column_name, plot_type=None): + """Return a plot of the real and synthetic data for a given column. + + Args: + real_data (pandas.DataFrame): + The real table data. + synthetic_data (pandas.DataFrame): + The synthetic table data. + column_name (str): + The name of the column. + plot_type (str or None): + The plot to be used. Can choose between ``distplot``, ``bar`` or ``None``. If ``None` + select between ``distplot`` or ``bar`` depending on the data that the column contains, + ``distplot`` for datetime and numerical values and ``bar`` for categorical. + Defaults to ``None``. + + Returns: + plotly.graph_objects._figure.Figure + """ + if plot_type not in ['bar', 'distplot', None]: + raise ValueError( + f"Invalid plot_type '{plot_type}'. Please use one of ['bar', 'distplot', None]." + ) + + if column_name not in real_data.columns: + raise ValueError(f"Column '{column_name}' not found in real table data.") + if column_name not in synthetic_data.columns: + raise ValueError(f"Column '{column_name}' not found in synthetic table data.") + + column_is_datetime = is_datetime(real_data[column_name]) + real_column = real_data[column_name] + dtype = real_column.dropna().infer_objects().dtype.kind + if plot_type is None: + if column_is_datetime or dtype in ('i', 'f'): + plot_type = 'distplot' + else: + plot_type = 'bar' + + real_column = real_data[column_name] + synthetic_column = synthetic_data[column_name] + + fig = _generate_column_plot(real_column, synthetic_column, plot_type) + + return fig diff --git a/tests/unit/test_visualization.py b/tests/unit/test_visualization.py index 902b55b1..e6ce953a 100644 --- a/tests/unit/test_visualization.py +++ b/tests/unit/test_visualization.py @@ -5,7 +5,7 @@ import pytest from sdmetrics.visualization import ( - _generate_cardinality_plot, _get_cardinality, get_cardinality_plot) + _generate_cardinality_plot, _get_cardinality, get_cardinality_plot, get_column_plot) from tests.utils import DataFrameMatcher, SeriesMatcher @@ -221,3 +221,145 @@ def test_get_cardinality_plot_bad_plot_type(): real_data, synthetic_data, child_table_name, parent_table_name, child_foreign_key, parent_primary_key, plot_type='bad_type' ) + + +def test_get_column_plot_column_not_found(): + """Test the ``get_column_plot`` method when column is not present.""" + # Setup + real_data = pd.DataFrame({'values': [1, 2, 2, 3, 5]}) + synthetic_data = pd.DataFrame({'values': [2, 2, 3, 4, 5]}) + + # Run and assert + match = re.escape("Column 'start_date' not found in real table data.") + with pytest.raises(ValueError, match=match): + get_column_plot(real_data, synthetic_data, 'start_date') + + match = re.escape("Column 'start_date' not found in synthetic table data.") + with pytest.raises(ValueError, match=match): + get_column_plot(pd.DataFrame({'start_date': []}), synthetic_data, 'start_date') + + +def test_get_column_plot_bad_plot_type(): + """Test the ``get_column_plot`` method.""" + # Setup + real_data = pd.DataFrame({'values': [1, 2, 2, 3, 5]}) + synthetic_data = pd.DataFrame({'values': [2, 2, 3, 4, 5]}) + + # Run and assert + match = re.escape("Invalid plot_type 'bad_type'. Please use one of ['bar', 'distplot', None].") + with pytest.raises(ValueError, match=match): + get_column_plot(real_data, synthetic_data, 'valeus', plot_type='bad_type') + + +@patch('sdmetrics.visualization._generate_column_plot') +def test_get_column_plot_plot_type_none_data_int(mock__generate_column_plot): + """Test ``get_column_plot`` when ``plot_type`` is ``None`` and data is ``int``.""" + # Setup + real_data = pd.DataFrame({'values': [1, 2, 2, 3, 5]}) + synthetic_data = pd.DataFrame({'values': [2, 2, 3, 4, 5]}) + + # Run + figure = get_column_plot(real_data, synthetic_data, 'values') + + # Assert + mock__generate_column_plot.assert_called_once_with( + real_data['values'], + synthetic_data['values'], + 'distplot' + ) + assert figure == mock__generate_column_plot.return_value + + +@patch('sdmetrics.visualization._generate_column_plot') +def test_get_column_plot_plot_type_none_data_float(mock__generate_column_plot): + """Test ``get_column_plot`` when ``plot_type`` is ``None`` and data is ``float``.""" + # Setup + real_data = pd.DataFrame({'values': [1., 2., 2., 3., 5.]}) + synthetic_data = pd.DataFrame({'values': [2., 2., 3., 4., 5.]}) + + # Run + figure = get_column_plot(real_data, synthetic_data, 'values') + + # Assert + mock__generate_column_plot.assert_called_once_with( + real_data['values'], + synthetic_data['values'], + 'distplot' + ) + assert figure == mock__generate_column_plot.return_value + + +@patch('sdmetrics.visualization._generate_column_plot') +def test_get_column_plot_plot_type_none_data_datetime(mock__generate_column_plot): + """Test ``get_column_plot`` when ``plot_type`` is ``None`` and data is ``datetime``.""" + # Setup + real_data = pd.DataFrame({'values': pd.to_datetime(['2021-01-20', '2022-01-21'])}) + synthetic_data = pd.DataFrame({'values': pd.to_datetime(['2021-01-20', '2022-01-21'])}) + + # Run + figure = get_column_plot(real_data, synthetic_data, 'values') + + # Assert + mock__generate_column_plot.assert_called_once_with( + real_data['values'], + synthetic_data['values'], + 'distplot' + ) + assert figure == mock__generate_column_plot.return_value + + +@patch('sdmetrics.visualization._generate_column_plot') +def test_get_column_plot_plot_type_none_data_category(mock__generate_column_plot): + """Test ``get_column_plot`` when ``plot_type`` is ``None`` and data is ``category``.""" + # Setup + real_data = pd.DataFrame({'values': ['John', 'Doe']}) + synthetic_data = pd.DataFrame({'values': ['Johanna', 'Doe']}) + + # Run + figure = get_column_plot(real_data, synthetic_data, 'values') + + # Assert + mock__generate_column_plot.assert_called_once_with( + real_data['values'], + synthetic_data['values'], + 'bar' + ) + assert figure == mock__generate_column_plot.return_value + + +@patch('sdmetrics.visualization._generate_column_plot') +def test_get_column_plot_plot_type_bar(mock__generate_column_plot): + """Test ``get_column_plot`` when ``plot_type`` is ``bar``.""" + # Setup + real_data = pd.DataFrame({'values': [1., 2., 2., 3., 5.]}) + synthetic_data = pd.DataFrame({'values': [2., 2., 3., 4., 5.]}) + + # Run + figure = get_column_plot(real_data, synthetic_data, 'values', plot_type='bar') + + # Assert + mock__generate_column_plot.assert_called_once_with( + real_data['values'], + synthetic_data['values'], + 'bar' + ) + assert figure == mock__generate_column_plot.return_value + + +@patch('sdmetrics.visualization._generate_column_plot') +def test_get_column_plot_plot_type_distplot(mock__generate_column_plot): + """Test ``get_column_plot`` when ``plot_type`` is ``distplot``.""" + # Setup + real_data = pd.DataFrame({'values': ['John', 'Doe']}) + synthetic_data = pd.DataFrame({'values': ['Johanna', 'Doe']}) + + # Run + figure = get_column_plot(real_data, synthetic_data, 'values', plot_type='distplot') + + # Assert + mock__generate_column_plot.assert_called_once_with( + real_data['values'], + synthetic_data['values'], + 'distplot' + ) + assert figure == mock__generate_column_plot.return_value