Skip to content

Commit

Permalink
cardinality clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Oct 31, 2023
1 parent 40f1d34 commit 02000f3
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 43 deletions.
20 changes: 3 additions & 17 deletions sdmetrics/reports/multi_table/_properties/cardinality.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,22 +67,8 @@ def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=No
'Error': error_messages,
})

def get_details(self, table_name=None):
"""Return the details for the property.
Args:
table_name (str):
Table name to get the details for.
Defaults to ``None``.
Returns:
pandas.DataFrame:
The details for the property.
"""
if table_name is None:
return self.details.copy()

return self._get_details_for_table_name_with_relationships(table_name)
if self.details['Error'].isna().all():
self.details = self.details.drop('Error', axis=1)

def _get_table_relationships_plot(self, table_name):
"""Get the table relationships plot from the parent child relationship scores for a table.
Expand All @@ -94,7 +80,7 @@ def _get_table_relationships_plot(self, table_name):
Returns:
plotly.graph_objects._figure.Figure
"""
plot_data = self._get_details_for_table_name_with_relationships(table_name).copy()
plot_data = self.get_details(table_name).copy()
column_name = 'Child → Parent Relationship'
plot_data[column_name] = plot_data['Child Table'] + ' → ' + plot_data['Parent Table']
plot_data = plot_data.drop(['Child Table', 'Parent Table'], axis=1)
Expand Down
33 changes: 7 additions & 26 deletions tests/unit/reports/multi_table/_properties/test_cardinality.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ def test_get_score_raises_errors(self, mock_cardinalityshapesimilarity):
progress_bar.update.assert_called()
assert progress_bar.update.call_count == 2

def test_get_details_for_table_name(self):
"""Test the ``_get_details_for_table_name`` method.
def test_get_details_with_table_name(self):
"""Test the ``get_details`` method.
Test that the method returns the correct details for the given table name,
either from the child or parent table.
Expand All @@ -122,8 +122,8 @@ def test_get_details_for_table_name(self):
})

# Run
details_users_child = cardinality._get_details_for_table_name('users_child')
details_sessions_parent = cardinality._get_details_for_table_name('sessions_parent')
details_users_child = cardinality.get_details('users_child')
details_sessions_parent = cardinality.get_details('sessions_parent')

# Assert for child table
assert details_users_child.equals(pd.DataFrame({
Expand All @@ -143,26 +143,6 @@ def test_get_details_for_table_name(self):
'Error': ['Some error']
}, index=[1]))

def test_get_details(self):
"""Test the ``get_details`` method.
Test that the method returns the correct details for the given property and table name.
"""
# Setup
mock__get_details_for_table_name = Mock(return_value='Details for table name')
cardinality = Cardinality()
cardinality.details = pd.DataFrame({'a': ['b']})
cardinality._get_details_for_table_name = mock__get_details_for_table_name

# Run
details = cardinality.get_details('table_name')
entire_details = cardinality.get_details()

# Assert
assert details == 'Details for table name'
pd.testing.assert_frame_equal(entire_details, pd.DataFrame({'a': ['b']}))
mock__get_details_for_table_name.assert_called_once_with('table_name')

def test_get_table_relationships_plot(self):
"""Test the ``_get_table_relationships_plot`` method.
Expand Down Expand Up @@ -195,12 +175,13 @@ def test_get_table_relationships_plot(self):
def test_get_visualization(self):
"""Test the ``get_visualization`` method."""
# Setup
mock__get_table_relationships_plot = Mock(return_value='Table relationships plot')
mock__get_table_relationships_plot = Mock(side_effect=[Figure()])
cardinality = Cardinality()
cardinality._get_table_relationships_plot = mock__get_table_relationships_plot

# Run
fig = cardinality.get_visualization('table_name')

# Assert
assert fig == 'Table relationships plot'
assert isinstance(fig, Figure)
mock__get_table_relationships_plot.assert_called_once_with('table_name')

0 comments on commit 02000f3

Please sign in to comment.