diff --git a/tests/unit/test_visualization.py b/tests/unit/test_visualization.py index 17ba8c08..bee928db 100644 --- a/tests/unit/test_visualization.py +++ b/tests/unit/test_visualization.py @@ -737,7 +737,11 @@ def test__generate_scatter_plot(px_mock): color_discrete_map={'Real': '#000036', 'Synthetic': '#01E0C9'}, symbol='Data', ) - mock_figure.update_layout.assert_called_once() + mock_figure.update_layout.assert_called_once_with( + title="Real vs. Synthetic Data for columns 'col1' and 'col2'", + plot_bgcolor='#F5F5F8', + font={'size': 18}, + ) assert fig == mock_figure @@ -787,7 +791,11 @@ def test__generate_heatmap_plot(px_mock): facet_col='Data', histnorm='probability', ) - mock_figure.update_layout.assert_called_once() + mock_figure.update_layout.assert_called_once_with( + title_text="Real vs. Synthetic Data for columns 'col1' and 'col2'", + coloraxis={'colorscale': ['#000036', '#01E0C9']}, + font={'size': 18}, + ) mock_figure.for_each_annotation.assert_called_once() assert fig == mock_figure @@ -909,10 +917,37 @@ def test__generate_box_plot(px_mock): color='Data', color_discrete_map={'Real': '#000036', 'Synthetic': '#01E0C9'}, ) - mock_figure.update_layout.assert_called_once() + mock_figure.update_layout.assert_called_once_with( + title="Real vs. Synthetic Data for columns 'col1' and 'col2'", + plot_bgcolor='#F5F5F8', + font={'size': 18}, + ) assert fig == mock_figure +@patch('sdmetrics.visualization.px') +def test__generate_box_plot_title_one_dataset_only(px_mock): + """Test the ``_generate_box_plot`` title when only one dataset is passed.""" + # Setup + real_data = pd.DataFrame({ + 'col1': [1, 2, 3, 4], + 'col2': ['a', 'b', 'c', 'd'], + 'Data': ['Real'] * 4, + }) + columns = ['col1', 'col2'] + mock_figure = Mock() + px_mock.box.side_effect = [mock_figure, mock_figure] + + # Run + fig_real = _generate_box_plot(real_data, columns) + + # Assert + mock_figure.update_layout.assert_called_once_with( + title="Real Data for columns 'col1' and 'col2'", plot_bgcolor='#F5F5F8', font={'size': 18} + ) + assert fig_real == mock_figure + + def test_get_column_pair_plot_invalid_column_names(): """Test ``get_column_pair_plot`` method with invalid ``column_names``.""" # Setup @@ -1013,8 +1048,7 @@ def test_get_column_pair_plot_plot_single_data(mock__generate_scatter_plot): assert synth_fig == 'mock_return_2' -@patch('sdmetrics.visualization._generate_scatter_plot') -def test_get_column_pair_plot_plot_no_data(mock__generate_scatter_plot): +def test_get_column_pair_plot_plot_no_data(): """Test ``get_column_pair_plot`` with neither real or synthetic data""" # Setup columns = ['amount', 'price']