diff --git a/sdmetrics/reports/base_report.py b/sdmetrics/reports/base_report.py index 1f67d009..3536b3e0 100644 --- a/sdmetrics/reports/base_report.py +++ b/sdmetrics/reports/base_report.py @@ -55,9 +55,6 @@ def validate(self, real_data, synthetic_data, metadata): metadata (dict): The metadata of the table. """ - if not isinstance(metadata, dict): - metadata = metadata.to_dict() - self._validate_metadata_matches_data(real_data, synthetic_data, metadata) def _handle_results(self, verbose): @@ -101,6 +98,9 @@ def generate(self, real_data, synthetic_data, metadata, verbose=True): verbose (bool): Whether or not to print report summary and progress. """ + if not isinstance(metadata, dict): + raise TypeError('The provided metadata is not a dictionary.') + self.validate(real_data, synthetic_data, metadata) self.convert_datetimes(real_data, synthetic_data, metadata) diff --git a/tests/unit/reports/test_base_report.py b/tests/unit/reports/test_base_report.py index 90275f07..156b096a 100644 --- a/tests/unit/reports/test_base_report.py +++ b/tests/unit/reports/test_base_report.py @@ -152,6 +152,27 @@ def test_convert_datetimes(self): pd.testing.assert_frame_equal(real_data, expected_real_data) pd.testing.assert_frame_equal(synthetic_data, expected_synthetic_data) + def test_generate_metadata_not_dict(self): + """Test the ``generate`` method with metadata not being a dict.""" + # Setup + base_report = BaseReport() + real_data = pd.DataFrame({ + 'column1': [1, 2, 3], + 'column2': ['a', 'b', 'c'] + }) + synthetic_data = pd.DataFrame({ + 'column1': [1, 2, 3], + 'column2': ['a', 'b', 'c'] + }) + metadata = 'metadata' + + # Run and Assert + expected_message = ( + 'The provided metadata is not a dictionary.' + ) + with pytest.raises(TypeError, match=expected_message): + base_report.generate(real_data, synthetic_data, metadata, verbose=False) + def test_generate(self): """Test the ``generate`` method.