Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate that the metadata is always a dict #466

Merged
merged 5 commits into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions sdmetrics/reports/base_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,25 @@ def validate(self, real_data, synthetic_data, metadata):
metadata (dict):
The metadata of the table.
"""
if not isinstance(metadata, dict):
metadata = metadata.to_dict()

self._validate_metadata_matches_data(real_data, synthetic_data, metadata)

def _handle_results(self, verbose):
raise NotImplementedError

@staticmethod
def _convert_metadata(metadata):
"""If the metadta is not a dict, try to convert it."""
if not isinstance(metadata, dict):
try:
metadata = metadata.to_dict()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this something we actually want to support? Other objects might implement a to_dict method which might lead to weird issues down the road (i.e. I'm pretty sure you can call to_dict on a pandas dataframe). Maybe @npatki has thoughts?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@frances-h agreed. I think we should stick to the issue's description, which states that we should only accept dictionaries. We do not intend to convert to a dictionary on the user's behalf.

except Exception:
raise TypeError(
'The provided metadata is not a dictionary and does not have a to_dict method.'
'Please convert the metadata to a dictionary.'
)

return metadata

@staticmethod
def convert_datetimes(real_data, synthetic_data, metadata):
"""Try to convert all datetime columns to datetime dtype.
Expand Down Expand Up @@ -101,6 +112,7 @@ def generate(self, real_data, synthetic_data, metadata, verbose=True):
verbose (bool):
Whether or not to print report summary and progress.
"""
metadata = self._convert_metadata(metadata)
self.validate(real_data, synthetic_data, metadata)
self.convert_datetimes(real_data, synthetic_data, metadata)

Expand Down
53 changes: 52 additions & 1 deletion tests/unit/reports/test_base_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,56 @@ def test_convert_datetimes(self):
pd.testing.assert_frame_equal(real_data, expected_real_data)
pd.testing.assert_frame_equal(synthetic_data, expected_synthetic_data)

def test_generate(self):
def test__convert_metadata_with_to_dict_method(self):
"""Test ``_convert_metadata`` when the metadata object has a ``to_dict`` method."""
# Setup
metadata_example = {
'column1': {'sdtype': 'numerical'},
'column2': {'sdtype': 'categorical'},
}

class Metadata:
def __init__(self):
self.columns = metadata_example

def to_dict(self):
return self.columns

metadata = Metadata()

# Run
converted_metadata = BaseReport._convert_metadata(metadata)

# Assert
assert converted_metadata == metadata_example

def test__convert_metadata_without_to_dict_method(self):
"""Test ``_convert_metadata`` when the metadata object has no ``to_dict`` method."""
# Setup
metadata_example = {
'column1': {'sdtype': 'numerical'},
'column2': {'sdtype': 'categorical'},
}

class Metadata:
def __init__(self):
self.columns = metadata_example

metadata = Metadata()

# Run and Assert
expected_message = re.escape(
'The provided metadata is not a dictionary and does not have a to_dict method.'
'Please convert the metadata to a dictionary.'
)
with pytest.raises(TypeError, match=expected_message):
BaseReport._convert_metadata(metadata)

result = BaseReport._convert_metadata(metadata_example)
assert result == metadata_example

@patch('sdmetrics.reports.base_report.BaseReport._convert_metadata')
def test_generate(self, mock__convert_metadata):
"""Test the ``generate`` method.

This test checks that the method calls the ``validate`` method and the ``get_score``
Expand Down Expand Up @@ -183,11 +232,13 @@ def test_generate(self):
'column2': {'sdtype': 'categorical'}
}
}
mock__convert_metadata.return_value = metadata

# Run
base_report.generate(real_data, synthetic_data, metadata, verbose=False)

# Assert
mock__convert_metadata.assert_called_once_with(metadata)
mock_validate.assert_called_once_with(real_data, synthetic_data, metadata)
mock_handle_results.assert_called_once_with(False)
base_report._properties['Property 1'].get_score.assert_called_with(
Expand Down
Loading