diff --git a/sdmetrics/visualization.py b/sdmetrics/visualization.py
index 64630271..734f5aac 100644
--- a/sdmetrics/visualization.py
+++ b/sdmetrics/visualization.py
@@ -365,6 +365,56 @@ def _generate_column_plot(
     return fig
 
 
+def _get_max_between_datasets(real_data, synthetic_data):
+    """Return the largest value found across the real and synthetic data.
+
+    Args:
+        real_data (iterable or None):
+            The real values. If None, only the synthetic values are considered.
+        synthetic_data (iterable or None):
+            The synthetic values. If None, only the real values are considered.
+
+    Returns:
+        The maximum value among all the provided values.
+
+    Raises:
+        ValueError:
+            If both ``real_data`` and ``synthetic_data`` are None.
+    """
+    if synthetic_data is None and real_data is None:
+        raise ValueError('Cannot get max between two None values.')
+    if real_data is None:
+        return max(synthetic_data)
+    elif synthetic_data is None:
+        return max(real_data)
+    return max(max(real_data), max(synthetic_data))
+
+
+def _get_min_between_datasets(real_data, synthetic_data):
+    """Return the smallest value found across the real and synthetic data.
+
+    Args:
+        real_data (iterable or None):
+            The real values. If None, only the synthetic values are considered.
+        synthetic_data (iterable or None):
+            The synthetic values. If None, only the real values are considered.
+
+    Returns:
+        The minimum value among all the provided values.
+
+    Raises:
+        ValueError:
+            If both ``real_data`` and ``synthetic_data`` are None.
+    """
+    if synthetic_data is None and real_data is None:
+        raise ValueError('Cannot get min between two None values.')
+    if real_data is None:
+        return min(synthetic_data)
+    elif synthetic_data is None:
+        return min(real_data)
+    return min(min(real_data), min(synthetic_data))
+
+
 def _generate_cardinality_plot(
     real_data, synthetic_data, parent_primary_key, child_foreign_key, plot_type='bar'
 ):
@@ -376,8 +426,8 @@ def _generate_cardinality_plot(
 
     plot_kwargs = {}
     if plot_type == 'bar':
-        max_cardinality = max(max(real_data), max(synthetic_data))
-        min_cardinality = min(min(real_data), min(synthetic_data))
+        max_cardinality = _get_max_between_datasets(real_data, synthetic_data)
+        min_cardinality = _get_min_between_datasets(real_data, synthetic_data)
         plot_kwargs = {'nbins': max_cardinality - min_cardinality + 1}
 
     return _generate_column_plot(
@@ -420,10 +470,10 @@ def get_cardinality_plot(
     """Return a plot of the cardinality of the parent-child relationship.
 
     Args:
-        real_data (dict):
-            The real data.
-        synthetic_data (dict):
-            The synthetic data.
+        real_data (dict or None):
+            The real data. If None this data will not be graphed.
+        synthetic_data (dict or None):
+            The synthetic data. If None this data will not be graphed.
         child_table_name (string):
             The name of the child table.
         parent_table_name (string):
@@ -442,18 +492,27 @@ def get_cardinality_plot(
     if plot_type not in ['bar', 'distplot']:
         raise ValueError(f"Invalid plot_type '{plot_type}'. Please use one of ['bar', 'distplot'].")
 
-    real_cardinality = _get_cardinality(
-        real_data[parent_table_name],
-        real_data[child_table_name],
-        parent_primary_key,
-        child_foreign_key,
-    )
-    synth_cardinality = _get_cardinality(
-        synthetic_data[parent_table_name],
-        synthetic_data[child_table_name],
-        parent_primary_key,
-        child_foreign_key,
-    )
+    if real_data is None and synthetic_data is None:
+        raise ValueError('No data provided to plot. Please provide either real or synthetic data.')
+
+    real_cardinality = None
+    synth_cardinality = None
+
+    if real_data is not None:
+        real_cardinality = _get_cardinality(
+            real_data[parent_table_name],
+            real_data[child_table_name],
+            parent_primary_key,
+            child_foreign_key,
+        )
+
+    if synthetic_data is not None:
+        synth_cardinality = _get_cardinality(
+            synthetic_data[parent_table_name],
+            synthetic_data[child_table_name],
+            parent_primary_key,
+            child_foreign_key,
+        )
 
     fig = _generate_cardinality_plot(
         real_cardinality,
diff --git a/tests/unit/test_visualization.py b/tests/unit/test_visualization.py
index 7b3aaaa7..632a0082 100644
--- a/tests/unit/test_visualization.py
+++ b/tests/unit/test_visualization.py
@@ -15,6 +15,8 @@
     _generate_line_plot,
     _generate_scatter_plot,
     _get_cardinality,
+    _get_max_between_datasets,
+    _get_min_between_datasets,
     get_cardinality_plot,
     get_column_line_plot,
     get_column_pair_plot,
@@ -48,6 +50,54 @@ def test_get_cardinality():
     pd.testing.assert_series_equal(result, expected_result)
 
 
+def test__get_max_between_datasets():
+    """Test the ``_get_max_between_datasets`` method."""
+    # Setup
+    mock_real_data = pd.Series([1, 1, 2, 2, 2])
+    mock_synthetic_data = pd.Series([3, 3, 4])
+
+    # Run
+    real_only_val = _get_max_between_datasets(mock_real_data, None)
+    synth_only_val = _get_max_between_datasets(None, mock_synthetic_data)
+    all_val = _get_max_between_datasets(mock_real_data, mock_synthetic_data)
+
+    # Assert
+    expected_real_only_val = 2
+    expected_synth_only_val = 4
+    expected_all_val = 4
+    assert expected_real_only_val == real_only_val
+    assert expected_synth_only_val == synth_only_val
+    assert expected_all_val == all_val
+
+    error_msg = re.escape('Cannot get max between two None values.')
+    with pytest.raises(ValueError, match=error_msg):
+        _get_max_between_datasets(None, None)
+
+
+def test__get_min_between_datasets():
+    """Test the ``_get_min_between_datasets`` method."""
+    # Setup
+    mock_real_data = pd.Series([1, 1, 2, 2, 2])
+    mock_synthetic_data = pd.Series([3, 3, 4])
+
+    # Run
+    real_only_val = _get_min_between_datasets(mock_real_data, None)
+    synth_only_val = _get_min_between_datasets(None, mock_synthetic_data)
+    all_val = _get_min_between_datasets(mock_real_data, mock_synthetic_data)
+
+    # Assert
+    expected_real_only_val = 1
+    expected_synth_only_val = 3
+    expected_all_val = 1
+    assert expected_real_only_val == real_only_val
+    assert expected_synth_only_val == synth_only_val
+    assert expected_all_val == all_val
+
+    error_msg = re.escape('Cannot get min between two None values.')
+    with pytest.raises(ValueError, match=error_msg):
+        _get_min_between_datasets(None, None)
+
+
 @patch('sdmetrics.visualization.px')
 def test_generate_cardinality_bar_plot(mock_px):
     """Test the ``_generate_cardinality_plot`` method."""
@@ -217,6 +267,71 @@ def test_get_cardinality_plot(mock_generate_cardinality_plot, mock_get_cardinali
     assert mock_generate_cardinality_plot.call_args.kwargs == {'plot_type': 'bar'}
 
 
+def test_get_cardinality_plot_no_data():
+    """Test the ``get_cardinality_plot`` method with no data passed in."""
+    # Run and assert
+    error_msg = re.escape('No data provided to plot. Please provide either real or synthetic data.')
+    with pytest.raises(ValueError, match=error_msg):
+        get_cardinality_plot(
+            None, None, 'mock_child_table', 'mock_parent_name', 'child_fk', 'parent_fk', 'bar'
+        )
+
+
+@patch('sdmetrics.visualization._get_cardinality')
+@patch('sdmetrics.visualization._generate_cardinality_plot')
+def test_get_cardinality_plot_plot_single_data(
+    mock_generate_cardinality_plot, mock_get_cardinality
+):
+    """Test the ``get_cardinality_plot`` method runs fine with individual datasets."""
+    # Setup
+    real_data = {'table1': None, 'table2': None}
+    synthetic_data = {'table1': None, 'table2': None}
+    child_foreign_key = 'child_key'
+    parent_primary_key = 'parent_key'
+    parent_table_name = 'table1'
+    child_table_name = 'table2'
+
+    real_cardinality = pd.Series([1, 2, 2, 3, 5])
+    synthetic_cardinality = pd.Series([2, 2, 3, 4, 5])
+    mock_get_cardinality.side_effect = [real_cardinality, synthetic_cardinality]
+
+    mock_generate_cardinality_plot.side_effect = ['mock_return_1', 'mock_return_2']
+
+    # Run
+    fig_real = get_cardinality_plot(
+        real_data,
+        None,
+        child_table_name,
+        parent_table_name,
+        child_foreign_key,
+        parent_primary_key,
+    )
+    fig_synth = get_cardinality_plot(
+        None,
+        synthetic_data,
+        child_table_name,
+        parent_table_name,
+        child_foreign_key,
+        parent_primary_key,
+    )
+
+    # Assert
+    assert fig_real == 'mock_return_1'
+    assert fig_synth == 'mock_return_2'
+
+    # Each provided dataset must trigger exactly one cardinality computation.
+    calls = [
+        call(real_data['table1'], real_data['table2'], 'parent_key', 'child_key'),
+        call(synthetic_data['table1'], synthetic_data['table2'], 'parent_key', 'child_key'),
+    ]
+    assert mock_get_cardinality.call_args_list == calls
+
+    # The real cardinality is the first positional argument of the plotting helper.
+    real_only_call, synth_only_call = mock_generate_cardinality_plot.call_args_list
+    assert real_only_call.args[0] is real_cardinality
+    assert synth_only_call.args[0] is None
+
+
 def test_get_cardinality_plot_bad_plot_type():
     """Test the ``get_cardinality_plot`` method."""
     # Setup