From 02000f34761231915c37b8d99357e9ee2ef2724d Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Mon, 30 Oct 2023 18:00:10 -0600 Subject: [PATCH] cardinality clean up --- .../multi_table/_properties/cardinality.py | 20 ++--------- .../_properties/test_cardinality.py | 33 ++++--------------- 2 files changed, 10 insertions(+), 43 deletions(-) diff --git a/sdmetrics/reports/multi_table/_properties/cardinality.py b/sdmetrics/reports/multi_table/_properties/cardinality.py index 4c5e5f51..ca683844 100644 --- a/sdmetrics/reports/multi_table/_properties/cardinality.py +++ b/sdmetrics/reports/multi_table/_properties/cardinality.py @@ -67,22 +67,8 @@ def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=No 'Error': error_messages, }) - def get_details(self, table_name=None): - """Return the details for the property. - - Args: - table_name (str): - Table name to get the details for. - Defaults to ``None``. - - Returns: - pandas.DataFrame: - The details for the property. - """ - if table_name is None: - return self.details.copy() - - return self._get_details_for_table_name_with_relationships(table_name) + if self.details['Error'].isna().all(): + self.details = self.details.drop('Error', axis=1) def _get_table_relationships_plot(self, table_name): """Get the table relationships plot from the parent child relationship scores for a table. @@ -94,7 +80,7 @@ def _get_table_relationships_plot(self, table_name): Returns: plotly.graph_objects._figure.Figure """ - plot_data = self._get_details_for_table_name_with_relationships(table_name).copy() + plot_data = self.get_details(table_name).copy() column_name = 'Child → Parent Relationship' plot_data[column_name] = plot_data['Child Table'] + ' → ' + plot_data['Parent Table'] plot_data = plot_data.drop(['Child Table', 'Parent Table'], axis=1) diff --git a/tests/unit/reports/multi_table/_properties/test_cardinality.py b/tests/unit/reports/multi_table/_properties/test_cardinality.py index 56e55e2c..f0d47698 100644 --- a/tests/unit/reports/multi_table/_properties/test_cardinality.py +++ b/tests/unit/reports/multi_table/_properties/test_cardinality.py @@ -105,8 +105,8 @@ def test_get_score_raises_errors(self, mock_cardinalityshapesimilarity): progress_bar.update.assert_called() assert progress_bar.update.call_count == 2 - def test_get_details_for_table_name(self): - """Test the ``_get_details_for_table_name`` method. + def test_get_details_with_table_name(self): + """Test the ``get_details`` method. Test that the method returns the correct details for the given table name, either from the child or parent table. @@ -122,8 +122,8 @@ def test_get_details_for_table_name(self): }) # Run - details_users_child = cardinality._get_details_for_table_name('users_child') - details_sessions_parent = cardinality._get_details_for_table_name('sessions_parent') + details_users_child = cardinality.get_details('users_child') + details_sessions_parent = cardinality.get_details('sessions_parent') # Assert for child table assert details_users_child.equals(pd.DataFrame({ @@ -143,26 +143,6 @@ def test_get_details_for_table_name(self): 'Error': ['Some error'] }, index=[1])) - def test_get_details(self): - """Test the ``get_details`` method. - - Test that the method returns the correct details for the given property and table name. - """ - # Setup - mock__get_details_for_table_name = Mock(return_value='Details for table name') - cardinality = Cardinality() - cardinality.details = pd.DataFrame({'a': ['b']}) - cardinality._get_details_for_table_name = mock__get_details_for_table_name - - # Run - details = cardinality.get_details('table_name') - entire_details = cardinality.get_details() - - # Assert - assert details == 'Details for table name' - pd.testing.assert_frame_equal(entire_details, pd.DataFrame({'a': ['b']})) - mock__get_details_for_table_name.assert_called_once_with('table_name') - def test_get_table_relationships_plot(self): """Test the ``_get_table_relationships_plot`` method. @@ -195,7 +175,7 @@ def test_get_table_relationships_plot(self): def test_get_visualization(self): """Test the ``get_visualization`` method.""" # Setup - mock__get_table_relationships_plot = Mock(return_value='Table relationships plot') + mock__get_table_relationships_plot = Mock(side_effect=[Figure()]) cardinality = Cardinality() cardinality._get_table_relationships_plot = mock__get_table_relationships_plot @@ -203,4 +183,5 @@ def test_get_visualization(self): fig = cardinality.get_visualization('table_name') # Assert - assert fig == 'Table relationships plot' + assert isinstance(fig, Figure) + mock__get_table_relationships_plot.assert_called_once_with('table_name')