Skip to content

Commit

Permalink
Add generate_column_plot
Browse files Browse the repository at this point in the history
  • Loading branch information
pvk-developer committed Sep 25, 2023
1 parent efeb278 commit b3c63de
Show file tree
Hide file tree
Showing 2 changed files with 203 additions and 11 deletions.
70 changes: 60 additions & 10 deletions sdmetrics/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pandas.api.types import is_datetime64_dtype

from sdmetrics.reports.utils import PlotConfig
from sdmetrics.utils import get_missing_percentage
from sdmetrics.utils import get_missing_percentage, is_datetime


def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}):
Expand All @@ -25,7 +25,7 @@ def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}):
plotly.graph_objects._figure.Figure
"""
all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True)
default_histogram_kwargs = {
histogram_kwargs = {
'x': 'values',
'color': 'Data',
'barmode': 'group',
Expand All @@ -34,9 +34,10 @@ def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}):
'pattern_shape_sequence': ['', '/'],
'histnorm': 'probability density',
}
histogram_kwargs.update(plot_kwargs)
fig = px.histogram(
all_data,
**{**default_histogram_kwargs, **plot_kwargs}
**histogram_kwargs
)

return fig
Expand Down Expand Up @@ -105,25 +106,28 @@ def _generate_column_plot(real_column,

column_name = real_column.name if hasattr(real_column, 'name') else ''

real_data = pd.DataFrame({'values': real_column.copy()})
missing_data_real = get_missing_percentage(real_column)
missing_data_synthetic = get_missing_percentage(synthetic_column)

real_data = pd.DataFrame({'values': real_column.copy().dropna()})
real_data['Data'] = 'Real'
synthetic_data = pd.DataFrame({'values': synthetic_column.copy()})
synthetic_data = pd.DataFrame({'values': synthetic_column.copy().dropna()})
synthetic_data['Data'] = 'Synthetic'

is_datetime_sdtype = False
if is_datetime64_dtype(real_column.dtype):
is_datetime_sdtype = True
real_data = real_data.astype('int64')
synthetic_data = synthetic_data.astype('int64')

missing_data_real = get_missing_percentage(real_column)
missing_data_synthetic = get_missing_percentage(synthetic_column)
real_data['values'] = real_data['values'].astype('int64')
synthetic_data['values'] = synthetic_data['values'].astype('int64')

trace_args = {}

if plot_type == 'bar':
fig = _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs)
elif plot_type == 'distplot':
if x_label is None:
x_label = 'Value'

fig = _generate_column_distplot(real_data, synthetic_data, plot_kwargs)
trace_args = {'fill': 'tozeroy'}

Expand Down Expand Up @@ -259,3 +263,49 @@ def get_cardinality_plot(real_data, synthetic_data, child_table_name, parent_tab
)

return fig


def get_column_plot(real_data, synthetic_data, column_name, plot_type=None):
"""Return a plot of the real and synthetic data for a given column.
Args:
real_data (pandas.DataFrame):
The real table data.
synthetic_data (pandas.DataFrame):
The synthetic table data.
column_name (str):
The name of the column.
plot_type (str or None):
The plot to be used. Can choose between ``distplot``, ``bar`` or ``None``. If ``None`
select between ``distplot`` or ``bar`` depending on the data that the column contains,
``distplot`` for datetime and numerical values and ``bar`` for categorical.
Defaults to ``None``.
Returns:
plotly.graph_objects._figure.Figure
"""
if plot_type not in ['bar', 'distplot', None]:
raise ValueError(
f"Invalid plot_type '{plot_type}'. Please use one of ['bar', 'distplot', None]."
)

if column_name not in real_data.columns:
raise ValueError(f"Column '{column_name}' not found in real table data.")
if column_name not in synthetic_data.columns:
raise ValueError(f"Column '{column_name}' not found in synthetic table data.")

column_is_datetime = is_datetime(real_data[column_name])
real_column = real_data[column_name]
dtype = real_column.dropna().infer_objects().dtype.kind
if plot_type is None:
if column_is_datetime or dtype in ('i', 'f'):
plot_type = 'distplot'
else:
plot_type = 'bar'

real_column = real_data[column_name]
synthetic_column = synthetic_data[column_name]

fig = _generate_column_plot(real_column, synthetic_column, plot_type)

return fig
144 changes: 143 additions & 1 deletion tests/unit/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

from sdmetrics.visualization import (
_generate_cardinality_plot, _get_cardinality, get_cardinality_plot)
_generate_cardinality_plot, _get_cardinality, get_cardinality_plot, get_column_plot)
from tests.utils import DataFrameMatcher, SeriesMatcher


Expand Down Expand Up @@ -221,3 +221,145 @@ def test_get_cardinality_plot_bad_plot_type():
real_data, synthetic_data, child_table_name, parent_table_name, child_foreign_key,
parent_primary_key, plot_type='bad_type'
)


def test_get_column_plot_column_not_found():
"""Test the ``get_column_plot`` method when column is not present."""
# Setup
real_data = pd.DataFrame({'values': [1, 2, 2, 3, 5]})
synthetic_data = pd.DataFrame({'values': [2, 2, 3, 4, 5]})

# Run and assert
match = re.escape("Column 'start_date' not found in real table data.")
with pytest.raises(ValueError, match=match):
get_column_plot(real_data, synthetic_data, 'start_date')

match = re.escape("Column 'start_date' not found in synthetic table data.")
with pytest.raises(ValueError, match=match):
get_column_plot(pd.DataFrame({'start_date': []}), synthetic_data, 'start_date')


def test_get_column_plot_bad_plot_type():
"""Test the ``get_column_plot`` method."""
# Setup
real_data = pd.DataFrame({'values': [1, 2, 2, 3, 5]})
synthetic_data = pd.DataFrame({'values': [2, 2, 3, 4, 5]})

# Run and assert
match = re.escape("Invalid plot_type 'bad_type'. Please use one of ['bar', 'distplot', None].")
with pytest.raises(ValueError, match=match):
get_column_plot(real_data, synthetic_data, 'valeus', plot_type='bad_type')


@patch('sdmetrics.visualization._generate_column_plot')
def test_get_column_plot_plot_type_none_data_int(mock__generate_column_plot):
"""Test ``get_column_plot`` when ``plot_type`` is ``None`` and data is ``int``."""
# Setup
real_data = pd.DataFrame({'values': [1, 2, 2, 3, 5]})
synthetic_data = pd.DataFrame({'values': [2, 2, 3, 4, 5]})

# Run
figure = get_column_plot(real_data, synthetic_data, 'values')

# Assert
mock__generate_column_plot.assert_called_once_with(
real_data['values'],
synthetic_data['values'],
'distplot'
)
assert figure == mock__generate_column_plot.return_value


@patch('sdmetrics.visualization._generate_column_plot')
def test_get_column_plot_plot_type_none_data_float(mock__generate_column_plot):
"""Test ``get_column_plot`` when ``plot_type`` is ``None`` and data is ``float``."""
# Setup
real_data = pd.DataFrame({'values': [1., 2., 2., 3., 5.]})
synthetic_data = pd.DataFrame({'values': [2., 2., 3., 4., 5.]})

# Run
figure = get_column_plot(real_data, synthetic_data, 'values')

# Assert
mock__generate_column_plot.assert_called_once_with(
real_data['values'],
synthetic_data['values'],
'distplot'
)
assert figure == mock__generate_column_plot.return_value


@patch('sdmetrics.visualization._generate_column_plot')
def test_get_column_plot_plot_type_none_data_datetime(mock__generate_column_plot):
"""Test ``get_column_plot`` when ``plot_type`` is ``None`` and data is ``datetime``."""
# Setup
real_data = pd.DataFrame({'values': pd.to_datetime(['2021-01-20', '2022-01-21'])})
synthetic_data = pd.DataFrame({'values': pd.to_datetime(['2021-01-20', '2022-01-21'])})

# Run
figure = get_column_plot(real_data, synthetic_data, 'values')

# Assert
mock__generate_column_plot.assert_called_once_with(
real_data['values'],
synthetic_data['values'],
'distplot'
)
assert figure == mock__generate_column_plot.return_value


@patch('sdmetrics.visualization._generate_column_plot')
def test_get_column_plot_plot_type_none_data_category(mock__generate_column_plot):
"""Test ``get_column_plot`` when ``plot_type`` is ``None`` and data is ``category``."""
# Setup
real_data = pd.DataFrame({'values': ['John', 'Doe']})
synthetic_data = pd.DataFrame({'values': ['Johanna', 'Doe']})

# Run
figure = get_column_plot(real_data, synthetic_data, 'values')

# Assert
mock__generate_column_plot.assert_called_once_with(
real_data['values'],
synthetic_data['values'],
'bar'
)
assert figure == mock__generate_column_plot.return_value


@patch('sdmetrics.visualization._generate_column_plot')
def test_get_column_plot_plot_type_bar(mock__generate_column_plot):
"""Test ``get_column_plot`` when ``plot_type`` is ``bar``."""
# Setup
real_data = pd.DataFrame({'values': [1., 2., 2., 3., 5.]})
synthetic_data = pd.DataFrame({'values': [2., 2., 3., 4., 5.]})

# Run
figure = get_column_plot(real_data, synthetic_data, 'values', plot_type='bar')

# Assert
mock__generate_column_plot.assert_called_once_with(
real_data['values'],
synthetic_data['values'],
'bar'
)
assert figure == mock__generate_column_plot.return_value


@patch('sdmetrics.visualization._generate_column_plot')
def test_get_column_plot_plot_type_distplot(mock__generate_column_plot):
"""Test ``get_column_plot`` when ``plot_type`` is ``distplot``."""
# Setup
real_data = pd.DataFrame({'values': ['John', 'Doe']})
synthetic_data = pd.DataFrame({'values': ['Johanna', 'Doe']})

# Run
figure = get_column_plot(real_data, synthetic_data, 'values', plot_type='distplot')

# Assert
mock__generate_column_plot.assert_called_once_with(
real_data['values'],
synthetic_data['values'],
'distplot'
)
assert figure == mock__generate_column_plot.return_value

0 comments on commit b3c63de

Please sign in to comment.