Skip to content

Commit

Permalink
unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Oct 30, 2023
1 parent 05a800e commit fea8ead
Show file tree
Hide file tree
Showing 2 changed files with 332 additions and 0 deletions.
30 changes: 30 additions & 0 deletions tests/unit/reports/multi_table/_properties/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,36 @@ def test__get_num_iterations(self):
base_property._num_iteration_case = 'inter_table_column_pair'
assert base_property._get_num_iterations(metadata) == 11

def test__extract_tuple(self):
"""Test the ``_extract_tuple`` method."""
# Setup
base_property = BaseMultiTableProperty()
real_user_df = pd.DataFrame({
'user_id': ['user1', 'user2'],
'columnA': ['A', 'B'],
'columnB': [np.nan, 1.0]
})
real_session_df = pd.DataFrame({
'session_id': ['session1', 'session2', 'session3'],
'user_id': ['user1', 'user1', 'user2'],
'columnC': ['X', 'Y', 'Z'],
'columnD': [4.0, 6.0, 7.0]
})

real_data = {'users': real_user_df, 'sessions': real_session_df}
relation = {
'parent_table_name': 'users',
'child_table_name': 'sessions',
'parent_primary_key': 'user_id',
'child_foreign_key': 'user_id'
}

# Run
real_columns = base_property._extract_tuple(real_data, relation)

# Assert
assert real_columns == (real_data['users']['user_id'], real_data['sessions']['user_id'])

def test__generate_details_property(self):
"""Test the ``_generate_details`` method."""
# Setup
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,302 @@
"""Test multi-table RelationshipValidity properties."""

from unittest.mock import Mock, patch

import numpy as np
import pandas as pd
import pytest
from plotly.graph_objects import Figure

from sdmetrics.reports.multi_table._properties.relationship_validity import RelationshipValidity


@pytest.fixture()
def real_data_fixture():
real_user_df = pd.DataFrame({
'user_id': ['user1', 'user2'],
'columnA': ['A', 'B'],
'columnB': [np.nan, 1.0]
})
real_session_df = pd.DataFrame({
'session_id': ['session1', 'session2', 'session3'],
'user_id': ['user1', 'user1', 'user2'],
'columnC': ['X', 'Y', 'Z'],
'columnD': [4.0, 6.0, 7.0]
})
return {'users': real_user_df, 'sessions': real_session_df}


@pytest.fixture()
def synthetic_data_fixture():
synthetic_user_df = pd.DataFrame({
'user_id': ['user1', 'user2'],
'columnA': ['A', 'A'],
'columnB': [0.5, np.nan]
})
synthetic_session_df = pd.DataFrame({
'session_id': ['session1', 'session2', 'session3'],
'user_id': ['user1', 'user1', 'user2'],
'columnC': ['X', 'Z', 'Y'],
'columnD': [3.6, 5.0, 6.0]
})
return {'users': synthetic_user_df, 'sessions': synthetic_session_df}


@pytest.fixture()
def metadata_fixture():
return {
'tables': {
'users': {
'primary_key': 'user_id',
'columns': {
'user_id': {'sdtype': 'id'},
'columnA': {'sdtype': 'categorical'},
'columnB': {'sdtype': 'numerical'}
},
},
'sessions': {
'primary_key': 'session_id',
'columns': {
'session_id': {'sdtype': 'id'},
'user_id': {'sdtype': 'id'},
'columnC': {'sdtype': 'categorical'},
'columnD': {'sdtype': 'numerical'}
}
}
},
'relationships': [
{
'parent_table_name': 'users',
'child_table_name': 'sessions',
'parent_primary_key': 'user_id',
'child_foreign_key': 'user_id'
}
]
}


class TestRelationshipValidity:

def test__extract_tuple(self, real_data_fixture):
"""Test the ``_extract_tuple`` method."""
# Setup
relationship_validity = RelationshipValidity()
real_data = real_data_fixture
relation = {
'parent_table_name': 'users',
'child_table_name': 'sessions',
'parent_primary_key': 'user_id',
'child_foreign_key': 'user_id'
}

# Run
real_columns = relationship_validity._extract_tuple(real_data, relation)

# Assert
assert real_columns == (real_data['users']['user_id'], real_data['sessions']['user_id'])

def test__get_num_iteration(self):
"""Test the ``_get_num_iterations`` method."""
# Setup
metadata = {
'relationships': [
{
'parent_table_name': 'table1',
'parent_primary_key': 'col1',
'child_table_name': 'table2',
'child_foreign_key': 'col6'
},
{
'parent_table_name': 'table1',
'parent_primary_key': 'col1',
'child_table_name': 'table3',
'child_foreign_key': 'col7'
},
{
'parent_table_name': 'table2',
'parent_primary_key': 'col6',
'child_table_name': 'table4',
'child_foreign_key': 'col8'
},
]
}
relationship_validity = RelationshipValidity()

# Run
num_iterations = relationship_validity._get_num_iterations(metadata)

# Assert
assert num_iterations == 3

@patch('sdmetrics.reports.multi_table._properties.relationship_validity.'
'CardinalityBoundaryAdherence')
@patch('sdmetrics.reports.multi_table._properties.relationship_validity.ReferentialIntegrity')
def test_get_score(
self, mock_referentialintegrity, mock_cardinalityboundaryadherence,
real_data_fixture, synthetic_data_fixture, metadata_fixture
):
"""Test the ``get_score`` function.
Test that when given a ``progress_bar`` and relationships, this calls
``CardinalityBoundaryAdherence`` and ``ReferentialIntegrity`` compute
method for each relationship.
"""
# Setup
mock_referentialintegrity.compute.return_value = 0.7
mock_referentialintegrity.__name__ = 'ReferentialIntegrity'
mock_cardinalityboundaryadherence.compute.return_value = 0.3
mock_cardinalityboundaryadherence.__name__ = 'CardinalityBoundaryAdherence'
mock_compute_average = Mock(return_value=0.5)
relationship_validity = RelationshipValidity()
relationship_validity._compute_average = mock_compute_average
progress_bar = Mock()

