diff --git a/sdmetrics/reports/base_report.py b/sdmetrics/reports/base_report.py index 89f5b8a5..1f67d009 100644 --- a/sdmetrics/reports/base_report.py +++ b/sdmetrics/reports/base_report.py @@ -102,6 +102,7 @@ def generate(self, real_data, synthetic_data, metadata, verbose=True): Whether or not to print report summary and progress. """ self.validate(real_data, synthetic_data, metadata) + self.convert_datetimes(real_data, synthetic_data, metadata) scores = [] progress_bar = None diff --git a/tests/integration/reports/multi_table/test_diagnostic_report.py b/tests/integration/reports/multi_table/test_diagnostic_report.py index 8332b63d..dd12a484 100644 --- a/tests/integration/reports/multi_table/test_diagnostic_report.py +++ b/tests/integration/reports/multi_table/test_diagnostic_report.py @@ -32,6 +32,43 @@ def test_end_to_end(self): } assert results == expected_results + def test_end_to_end_with_object_datetimes(self): + """Test the ``DiagnosticReport`` report with object datetimes.""" + real_data, synthetic_data, metadata = load_demo(modality='multi_table') + for table, table_meta in metadata['tables'].items(): + for column, column_meta in table_meta['columns'].items(): + if column_meta['sdtype'] == 'datetime': + dt_format = column_meta['datetime_format'] + real_data[table][column] = real_data[table][column].dt.strftime(dt_format) + + report = DiagnosticReport() + + # Run + report.generate(real_data, synthetic_data, metadata, verbose=False) + results = report.get_results() + properties = report.get_properties() + + # Assert + expected_dataframe = pd.DataFrame({ + 'Property': ['Coverage', 'Boundary', 'Synthesis'], + 'Score': [0.9573447196980541, 0.8666666666666667, 0.6333333333333333] + }) + expected_results = { + 'SUCCESS': [ + 'The synthetic data covers over 90% of the categories present in the real data', + 'The synthetic data covers over 90% of the numerical ranges present' + ' in the real data' + ], + 'WARNING': [ + 'More than 10% the synthetic data does not follow the min/max boundaries' + ' set by the real data', + 'More than 10% of the synthetic rows are copies of the real data' + ], + 'DANGER': [] + } + assert results == expected_results + pd.testing.assert_frame_equal(properties, expected_dataframe) + def test_end_to_end_with_metrics_failing(self): """Test the ``DiagnosticReport`` report when some metrics crash. diff --git a/tests/integration/reports/multi_table/test_quality_report.py b/tests/integration/reports/multi_table/test_quality_report.py index da7a752d..3b0cdf49 100644 --- a/tests/integration/reports/multi_table/test_quality_report.py +++ b/tests/integration/reports/multi_table/test_quality_report.py @@ -186,6 +186,32 @@ def test_quality_report_end_to_end(): pd.testing.assert_frame_equal(properties, expected_properties) +def test_quality_report_with_object_datetimes(): + """Test the multi table QualityReport with object datetimes.""" + # Setup + real_data, synthetic_data, metadata = load_demo(modality='multi_table') + for table, table_meta in metadata['tables'].items(): + for column, column_meta in table_meta['columns'].items(): + if column_meta['sdtype'] == 'datetime': + dt_format = column_meta['datetime_format'] + real_data[table][column] = real_data[table][column].dt.strftime(dt_format) + + report = QualityReport() + + # Run + report.generate(real_data, synthetic_data, metadata) + score = report.get_score() + properties = report.get_properties() + + # Assert + expected_properties = pd.DataFrame({ + 'Property': ['Column Shapes', 'Column Pair Trends', 'Cardinality'], + 'Score': [0.7922619047619048, 0.4249665433225429, 0.8], + }) + assert score == 0.672409482694816 + pd.testing.assert_frame_equal(properties, expected_properties) + + def test_quality_report_with_errors(): """Test the multi table QualityReport with errors when computing metrics.""" # Setup diff --git a/tests/integration/reports/single_table/test_diagnostic_report.py b/tests/integration/reports/single_table/test_diagnostic_report.py index de51bb96..939a75cc 100644 --- a/tests/integration/reports/single_table/test_diagnostic_report.py +++ b/tests/integration/reports/single_table/test_diagnostic_report.py @@ -143,6 +143,76 @@ def test_end_to_end(self): expected_details_boundary ) + def test_generate_with_object_datetimes(self): + """Test the diagnostic report with object datetimes.""" + # Setup + real_data, synthetic_data, metadata = load_demo(modality='single_table') + for column, column_meta in metadata['columns'].items(): + if column_meta['sdtype'] == 'datetime': + dt_format = column_meta['datetime_format'] + real_data[column] = real_data[column].dt.strftime(dt_format) + + report = DiagnosticReport() + + # Run + report.generate(real_data, synthetic_data, metadata) + + # Assert + expected_details_synthetis = pd.DataFrame( + { + 'Metric': 'NewRowSynthesis', + 'Score': 1.0, + 'Num Matched Rows': 0, + 'Num New Rows': 215 + }, index=[0] + ) + + expected_details_coverage = pd.DataFrame({ + 'Column': [ + 'start_date', 'end_date', 'salary', 'duration', 'high_perc', 'high_spec', + 'mba_spec', 'second_perc', 'gender', 'degree_perc', 'placed', 'experience_years', + 'employability_perc', 'mba_perc', 'work_experience', 'degree_type' + ], + 'Metric': [ + 'RangeCoverage', 'RangeCoverage', 'RangeCoverage', 'RangeCoverage', + 'RangeCoverage', 'CategoryCoverage', 'CategoryCoverage', 'RangeCoverage', + 'CategoryCoverage', 'RangeCoverage', 'CategoryCoverage', 'RangeCoverage', + 'RangeCoverage', 'RangeCoverage', 'CategoryCoverage', 'CategoryCoverage' + ], + 'Score': [ + 1.0, 1.0, 0.42333783783783785, 1.0, 0.9807348482826732, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 0.6666666666666667, 1.0, 1.0, 1.0, 1.0 + ] + }) + + expected_details_boundary = pd.DataFrame({ + 'Column': [ + 'start_date', 'end_date', 'salary', 'duration', 'high_perc', 'second_perc', + 'degree_perc', 'experience_years', 'employability_perc', 'mba_perc' + ], + 'Metric': ['BoundaryAdherence'] * 10, + 'Score': [ + 0.8503937007874016, 0.8615384615384616, 0.9444444444444444, 1.0, + 0.8651162790697674, 0.9255813953488372, 0.9441860465116279, 1.0, + 0.8883720930232558, 0.8930232558139535 + ] + }) + + pd.testing.assert_frame_equal( + report.get_details('Synthesis'), + expected_details_synthetis + ) + + pd.testing.assert_frame_equal( + report.get_details('Coverage'), + expected_details_coverage + ) + + pd.testing.assert_frame_equal( + report.get_details('Boundary'), + expected_details_boundary + ) + def test_generate_multiple_times(self): """The results should be the same both times.""" # Setup diff --git a/tests/integration/reports/single_table/test_quality_report.py b/tests/integration/reports/single_table/test_quality_report.py index 19f1a26a..b8e87915 100644 --- a/tests/integration/reports/single_table/test_quality_report.py +++ b/tests/integration/reports/single_table/test_quality_report.py @@ -126,6 +126,73 @@ def test_report_end_to_end(self): ) assert report.get_score() == 0.7804181608907237 + def test_quality_report_with_object_datetimes(self): + """Test the quality report with object datetimes. + + The report must compute each property and the overall quality score. + """ + # Setup + column_names = [ + 'student_id', 'degree_type', 'start_date', 'second_perc', 'work_experience' + ] + real_data, synthetic_data, metadata = load_demo(modality='single_table') + for column, column_meta in metadata['columns'].items(): + if column_meta['sdtype'] == 'datetime': + dt_format = column_meta['datetime_format'] + real_data[column] = real_data[column].dt.strftime(dt_format) + + metadata['columns'] = { + key: val for key, val in metadata['columns'].items() if key in column_names + } + report = QualityReport() + + # Run + report.generate(real_data[column_names], synthetic_data[column_names], metadata) + + # Assert + expected_details_column_shapes_dict = { + 'Column': ['start_date', 'second_perc', 'work_experience', 'degree_type'], + 'Metric': ['KSComplement', 'KSComplement', 'TVComplement', 'TVComplement'], + 'Score': [ + 0.7011066184294531, 0.627906976744186, 0.9720930232558139, 0.9255813953488372 + ], + } + + expected_details_cpt__dict = { + 'Column 1': [ + 'start_date', 'start_date', 'start_date', 'second_perc', + 'second_perc', 'work_experience' + ], + 'Column 2': [ + 'second_perc', 'work_experience', 'degree_type', 'work_experience', + 'degree_type', 'degree_type' + ], + 'Metric': [ + 'CorrelationSimilarity', 'ContingencySimilarity', 'ContingencySimilarity', + 'ContingencySimilarity', 'ContingencySimilarity', 'ContingencySimilarity' + ], + 'Score': [ + 0.9854510263003199, 0.586046511627907, 0.6232558139534884, 0.7348837209302326, + 0.6976744186046512, 0.8976744186046511 + ], + 'Real Correlation': [ + 0.04735340044317632, np.nan, np.nan, np.nan, np.nan, np.nan + ], + 'Synthetic Correlation': [ + 0.07645134784253645, np.nan, np.nan, np.nan, np.nan, np.nan + ] + } + expected_details_column_shapes = pd.DataFrame(expected_details_column_shapes_dict) + expected_details_cpt = pd.DataFrame(expected_details_cpt__dict) + + pd.testing.assert_frame_equal( + report.get_details('Column Shapes'), expected_details_column_shapes + ) + pd.testing.assert_frame_equal( + report.get_details('Column Pair Trends'), expected_details_cpt + ) + assert report.get_score() == 0.7804181608907237 + def test_report_end_to_end_with_errors(self): """Test the quality report end to end with errors in the properties computation.""" # Setup diff --git a/tests/unit/reports/test_base_report.py b/tests/unit/reports/test_base_report.py index 8d226c7f..90275f07 100644 --- a/tests/unit/reports/test_base_report.py +++ b/tests/unit/reports/test_base_report.py @@ -179,8 +179,8 @@ def test_generate(self): }) metadata = { 'columns': { - 'column1': {'sdtypes': 'numerical'}, - 'column2': {'sdtypes': 'categorical'} + 'column1': {'sdtype': 'numerical'}, + 'column2': {'sdtype': 'categorical'} } } @@ -238,10 +238,10 @@ def test_generate_verbose(self, mock_tqdm): }) metadata = { 'columns': { - 'column1': {'sdtypes': 'numerical'}, - 'column2': {'sdtypes': 'categorical'}, - 'column3': {'sdtypes': 'numerical'}, - 'column4': {'sdtypes': 'numerical'}, + 'column1': {'sdtype': 'numerical'}, + 'column2': {'sdtype': 'categorical'}, + 'column3': {'sdtype': 'numerical'}, + 'column4': {'sdtype': 'numerical'}, } }