Skip to content

Commit

Permalink
Allow Cardinality Plot to work with individual datasets (#597)
Browse files Browse the repository at this point in the history
  • Loading branch information
lajohn4747 authored Jun 26, 2024
1 parent 647525d commit 85f221f
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 18 deletions.
65 changes: 47 additions & 18 deletions sdmetrics/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,26 @@ def _generate_column_plot(
return fig


def _get_max_between_datasets(real_data, synthetic_data):
if synthetic_data is None and real_data is None:
raise ValueError('Cannot get max between two None values.')
if real_data is None:
return max(synthetic_data)
elif synthetic_data is None:
return max(real_data)
return max(max(real_data), max(synthetic_data))


def _get_min_between_datasets(real_data, synthetic_data):
if synthetic_data is None and real_data is None:
raise ValueError('Cannot get min between two None values.')
if real_data is None:
return min(synthetic_data)
elif synthetic_data is None:
return min(real_data)
return min(min(real_data), min(synthetic_data))


def _generate_cardinality_plot(
real_data, synthetic_data, parent_primary_key, child_foreign_key, plot_type='bar'
):
Expand All @@ -376,8 +396,8 @@ def _generate_cardinality_plot(

plot_kwargs = {}
if plot_type == 'bar':
max_cardinality = max(max(real_data), max(synthetic_data))
min_cardinality = min(min(real_data), min(synthetic_data))
max_cardinality = _get_max_between_datasets(real_data, synthetic_data)
min_cardinality = _get_min_between_datasets(real_data, synthetic_data)
plot_kwargs = {'nbins': max_cardinality - min_cardinality + 1}

return _generate_column_plot(
Expand Down Expand Up @@ -420,10 +440,10 @@ def get_cardinality_plot(
"""Return a plot of the cardinality of the parent-child relationship.
Args:
real_data (dict):
The real data.
synthetic_data (dict):
The synthetic data.
real_data (dict or None):
The real data. If None this data will not be graphed.
synthetic_data (dict or None):
The synthetic data. If None this data will not be graphed.
child_table_name (string):
The name of the child table.
parent_table_name (string):
Expand All @@ -442,18 +462,27 @@ def get_cardinality_plot(
if plot_type not in ['bar', 'distplot']:
raise ValueError(f"Invalid plot_type '{plot_type}'. Please use one of ['bar', 'distplot'].")

real_cardinality = _get_cardinality(
real_data[parent_table_name],
real_data[child_table_name],
parent_primary_key,
child_foreign_key,
)
synth_cardinality = _get_cardinality(
synthetic_data[parent_table_name],
synthetic_data[child_table_name],
parent_primary_key,
child_foreign_key,
)
if real_data is None and synthetic_data is None:
raise ValueError('No data provided to plot. Please provide either real or synthetic data.')

real_cardinality = None
synth_cardinality = None

if real_data is not None:
real_cardinality = _get_cardinality(
real_data[parent_table_name],
real_data[child_table_name],
parent_primary_key,
child_foreign_key,
)

if synthetic_data is not None:
synth_cardinality = _get_cardinality(
synthetic_data[parent_table_name],
synthetic_data[child_table_name],
parent_primary_key,
child_foreign_key,
)

fig = _generate_cardinality_plot(
real_cardinality,
Expand Down
108 changes: 108 additions & 0 deletions tests/unit/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
_generate_line_plot,
_generate_scatter_plot,
_get_cardinality,
_get_max_between_datasets,
_get_min_between_datasets,
get_cardinality_plot,
get_column_line_plot,
get_column_pair_plot,
Expand Down Expand Up @@ -48,6 +50,54 @@ def test_get_cardinality():
pd.testing.assert_series_equal(result, expected_result)


def test__get_max_between_datasets():
"""Test the ``_get_max_between_datasets`` method"""
# Setup
mock_real_data = pd.Series([1, 1, 2, 2, 2])
mock_synthetic_data = pd.Series([3, 3, 4])

# Run
real_only_val = _get_max_between_datasets(mock_real_data, None)
synth_only_val = _get_max_between_datasets(None, mock_synthetic_data)
all_val = _get_max_between_datasets(mock_real_data, mock_synthetic_data)

# Assert
expected_real_only_val = 2
expected_synth_only_val = 4
expected_all_val = 4
assert expected_real_only_val == real_only_val
assert expected_synth_only_val == synth_only_val
assert expected_all_val == all_val

error_msg = re.escape('Cannot get max between two None values.')
with pytest.raises(ValueError, match=error_msg):
_get_max_between_datasets(None, None)


def test__get_min_between_datasets():
"""Test the ``_get_min_between_datasets`` method"""
# Setup
mock_real_data = pd.Series([1, 1, 2, 2, 2])
mock_synthetic_data = pd.Series([3, 3, 4])

# Run
real_only_val = _get_min_between_datasets(mock_real_data, None)
synth_only_val = _get_min_between_datasets(None, mock_synthetic_data)
all_val = _get_min_between_datasets(mock_real_data, mock_synthetic_data)

# Assert
expected_real_only_val = 1
expected_synth_only_val = 3
expected_all_val = 1
assert expected_real_only_val == real_only_val
assert expected_synth_only_val == synth_only_val
assert expected_all_val == all_val

error_msg = re.escape('Cannot get min between two None values.')
with pytest.raises(ValueError, match=error_msg):
_get_min_between_datasets(None, None)


@patch('sdmetrics.visualization.px')
def test_generate_cardinality_bar_plot(mock_px):
"""Test the ``_generate_cardinality_plot`` method."""
Expand Down Expand Up @@ -217,6 +267,64 @@ def test_get_cardinality_plot(mock_generate_cardinality_plot, mock_get_cardinali
assert mock_generate_cardinality_plot.call_args.kwargs == {'plot_type': 'bar'}


def test_get_cardinality_plot_no_data():
"""Test the ``get_cardinality_plot`` method with no data passed in."""
# Run and assert
error_msg = re.escape('No data provided to plot. Please provide either real or synthetic data.')
with pytest.raises(ValueError, match=error_msg):
get_cardinality_plot(
None, None, 'mock_child_table', 'mock_parent_name', 'child_fk', 'parent_fk', 'bar'
)


@patch('sdmetrics.visualization._get_cardinality')
@patch('sdmetrics.visualization._generate_cardinality_plot')
def test_get_cardinality_plot_plot_single_data(
mock_generate_cardinality_plot, mock_get_cardinality
):
"""Test the ``get_cardinality_plot`` method runs fine with individual datasets."""
# Setup
real_data = {'table1': None, 'table2': None}
synthetic_data = {'table1': None, 'table2': None}
child_foreign_key = 'child_key'
parent_primary_key = 'parent_key'
parent_table_name = 'table1'
child_table_name = 'table2'

real_cardinality = pd.Series([1, 2, 2, 3, 5])
synthetic_cardinality = pd.Series([2, 2, 3, 4, 5])
mock_get_cardinality.side_effect = [real_cardinality, synthetic_cardinality]

mock_generate_cardinality_plot.side_effect = ['mock_return_1', 'mock_return_2']

# Run
fig_real = get_cardinality_plot(
real_data,
None,
child_table_name,
parent_table_name,
child_foreign_key,
parent_primary_key,
)
fig_synth = get_cardinality_plot(
None,
synthetic_data,
child_table_name,
parent_table_name,
child_foreign_key,
parent_primary_key,
)
assert fig_real == 'mock_return_1'
assert fig_synth == 'mock_return_2'

# Assert by checking the calls
calls = [
call(real_data['table1'], real_data['table2'], 'parent_key', 'child_key'),
call(synthetic_data['table1'], synthetic_data['table2'], 'parent_key', 'child_key'),
]
mock_get_cardinality.assert_has_calls(calls)


def test_get_cardinality_plot_bad_plot_type():
"""Test the ``get_cardinality_plot`` method."""
# Setup
Expand Down

0 comments on commit 85f221f

Please sign in to comment.