real_data = real_data_fixture
synthetic_data = synthetic_data_fixture
metadata = metadata_fixture

# Run
score = relationship_validity.get_score(
real_data=real_data, synthetic_data=synthetic_data, metadata=metadata,
progress_bar=progress_bar
)

# Assert
expected_details_property = pd.DataFrame({
'Parent Table': ['users', 'users'],
'Child Table': ['sessions', 'sessions'],
'Primary Key': ['user_id', 'user_id'],
'Foreign Key': ['user_id', 'user_id'],
'Metric': ['ReferentialIntegrity', 'CardinalityBoundaryAdherence'],
'Score': [0.7, 0.3],
})

assert score == 0.5
progress_bar.update.assert_called()
assert progress_bar.update.call_count == 2
mock_compute_average.assert_called_once_with()
pd.testing.assert_frame_equal(relationship_validity.details, expected_details_property)

@patch('sdmetrics.reports.multi_table._properties.relationship_validity.'
'CardinalityBoundaryAdherence')
@patch('sdmetrics.reports.multi_table._properties.relationship_validity.ReferentialIntegrity')
def test_get_score_raises_errors(
self, mock_referentialintegrity, mock_cardinalityboundaryadherence,
real_data_fixture, synthetic_data_fixture, metadata_fixture
):
"""Test the ``get_score`` when ``ReferentialIntegrity`` or
``CardinalityBoundaryAdherence`` crashes"""
# Setup
mock_referentialintegrity.compute.side_effect = [ValueError('error 1')]
mock_referentialintegrity.__name__ = 'ReferentialIntegrity'
mock_cardinalityboundaryadherence.compute.side_effect = [ValueError('error 2')]
mock_cardinalityboundaryadherence.__name__ = 'CardinalityBoundaryAdherence'
relationship_validity = RelationshipValidity()
progress_bar = Mock()

real_data = real_data_fixture
synthetic_data = synthetic_data_fixture
metadata = metadata_fixture

# Run
score = relationship_validity.get_score(
real_data=real_data, synthetic_data=synthetic_data, metadata=metadata,
progress_bar=progress_bar
)

# Assert
expected_details_property = pd.DataFrame({
'Parent Table': ['users', 'users'],
'Child Table': ['sessions', 'sessions'],
'Primary Key': ['user_id', 'user_id'],
'Foreign Key': ['user_id', 'user_id'],
'Metric': ['ReferentialIntegrity', 'CardinalityBoundaryAdherence'],
'Score': [np.nan, np.nan],
'Error': ['ValueError: error 1', 'ValueError: error 2']
})

assert pd.isna(score)
pd.testing.assert_frame_equal(relationship_validity.details, expected_details_property)
progress_bar.update.assert_called()
assert progress_bar.update.call_count == 2

def test_get_details_with_table_name(self):
"""Test the ``get_details`` method.
Test that the method returns the correct details for the given table name,
either from the child or parent table.
"""
# Setup
relationship_validity = RelationshipValidity()
relationship_validity.details = pd.DataFrame({
'Child Table': ['users_child', 'sessions_child'],
'Parent Table': ['users_parent', 'sessions_parent'],
'Metric': ['RelationshipValidityShapeSimilarity', 'SomeOtherMetric'],
'Score': [1.0, 0.5],
'Error': [None, 'Some error']
})

# Run
details_users_child = relationship_validity.get_details('users_child')
details_sessions_parent = relationship_validity.get_details('sessions_parent')

# Assert for child table
assert details_users_child.equals(pd.DataFrame({
'Child Table': ['users_child'],
'Parent Table': ['users_parent'],
'Metric': ['RelationshipValidityShapeSimilarity'],
'Score': [1.0],
'Error': [None]
}, index=[0]))

# Assert for parent table
assert details_sessions_parent.equals(pd.DataFrame({
'Child Table': ['sessions_child'],
'Parent Table': ['sessions_parent'],
'Metric': ['SomeOtherMetric'],
'Score': [0.5],
'Error': ['Some error']
}, index=[1]))

def test_get_table_relationships_plot(self):
"""Test the ``_get_table_relationships_plot`` method.
Test that the method returns the correct plotly figure for the given table name.
"""
# Setup
instance = RelationshipValidity()
instance.details = pd.DataFrame({
'Child Table': ['users_child', 'sessions_child'],
'Parent Table': ['users_parent', 'sessions_parent'],
'Metric': ['RelationshipValidityShapeSimilarity', 'SomeOtherMetric'],
'Score': [1.0, 0.5],
'Error': [None, 'Some error']
})

# Run
fig = instance._get_table_relationships_plot('users_child')

# Assert
assert isinstance(fig, Figure)

expected_x = ['users_child → users_parent']
expected_y = [1.0]
expected_title = 'Table Relationships (Average Score=1.0)'

assert fig.data[0].x.tolist() == expected_x
assert fig.data[0].y.tolist() == expected_y
assert fig.layout.title.text == expected_title

def test_get_visualization(self):
"""Test the ``get_visualization`` method."""
# Setup
mock__get_table_relationships_plot = Mock(side_effect=[Figure()])
relationship_validity = RelationshipValidity()
relationship_validity._get_table_relationships_plot = mock__get_table_relationships_plot

# Run
fig = relationship_validity.get_visualization('table_name')

# Assert
assert isinstance(fig, Figure)
mock__get_table_relationships_plot.assert_called_once_with('table_name')

0 comments on commit fea8ead

Please sign in to comment.