diff --git a/sdmetrics/reports/base_report.py b/sdmetrics/reports/base_report.py index aa8daffe..a179baed 100644 --- a/sdmetrics/reports/base_report.py +++ b/sdmetrics/reports/base_report.py @@ -143,7 +143,10 @@ def generate(self, real_data, synthetic_data, metadata, verbose=True): Whether or not to print report summary and progress. """ if not isinstance(metadata, dict): - raise TypeError('The provided metadata is not a dictionary.') + raise TypeError( + f"Expected a dictionary but received a '{type(metadata).__name__}' instead." + " For SDV metadata objects, please use the 'to_dict' function to convert it to a dictionary." + ) self._validate(real_data, synthetic_data, metadata) self.convert_datetimes(real_data, synthetic_data, metadata) diff --git a/sdmetrics/reports/multi_table/base_multi_table_report.py b/sdmetrics/reports/multi_table/base_multi_table_report.py index 9a714148..1b78aa07 100644 --- a/sdmetrics/reports/multi_table/base_multi_table_report.py +++ b/sdmetrics/reports/multi_table/base_multi_table_report.py @@ -101,8 +101,10 @@ def generate(self, real_data, synthetic_data, metadata, verbose=True): verbose (bool): Whether or not to print report summary and progress. """ + results = super().generate(real_data, synthetic_data, metadata, verbose) self.table_names = list(metadata.get('tables', {}).keys()) - return super().generate(real_data, synthetic_data, metadata, verbose) + + return results def _check_table_names(self, table_name): if table_name not in self.table_names: diff --git a/tests/unit/reports/test_base_report.py b/tests/unit/reports/test_base_report.py index e7819392..3a82d7f9 100644 --- a/tests/unit/reports/test_base_report.py +++ b/tests/unit/reports/test_base_report.py @@ -232,7 +232,10 @@ def test_generate_metadata_not_dict(self): metadata = 'metadata' # Run and Assert - expected_message = 'The provided metadata is not a dictionary.' + expected_message = ( + "Expected a dictionary but received a 'str' instead. For SDV metadata objects, " + "please use the 'to_dict' function to convert it to a dictionary." + ) with pytest.raises(TypeError, match=expected_message): base_report.generate(real_data, synthetic_data, metadata, verbose=False)