Skip to content

Commit

Permalink
Add get_column_plot (#455)
Browse files Browse the repository at this point in the history
  • Loading branch information
pvk-developer authored Sep 27, 2023
1 parent efeb278 commit f80dc79
Show file tree
Hide file tree
Showing 5 changed files with 205 additions and 267 deletions.
3 changes: 1 addition & 2 deletions sdmetrics/reports/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Reports for sdmetrics."""
from sdmetrics.reports.utils import get_column_pair_plot, get_column_plot
from sdmetrics.reports.utils import get_column_pair_plot

__all__ = [
'get_column_pair_plot',
'get_column_plot',
]
47 changes: 0 additions & 47 deletions sdmetrics/reports/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,53 +224,6 @@ def make_continuous_column_plot(real_column, synthetic_column, sdtype):
return fig


def get_column_plot(real_data, synthetic_data, column_name, metadata):
    """Build the comparison plot of one column across the real and synthetic tables.

    Args:
        real_data (pandas.DataFrame):
            The real table data.
        synthetic_data (pandas.DataFrame):
            The synthetic table data.
        column_name (str):
            The name of the column.
        metadata (dict):
            The table metadata.

    Returns:
        plotly.graph_objects._figure.Figure

    Raises:
        ValueError:
            If the column is absent from the metadata or from either table, if its
            metadata has no ``sdtype`` entry, or if its sdtype is neither continuous
            nor discrete.
    """
    metadata_columns = get_columns_from_metadata(metadata)
    if column_name not in metadata_columns:
        raise ValueError(f"Column '{column_name}' not found in metadata.")

    column_meta = metadata_columns[column_name]
    if 'sdtype' not in column_meta:
        raise ValueError(f"Metadata for column '{column_name}' missing 'type' information.")

    if column_name not in real_data.columns:
        raise ValueError(f"Column '{column_name}' not found in real table data.")

    if column_name not in synthetic_data.columns:
        raise ValueError(f"Column '{column_name}' not found in synthetic table data.")

    sdtype = get_type_from_column_meta(column_meta)
    real_column = real_data[column_name]
    synthetic_column = synthetic_data[column_name]
    if sdtype == 'datetime':
        # Datetime columns go through the conversion helper before being plotted.
        real_column, synthetic_column = convert_datetime_columns(
            real_column,
            synthetic_column,
            column_meta
        )

    if sdtype in CONTINUOUS_SDTYPES:
        return make_continuous_column_plot(real_column, synthetic_column, sdtype)

    if sdtype in DISCRETE_SDTYPES:
        return make_discrete_column_plot(real_column, synthetic_column, sdtype)

    raise ValueError(f"sdtype of type '{sdtype}' not recognized.")


def make_continuous_column_pair_plot(real_data, synthetic_data):
"""Make a column pair plot for continuous data.
Expand Down
70 changes: 60 additions & 10 deletions sdmetrics/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pandas.api.types import is_datetime64_dtype

from sdmetrics.reports.utils import PlotConfig
from sdmetrics.utils import get_missing_percentage
from sdmetrics.utils import get_missing_percentage, is_datetime


def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}):
Expand All @@ -25,7 +25,7 @@ def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}):
plotly.graph_objects._figure.Figure
"""
all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True)
default_histogram_kwargs = {
histogram_kwargs = {
'x': 'values',
'color': 'Data',
'barmode': 'group',
Expand All @@ -34,9 +34,10 @@ def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}):
'pattern_shape_sequence': ['', '/'],
'histnorm': 'probability density',
}
histogram_kwargs.update(plot_kwargs)
fig = px.histogram(
all_data,
**{**default_histogram_kwargs, **plot_kwargs}
**histogram_kwargs
)

return fig
Expand Down Expand Up @@ -105,25 +106,28 @@ def _generate_column_plot(real_column,

column_name = real_column.name if hasattr(real_column, 'name') else ''

real_data = pd.DataFrame({'values': real_column.copy()})
missing_data_real = get_missing_percentage(real_column)
missing_data_synthetic = get_missing_percentage(synthetic_column)

real_data = pd.DataFrame({'values': real_column.copy().dropna()})
real_data['Data'] = 'Real'
synthetic_data = pd.DataFrame({'values': synthetic_column.copy()})
synthetic_data = pd.DataFrame({'values': synthetic_column.copy().dropna()})
synthetic_data['Data'] = 'Synthetic'

is_datetime_sdtype = False
if is_datetime64_dtype(real_column.dtype):
is_datetime_sdtype = True
real_data = real_data.astype('int64')
synthetic_data = synthetic_data.astype('int64')

missing_data_real = get_missing_percentage(real_column)
missing_data_synthetic = get_missing_percentage(synthetic_column)
real_data['values'] = real_data['values'].astype('int64')
synthetic_data['values'] = synthetic_data['values'].astype('int64')

trace_args = {}

if plot_type == 'bar':
fig = _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs)
elif plot_type == 'distplot':
if x_label is None:
x_label = 'Value'

fig = _generate_column_distplot(real_data, synthetic_data, plot_kwargs)
trace_args = {'fill': 'tozeroy'}

Expand Down Expand Up @@ -259,3 +263,49 @@ def get_cardinality_plot(real_data, synthetic_data, child_table_name, parent_tab
)

return fig


def get_column_plot(real_data, synthetic_data, column_name, plot_type=None):
    """Return a plot of the real and synthetic data for a given column.

    Args:
        real_data (pandas.DataFrame):
            The real table data.
        synthetic_data (pandas.DataFrame):
            The synthetic table data.
        column_name (str):
            The name of the column.
        plot_type (str or None):
            The plot to be used. Can choose between ``distplot``, ``bar`` or ``None``. If
            ``None``, select between ``distplot`` or ``bar`` depending on the data that the
            real column contains: ``distplot`` for datetime and numerical values and ``bar``
            for categorical. Defaults to ``None``.

    Returns:
        plotly.graph_objects._figure.Figure

    Raises:
        ValueError:
            If ``plot_type`` is not one of ``'bar'``, ``'distplot'`` or ``None``, or if
            ``column_name`` is missing from the real or the synthetic table.
    """
    if plot_type not in ['bar', 'distplot', None]:
        raise ValueError(
            f"Invalid plot_type '{plot_type}'. Please use one of ['bar', 'distplot', None]."
        )

    if column_name not in real_data.columns:
        raise ValueError(f"Column '{column_name}' not found in real table data.")
    if column_name not in synthetic_data.columns:
        raise ValueError(f"Column '{column_name}' not found in synthetic table data.")

    real_column = real_data[column_name]
    synthetic_column = synthetic_data[column_name]

    if plot_type is None:
        column_is_datetime = is_datetime(real_column)
        # Inspect only the non-null values and let pandas soft-convert object columns,
        # so the dtype kind reflects the values actually stored in the real column.
        dtype_kind = real_column.dropna().infer_objects().dtype.kind
        # 'i', 'u' and 'f' are the signed integer, unsigned integer and float kinds;
        # all of them are numerical and therefore get a distribution plot.
        if column_is_datetime or dtype_kind in ('i', 'u', 'f'):
            plot_type = 'distplot'
        else:
            plot_type = 'bar'

    return _generate_column_plot(real_column, synthetic_column, plot_type)
Loading

0 comments on commit f80dc79

Please sign in to comment.